comparison spring_roc.py @ 29:41353488926c draft

"planemo upload commit 1c0a60f98e36bccb6d6c85ff82a8d737a811b4d5"
author guerler
date Sun, 22 Nov 2020 14:15:24 +0000
parents
children b0e195a47df7
comparison
equal deleted inserted replaced
28:75d1aedc9b3f 29:41353488926c
1 #! /usr/bin/env python
2 import argparse
3 import math
4 import random
5 from datetime import datetime
6
7 from matplotlib import pyplot as plt
8
9
def getIds(rawIds):
    """Split a "|"-delimited identifier field into its individual ids.

    :param rawIds: string such as ``"A|B|C"``.
    :returns: list of the pieces between the "|" characters (always at
        least one element, possibly empty strings).
    """
    pieces = rawIds.split("|")
    return pieces
12
13
def getCenterId(rawId):
    """Return the second "|"-delimited field of rawId.

    For ``"a|b|c"`` this yields ``"b"``. An id that contains no "|" is
    returned unchanged.
    """
    _, separator, remainder = rawId.partition("|")
    if not separator:
        return rawId
    # Only the field directly after the first "|" is wanted.
    return remainder.partition("|")[0]
19
20
def getOrganism(rawId):
    """Return the text after the last "_" in rawId (the whole id if none)."""
    return rawId.rpartition("_")[2]
24
25
def getKey(a, b):
    """Build an order-independent pair key: larger value first, joined by "_".

    ``getKey(a, b) == getKey(b, a)`` for any comparable a and b.
    """
    first, second = (a, b) if a > b else (b, a)
    return "%s_%s" % (first, second)
32
33
def getPercentage(rate, denominator):
    """Express rate as a percentage of denominator.

    Returns 0.0 instead of dividing when the denominator is not positive.
    """
    if not denominator > 0:
        return 0.0
    return 100.0 * rate / denominator
38
39
def getFilter(filterName):
    """Collect the ids found in a file, grouped by organism.

    Only the first two whitespace-separated columns of each line are read.
    Each entry is reduced to its center id with getCenterId() and grouped
    under the organism suffix returned by getOrganism().

    :param filterName: path of the file to read.
    :returns: dict mapping organism name -> set of ids.
    """
    print("Loading target organism(s)...")
    filterSets = dict()
    with open(filterName) as filterFile:
        for line in filterFile:
            # Only the two partner columns matter; missing ones are ignored.
            for colEntry in line.split()[:2]:
                proteinId = getCenterId(colEntry)
                organism = getOrganism(colEntry)
                filterSets.setdefault(organism, set()).add(proteinId)
    # list() so the message shows the organism names, not a dict_keys repr.
    print("Organism(s) in set: %s." % list(filterSets.keys()))
    return filterSets
57
58
def getReference(fileName, filterA=None, filterB=None, minScore=None, aCol=0,
                 bCol=1, scoreCol=-1, separator=None,
                 skipFirstLine=False, filterValues=None):
    """Read interacting id pairs (and optional scores) from a column file.

    :param fileName: path of the file to read.
    :param filterA: optional set of ids; a pair is kept only if at least one
        partner is in this set.
    :param filterB: optional set of ids, applied like filterA.
    :param minScore: if given, reading stops at the first line whose score
        is below this value.
    :param aCol: column index of the first partner.
    :param bCol: column index of the second partner.
    :param scoreCol: column index of a float score. When negative, or when a
        line is too short to have it, the pair gets score 1.0.
    :param separator: column separator passed to str.split(). When given,
        the id columns may hold several "|"-separated ids (getIds). When
        None, lines are split on whitespace and each id column is reduced
        with getCenterId().
    :param skipFirstLine: skip a header line.
    :param filterValues: optional list of [column, substring] entries. A
        line is kept only if every listed column that exists on the line
        contains the substring (case-insensitive).
    :returns: (index, count). index maps getKey(a, b) -> score for each
        accepted pair, keeping the first occurrence only. count is the
        number of lines that contributed at least one new pair.
    """
    if filterValues is None:
        filterValues = []
    index = dict()
    count = 0
    with open(fileName) as fp:
        if skipFirstLine:
            fp.readline()
        for line in fp:
            # Blank lines have no id columns and would raise IndexError.
            if not line.strip():
                continue
            ls = line.split(separator)
            if separator is not None:
                aList = getIds(ls[aCol])
                bList = getIds(ls[bCol])
            else:
                aList = [getCenterId(ls[aCol])]
                bList = [getCenterId(ls[bCol])]
            # The column filters depend only on the line, so test them once.
            lineMatches = all(f[1].lower() in ls[f[0]].lower()
                              for f in filterValues if len(ls) > f[0])
            validEntry = False
            for a in aList:
                for b in bList:
                    if not lineMatches or a == "-" or b == "-":
                        continue
                    if (filterA is not None and a not in filterA
                            and b not in filterA):
                        continue
                    if (filterB is not None and a not in filterB
                            and b not in filterB):
                        continue
                    name = getKey(a, b)
                    if name in index:
                        continue
                    validEntry = True
                    if scoreCol >= 0 and len(ls) > scoreCol:
                        score = float(ls[scoreCol])
                        if minScore is not None and minScore > score:
                            # NOTE(review): presumably the file is sorted by
                            # decreasing score, so the first low score ends
                            # the scan - confirm with the input producer.
                            return index, count
                        index[name] = score
                    else:
                        index[name] = 1.0
            if validEntry:
                count = count + 1
    return index, count
112
113
def getXY(prediction, positive, positiveCount, negative):
    """Compute ROC curve points and report the MCC-optimal score threshold.

    Predictions are ranked by decreasing score. Walking down the ranking,
    each prediction contained in ``positive`` counts as a true positive, and
    each contained in ``negative`` counts as a false positive. A curve point
    is appended whenever a labelled prediction is passed. The threshold with
    the highest Matthews correlation coefficient (MCC) is printed along with
    summary counts.

    :param prediction: dict mapping pair key -> score.
    :param positive: collection of positive pair keys.
    :param positiveCount: total number of positives, used for the TPR and
        the false-negative count.
    :param negative: collection of negative pair keys.
    :returns: (x, y, xMax). x holds false-positive-rate percentages and y
        holds true-positive-rate percentages; both start at 0. xMax is the
        largest x value.
    """
    sortedPrediction = sorted(prediction.items(), key=lambda x: x[1],
                              reverse=True)
    positiveTotal = positiveCount
    negativeTotal = len(negative)
    x = [0]
    y = [0]
    xMax = 0
    topCount = 0
    topMCC = 0.0
    topPrecision = 0.0
    topScore = 0.0
    tp = 0
    fp = 0
    # rank = number of predictions scoring at or above the current one.
    for rank, (name, score) in enumerate(sortedPrediction, start=1):
        isPositive = name in positive
        isNegative = name in negative
        if isPositive:
            tp = tp + 1
        if isNegative:
            fp = fp + 1
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        fn = positiveTotal - tp
        tn = negativeTotal - fp
        denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
        if denom > 0:
            mcc = (tp * tn - fp * fn) / math.sqrt(denom)
            if mcc >= topMCC:
                topMCC = mcc
                topScore = score
                # Previously the 0-based index, which undercounted by one.
                topCount = rank
                topPrecision = precision
        if isPositive or isNegative:
            y.append(getPercentage(tp, tp + fn))
            xValue = getPercentage(fp, fp + tn)
            x.append(xValue)
            xMax = max(xValue, xMax)
    if sortedPrediction:
        print("Top ranking prediction %s." % str(sortedPrediction[0]))
    else:
        print("No predictions to rank.")
    print("Total count of prediction set: %s (precision=%1.2f)." %
          (topCount, topPrecision))
    print("Total count of positive set: %s." % len(positive))
    print("Total count of negative set: %s." % len(negative))
    print("Matthews-Correlation-Coefficient: %s at Score >= %s." %
          (round(topMCC, 2), topScore))
    return x, y, xMax
165
166
def _sampleNegatives(filterA, filterB, excluded, required, maxAttempts=None):
    """Randomly draw id pairs from filterA x filterB as a background set.

    A drawn pair is rejected if its key (getKey) is in ``excluded`` or was
    already drawn. Sampling stops after ``required`` keys or after
    ``maxAttempts`` draws, whichever comes first. The cap prevents an
    endless loop when fewer than ``required`` unused pairs exist.

    :returns: set of pair keys; may be smaller than ``required``.
    """
    if maxAttempts is None:
        maxAttempts = max(1000, 100 * required)
    negative = set()
    filterAList = list(filterA)
    filterBList = list(filterB)
    attempts = 0
    while len(negative) < required and attempts < maxAttempts:
        attempts = attempts + 1
        key = getKey(random.choice(filterAList), random.choice(filterBList))
        if key not in excluded and key not in negative:
            negative.add(key)
    if len(negative) < required:
        print("Warning: sampled only %d of %d negative pairs." %
              (len(negative), required))
    return negative


def main(args):
    """Build a ROC plot of predictions against BioGRID interactions.

    :param args: namespace with ``input`` (prediction file), ``biogrid``
        (BioGRID file), ``method``/``experiment``/``throughput`` (optional
        substring filters; empty disables them) and ``output`` (PNG path).
    """
    # load source files
    filterSets = getFilter(args.input)
    filterKeys = list(filterSets.keys())
    filterA = filterSets[filterKeys[0]]
    # With a single organism, both partners are drawn from the same id set.
    filterB = filterSets[filterKeys[1]] if len(filterKeys) > 1 else filterA

    # identify biogrid filter options as [column index, substring] entries
    filterValues = []
    if args.method:
        filterValues.append([11, args.method])
    if args.experiment:
        filterValues.append([12, args.experiment])
    if args.throughput:
        filterValues.append([17, args.throughput])

    # process biogrid database
    print("Loading positive set from BioGRID file...")
    positive, positiveCount = getReference(args.biogrid, aCol=23, bCol=26,
                                           separator="\t", filterA=filterA,
                                           filterB=filterB, skipFirstLine=True,
                                           filterValues=filterValues)

    # rescan biogrid database to identify set of putative interactions
    if filterValues:
        print("Filtered entries by (column, value): %s" % filterValues)
        print("Loading putative set from BioGRID file...")
        putative, putativeCount = getReference(args.biogrid, aCol=23, bCol=26,
                                               separator="\t", filterA=filterA,
                                               filterB=filterB,
                                               skipFirstLine=True)
        print("Found %s." % putativeCount)
    else:
        putative = positive

    # process prediction file
    print("Loading prediction file...")
    prediction, _ = getReference(args.input, scoreCol=2)

    # estimate background noise
    print("Estimating background noise...")
    # random.seed() rejects datetime objects since Python 3.11; seed with the
    # POSIX timestamp to keep the time-based seeding.
    random.seed(datetime.now().timestamp())
    negative = _sampleNegatives(filterA, filterB, putative, positiveCount)

    # create plot
    print("Producing plot data...")
    print("Total count in prediction file: %d." % len(prediction))
    print("Total count in positive file: %d." % len(positive))
    plt.ylabel('True Positive Rate (%)')
    plt.xlabel('False Positive Rate (%)')
    title = " vs. ".join(filterSets)
    plt.suptitle(title)
    if filterValues:
        filterAttributes = [entry[1] for entry in filterValues]
        plt.title("BioGRID filters: %s" % filterAttributes, fontsize=10)
    x, y, xMax = getXY(prediction, positive, positiveCount, negative)
    plt.plot(x, y)
    # diagonal reference line
    plt.plot([0, xMax], [0, xMax])
    plt.savefig(args.output, format="png")
240
241
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='Create ROC plot.')
    cli.add_argument('-i', '--input', required=True,
                     help='Input prediction file.')
    cli.add_argument('-b', '--biogrid', required=True,
                     help='BioGRID interaction database file')
    # Optional BioGRID filters; the empty-string default means "no filter".
    for shortFlag, longFlag, helpText in (
            ('-e', '--experiment', 'Type (physical/genetic)'),
            ('-t', '--throughput', 'Throughput (low/high)'),
            ('-m', '--method', 'Method e.g. Two-hybrid')):
        cli.add_argument(shortFlag, longFlag, help=helpText, default="",
                         required=False)
    cli.add_argument('-o', '--output', required=True, help='Output (png)')
    main(cli.parse_args())