Mercurial > repos > guerler > springsuite
comparison spring_roc.py @ 29:41353488926c draft
"planemo upload commit 1c0a60f98e36bccb6d6c85ff82a8d737a811b4d5"
author | guerler |
---|---|
date | Sun, 22 Nov 2020 14:15:24 +0000 |
parents | |
children | b0e195a47df7 |
comparison
equal
deleted
inserted
replaced
28:75d1aedc9b3f | 29:41353488926c |
---|---|
1 #! /usr/bin/env python | |
2 import argparse | |
3 import math | |
4 import random | |
5 from datetime import datetime | |
6 | |
7 from matplotlib import pyplot as plt | |
8 | |
9 | |
def getIds(rawIds):
    """Split a '|'-joined identifier string into its individual ids.

    An input without '|' yields a one-element list; empty fields are kept.
    """
    return rawIds.split(sep="|")
12 | |
13 | |
def getCenterId(rawId):
    """Return the second '|'-separated field of rawId.

    When rawId contains no '|' at all, it is returned unchanged.
    """
    _, separator, remainder = rawId.partition("|")
    if not separator:
        return rawId
    # The wanted field ends at the next '|' (or at the end of the string).
    return remainder.partition("|")[0]
19 | |
20 | |
def getOrganism(rawId):
    """Return the text after the last underscore of rawId.

    If rawId has no underscore, the whole string is returned.
    """
    return rawId.rpartition("_")[2]
24 | |
25 | |
def getKey(a, b):
    """Build an order-independent key for the pair (a, b).

    The larger of the two values comes first, so getKey(a, b) and
    getKey(b, a) produce the same "<larger>_<smaller>" string.
    """
    first, second = (a, b) if a > b else (b, a)
    return "%s_%s" % (first, second)
32 | |
33 | |
def getPercentage(rate, denominator):
    """Return rate as a percentage of denominator.

    A non-positive denominator yields 0.0 instead of a division error.
    """
    return 100.0 * rate / denominator if denominator > 0 else 0.0
38 | |
39 | |
def getFilter(filterName):
    """Group the protein ids listed in a file by organism.

    Only the first two whitespace-separated columns of each line are read.
    Each entry is reduced to its center id (getCenterId) and stored under the
    organism returned by getOrganism for the raw entry.

    Args:
        filterName: path of the file to read.

    Returns:
        dict mapping organism name -> set of ids, in first-seen order.
    """
    print("Loading target organism(s)...")
    organismIds = dict()
    with open(filterName) as handle:
        for row in handle:
            for entry in row.split()[:2]:
                organismIds.setdefault(getOrganism(entry), set()).add(
                    getCenterId(entry))
    print("Organism(s) in set: %s." % organismIds.keys())
    return organismIds
57 | |
58 | |
def _pairPassesFilter(idFilter, a, b):
    """Return True when no filter is given or either id of the pair is in it."""
    return idFilter is None or a in idFilter or b in idFilter


def getReference(fileName, filterA=None, filterB=None, minScore=None, aCol=0,
                 bCol=1, scoreCol=-1, separator=None,
                 skipFirstLine=False, filterValues=None):
    """Load an interaction file into a dict keyed by protein pair.

    Each line is split on ``separator`` (whitespace when None). The ids in
    columns ``aCol`` and ``bCol`` become pair keys via getKey. With an
    explicit separator, a column may hold several ids joined by "|", and
    every combination is considered. Without one, the column is reduced to
    its center id via getCenterId.

    A pair is skipped when:
      * either id is "-";
      * filterA is given and neither id is in it;
      * filterB is given and neither id is in it;
      * any [column, text] entry of filterValues names a column present on
        the line that does not contain text (case-insensitive).

    Args:
        fileName: path of the file to read.
        filterA, filterB: optional collections of accepted ids.
        minScore: optional threshold. Reading STOPS at the first new pair
            whose score is below it, which assumes the file is sorted by
            descending score.
        aCol, bCol: column indices of the two interaction partners.
        scoreCol: column holding a float score. When it is negative, or the
            line has no such column, the pair is stored with score 1.0.
        separator: column separator passed to str.split.
        skipFirstLine: ignore the first (header) line.
        filterValues: optional list of [columnIndex, searchText] pairs.

    Returns:
        (index, count): a dict mapping pair key -> score, and the number of
        lines that contributed at least one new pair.
    """
    index = dict()
    count = 0
    # Lower-case the search terms once rather than once per id pair. The
    # None default avoids a shared mutable default list.
    loweredFilters = [(column, text.lower())
                      for column, text in (filterValues or [])]
    with open(fileName) as fp:
        if skipFirstLine:
            next(fp, None)
        for line in fp:
            # Blank lines (e.g. a trailing newline) carry no interaction and
            # would otherwise raise IndexError on the column lookups below.
            if not line.strip():
                continue
            ls = line.split(separator)
            # The attribute filters depend only on the line, not on the ids,
            # so a line that fails them contributes no pairs at all.
            if any(len(ls) > column and text not in ls[column].lower()
                   for column, text in loweredFilters):
                continue
            if separator is not None:
                aList = getIds(ls[aCol])
                bList = getIds(ls[bCol])
            else:
                aList = [getCenterId(ls[aCol])]
                bList = [getCenterId(ls[bCol])]
            validEntry = False
            for a in aList:
                for b in bList:
                    if a == "-" or b == "-":
                        continue
                    if not (_pairPassesFilter(filterA, a, b)
                            and _pairPassesFilter(filterB, a, b)):
                        continue
                    name = getKey(a, b)
                    if name in index:
                        continue
                    validEntry = True
                    if scoreCol >= 0 and len(ls) > scoreCol:
                        score = float(ls[scoreCol])
                        if minScore is not None and minScore > score:
                            # Early exit: relies on descending score order.
                            return index, count
                        index[name] = score
                    else:
                        index[name] = 1.0
            if validEntry:
                count += 1
    return index, count
112 | |
113 | |
def getXY(prediction, positive, positiveCount, negative):
    """Compute ROC curve points for scored predictions and report the best MCC.

    Predictions are visited from highest to lowest score while running true
    and false positive counts are kept. A curve point is appended whenever a
    prediction belongs to the positive or negative set. The score threshold
    that maximizes the Matthews correlation coefficient is printed along with
    summary counts.

    Args:
        prediction: dict mapping pair key -> score.
        positive: collection of pair keys treated as true interactions.
        positiveCount: total used as TP + FN. NOTE(review): main passes
            getReference's line count here, which can be smaller than
            len(positive) when one line expands to several pairs; confirm
            that this is intended.
        negative: collection of pair keys treated as non-interactions.

    Returns:
        (x, y, xMax): false positive rates (%) and true positive rates (%),
        both starting at 0, and the largest false positive rate reached.
    """
    sortedPrediction = sorted(prediction.items(), key=lambda item: item[1],
                              reverse=True)
    positiveTotal = positiveCount
    negativeTotal = len(negative)
    x = [0]
    y = [0]
    xMax = 0
    topCount = 0
    topMCC = 0.0
    topPrecision = 0.0
    topScore = 0.0
    tp = 0
    fp = 0
    # rank is 1-based: after visiting a prediction, exactly `rank`
    # predictions have been counted, including the current one.
    for rank, (name, score) in enumerate(sortedPrediction, start=1):
        found = False
        if name in positive:
            found = True
            tp += 1
        if name in negative:
            found = True
            fp += 1
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        fn = positiveTotal - tp
        tn = negativeTotal - fp
        denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
        if denom > 0:
            mcc = (tp * tn - fp * fn) / math.sqrt(denom)
            if mcc >= topMCC:
                topMCC = mcc
                topScore = score
                # number of predictions at this threshold (was off by one)
                topCount = rank
                topPrecision = precision
        if found:
            # Only labeled predictions move the ROC curve.
            yValue = getPercentage(tp, tp + fn)
            xValue = getPercentage(fp, fp + tn)
            y.append(yValue)
            x.append(xValue)
            xMax = max(xValue, xMax)
    if sortedPrediction:
        print("Top ranking prediction %s." % str(sortedPrediction[0]))
    else:
        # Previously an IndexError on an empty prediction set.
        print("No predictions available.")
    print("Total count of prediction set: %s (precision=%1.2f)." %
          (topCount, topPrecision))
    print("Total count of positive set: %s." % len(positive))
    print("Total count of negative set: %s." % len(negative))
    print("Matthews-Correlation-Coefficient: %s at Score >= %s." %
          (round(topMCC, 2), topScore))
    return x, y, xMax
165 | |
166 | |
def _sampleNegatives(filterA, filterB, excluded, required, maxMisses=100000):
    """Randomly draw up to `required` distinct pair keys from filterA x filterB.

    Keys found in `excluded`, or already drawn, are rejected. Sampling stops
    early with a warning once `maxMisses` consecutive draws fail to produce a
    new key. Without this guard the loop never ends when fewer than
    `required` eligible pairs exist, e.g. a small input set facing many
    known interactions.

    Returns:
        set of sampled pair keys (may hold fewer than `required` entries).
    """
    negative = set()
    filterAList = list(filterA)
    filterBList = list(filterB)
    misses = 0
    while len(negative) < required:
        nameA = random.choice(filterAList)
        nameB = random.choice(filterBList)
        key = getKey(nameA, nameB)
        if key not in excluded and key not in negative:
            negative.add(key)
            misses = 0
        else:
            misses += 1
            if misses >= maxMisses:
                print("Warning: only %d of %d negative pairs could be "
                      "sampled." % (len(negative), required))
                break
    return negative


def main(args):
    """Build a ROC plot comparing predicted interactions against BioGRID.

    Steps:
      1. Read the organism id sets from the prediction file.
      2. Load the (optionally attribute-filtered) positive set and, if
         filters are used, the unfiltered putative set from BioGRID.
      3. Load the prediction scores.
      4. Sample random negative pairs that are not putative interactions.
      5. Plot the ROC curve to args.output as PNG.

    Args:
        args: parsed command-line namespace with input, biogrid, method,
            experiment, throughput and output attributes.

    Raises:
        SystemExit: if the input file yields no identifiers.
    """
    # load source files
    filterSets = getFilter(args.input)
    filterKeys = list(filterSets.keys())
    if not filterKeys:
        raise SystemExit("No protein identifiers found in input file %s."
                         % args.input)
    filterA = filterSets[filterKeys[0]]
    # with a single organism, it is compared against itself
    filterB = filterSets[filterKeys[1]] if len(filterKeys) > 1 else filterA

    # identify biogrid filter options as [column index, search text]
    filterValues = []
    if args.method:
        filterValues.append([11, args.method])
    if args.experiment:
        filterValues.append([12, args.experiment])
    if args.throughput:
        filterValues.append([17, args.throughput])

    # process biogrid database
    print("Loading positive set from BioGRID file...")
    positive, positiveCount = getReference(args.biogrid, aCol=23, bCol=26,
                                           separator="\t", filterA=filterA,
                                           filterB=filterB, skipFirstLine=True,
                                           filterValues=filterValues)

    # rescan biogrid database without the attribute filters so that every
    # known interaction is excluded from the negative sample
    if filterValues:
        print("Filtered entries by (column, value): %s" % filterValues)
        print("Loading putative set from BioGRID file...")
        putative, putativeCount = getReference(args.biogrid, aCol=23, bCol=26,
                                               separator="\t", filterA=filterA,
                                               filterB=filterB,
                                               skipFirstLine=True)
        print("Found %s." % putativeCount)
    else:
        putative = positive

    # process prediction file
    print("Loading prediction file...")
    prediction, _ = getReference(args.input, scoreCol=2)

    # estimate background noise
    print("Estimating background noise...")
    # random.seed() rejects datetime objects since Python 3.11; seed from
    # the current time as a float instead.
    random.seed(datetime.now().timestamp())
    negative = _sampleNegatives(filterA, filterB, putative, positiveCount)

    # create plot
    print("Producing plot data...")
    print("Total count in prediction file: %d." % len(prediction))
    print("Total count in positive file: %d." % len(positive))
    plt.ylabel('True Positive Rate (%)')
    plt.xlabel('False Positive Rate (%)')
    title = " vs. ".join(filterSets)
    plt.suptitle(title)
    if filterValues:
        filterAttributes = [value for _, value in filterValues]
        plt.title("BioGRID filters: %s" % filterAttributes, fontsize=10)
    x, y, xMax = getXY(prediction, positive, positiveCount, negative)
    plt.plot(x, y)
    # diagonal reference line (random classifier)
    plt.plot([0, xMax], [0, xMax])
    plt.savefig(args.output, format="png")
240 | |
241 | |
if __name__ == "__main__":
    # Script entry point: collect the command-line options and run main().
    cli = argparse.ArgumentParser(description='Create ROC plot.')
    cli.add_argument('-i', '--input', required=True,
                     help='Input prediction file.')
    cli.add_argument('-b', '--biogrid', required=True,
                     help='BioGRID interaction database file')
    # Optional BioGRID filters; main() ignores an empty string.
    cli.add_argument('-e', '--experiment', default="", required=False,
                     help='Type (physical/genetic)')
    cli.add_argument('-t', '--throughput', default="", required=False,
                     help='Throughput (low/high)')
    cli.add_argument('-m', '--method', default="", required=False,
                     help='Method e.g. Two-hybrid')
    cli.add_argument('-o', '--output', required=True, help='Output (png)')
    main(cli.parse_args())