Mercurial > repos > guerler > springsuite
diff spring_roc.py @ 29:41353488926c draft
"planemo upload commit 1c0a60f98e36bccb6d6c85ff82a8d737a811b4d5"
author | guerler |
---|---|
date | Sun, 22 Nov 2020 14:15:24 +0000 |
parents | |
children | b0e195a47df7 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/spring_roc.py Sun Nov 22 14:15:24 2020 +0000 @@ -0,0 +1,256 @@ +#! /usr/bin/env python +import argparse +import math +import random +from datetime import datetime + +from matplotlib import pyplot as plt + + +def getIds(rawIds): + return rawIds.split("|") + + +def getCenterId(rawId): + elements = rawId.split("|") + if len(elements) > 1: + return elements[1] + return rawId + + +def getOrganism(rawId): + elements = rawId.split("_") + return elements[-1] + + +def getKey(a, b): + if a > b: + name = "%s_%s" % (a, b) + else: + name = "%s_%s" % (b, a) + return name + + +def getPercentage(rate, denominator): + if denominator > 0: + return 100.0 * rate / denominator + return 0.0 + + +def getFilter(filterName): + print("Loading target organism(s)...") + filterSets = dict() + with open(filterName) as filterFile: + for line in filterFile: + columns = line.split() + for colIndex in [0, 1]: + if colIndex >= len(columns): + break + colEntry = columns[colIndex] + id = getCenterId(colEntry) + organism = getOrganism(colEntry) + if organism not in filterSets: + filterSets[organism] = set() + filterSets[organism].add(id) + print("Organism(s) in set: %s." % filterSets.keys()) + return filterSets + + +def getReference(fileName, filterA=None, filterB=None, minScore=None, aCol=0, + bCol=1, scoreCol=-1, separator=None, + skipFirstLine=False, filterValues=list()): + index = dict() + count = 0 + with open(fileName) as fp: + line = fp.readline() + if skipFirstLine: + line = fp.readline() + while line: + ls = line.split(separator) + if separator is not None: + aList = getIds(ls[aCol]) + bList = getIds(ls[bCol]) + else: + aList = [getCenterId(ls[aCol])] + bList = [getCenterId(ls[bCol])] + validEntry = False + for a in aList: + for b in bList: + skip = False + if a == "-" or b == "-": + skip = True + if filterA is not None: + if a not in filterA and b not in filterA: + skip = True + if filterB is not None: + if a not in filterB and b not in filterB: + skip = True + for f in filterValues: + if len(ls) > f[0]: + columnEntry = ls[f[0]].lower() + searchEntry = f[1].lower() + if columnEntry.find(searchEntry) == -1: + skip = True + if not skip: + name = getKey(a, b) + if name not in index: + validEntry = True + if scoreCol >= 0 and len(ls) > scoreCol: + score = float(ls[scoreCol]) + skip = False + if minScore is not None: + if minScore > score: + return index, count + if not skip: + index[name] = score + else: + index[name] = 1.0 + if validEntry: + count = count + 1 + line = fp.readline() + return index, count + + +def getXY(prediction, positive, positiveCount, negative): + sortedPrediction = sorted(prediction.items(), key=lambda x: x[1], + reverse=True) + positiveTotal = positiveCount + negativeTotal = len(negative) + x = list([0]) + y = list([0]) + xMax = 0 + topCount = 0 + topMCC = 0.0 + topPrecision = 0.0 + topScore = 0.0 + tp = 0 + fp = 0 + count = 0 + for (name, score) in sortedPrediction: + found = False + if name in positive: + found = True + tp = tp + 1 + if name in negative: + found = True + fp = fp + 1 + precision = 0.0 + if tp > 0 or fp > 0: + precision = tp / (tp + fp) + fn = positiveTotal - tp + tn = negativeTotal - fp + denom = (tp+fp)*(tp+fn)*(tn+fp)*(tn+fn) + if denom > 0.0: + mcc = (tp*tn-fp*fn)/math.sqrt(denom) + if mcc >= topMCC: + topMCC = mcc + topScore = score + topCount = count + topPrecision = precision + if found: + yValue = getPercentage(tp, tp + fn) + xValue = getPercentage(fp, fp + tn) + y.append(yValue) + x.append(xValue) + xMax = max(xValue, xMax) + count = count + 1 + print("Top ranking prediction %s." % str(sortedPrediction[0])) + print("Total count of prediction set: %s (precision=%1.2f)." % + (topCount, topPrecision)) + print("Total count of positive set: %s." % len(positive)) + print("Total count of negative set: %s." % len(negative)) + print("Matthews-Correlation-Coefficient: %s at Score >= %s." % + (round(topMCC, 2), topScore)) + return x, y, xMax + + +def main(args): + # load source files + filterSets = getFilter(args.input) + filterKeys = list(filterSets.keys()) + filterA = filterSets[filterKeys[0]] + if len(filterKeys) > 1: + filterB = filterSets[filterKeys[1]] + else: + filterB = filterA + + # identify biogrid filter options + filterValues = [] + if args.method: + filterValues.append([11, args.method]) + if args.experiment: + filterValues.append([12, args.experiment]) + if args.throughput: + filterValues.append([17, args.throughput]) + + # process biogrid database + print("Loading positive set from BioGRID file...") + positive, positiveCount = getReference(args.biogrid, aCol=23, bCol=26, + separator="\t", filterA=filterA, + filterB=filterB, skipFirstLine=True, + filterValues=filterValues) + + # rescan biogrid database to identify set of putative interactions + if filterValues: + print("Filtered entries by (column, value): %s" % filterValues) + print("Loading putative set from BioGRID file...") + putative, putativeCount = getReference(args.biogrid, aCol=23, bCol=26, + separator="\t", filterA=filterA, + filterB=filterB, + skipFirstLine=True) + print("Found %s." % putativeCount) + else: + putative = positive + + # process prediction file + print("Loading prediction file...") + prediction, _ = getReference(args.input, scoreCol=2) + + # estimate background noise + print("Estimating background noise...") + negative = set() + filterAList = list(filterA) + filterBList = list(filterB) + negativeCount = positiveCount + negativeRequired = negativeCount + random.seed(datetime.now()) + while negativeRequired > 0: + nameA = random.choice(filterAList) + nameB = random.choice(filterBList) + key = getKey(nameA, nameB) + if key not in putative and key not in negative: + negative.add(key) + negativeRequired = negativeRequired - 1 + + # create plot + print("Producing plot data...") + print("Total count in prediction file: %d." % len(prediction)) + print("Total count in positive file: %d." % len(positive)) + plt.ylabel('True Positive Rate (%)') + plt.xlabel('False Positive Rate (%)') + title = " vs. ".join(filterSets) + plt.suptitle(title) + if filterValues: + filterAttributes = list(map(lambda x: x[1], filterValues)) + plt.title("BioGRID filters: %s" % filterAttributes, fontsize=10) + x, y, xMax = getXY(prediction, positive, positiveCount, negative) + plt.plot(x, y) + plt.plot([0, xMax], [0, xMax]) + plt.savefig(args.output, format="png") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Create ROC plot.') + parser.add_argument('-i', '--input', help='Input prediction file.', + required=True) + parser.add_argument('-b', '--biogrid', help='BioGRID interaction ' + + 'database file', required=True) + parser.add_argument('-e', '--experiment', help='Type (physical/genetic)', + default="", required=False) + parser.add_argument('-t', '--throughput', help='Throughput (low/high)', + default="", required=False) + parser.add_argument('-m', '--method', help='Method e.g. Two-hybrid', + default="", required=False) + parser.add_argument('-o', '--output', help='Output (png)', required=True) + args = parser.parse_args() + main(args)