diff sashimi-plot.py @ 0:9304dd9a16a2 draft

"planemo upload for repository https://github.com/ARTbio/tools-artbio/tree/master/tools/sashimi_plot commit 746c03a1187e1d708af8628920a0c615cddcdacc"
author artbio
date Fri, 23 Aug 2019 11:38:29 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sashimi-plot.py	Fri Aug 23 11:38:29 2019 -0400
@@ -0,0 +1,1024 @@
+#!/usr/bin/env python
+
+# Import modules
+import copy
+import os
+import re
+import subprocess as sp
+import sys
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+
+def define_options():
+    # Argument parsing
+    parser = ArgumentParser(description="""Create sashimi plot for a given
+                            genomic region""")
+    parser.add_argument("-b", "--bam", type=str,
+                        help="""
+                        Individual bam file or file with a list of bam files.
+                        In the case of a list of files the format is tsv:
+                        1col: id for bam file,
+                        2col: path of bam file,
+                        3+col: additional columns
+                        """)
+    parser.add_argument("-c", "--coordinates", type=str,
+                        help="Genomic region. Format: chr:start-end (1-based)")
+    parser.add_argument("-o", "--out-prefix", type=str, dest="out_prefix",
+                        default="sashimi",
+                        help="Prefix for plot file name [default=%(default)s]")
+    parser.add_argument("-S", "--out-strand", type=str, dest="out_strand",
+                        default="both", help="""Only for --strand other than
+                        'NONE'. Choose which signal strand to plot:
+                        <both> <plus> <minus> [default=%(default)s]""")
+    parser.add_argument("-M", "--min-coverage", type=int, default=1,
+                        dest="min_coverage", help="""Minimum number of reads
+                        supporting a junction to be drawn [default=1]""")
+    parser.add_argument("-j", "--junctions-bed", type=str, default="",
+                        dest="junctions_bed", help="""Junction BED file name
+                        [default=no junction file]""")
+    parser.add_argument("-g", "--gtf",
+                        help="Gtf file with annotation (only exons is enough)")
+    parser.add_argument("-s", "--strand", default="NONE", type=str,
+                        help="""Strand specificity: <NONE> <SENSE> <ANTISENSE>
+                        <MATE1_SENSE> <MATE2_SENSE> [default=%(default)s]""")
+    parser.add_argument("--shrink", action="store_true",
+                        help="""Shrink the junctions by a factor for nicer
+                        display [default=%(default)s]""")
+    parser.add_argument("-O", "--overlay", type=int,
+                        help="Index of column with overlay levels (1-based)")
+    parser.add_argument("-A", "--aggr", type=str, default="",
+                        help="""Aggregate function for overlay:
+                        <mean> <median> <mean_j> <median_j>.
+                        Use mean_j | median_j to keep density overlay but
+                        aggregate junction counts [default=no aggregation]""")
+    parser.add_argument("-C", "--color-factor", type=int, dest="color_factor",
+                        help="Index of column with color levels (1-based)")
+    parser.add_argument("--alpha", type=float, default=0.5,
+                        help="""Transparency level for density histogram
+                        [default=%(default)s]""")
+    parser.add_argument("-P", "--palette", type=str,
+                        help="""Color palette file. tsv file with >=1 columns,
+                        where the color is the first column""")
+    parser.add_argument("-L", "--labels", type=int, dest="labels", default=1,
+                        help="""Index of column with labels (1-based)
+                        [default=%(default)s]""")
+    parser.add_argument("--height", type=float, default=2,
+                        help="""Height of the individual signal plot in inches
+                        [default=%(default)s]""")
+    parser.add_argument("--ann-height", type=float, default=1.5,
+                        dest="ann_height", help="""Height of annotation plot in
+                        inches [default=%(default)s]""")
+    parser.add_argument("--width", type=float, default=10,
+                        help="""Width of the plot in inches
+                        [default=%(default)s]""")
+    parser.add_argument("--base-size", type=float, default=14,
+                        dest="base_size", help="""Base font size of the plot in
+                        pch [default=%(default)s]""")
+    parser.add_argument("-F", "--out-format", type=str, default="pdf",
+                        dest="out_format", help="""Output file format:
+                        <pdf> <svg> <png> <jpeg> <tiff>
+                        [default=%(default)s]""")
+    parser.add_argument("-R", "--out-resolution", type=int, default=300,
+                        dest="out_resolution", help="""Output file resolution in
+                        PPI (pixels per inch). Applies only to raster output
+                        formats [default=%(default)s]""")
+    return parser
+
+
+def parse_coordinates(c):
+    c = c.replace(",", "")
+    chr = c.split(":")[0]
+    start, end = c.split(":")[1].split("-")
+    # Convert to 0-based
+    start, end = int(start) - 1, int(end)
+    return chr, start, end
+
+
+def count_operator(CIGAR_op, CIGAR_len, pos, start, end, a, junctions):
+
+    # Match
+    if CIGAR_op == "M":
+        for i in range(pos, pos + CIGAR_len):
+            if i < start or i >= end:
+                continue
+            ind = i - start
+            a[ind] += 1
+
+    # Insertion or Soft-clip
+    if CIGAR_op == "I" or CIGAR_op == "S":
+        return pos
+
+    # Deletion
+    if CIGAR_op == "D":
+        pass
+
+    # Junction
+    if CIGAR_op == "N":
+        don = pos
+        acc = pos + CIGAR_len
+        if don > start and acc < end:
+            junctions[(don, acc)] = junctions.setdefault((don, acc), 0) + 1
+
+    pos = pos + CIGAR_len
+
+    return pos
+
+
+def flip_read(s, samflag):
+    if s == "NONE" or s == "SENSE":
+        return 0
+    if s == "ANTISENSE":
+        return 1
+    if s == "MATE1_SENSE":
+        if int(samflag) & 64:
+            return 0
+        if int(samflag) & 128:
+            return 1
+    if s == "MATE2_SENSE":
+        if int(samflag) & 64:
+            return 1
+        if int(samflag) & 128:
+            return 0
+
+
+def read_bam(f, c, s):
+
+    _, start, end = parse_coordinates(c)
+
+    # Initialize coverage array and junction dict
+    a = {"+": [0] * (end - start)}
+    junctions = {"+": OrderedDict()}
+    if s != "NONE":
+        a["-"] = [0] * (end - start)
+        junctions["-"] = OrderedDict()
+
+    p = sp.Popen("samtools view %s %s " % (f, c), shell=True, stdout=sp.PIPE)
+    for line in p.communicate()[0].decode('utf8').strip().split("\n"):
+
+        if line == "":
+            continue
+
+        line_sp = line.strip().split("\t")
+        samflag, read_start, CIGAR = line_sp[1], int(line_sp[3]), line_sp[5]
+
+        # Ignore reads with more exotic CIGAR operators
+        if any(map(lambda x: x in CIGAR, ["H", "P", "X", "="])):
+            continue
+
+        read_strand = ["+", "-"][flip_read(s, samflag) ^ bool(int(samflag) &
+                                                              16)]
+        if s == "NONE":
+            read_strand = "+"
+
+        CIGAR_lens = re.split("[MIDNS]", CIGAR)[:-1]
+        CIGAR_ops = re.split("[0-9]+", CIGAR)[1:]
+
+        pos = read_start
+
+        for n, CIGAR_op in enumerate(CIGAR_ops):
+            CIGAR_len = int(CIGAR_lens[n])
+            pos = count_operator(CIGAR_op, CIGAR_len, pos, start, end,
+                                 a[read_strand], junctions[read_strand])
+
+    p.stdout.close()
+    return a, junctions
+
+
+def get_bam_path(index, path):
+    if os.path.isabs(path):
+        return path
+    base_dir = os.path.dirname(index)
+    return os.path.join(base_dir, path)
+
+
+def read_bam_input(f, overlay, color, label):
+    if f.endswith(".bam"):
+        bn = f.strip().split("/")[-1].strip(".bam")
+        yield bn, f, None, None, bn
+        return
+    with open(f) as openf:
+        for line in openf:
+            line_sp = line.strip().split("\t")
+            bam = get_bam_path(f, line_sp[1])
+            overlay_level = line_sp[overlay-1] if overlay else None
+            color_level = line_sp[color-1] if color else None
+            label_text = line_sp[label-1] if label else None
+            yield line_sp[0], bam, overlay_level, color_level, label_text
+
+
+def prepare_for_R(a, junctions, c, m):
+
+    _, start, _ = parse_coordinates(args.coordinates)
+
+    # Convert the array index to genomic coordinates
+    x = list(i+start for i in range(len(a)))
+    y = a
+
+    # Arrays for R
+    dons, accs, yd, ya, counts = [], [], [], [], []
+
+    # Prepare arrays for junctions (which will be the arcs)
+    for (don, acc), n in junctions.items():
+
+        # Do not add junctions with less than defined coverage
+        if n < m:
+            continue
+
+        dons.append(don)
+        accs.append(acc)
+        counts.append(n)
+
+        yd.append(a[don - start - 1])
+        ya.append(a[acc - start + 1])
+
+    return x, y, dons, accs, yd, ya, counts
+
+
+def intersect_introns(data):
+    data = sorted(data)
+    it = iter(data)
+    a, b = next(it)
+    for c, d in it:
+        if b > c:
+            # Use `if b > c` if you want (1,2), (2,3) not to be
+            # treated as intersection.
+            b = min(b, d)
+            a = max(a, c)
+        else:
+            yield a, b
+            a, b = c, d
+    yield a, b
+
+
+def shrink_density(x, y, introns):
+    new_x, new_y = [], []
+    shift = 0
+    start = 0
+    # introns are already sorted by coordinates
+    for a, b in introns:
+        end = x.index(a)+1
+        new_x += [int(i-shift) for i in x[start:end]]
+        new_y += y[start:end]
+        start = x.index(b)
+        L = (b-a)
+        shift += L-L**0.7
+    new_x += [int(i-shift) for i in x[start:]]
+    new_y += y[start:]
+    return new_x, new_y
+
+
+def shrink_junctions(dons, accs, introns):
+    new_dons, new_accs = [0]*len(dons), [0]*len(accs)
+    shift_acc = 0
+    shift_don = 0
+    s = set()
+    junctions = list(zip(dons, accs))
+    for a, b in introns:
+        L = b - a
+        shift_acc += L-int(L**0.7)
+        for i, (don, acc) in enumerate(junctions):
+            if a >= don and b <= acc:
+                if (don, acc) not in s:
+                    new_dons[i] = don - shift_don
+                    new_accs[i] = acc - shift_acc
+                else:
+                    new_accs[i] = acc - shift_acc
+                s.add((don, acc))
+        shift_don = shift_acc
+    return new_dons, new_accs
+
+
+def read_palette(f):
+    palette = "#ff0000", "#00ff00", "#0000ff", "#000000"
+    if f:
+        with open(f) as openf:
+            palette = list(line.split("\t")[0].strip() for line in openf)
+    return palette
+
+
+def read_gtf(f, c):
+    exons = OrderedDict()
+    transcripts = OrderedDict()
+    chr, start, end = parse_coordinates(c)
+    end = end - 1
+    with open(f) as openf:
+        for line in openf:
+            if line.startswith("#"):
+                continue
+            (el_chr, _, el, el_start, el_end, _,
+             strand, _, tags) = line.strip().split("\t")
+            if el_chr != chr:
+                continue
+            d = dict(kv.strip().split(" ") for kv in
+                     tags.strip(";").split("; "))
+            transcript_id = d["transcript_id"]
+            el_start, el_end = int(el_start) - 1, int(el_end)
+            strand = '"' + strand + '"'
+            if el == "transcript":
+                if (el_end > start and el_start < end):
+                    transcripts[transcript_id] = (max(start, el_start),
+                                                  min(end, el_end),
+                                                  strand)
+                    continue
+            if el == "exon":
+                if (start < el_start < end or start < el_end < end):
+                    exons.setdefault(transcript_id,
+                                     []).append((max(el_start, start),
+                                                 min(end, el_end), strand))
+
+    return transcripts, exons
+
+
+def make_introns(transcripts, exons, intersected_introns=None):
+    """Derive the introns of each transcript, optionally after shrinking.
+
+    Args:
+        transcripts: mapping of transcript id to ``(start, end, strand)``.
+        exons: mapping of transcript id to a list of
+            ``(start, end, strand)`` tuples.
+        intersected_introns: optional list of ``(a, b)`` intervals; when it
+            is non-empty, exon and transcript coordinates are shifted first.
+
+    Returns:
+        dict: keys 'transcripts', 'exons' and 'introns'. The first two are
+        (possibly shifted) deep copies of the inputs; 'introns' maps a
+        transcript id to a list of ``(start, end, strand)`` tuples.
+    """
+    # Work on deep copies: the inputs are only read, never modified
+    new_transcripts = copy.deepcopy(transcripts)
+    new_exons = copy.deepcopy(exons)
+    introns = OrderedDict()
+    if intersected_introns:
+        for tx, (tx_start, tx_end, strand) in new_transcripts.items():
+            # Sum of the shifts applied so far to this transcript
+            total_shift = 0
+            for a, b in intersected_introns:
+                L = b - a
+                # Number of positions removed for an interval of length L
+                shift = L - int(L**0.7)
+                total_shift += shift
+                # Conditions are tested on the ORIGINAL exon coordinates,
+                # while the shift is applied to the current entry of
+                # new_exons.
+                for i, (exon_start, exon_end, strand) in \
+                        enumerate(exons.get(tx, [])):
+                    new_exon_start, new_exon_end = new_exons[tx][i][:2]
+                    if a < exon_start:
+                        if b > exon_end:
+                            # The exon lies strictly inside (a, b).
+                            # int(L**-0.3) is 0 whenever L > 1, so the
+                            # factor (1 - int(L**-0.3)) is then 1.
+                            if i == len(exons[tx])-1:
+                                total_shift = total_shift - shift + \
+                                             (exon_start - a)*(1-int(L**-0.3))
+                            # NOTE(review): `shift` is reassigned here and
+                            # the new value is also used for the exons that
+                            # follow in this inner loop - confirm intended.
+                            shift = (exon_start - a)*(1-int(L**-0.3))
+                            new_exon_end = new_exons[tx][i][1] - shift
+                        new_exon_start = new_exons[tx][i][0] - shift
+                    if b <= exon_end:
+                        new_exon_end = new_exons[tx][i][1] - shift
+                    new_exons[tx][i] = (new_exon_start, new_exon_end, strand)
+            # The transcript starts no later than its leftmost shifted exon;
+            # sys.maxsize is the neutral value for transcripts without exons
+            tx_start = min(tx_start,
+                           sorted(new_exons.get(tx, [[sys.maxsize]]))[0][0])
+            # NOTE(review): the inner loop rebinds `strand`, so when the
+            # transcript has exons the strand stored here is the one of its
+            # last exon.
+            new_transcripts[tx] = (tx_start, tx_end - total_shift, strand)
+
+    # Introns are the gaps between the sorted exons of each transcript
+    for tx, (tx_start, tx_end, strand) in new_transcripts.items():
+        intron_start = tx_start
+        # Stays 0 when the transcript has no exons
+        ex_end = 0
+        for ex_start, ex_end, strand in sorted(new_exons.get(tx, [])):
+            intron_end = ex_start
+            # A gap is recorded only for exons starting after tx_start
+            if tx_start < ex_start:
+                introns.setdefault(tx, []).append((intron_start, intron_end,
+                                                   strand))
+            intron_start = ex_end
+        # Trailing gap between the last exon and the transcript end
+        if tx_end > ex_end:
+            introns.setdefault(tx, []).append((intron_start, tx_end, strand))
+    d = {'transcripts': new_transcripts,
+         'exons': new_exons,
+         'introns': introns}
+    return d
+
+
+def gtf_for_ggplot(annotation, start, end, arrow_bins):
+    arrow_space = int((end - start)/arrow_bins)
+    s = """
+
+    # data table with exons
+    ann_list = list(
+            "exons" = data.table(),
+            "introns" = data.table()
+    )
+    """
+
+    if annotation["exons"]:
+
+        s += """
+        ann_list[['exons']] = data.table(
+                tx = rep(c(%(tx_exons)s), c(%(n_exons)s)),
+                start = c(%(exon_start)s),
+                end = c(%(exon_end)s),
+                strand = c(%(strand)s)
+        )
+        """ % ({
+            "tx_exons": ",".join(annotation["exons"].keys()),
+            "n_exons": ",".join(map(str, map(len,
+                                annotation["exons"].values()))),
+            "exon_start": ",".join(map(str, (v[0] for vs in
+                                   annotation["exons"].values() for v in vs))),
+            "exon_end": ",".join(map(str, (v[1] for vs in
+                                 annotation["exons"].values() for v in vs))),
+            "strand": ",".join(map(str, (v[2] for vs in
+                                   annotation["exons"].values() for v in vs))),
+            })
+
+    if annotation["introns"]:
+
+        s += """
+        ann_list[['introns']] = data.table(
+                tx = rep(c(%(tx_introns)s), c(%(n_introns)s)),
+                start = c(%(intron_start)s),
+                end = c(%(intron_end)s),
+                strand = c(%(strand)s)
+        )
+        # Create data table for strand arrows
+        txarrows = data.table()
+        introns = ann_list[['introns']]
+        # Add right-pointing arrows for plus strand
+        if ("+" %%in%% introns$strand) {
+            txarrows = rbind(
+                    txarrows,
+                    introns[strand=="+" & end-start>5, list(
+                            seq(start+4,end,by=%(arrow_space)s)-1,
+                            seq(start+4,end,by=%(arrow_space)s)
+                            ), by=.(tx,start,end)
+                    ]
+            )
+        }
+        # Add left-pointing arrows for minus strand
+        if ("-" %%in%% introns$strand) {
+          txarrows = rbind(txarrows,
+                           introns[strand=="-" & end-start>5,
+                                   list(seq(start,max(start+1, end-4),
+                                            by=%(arrow_space)s),
+                                        seq(start,max(start+1, end-4),
+                                            by=%(arrow_space)s)-1
+                                        ),
+                                   by=.(tx,start,end)
+                                   ]
+                           )
+        }
+        """ % ({
+            "tx_introns": ",".join(annotation["introns"].keys()),
+            "n_introns": ",".join(map(str, map(len,
+                                  annotation["introns"].values()))),
+            "intron_start": ",".join(map(str, (v[0] for vs in
+                                     annotation["introns"].values() for v in
+                                     vs))),
+            "intron_end": ",".join(map(str, (v[1] for vs in
+                                   annotation["introns"].values() for v in
+                                   vs))),
+            "strand": ",".join(map(str, (v[2] for vs in
+                               annotation["introns"].values() for v in vs))),
+            "arrow_space": arrow_space,
+            })
+
+    s += """
+
+    gtfp = ggplot()
+    if (length(ann_list[['introns']]) > 0) {
+      gtfp = gtfp + geom_segment(data = ann_list[['introns']],
+                                 aes(x = start,
+                                     xend = end,
+                                     y = tx,
+                                     yend = tx),
+                                 size = 0.3)
+     gtfp = gtfp + geom_segment(data = txarrows,
+                                aes(x = V1,
+                                     xend = V2,
+                                     y = tx,
+                                     yend = tx),
+                                arrow = arrow(length = unit(0.02, "npc")))
+    }
+    if (length(ann_list[['exons']]) > 0) {
+      gtfp = gtfp + geom_segment(data = ann_list[['exons']],
+                                 aes(x = start,
+                                     xend = end,
+                                     y = tx,
+                                     yend = tx),
+                                 size = 5,
+                                 alpha = 1)
+    }
+    gtfp = gtfp + scale_y_discrete(expand = c(0, 0.5))
+    gtfp = gtfp + scale_x_continuous(expand = c(0, 0.25),
+                                     limits = c( %s,% s))
+    gtfp = gtfp + labs(y = NULL)
+    gtfp = gtfp + theme(axis.line = element_blank(),
+                        axis.text.x = element_blank(),
+                        axis.ticks = element_blank())
+    """ % (start, end)
+
+    return s
+
+
+def setup_R_script(h, w, b, label_dict):
+    s = """
+    library(ggplot2)
+    library(grid)
+    library(gridExtra)
+    library(data.table)
+    library(gtable)
+
+    scale_lwd = function(r) {
+        lmin = 0.1
+        lmax = 4
+        return( r*(lmax-lmin)+lmin )
+    }
+
+    base_size = %(b)s
+    height = ( %(h)s + base_size*0.352777778/67 ) * 1.02
+    width = %(w)s
+    theme_set(theme_bw(base_size=base_size))
+    theme_update(
+        plot.margin = unit(c(15,15,15,15), "pt"),
+        panel.grid = element_blank(),
+        panel.border = element_blank(),
+        axis.line = element_line(size=0.5),
+        axis.title.x = element_blank(),
+        axis.title.y = element_text(angle=0, vjust=0.5)
+    )
+
+    labels = list(%(labels)s)
+
+    density_list = list()
+    junction_list = list()
+
+    """ % ({
+        'h': h,
+        'w': w,
+        'b': b,
+        'labels': ",".join(('"%s"="%s"' % (id, lab) for id, lab in
+                            label_dict.items())),
+        })
+    return s
+
+
+def median(lst):
+    quotient, remainder = divmod(len(lst), 2)
+    if remainder:
+        return sorted(lst)[quotient]
+    return sum(sorted(lst)[quotient - 1:quotient + 1]) / 2.
+
+
+def mean(lst):
+    return sum(lst)/len(lst)
+
+
+def make_R_lists(id_list, d, overlay_dict, aggr, intersected_introns):
+    s = ""
+    aggr_f = {
+        "mean": mean,
+        "median": median,
+    }
+    id_list = id_list if not overlay_dict else overlay_dict.keys()
+    # Iterate over ids to get bam signal and junctions
+    for k in id_list:
+        x, y, dons, accs, yd, ya, counts = [], [], [], [], [], [], []
+        if not overlay_dict:
+            x, y, dons, accs, yd, ya, counts = d[k]
+            if intersected_introns:
+                x, y = shrink_density(x, y, intersected_introns)
+                dons, accs = shrink_junctions(dons, accs, intersected_introns)
+        else:
+            for id in overlay_dict[k]:
+                xid, yid, donsid, accsid, ydid, yaid, countsid = d[id]
+                if intersected_introns:
+                    xid, yid = shrink_density(xid, yid, intersected_introns)
+                    donsid, accsid = shrink_junctions(donsid, accsid,
+                                                      intersected_introns)
+                x += xid
+                y += yid
+                dons += donsid
+                accs += accsid
+                yd += ydid
+                ya += yaid
+                counts += countsid
+            if aggr and "_j" not in aggr:
+                x = d[overlay_dict[k][0]][0]
+                y = list(map(aggr_f[aggr], zip(*(d[id][1] for id in
+                                               overlay_dict[k]))))
+                if intersected_introns:
+                    x, y = shrink_density(x, y, intersected_introns)
+                # dons, accs, yd, ya, counts = [], [], [], [], []
+        s += """
+        density_list[["%(id)s"]] = data.frame(x = c(%(x)s), y = c(%(y)s))
+        junction_list[["%(id)s"]] = data.frame(x = c(%(dons)s),
+                                               xend=c(%(accs)s),
+                                               y=c(%(yd)s),
+                                               yend=c(%(ya)s),
+                                               count=c(%(counts)s))
+        """ % ({
+            "id": k,
+            'x': ",".join(map(str, x)),
+            'y': ",".join(map(str, y)),
+            'dons': ",".join(map(str, dons)),
+            'accs': ",".join(map(str, accs)),
+            'yd': ",".join(map(str, yd)),
+            'ya': ",".join(map(str, ya)),
+            'counts': ",".join(map(str, counts)),
+            })
+    return s
+
+
+def plot(R_script):
+    p = sp.Popen("R --vanilla --slave", shell=True, stdin=sp.PIPE)
+    p.communicate(input=R_script.encode())
+    p.stdin.close()
+    p.wait()
+    return
+
+
+def colorize(d, p, color_factor):
+    levels = sorted(set(d.values()))
+    n = len(levels)
+    if n > len(p):
+        p = (p*n)[:n]
+    if color_factor:
+        s = "color_list = list(%s)\n" % (",".join('%s="%s"' % (k,
+                                         p[levels.index(v)]) for k, v in
+                                         d.items()))
+    else:
+        s = "color_list = list(%s)\n" % (",".join('%s="%s"' % (k, "grey") for
+                                         k, v in d.items()))
+    return s
+
+
+if __name__ == "__main__":
+
+    strand_dict = {"plus": "+", "minus": "-"}
+
+    parser = define_options()
+    if len(sys.argv) == 1:
+        parser.print_help()
+        sys.exit(1)
+    args = parser.parse_args()
+
+    if args.aggr and not args.overlay:
+        print("""ERROR: Cannot apply aggregate function
+              if overlay is not selected.""")
+        exit(1)
+
+    palette = read_palette(args.palette)
+
+    (bam_dict, overlay_dict, color_dict,
+     id_list, label_dict) = ({"+": OrderedDict()}, OrderedDict(),
+                             OrderedDict(), [], OrderedDict())
+    if args.strand != "NONE":
+        bam_dict["-"] = OrderedDict()
+    if args.junctions_bed != "":
+        junctions_list = []
+
+    for (id, bam, overlay_level,
+         color_level, label_text) in read_bam_input(args.bam,
+                                                    args.overlay,
+                                                    args.color_factor,
+                                                    args.labels):
+        if not os.path.isfile(bam):
+            continue
+        id_list.append(id)
+        label_dict[id] = label_text
+        a, junctions = read_bam(bam, args.coordinates, args.strand)
+        if a.keys() == ["+"] and all(map(lambda x: x == 0,
+                                         list(a.values()[0]))):
+            print("ERROR: No reads in the specified area.")
+            exit(1)
+        for strand in a:
+            # Store junction information
+            if args.junctions_bed:
+                for k, v in zip(junctions[strand].keys(),
+                                junctions[strand].values()):
+                    if v > args.min_coverage:
+                        junctions_list.append('\t'.join([args.coordinates.split
+                                              (':')[0], str(k[0]), str(k[1]),
+                                              id, str(v), strand]))
+            bam_dict[strand][id] = prepare_for_R(a[strand],
+                                                 junctions[strand],
+                                                 args.coordinates,
+                                                 args.min_coverage)
+        if color_level is None:
+            color_dict.setdefault(id, id)
+        if overlay_level is not None:
+            overlay_dict.setdefault(overlay_level, []).append(id)
+            label_dict[overlay_level] = overlay_level
+            color_dict.setdefault(overlay_level, overlay_level)
+        if overlay_level is None:
+            color_dict.setdefault(id, color_level)
+
+    # No bam files
+    if not bam_dict["+"]:
+        print("ERROR: No available bam files.")
+        exit(1)
+
+    # Write junctions to BED
+    if args.junctions_bed:
+        if not args.junctions_bed.endswith('.bed'):
+            args.junctions_bed = args.junctions_bed + '.bed'
+        jbed = open(args.junctions_bed, 'w')
+        jbed.write('\n'.join(sorted(junctions_list)))
+        jbed.close()
+
+    if args.gtf:
+        transcripts, exons = read_gtf(args.gtf, args.coordinates)
+
+    if args.out_format not in ('pdf', 'png', 'svg', 'tiff', 'jpeg'):
+        print("""ERROR: Provided output format '%s' is not available.
+        Please select among 'pdf', 'png', 'svg',
+        'tiff' or 'jpeg'""" % args.out_format)
+        exit(1)
+
+    # Iterate for plus and minus strand
+    for strand in bam_dict:
+
+        # Output file name (allow tiff/tif and jpeg/jpg extensions)
+        if args.out_prefix.endswith(('.pdf', '.png', '.svg', '.tiff',
+                                     '.tif', '.jpeg', '.jpg')):
+            out_split = os.path.splitext(args.out_prefix)
+            if (args.out_format == out_split[1][1:] or
+                    args.out_format == 'tiff'
+                    and out_split[1] in ('.tiff', '.tif') or
+                    args.out_format == 'jpeg'
+                    and out_split[1] in ('.jpeg', '.jpg')):
+                args.out_prefix = out_split[0]
+                out_suffix = out_split[1][1:]
+            else:
+                out_suffix = args.out_format
+        else:
+            out_suffix = args.out_format
+        out_prefix = args.out_prefix + "_" + strand
+        if args.strand == "NONE":
+            out_prefix = args.out_prefix
+        else:
+            if args.out_strand != "both" \
+                    and strand != strand_dict[args.out_strand]:
+                continue
+
+        # Find set of junctions to perform shrink
+        intersected_introns = None
+        if args.shrink:
+            introns = (v for vs in bam_dict[strand].values() for v in
+                       zip(vs[2], vs[3]))
+            intersected_introns = list(intersect_introns(introns))
+
+        # *** PLOT *** Define plot height
+        bam_height = args.height * len(id_list)
+        if args.overlay:
+            bam_height = args.height * len(overlay_dict)
+        if args.gtf:
+            bam_height += args.ann_height
+
+        # *** PLOT *** Start R script by loading libraries,
+        # initializing variables, etc...
+        R_script = setup_R_script(bam_height, args.width,
+                                  args.base_size, label_dict)
+
+        R_script += colorize(color_dict, palette, args.color_factor)
+
+        # *** PLOT *** Prepare annotation plot only for the first bam file
+        arrow_bins = 50
+        if args.gtf:
+            # Make introns from annotation (they are shrunk if required)
+            annotation = make_introns(transcripts, exons, intersected_introns)
+            x = list(bam_dict[strand].values())[0][0]
+            if args.shrink:
+                x, _ = shrink_density(x, x, intersected_introns)
+            R_script += gtf_for_ggplot(annotation, x[0], x[-1], arrow_bins)
+
+        R_script += make_R_lists(id_list, bam_dict[strand], overlay_dict,
+                                 args.aggr, intersected_introns)  # one strand
+
+        R_script += """
+
+        pdf(NULL) # just to remove the blank pdf produced by ggplotGrob
+        # fix problems with ggplot2 vs >3.0.0
+        if(packageVersion('ggplot2') >= '3.0.0'){
+            vs = 1
+        } else {
+            vs = 0
+        }
+
+        density_grobs = list();
+
+            for (bam_index in 1:length(density_list)) {
+
+                id = names(density_list)[bam_index]
+                d = data.table(density_list[[id]])
+                junctions = data.table(junction_list[[id]])
+
+                maxheight = max(d[['y']])
+
+                # Density plot
+                gp = ggplot(d) + geom_bar(aes(x, y), width=1,
+                                          position='identity',
+                                          stat='identity',
+                                          fill=color_list[[id]],
+                                          alpha=%(alpha)s)
+                gp = gp + labs(y=labels[[id]])
+                gp = gp + scale_x_continuous(expand=c(0,0.2))
+
+                # fix problems with ggplot2 vs >3.0.0
+                if(packageVersion('ggplot2') >= '3.0.0') {
+                    gp = gp +
+                        scale_y_continuous(breaks =
+                        ggplot_build(gp
+                                     )$layout$panel_params[[1]]$y.major_source)
+                } else {
+                    gp = gp +
+                        scale_y_continuous(breaks =
+                        ggplot_build(gp
+                                     )$layout$panel_ranges[[1]]$y.major_source)
+                }
+
+                # Aggregate junction counts
+                row_i = c()
+                if (nrow(junctions) >0 ) {
+
+                    junctions$jlabel = as.character(junctions$count)
+                    junctions = setNames(junctions[,.(max(y),
+                                                      max(yend),
+                                                      round(mean(count)),
+                                                      paste(jlabel,
+                                                            collapse=",")),
+                                                    keyby=.(x,xend)],
+                                         names(junctions))
+                    if ("%(args.aggr)s" != "") {
+                        junctions = setNames(
+                          junctions[,.(max(y),
+                                       max(yend),
+                                       round(%(args.aggr)s(count)),
+                                       round(%(args.aggr)s(count))),
+                                    keyby=.(x,xend)],
+                          names(junctions))
+                    }
+                    # The number of rows (unique junctions per bam) has to be
+                    # calculated after aggregation
+                    row_i = 1:nrow(junctions)
+                }
+
+
+                for (i in row_i) {
+
+                    j_tot_counts = sum(junctions[['count']])
+
+                    j = as.numeric(junctions[i,1:5])
+
+                    if ("%(args.aggr)s" != "") {
+                        j[3] = as.numeric(d[x==j[1]-1,y])
+                        j[4] = as.numeric(d[x==j[2]+1,y])
+                    }
+
+                    # Find intron midpoint
+                    xmid = round(mean(j[1:2]), 1)
+                    ymid = max(j[3:4]) * 1.2
+
+                    # Thickness of the arch
+                    lwd = scale_lwd(j[5]/j_tot_counts)
+
+                    curve_par = gpar(lwd=lwd, col=color_list[[id]])
+
+                    # Arc grobs
+
+                    # Choose position of the arch (top or bottom)
+                    nss = i
+                    if (nss%%%%2 == 0) {  #bottom
+                        ymid = -0.3 * maxheight
+                        # Draw the arcs
+                        # Left
+                        curve = xsplineGrob(x = c(0, 0, 1, 1),
+                                            y = c(1, 0, 0, 0),
+                                            shape = 1,
+                                            gp = curve_par)
+                        gp = gp + annotation_custom(grob = curve,
+                                                    j[1], xmid, 0, ymid)
+                        # Right
+                        curve = xsplineGrob(x = c(1, 1, 0, 0),
+                                            y = c(1, 0, 0, 0),
+                                            shape = 1,
+                                            gp = curve_par)
+                        gp = gp + annotation_custom(grob = curve,
+                                                    xmid,
+                                                    j[2],
+                                                    0, ymid)
+                    }
+
+                    if (nss%%%%2 != 0) {  #top
+                        # Draw the arcs
+                        # Left
+                        curve = xsplineGrob(x = c(0, 0, 1, 1),
+                                            y = c(0, 1, 1, 1),
+                                            shape = 1,
+                                            gp = curve_par)
+                        gp = gp + annotation_custom(grob = curve,
+                                                    j[1], xmid, j[3], ymid)
+                        # Right
+                        curve = xsplineGrob(x = c(1, 1, 0, 0),
+                                            y = c(0, 1, 1, 1),
+                                            shape = 1,
+                                            gp = curve_par)
+                        gp = gp + annotation_custom(grob = curve,
+                                                    xmid, j[2], j[4], ymid)
+                    }
+
+                    # Add junction labels
+                    gp = gp + annotate("label", x = xmid, y = ymid,
+                                       label = as.character(junctions[i,6]),
+                                       vjust = 0.5, hjust = 0.5,
+                                       label.padding = unit(0.01, "lines"),
+                                       label.size = NA,
+                                       size = (base_size*0.352777778)*0.6)
+
+                }
+
+                gpGrob = ggplotGrob(gp);
+                gpGrob$layout$clip[gpGrob$layout$name=="panel"] <- "off"
+                if (bam_index == 1) {
+                    # fix problems ggplot2 vs
+                    maxWidth = gpGrob$widths[2+vs] + gpGrob$widths[3+vs];
+                    maxYtextWidth = gpGrob$widths[3+vs];
+                    # Extract x axis grob (trim=F --> keep empty cells)
+                    xaxisGrob <- gtable_filter(gpGrob, "axis-b", trim=F)
+                    # fix problems ggplot2 vs
+                    xaxisGrob$heights[8+vs] = gpGrob$heights[1]
+                    x.axis.height = gpGrob$heights[7+vs] + gpGrob$heights[1]
+                }
+
+
+                # Remove x axis from all density plots
+                kept_names = gpGrob$layout$name[gpGrob$layout$name != "axis-b"]
+                gpGrob <- gtable_filter(gpGrob,
+                                        paste(kept_names, sep = "",
+                                              collapse = "|"),
+                                        trim=F)
+
+                # Find max width of y text and y label and max width of y text
+                # fix problems ggplot2 vs
+                maxWidth = grid::unit.pmax(maxWidth,
+                                           gpGrob$widths[2+vs] +
+                                                gpGrob$widths[3+vs]);
+                maxYtextWidth = grid::unit.pmax(maxYtextWidth,
+                                                gpGrob$widths[3+vs]);
+                density_grobs[[id]] = gpGrob;
+                }
+
+                # Add x axis grob after density grobs BEFORE annotation grob
+                density_grobs[["xaxis"]] = xaxisGrob
+
+                # Annotation grob
+                if (%(args.gtf)s == 1) {
+                        gtfGrob = ggplotGrob(gtfp);
+                        maxWidth = grid::unit.pmax(maxWidth,
+                                                   gtfGrob$widths[2+vs] +
+                                                       gtfGrob$widths[3+vs]);
+                        density_grobs[['gtf']] = gtfGrob;
+                }
+
+                # Reassign grob widths to align the plots
+                for (id in names(density_grobs)) {
+                    density_grobs[[id]]$widths[1] <-
+                        density_grobs[[id]]$widths[1] +
+                        maxWidth - (density_grobs[[id]]$widths[2 + vs] +
+                                    maxYtextWidth)
+# fix problems ggplot2 vs
+                    density_grobs[[id]]$widths[3 + vs] <-
+                        maxYtextWidth # fix problems ggplot2 vs
+                }
+
+                # Heights for density, x axis and annotation
+                heights = unit.c(
+                        unit(rep(%(signal_height)s,
+                                   length(density_list)), "in"),
+                        x.axis.height,
+                        unit(%(ann_height)s*%(args.gtf)s, "in")
+                        )
+
+                # Arrange grobs
+                argrobs = arrangeGrob(
+                        grobs=density_grobs,
+                        ncol=1,
+                        heights = heights,
+                );
+
+                # Save plot to file in the requested format
+                if ("%(out_format)s" == "tiff"){
+                        # TIFF images will be lzw-compressed
+                        ggsave("%(out)s",
+                               plot = argrobs,
+                               device = "tiff",
+                               width = width,
+                               height = height,
+                               units = "in",
+                               dpi = %(out_resolution)s,
+                               compression = "lzw")
+                } else {
+                        ggsave("%(out)s",
+                               plot = argrobs,
+                               device = "%(out_format)s",
+                               width = width,
+                               height = height,
+                               units = "in",
+                               dpi = %(out_resolution)s)
+                }
+
+                dev.log = dev.off()
+
+                """ % ({  # single pass: '%%%%' becomes R's '%%' modulo
+                        "out": "%s.%s" % (out_prefix, out_suffix),
+                        "out_format": args.out_format,
+                        "out_resolution": args.out_resolution,
+                        "args.gtf": float(bool(args.gtf)),  # 1.0 or 0.0
+                        "args.aggr": args.aggr.rstrip("_j"),  # drop _/j tail
+                        "signal_height": args.height,
+                        "ann_height": args.ann_height,
+                        "alpha": args.alpha,
+                        })
+        if os.getenv('GGSASHIMI_DEBUG') is not None:
+            with open("R_script", 'w') as r:
+                r.write(R_script)
+        else:
+            plot(R_script)
+    exit()