# view repmatch_gff3_util.py @ 3:f7608d0363bf draft
#
# Uploaded
# author iuc
# date Fri, 13 Jan 2017 10:52:02 -0500
# parents e5c7fffdc078
# children 6acaa2c93f47
# line wrap: on
# line source

import bisect
import csv
import os
import shutil
import sys
import tempfile

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot  # noqa: E402

# Graph settings shared by the histogram drawn in frequency_histogram().
# Axis captions.
Y_LABEL = 'Counts'
X_LABEL = 'Number of matched replicates'
# Edge width applied to the tick marks of both axes.
TICK_WIDTH = 3
# Subplot margins used to make the labels fit, in the order
# [left, right, top, bottom] (passed to pyplot.subplots_adjust).
ADJUST = [0.180, 0.9, 0.9, 0.1]
# Global matplotlib defaults: length of the major tick marks (their width
# comes from TICK_WIDTH), line widths and the font.
pyplot.rc('xtick.major', size=10.00)
pyplot.rc('ytick.major', size=10.00)
pyplot.rc('lines', linewidth=4.00)
pyplot.rc('axes', linewidth=3.00)
pyplot.rc('font', family='Bitstream Vera Sans', size=32.0)

# One single-letter matplotlib colour code per plotted series; indexed by the
# position of the series in frequency_histogram().
COLORS = 'krb'


class Replicate(object):
    """
    One replicate dataset: the peaks of a tab-delimited GFF file, grouped by
    chromosome name.

    Attributes:
        id: identifier supplied by the caller (perform_process passes the
            Galaxy HID of the dataset).
        dataset_path: path of the file the peaks were read from.
        chromosomes: dict mapping chromosome name -> Chromosome.

    perform_process() later assigns a ``median`` attribute that
    Peak.normalized_value() reads.
    """

    def __init__(self, id, dataset_path):
        """Load and index every peak found in ``dataset_path``."""
        self.id = id
        self.dataset_path = dataset_path
        # Close the input as soon as parsing is done instead of leaking the
        # file handle for the lifetime of the process.
        with open(dataset_path, 'rt') as handle:
            self.parse(csv.reader(handle, delimiter='\t'))

    def parse(self, reader):
        """
        Populate ``self.chromosomes`` from an iterable of GFF rows.

        Each data row must have nine columns and a ``cw_distance`` entry in
        its attributes column.  Rows whose first field starts with ``#`` or
        ``"`` are skipped, and so are blank rows.

        Raises ValueError for rows with the wrong number of columns or
        non-numeric fields, and KeyError when ``cw_distance`` is missing.
        """
        self.chromosomes = {}
        for line in reader:
            # csv.reader yields an empty list for a blank line; indexing it
            # would raise IndexError.
            if not line or line[0].startswith(('#', '"')):
                continue
            # Column order: name, source, type, start, end, score, strand,
            # phase, attributes.  The start column is used as the midpoint.
            cname, _source, _type, mid, midplus, value, _strand, _phase, attrs = line
            attrs = parse_gff_attrs(attrs)
            distance = int(attrs['cw_distance'])
            mid = int(mid)
            # The end column is not used, but it is still validated so a
            # malformed row is rejected exactly as before.
            int(midplus)
            value = float(value)
            if cname not in self.chromosomes:
                self.chromosomes[cname] = Chromosome(cname)
            self.chromosomes[cname].add_peak(Peak(cname, mid, value, distance, self))
        # Sorting also builds the bisect keys of every chromosome.
        for chrom in self.chromosomes.values():
            chrom.sort_by_index()

    def filter(self, up_limit, low_limit):
        """Keep only peaks with low_limit <= distance <= up_limit."""
        for chrom in self.chromosomes.values():
            chrom.filter(up_limit, low_limit)

    def size(self):
        """Return the number of peaks currently held across all chromosomes."""
        return sum(len(c.peaks) for c in self.chromosomes.values())


class Chromosome(object):
    """
    The peaks of one chromosome within a replicate.

    Attributes:
        name: chromosome name.
        peaks: list of peaks; sorted by midpoint once sort_by_index() ran.
        keys: the midpoints of ``peaks`` in the same order, used for bisect
            lookups.  It is rebuilt by sort_by_index() and filter(); add_peak()
            does not touch it.
    """

    def __init__(self, name):
        self.name = name
        self.peaks = []
        # Start with an empty index so remove_peak() is safe to call before
        # sort_by_index() has run.
        self.keys = []

    def add_peak(self, peak):
        """Append a peak; call sort_by_index() afterwards to refresh the index."""
        self.peaks.append(peak)

    def sort_by_index(self):
        """Sort the peaks by midpoint and rebuild the bisect keys."""
        self.peaks.sort(key=lambda peak: peak.midpoint)
        self.keys = make_keys(self.peaks)

    def remove_peak(self, peak):
        """
        Remove ``peak`` from this chromosome, if a peak at its midpoint exists.

        Several peaks may share a midpoint, so the exact object is looked for
        among them first.  If it is not present, the first peak at that
        midpoint is removed, which is what the previous implementation always
        did.
        """
        lo = bisect.bisect_left(self.keys, peak.midpoint)
        hi = min(bisect.bisect_right(self.keys, peak.midpoint), len(self.peaks))
        if lo >= hi:
            # No peak at this midpoint.
            return
        index = lo
        for i in range(lo, hi):
            if self.peaks[i] is peak:
                index = i
                break
        del self.keys[index]
        del self.peaks[index]

    def filter(self, up_limit, low_limit):
        """Keep only peaks with low_limit <= distance <= up_limit."""
        self.peaks = [p for p in self.peaks if low_limit <= p.distance <= up_limit]
        self.keys = make_keys(self.peaks)


class Peak(object):
    """A single peak: its position, its value and the replicate it came from."""

    def __init__(self, chrom, midpoint, value, distance, replicate):
        # Position of the peak.
        self.chrom, self.midpoint = chrom, midpoint
        # Measurements attached to the peak.
        self.value, self.distance = value, distance
        # Back-reference to the owning replicate.
        self.replicate = replicate

    def normalized_value(self, med):
        """Return ``value * med`` divided by the median of the owning replicate."""
        scaled = self.value * med
        return scaled / self.replicate.median


class PeakGroup(object):
    """
    A set of matched peaks, at most one per replicate.

    ``peaks`` maps replicate id -> peak.  Dict views are turned into lists
    before indexing because they are not subscriptable on Python 3.
    """

    def __init__(self):
        self.peaks = {}

    def add_peak(self, repid, peak):
        """Store ``peak`` as the member contributed by replicate ``repid``."""
        self.peaks[repid] = peak

    @property
    def chrom(self):
        """Chromosome of the first stored peak (IndexError if the group is empty)."""
        return list(self.peaks.values())[0].chrom

    @property
    def midpoint(self):
        """Median of the member midpoints."""
        return median([peak.midpoint for peak in self.peaks.values()])

    @property
    def num_replicates(self):
        """Number of replicates that contributed a peak."""
        return len(self.peaks)

    @property
    def median_distance(self):
        """Median of the member distances."""
        return median([peak.distance for peak in self.peaks.values()])

    @property
    def value_sum(self):
        """Sum of the member values."""
        return sum(peak.value for peak in self.peaks.values())

    def normalized_value(self, med):
        """Median of the members' values after normalisation against ``med``."""
        return median([peak.normalized_value(med) for peak in self.peaks.values()])

    @property
    def peakpeak_distance(self):
        """
        Absolute midpoint distance between the first two stored peaks.

        Raises IndexError when the group holds fewer than two peaks.
        """
        members = list(self.peaks.values())
        return abs(members[0].midpoint - members[1].midpoint)


class FrequencyDistribution(object):
    """Counts how many times each value has been added."""

    def __init__(self, d=None):
        # A falsy argument (None or an empty mapping) starts a fresh dict.
        self.dist = d if d else {}

    def add(self, x):
        """Record one more occurrence of ``x``."""
        if x in self.dist:
            self.dist[x] += 1
        else:
            self.dist[x] = 1

    def graph_series(self):
        """Return two parallel lists: the values and their counts."""
        values = list(self.dist.keys())
        counts = list(self.dist.values())
        return values, counts

    def mode(self):
        """Return the value with the highest count (ValueError when empty)."""
        return max(self.dist, key=self.dist.get)

    def size(self):
        """Return the total number of recorded occurrences."""
        total = 0
        for count in self.dist.values():
            total += count
        return total


def stop_err(msg):
    """Write ``msg`` to standard error and terminate with exit status 1."""
    sys.stderr.write(msg)
    raise SystemExit(1)


def median(data):
    """
    Find the median of the data set.

    Returns 0 for empty input.  For an even number of items the two middle
    values are averaged; when both are integers the average is floored so the
    result stays an integer (true division would turn integer coordinates
    into floats on Python 3).  Non-integer inputs are averaged with true
    division.
    """
    if not data:
        return 0
    ordered = sorted(data)
    half = len(ordered) // 2
    if len(ordered) % 2:
        return ordered[half]
    low, high = ordered[half - 1], ordered[half]
    if isinstance(low, int) and isinstance(high, int):
        return (low + high) // 2
    return (low + high) / 2


def make_keys(peaks):
    """Return the midpoints of ``peaks`` as a list, in the same order."""
    midpoints = []
    for peak in peaks:
        midpoints.append(peak.midpoint)
    return midpoints


def get_window(chromosome, target_peaks, distance):
    """
    Returns a window of all peaks from a replicate within a certain distance of
    a peak from another replicate.

    ``chromosome`` supplies the candidate peaks (``peaks``) and their sorted
    midpoints (``keys``).  ``target_peaks`` may be any iterable of peaks,
    including a dict view.  The window spans from the smallest
    ``midpoint - distance`` to the largest ``midpoint + distance`` over the
    target peaks, both ends inclusive.

    Returns a tuple ``(peaks_in_window, chromosome.name)``.
    Raises IndexError when ``target_peaks`` is empty.
    """
    # Materialise first: dict views cannot be indexed on Python 3.
    target_peaks = list(target_peaks)
    lower = upper = target_peaks[0].midpoint
    for peak in target_peaks:
        lower = min(lower, peak.midpoint - distance)
        upper = max(upper, peak.midpoint + distance)
    start_index = bisect.bisect_left(chromosome.keys, lower)
    end_index = bisect.bisect_right(chromosome.keys, upper)
    return (chromosome.peaks[start_index:end_index], chromosome.name)


def match_largest(window, peak, chrum):
    """
    Pick the peak with the highest value from ``window``.

    Returns None when the window is empty or when ``peak`` is not on the
    chromosome named ``chrum``.  On ties the earliest peak in the window wins.
    """
    if not window or peak.chrom != chrum:
        return None
    best = None
    for candidate in window:
        if best is None or candidate.value > best.value:
            best = candidate
    return best


def match_closest(window, peak, chrum):
    """
    Pick the peak from ``window`` whose midpoint is nearest to that of ``peak``.

    Returns None when the window is empty or when ``peak`` is not on the
    chromosome named ``chrum``.  On ties the earliest peak in the window wins.
    """
    if not window or peak.chrom != chrum:
        return None
    nearest = None
    nearest_gap = None
    for candidate in window:
        gap = abs(candidate.midpoint - peak.midpoint)
        if nearest is None or gap < nearest_gap:
            nearest, nearest_gap = candidate, gap
    return nearest


def frequency_histogram(freqs, dataset_path, labels=None, title=''):
    """
    Draw a bar chart of one or more frequency distributions and save it.

    The x axis runs from high to low: the largest key is drawn at the
    left-most position and the tick labels are reversed to match.

    freqs: sequence of objects whose ``graph_series()`` returns parallel
        lists of integer keys and their counts; one bar series is drawn per
        entry.
    dataset_path: where the figure is saved (passed to pyplot.savefig).
    labels, title: accepted for compatibility with existing callers; the
        function does not use them.
    """
    pyplot.clf()
    pyplot.figure(figsize=(10, 10))
    series = [freq.graph_series() for freq in freqs]
    all_keys = [key for xvals, _ in series for key in xvals]
    # With nothing to plot there is no key range to derive ticks from; the
    # empty axes are still saved.
    if all_keys:
        lowest, highest = min(all_keys), max(all_keys)
        bar_width = 0.8 / len(series)
        for i, (xvals, yvals) in enumerate(series):
            # Mirror each key inside [lowest, highest] so it lands under its
            # own reversed tick label.  Reversing the key list, as was done
            # before, only lines up when the keys are sorted and contiguous,
            # which insertion-ordered dicts do not guarantee.
            positions = [lowest + highest - x - 0.4 + bar_width * i for x in xvals]
            # Cycle through the palette rather than running off its end.
            pyplot.bar(positions, yvals, width=bar_width, color=COLORS[i % len(COLORS)])
        ticks = list(range(lowest, highest + 1))
        # Pass real lists: map() returns a one-shot iterator on Python 3.
        pyplot.xticks(ticks, [str(tick) for tick in reversed(ticks)])
    pyplot.xlabel(X_LABEL)
    pyplot.ylabel(Y_LABEL)
    pyplot.subplots_adjust(left=ADJUST[0], right=ADJUST[1], top=ADJUST[2], bottom=ADJUST[3])
    ax = pyplot.gca()
    for tick_line in list(ax.get_xticklines()) + list(ax.get_yticklines()):
        tick_line.set_markeredgewidth(TICK_WIDTH)
    pyplot.savefig(dataset_path)


# Matching strategies selectable by name.  Each callable takes
# (window, peak, chrum) and returns the chosen peak or None.
METHODS = {'closest': match_closest, 'largest': match_largest}


def gff_attrs(d):
    """
    Serialise a mapping as ``key=value`` pairs joined with ``;``.

    A falsy argument (None or an empty mapping) yields ``'.'``.
    """
    if d:
        pairs = []
        for key, val in d.items():
            pairs.append('%s=%s' % (key, val))
        return ';'.join(pairs)
    return '.'


def parse_gff_attrs(s):
    """
    Parse a ``key=value;key=value`` attribute string into a dict of strings.

    ``'.'`` means "no attributes" and yields an empty dict.  Empty items, such
    as the one produced by a trailing ``;``, are ignored, and only the first
    ``=`` of an item separates key from value so values may contain ``=``.

    Raises ValueError for a non-empty item that contains no ``=``.
    """
    d = {}
    if s == '.':
        return d
    for item in s.split(';'):
        if not item.strip():
            continue
        key, val = item.split('=', 1)
        d[key] = val
    return d


def gff_row(cname, start, end, score, source, type='.', strand='.', phase='.', attrs={}):
    """
    Arrange the given fields as a nine-item row tuple:
    (cname, source, type, start, end, score, strand, phase, attributes),
    where the attributes item is ``attrs`` serialised by gff_attrs().
    """
    attributes_column = gff_attrs(attrs)
    return (cname, source, type, start, end, score, strand, phase, attributes_column)


def get_temporary_plot_path():
    """
    Create an empty ``.pdf`` file inside a freshly created temporary
    directory (prefixed ``tmp-repmatch-``) and return the path of the file.
    """
    scratch_dir = tempfile.mkdtemp(prefix='tmp-repmatch-')
    descriptor, plot_path = tempfile.mkstemp(dir=scratch_dir, suffix='.pdf')
    # Only the path is needed; release the descriptor mkstemp opened.
    os.close(descriptor)
    return plot_path


def process_files(dataset_paths, galaxy_hids, method, distance, step, replicates, up_limit, low_limit, output_files,
                  output_matched_peaks, output_unmatched_peaks, output_detail, output_statistics_table, output_statistics_histogram):
    """
    Run perform_process() for the requested matching method(s).

    ``method`` is either a key of METHODS or ``'all'`` to run every method in
    turn.  Nothing is done when fewer than two datasets are given.  When both
    ``output_files`` and ``method`` are ``'all'``, a combined histogram with
    one series per method is written to ``output_statistics_histogram``.
    All other arguments are forwarded to perform_process() unchanged
    (``replicates`` is passed as its ``num_required``).
    """
    output_statistics_histogram_file = output_files in ["all"] and method in ["all"]
    if len(dataset_paths) < 2:
        return
    if method == 'all':
        match_methods = list(METHODS)
    else:
        match_methods = [method]
    # Keep the statistics of every run; previously each run overwrote the
    # previous one, so the combined histogram only ever showed the last method.
    statistics = []
    for match_method in match_methods:
        statistics.append(perform_process(dataset_paths,
                                          galaxy_hids,
                                          match_method,
                                          distance,
                                          step,
                                          replicates,
                                          up_limit,
                                          low_limit,
                                          output_files,
                                          output_matched_peaks,
                                          output_unmatched_peaks,
                                          output_detail,
                                          output_statistics_table,
                                          output_statistics_histogram))
    if output_statistics_histogram_file:
        tmp_statistics_histogram_path = get_temporary_plot_path()
        frequency_histogram([stat['distribution'] for stat in statistics],
                            tmp_statistics_histogram_path,
                            list(METHODS))
        shutil.move(tmp_statistics_histogram_path, output_statistics_histogram)


def perform_process(dataset_paths, galaxy_hids, method, distance, step, num_required, up_limit, low_limit, output_files,
                    output_matched_peaks, output_unmatched_peaks, output_detail, output_statistics_table, output_statistics_histogram):
    """
    Match peaks across replicates with a single method and write the outputs.

    dataset_paths / galaxy_hids: parallel sequences; each path is loaded as a
        Replicate whose id is the corresponding hid.  A file that cannot be
        parsed terminates the process through stop_err().
    method: key of METHODS selecting how a match is picked from a window.
    distance: matching window half-width; see ``step``.
    step: when greater than zero, matching is repeated for every distance in
        ``range(0, distance, step)``; otherwise it runs once with ``distance``.
    num_required: minimum number of replicates a group needs to be kept.
    up_limit / low_limit: peaks whose distance lies outside
        [low_limit, up_limit] are dropped before matching (only applied when
        up_limit < 1000 or low_limit > -1000).
    output_files: selects the optional outputs.  ``'all'`` enables the detail
        table, statistics table, unmatched peaks and histogram;
        ``'matched_peaks_unmatched_peaks'`` enables only the unmatched peaks.
        An optional output is also skipped when its path is None.
    output_*: destination paths.  The matched peaks file is always written.

    Returns ``{'distribution': freq}`` where ``freq`` is the
    FrequencyDistribution of group sizes, with every unmatched peak counted
    as a group of size 1.
    """
    output_detail_file = output_files in ["all"] and output_detail is not None
    output_statistics_table_file = output_files in ["all"] and output_statistics_table is not None
    output_unmatched_peaks_file = output_files in ["all", "matched_peaks_unmatched_peaks"] and output_unmatched_peaks is not None
    output_statistics_histogram_file = output_files in ["all"] and output_statistics_histogram is not None
    replicates = []
    for i, dataset_path in enumerate(dataset_paths):
        try:
            galaxy_hid = galaxy_hids[i]
            r = Replicate(galaxy_hid, dataset_path)
            replicates.append(r)
        except Exception as e:
            stop_err('Unable to parse file "%s", exception: %s' % (dataset_path, str(e)))

    # Every file opened through td_writer is remembered here so that it can
    # be closed (and therefore flushed) once all the output is written.
    open_files = []

    def td_writer(file_path):
        # Returns a tab-delimited writer for a certain output
        handle = open(file_path, 'wt')
        open_files.append(handle)
        return csv.writer(handle, delimiter='\t')

    # Header of the detail table: the group columns followed by one set of
    # peak columns per replicate.
    labels = ('chrom',
              'median midpoint',
              'median midpoint+1',
              'median normalized reads',
              'replicates',
              'median c-w distance',
              'reads sum')
    labels += ('chrom',
               'median midpoint',
               'median midpoint+1',
               'c-w sum',
               'c-w distance',
               'replicate id') * len(replicates)

    peak_groups = []
    unmatched_peaks = []
    freq = FrequencyDistribution()

    def do_match(reps, distance):
        # Copy list because we will mutate it, but keep replicate references.
        reps = reps[:]
        while len(reps) > 1:
            # Iterate over each replicate as "main"
            main = reps[0]
            reps.remove(main)
            for chromosome in main.chromosomes.values():
                peaks_by_value = chromosome.peaks[:]
                # Sort main replicate by value
                peaks_by_value.sort(key=lambda peak: -peak.value)

                def search_for_matches(group, anchor):
                    # Here we use multiple passes, expanding the window to be
                    #  +- distance from any previously matched peak.
                    # ``anchor`` is the reference peak handed to the matching
                    # method.  It is passed explicitly: this used to read a
                    # ``peak`` variable from the enclosing scope, which was
                    # stale (or not yet bound) when enlarging existing groups.
                    while True:
                        new_match = False
                        for replicate in reps:
                            if replicate.id in group.peaks:
                                # Stop if match already found for this replicate
                                continue
                            try:
                                # Lines changed to remove a major bug by Rohit Reja.
                                window, chrum = get_window(replicate.chromosomes[chromosome.name],
                                                           list(group.peaks.values()),
                                                           distance)
                                match = METHODS[method](window, anchor, chrum)
                            except KeyError:
                                # E.g. this replicate has no such chromosome.
                                continue
                            if match:
                                group.add_peak(replicate.id, match)
                                new_match = True
                        if not new_match:
                            break
                # Attempt to enlarge existing peak groups
                for group in peak_groups:
                    # list(): dict views cannot be sliced on Python 3.
                    old_peaks = list(group.peaks.values())
                    # The first stored peak is the group's reference.  Groups
                    # on another chromosome cannot gain peaks from this one.
                    anchor = old_peaks[0]
                    if anchor.chrom != chromosome.name:
                        continue
                    search_for_matches(group, anchor)
                    for peak in group.peaks.values():
                        if peak not in old_peaks:
                            peak.replicate.chromosomes[chromosome.name].remove_peak(peak)
                # Attempt to find new peaks groups.  For each peak in the
                # main replicate, search for matches in the other replicates
                for peak in peaks_by_value:
                    matches = PeakGroup()
                    matches.add_peak(main.id, peak)
                    search_for_matches(matches, peak)
                    # Were enough replicates matched?
                    if matches.num_replicates >= num_required:
                        # Matched peaks leave their replicates so they cannot
                        # be matched again.
                        for matched in matches.peaks.values():
                            matched.replicate.chromosomes[chromosome.name].remove_peak(matched)
                        peak_groups.append(matches)

    try:
        matched_peaks_output = td_writer(output_matched_peaks)
        if output_statistics_table_file:
            statistics_table_output = td_writer(output_statistics_table)
            statistics_table_output.writerow(('data', 'median read count'))
        if output_detail_file:
            detail_output = td_writer(output_detail)
            detail_output.writerow(labels)
        if output_unmatched_peaks_file:
            unmatched_peaks_output = td_writer(output_unmatched_peaks)
            unmatched_peaks_output.writerow(('chrom', 'midpoint', 'midpoint+1', 'c-w sum', 'c-w distance', 'replicate id'))
        # Perform filtering
        if up_limit < 1000 or low_limit > -1000:
            for replicate in replicates:
                replicate.filter(up_limit, low_limit)
        # Actually merge the peaks
        # Zero or less = no stepping
        if step <= 0:
            do_match(replicates, distance)
        else:
            # NOTE(review): range() excludes ``distance`` itself, so the full
            # distance is never tried when stepping - confirm this is intended.
            for d in range(0, distance, step):
                do_match(replicates, d)
        for group in peak_groups:
            freq.add(group.num_replicates)
        # Collect together the remaining unmatched_peaks
        for replicate in replicates:
            for chromosome in replicate.chromosomes.values():
                for peak in chromosome.peaks:
                    freq.add(1)
                    unmatched_peaks.append(peak)
        # Median value over all matched peaks, and per replicate; both are
        # needed to normalise the group scores below.
        med = median([peak.value for group in peak_groups for peak in group.peaks.values()])
        for replicate in replicates:
            replicate.median = median([peak.value for group in peak_groups for peak in group.peaks.values() if peak.replicate == replicate])
            # The statistics writer only exists when that output was requested.
            if output_statistics_table_file:
                statistics_table_output.writerow((replicate.id, replicate.median))
        for group in peak_groups:
            # Output matched_peaks (matched pairs).
            matched_peaks_output.writerow(gff_row(cname=group.chrom,
                                                  start=group.midpoint,
                                                  end=group.midpoint + 1,
                                                  source='repmatch',
                                                  score=group.normalized_value(med),
                                                  attrs={'median_distance': group.median_distance,
                                                         'replicates': group.num_replicates,
                                                         'value_sum': group.value_sum}))
            if output_detail_file:
                matched_peaks = (group.chrom,
                                 group.midpoint,
                                 group.midpoint + 1,
                                 group.normalized_value(med),
                                 group.num_replicates,
                                 group.median_distance,
                                 group.value_sum)
                for peak in group.peaks.values():
                    matched_peaks += (peak.chrom, peak.midpoint, peak.midpoint + 1, peak.value, peak.distance, peak.replicate.id)
                detail_output.writerow(matched_peaks)
        if output_unmatched_peaks_file:
            for unmatched_peak in unmatched_peaks:
                unmatched_peaks_output.writerow((unmatched_peak.chrom,
                                                 unmatched_peak.midpoint,
                                                 unmatched_peak.midpoint + 1,
                                                 unmatched_peak.value,
                                                 unmatched_peak.distance,
                                                 unmatched_peak.replicate.id))
    finally:
        for handle in open_files:
            handle.close()
    if output_statistics_histogram_file:
        tmp_statistics_histogram_path = get_temporary_plot_path()
        frequency_histogram([freq], tmp_statistics_histogram_path)
        shutil.move(tmp_statistics_histogram_path, output_statistics_histogram)
    return {'distribution': freq}