diff spring_mcc.py @ 39:172398348efd draft

"planemo upload commit 26b4018c88041ee0ca7c2976e0a012015173d7b6-dirty"
author guerler
date Fri, 22 Jan 2021 15:50:27 +0000
parents
children f316caf098a6
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/spring_mcc.py	Fri Jan 22 15:50:27 2021 +0000
@@ -0,0 +1,320 @@
+#! /usr/bin/env python
+import argparse
+import math
+from os.path import isfile
+import re
+from matplotlib import pyplot as plt
+
+
+def getIds(rawIds):
+    return rawIds.split("|")
+
+
+def getCenterId(rawId):
+    elements = rawId.split("|")
+    if len(elements) > 1:
+        return elements[1]
+    return rawId
+
+
+def getOrganism(rawId):
+    elements = rawId.split("_")
+    return elements[-1]
+
+
+def getKey(a, b):
+    if a > b:
+        name = "%s_%s" % (a, b)
+    else:
+        name = "%s_%s" % (b, a)
+    return name
+
+
+def getPercentage(rate, denominator):
+    if denominator > 0:
+        return 100.0 * rate / denominator
+    return 0.0
+
+
+def getFilter(filterName):
+    print("Loading target organism(s)...")
+    filterSets = dict()
+    with open(filterName) as filterFile:
+        for line in filterFile:
+            columns = line.split()
+            for colIndex in [0, 1]:
+                if colIndex >= len(columns):
+                    break
+                colEntry = columns[colIndex]
+                id = getCenterId(colEntry)
+                organism = getOrganism(colEntry)
+                if organism not in filterSets:
+                    filterSets[organism] = set()
+                filterSets[organism].add(id)
+    print("Organism(s) in set: %s." % filterSets.keys())
+    return filterSets
+
+
+def getReference(fileName, filterA=None, filterB=None, minScore=None, aCol=0,
+                 bCol=1, scoreCol=-1, separator=None,
+                 skipFirstLine=False, filterValues=list()):
+
+    index = dict()
+    count = 0
+    with open(fileName) as fp:
+        line = fp.readline()
+        if skipFirstLine:
+            line = fp.readline()
+        while line:
+            ls = line.split(separator)
+            skipEntry = False
+            if separator is not None:
+                aList = getIds(ls[aCol])
+                bList = getIds(ls[bCol])
+            else:
+                aId = getCenterId(ls[aCol])
+                bId = getCenterId(ls[bCol])
+                aList = [aId]
+                bList = [bId]
+            if not skipEntry:
+                validEntry = False
+                for a in aList:
+                    for b in bList:
+                        skip = False
+                        if a == "-" or b == "-":
+                            skip = True
+                        if filterA is not None and filterB is not None:
+                            skip = not ((a in filterA and b in filterB) or (a in filterB and b in filterA))
+                        for f in filterValues:
+                            if len(ls) > f[0]:
+                                columnEntry = ls[f[0]].lower()
+                                searchEntry = f[1].lower()
+                                if columnEntry.find(searchEntry) == -1:
+                                    skip = True
+                        if not skip:
+                            name = getKey(a, b)
+                            if name not in index:
+                                validEntry = True
+                                if scoreCol >= 0 and len(ls) > scoreCol:
+                                    score = float(ls[scoreCol])
+                                    skip = False
+                                    if minScore is not None:
+                                        if minScore > score:
+                                            return index, count
+                                    if not skip:
+                                        index[name] = score
+                                else:
+                                    index[name] = 1.0
+                if validEntry:
+                    count = count + 1
+            line = fp.readline()
+    return index, count
+
+
+def getMCC(prediction, positive, positiveCount, negative):
+    sortedPrediction = sorted(prediction.items(), key=lambda x: x[1],
+                              reverse=True)
+    positiveTotal = positiveCount
+    negativeTotal = len(negative)
+    x = list([0])
+    y = list([0])
+    xMax = 0
+    topCount = 0
+    topMCC = 0.0
+    topFP = 0.0
+    topTP = 0.0
+    topScore = 0.0
+    tp = 0
+    fp = 0
+    count = 0
+    for (name, score) in sortedPrediction:
+        found = False
+        if name in positive:
+            found = True
+            tp = tp + 1
+        if name in negative:
+            found = True
+            fp = fp + 1
+        fn = positiveTotal - tp
+        tn = negativeTotal - fp
+        denom = (tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)
+        yValue = getPercentage(tp, tp + fn)
+        xValue = getPercentage(fp, fp + tn)
+        if denom > 0.0:
+            mcc = (tp*tn-fp*fn)/math.sqrt(denom)
+            if mcc >= topMCC:
+                topMCC = mcc
+                topScore = score
+                topCount = count
+                topFP = xValue
+                topTP = yValue
+        if found:
+            y.append(yValue)
+            x.append(xValue)
+            xMax = max(xValue, xMax)
+        count = count + 1
+    if len(sortedPrediction) > 0:
+        print("Top ranking prediction %s." % str(sortedPrediction[0]))
+        print("Total count of prediction set: %s (tp=%1.2f, fp=%1.2f)." % (topCount, topTP, topFP))
+        print("Total count of positive set: %s." % len(positive))
+        print("Total count of negative set: %s." % len(negative))
+        print("Matthews-Correlation-Coefficient: %s at Score >= %s." % (round(topMCC, 2), topScore))
+    return topMCC
+
+
+def getNegativeSet(args, filterA, filterB, negativeRequired, jSize=5):
+    # determine negative set
+    print("Identifying non-interacting pairs...")
+    negative = set()
+    if args.negative and isfile(args.negative):
+        # load from explicit file
+        with open(args.negative) as file:
+            for line in file:
+                cols = line.split()
+                nameA = cols[0]
+                nameB = cols[1]
+                key = getKey(nameA, nameB)
+                if key not in negative:
+                    negative.add(key)
+    else:
+        if args.region_a and args.region_b:
+            locations = dict()
+            regionA = args.region_a.lower()
+            regionB = args.region_b.lower()
+            locations[regionA] = list()
+            locations[regionB] = list()
+            regions = [regionA, regionB]
+            print("Filtering regions %s" % str(regions))
+            with open(args.locations) as locFile:
+                for line in locFile:
+                    searchKey = "SUBCELLULAR LOCATION"
+                    searchPos = line.find(searchKey)
+                    if searchPos != -1:
+                        uniId = line.split()[0]
+                        if uniId not in filterA and uniId not in filterB:
+                            continue
+                        locStart = searchPos + len(searchKey) + 1
+                        locId = line[locStart:]
+                        locId = re.sub(r"\s*{.*}\s*", "", locId)
+                        locId = locId.replace(".", ",")
+                        locId = locId.strip().lower()
+                        filter_pos = locId.find("note=")
+                        if filter_pos >= 0:
+                            locId = locId[:filter_pos]
+                        filter_pos = locId.find(";")
+                        if filter_pos >= 0:
+                            locId = locId[:filter_pos]
+                        if locId:
+                            locId = list(map(lambda x: x.strip(), locId.split(",")))
+                            if (regionA in locId and regionB not in locId):
+                                locations[regionA].append(uniId)
+                            elif (regionA not in locId and regionB in locId):
+                                locations[regionB].append(uniId)
+            filterAList = sorted(locations[regionA])
+            filterBList = sorted(locations[regionB])
+        else:
+            filterAList = list(filterA)
+            filterBList = list(filterB)
+        for i, j in randomPairs(len(filterAList), len(filterBList), jSize):
+            nameA = filterAList[i]
+            nameB = filterBList[j]
+            key = getKey(nameA, nameB)
+            if key not in negative:
+                negative.add(key)
+                negativeRequired = negativeRequired - 1
+            if negativeRequired == 0:
+                break
+    return negative
+
+
+def randomPairs(iLen, jLen, jSize):
+    i = 0
+    jStart = 0
+    while i < iLen:
+        jMax = min(jStart + jSize, jLen)
+        for j in range(jStart, jMax):
+            yield i, j
+        i = i + 1
+        if i == iLen and jMax < jLen:
+            i = 0
+            jStart = jStart + jSize + 1
+
+
+def main(args):
+    # load source files
+    filterSets = getFilter(args.input)
+    filterKeys = list(filterSets.keys())
+    filterA = filterSets[filterKeys[0]]
+    if len(filterKeys) > 1:
+        filterB = filterSets[filterKeys[1]]
+    else:
+        filterB = filterA
+
+    # identify biogrid filter options
+    filterValues = list()
+    filterValues.append([11, args.method])
+
+    # process biogrid database
+    print("Loading positive set from BioGRID file...")
+    positive, positiveCount = getReference(args.biogrid, aCol=23, bCol=26,
+                                           separator="\t", filterA=filterA,
+                                           filterB=filterB, skipFirstLine=True,
+                                           filterValues=filterValues)
+
+    # estimate negative set
+    negative = getNegativeSet(args, filterA, filterB, positiveCount)
+
+    # get prediction results
+    print("Loading prediction file...")
+    prediction, _ = getReference(args.input, scoreCol=2, minScore=0.8)
+    mcc = getMCC(prediction, positive, positiveCount, negative)
+    yValues = [mcc]
+    yTicks = ["SPRING"]
+
+    # identify biogrid filter options
+    for method in ["Affinity Capture-MS",
+                   "Biochemical Activity",
+                   "Co-crystal Structure",
+                   "Co-fractionation",
+                   "Co-localization",
+                   "Co-purification",
+                   "Far Western",
+                   "FRET",
+                   "PCA",
+                   "Reconstituted Complex",
+                   "Two-hybrid"]:
+        if args.method != method:
+            print("Method: %s" % method)
+            filterValues = [[11, method]]
+            prediction, _ = getReference(args.biogrid, aCol=23, bCol=26,
+                                         separator="\t", filterA=filterA,
+                                         filterB=filterB, skipFirstLine=True,
+                                         filterValues=filterValues)
+            mcc = getMCC(prediction, positive, positiveCount, negative)
+            yValues.append(mcc)
+            yTicks.append(method)
+
+    # create plot
+    print("Producing plot data...")
+    print("Total count in prediction file: %d." % len(prediction))
+    print("Total count in positive file: %d." % len(positive))
+    plt.xlabel("Matthews-Correlation Coefficient (MCC)")
+    plt.title("Positive set: %s" % args.method)
+    plt.barh(yTicks, yValues)
+    plt.tight_layout()
+    plt.savefig(args.output, format="png")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Create ROC plot.')
+    parser.add_argument('-i', '--input', help='Input prediction file (2-columns).', required=True)
+    parser.add_argument('-b', '--biogrid', help='BioGRID interaction database file', required=True)
+    parser.add_argument('-l', '--locations', help='UniProt export table with subcellular locations', required=False)
+    parser.add_argument('-ra', '--region_a', help='First subcellular location', required=False)
+    parser.add_argument('-rb', '--region_b', help='Second subcellular location', required=False)
+    parser.add_argument('-n', '--negative', help='Negative set (2-columns)', required=False)
+    parser.add_argument('-t', '--throughput', help='Throughput (low/high)', required=False)
+    parser.add_argument('-m', '--method', help='Method e.g. Two-hybrid', required=False)
+    parser.add_argument('-o', '--output', help='Output (png)', required=True)
+    args = parser.parse_args()
+    main(args)