Mercurial > repos > guerler > springsuite
comparison spring_mcc.py @ 39:172398348efd draft
"planemo upload commit 26b4018c88041ee0ca7c2976e0a012015173d7b6-dirty"
author | guerler |
---|---|
date | Fri, 22 Jan 2021 15:50:27 +0000 |
parents | |
children | f316caf098a6 |
comparison
equal
deleted
inserted
replaced
38:80a4b98121b6 | 39:172398348efd |
---|---|
1 #! /usr/bin/env python | |
2 import argparse | |
3 import math | |
4 from os.path import isfile | |
5 import re | |
6 from matplotlib import pyplot as plt | |
7 | |
8 | |
def getIds(rawIds):
    """Split a "|"-delimited identifier field into a list of identifiers."""
    delimiter = "|"
    return rawIds.split(delimiter)
11 | |
12 | |
def getCenterId(rawId):
    """Return the second "|"-separated field of rawId.

    If rawId contains no "|" it is returned unchanged.
    """
    parts = rawId.split("|")
    return rawId if len(parts) <= 1 else parts[1]
18 | |
19 | |
def getOrganism(rawId):
    """Return the text after the last "_" in rawId (all of rawId if none)."""
    _, _, suffix = rawId.rpartition("_")
    return suffix
23 | |
24 | |
def getKey(a, b):
    """Build an order-independent key "<larger>_<smaller>" for a pair."""
    # the larger operand always goes first so (a, b) and (b, a) collide
    first, second = (a, b) if a > b else (b, a)
    return "%s_%s" % (first, second)
31 | |
32 | |
def getPercentage(rate, denominator):
    """Return rate as a percentage of denominator.

    Yields 0.0 whenever the denominator is not strictly positive.
    """
    if not denominator > 0:
        return 0.0
    return 100.0 * rate / denominator
37 | |
38 | |
def getFilter(filterName):
    """Collect the identifiers found in a file, grouped by organism suffix.

    Only the first two whitespace-separated columns of every line are
    inspected. For each entry the organism is taken with getOrganism() and
    the identifier with getCenterId().

    Returns:
        dict mapping organism -> set of identifiers.
    """
    print("Loading target organism(s)...")
    organismIds = dict()
    with open(filterName) as handle:
        for row in handle:
            # rows with fewer than two columns contribute what they have
            for entry in row.split()[:2]:
                members = organismIds.setdefault(getOrganism(entry), set())
                members.add(getCenterId(entry))
    print("Organism(s) in set: %s." % organismIds.keys())
    return organismIds
56 | |
57 | |
def getReference(fileName, filterA=None, filterB=None, minScore=None, aCol=0,
                 bCol=1, scoreCol=-1, separator=None,
                 skipFirstLine=False, filterValues=None):
    """Read interaction pairs from a column-based text file.

    Args:
        fileName: path of the file to read.
        filterA, filterB: optional identifier containers. When both are
            given, a pair is kept only if one partner is in filterA and the
            other in filterB (in either order).
        minScore: optional cut-off. Reading stops and the data gathered so
            far is returned as soon as a new pair scores below it.
        aCol, bCol: column indices of the two interaction partners.
        scoreCol: column index of the score; when negative or missing from
            a row the pair is stored with score 1.0.
        separator: column separator passed to str.split(). When set, each
            partner column is split on "|" with getIds(); otherwise the
            identifier is extracted with getCenterId().
        skipFirstLine: skip the first line of the file (header).
        filterValues: optional list of [column, text] entries. A row is
            dropped when the column exists and does not contain the text
            (case-insensitive).

    Returns:
        (index, count): index maps pair key -> score, count is the number
        of rows that contributed at least one new pair.
    """
    if filterValues is None:
        filterValues = []
    useOrganismFilter = filterA is not None and filterB is not None

    index = dict()
    count = 0
    with open(fileName) as fp:
        if skipFirstLine:
            next(fp, None)
        for line in fp:
            ls = line.split(separator)
            try:
                rawA = ls[aCol]
                rawB = ls[bCol]
            except IndexError:
                # blank or truncated row: no partner columns to read
                continue

            # the column filters only depend on the row, so test them once
            rowMatches = True
            for f in filterValues:
                if len(ls) > f[0] and f[1].lower() not in ls[f[0]].lower():
                    rowMatches = False
                    break
            if not rowMatches:
                continue

            if separator is not None:
                aList = getIds(rawA)
                bList = getIds(rawB)
            else:
                aList = [getCenterId(rawA)]
                bList = [getCenterId(rawB)]

            validEntry = False
            for a in aList:
                for b in bList:
                    # "-" marks a missing identifier
                    if a == "-" or b == "-":
                        continue
                    if useOrganismFilter and not (
                            (a in filterA and b in filterB) or
                            (a in filterB and b in filterA)):
                        continue
                    name = getKey(a, b)
                    if name in index:
                        continue
                    validEntry = True
                    if scoreCol >= 0 and len(ls) > scoreCol:
                        score = float(ls[scoreCol])
                        if minScore is not None and minScore > score:
                            # stop at the first new pair below the cut-off
                            return index, count
                        index[name] = score
                    else:
                        index[name] = 1.0
            if validEntry:
                count = count + 1
    return index, count
112 | |
113 | |
def getMCC(prediction, positive, positiveCount, negative):
    """Find the best Matthews correlation coefficient over score cut-offs.

    Predictions are visited from highest to lowest score. After each one,
    the confusion matrix obtained by treating everything visited so far as
    predicted positive is evaluated, and the largest MCC is remembered.
    A short summary is printed.

    Args:
        prediction: dict mapping pair key -> score.
        positive: container of pair keys counted as true positives.
        positiveCount: total number of positives, used to derive the
            false-negative count.
        negative: sized container of pair keys counted as false positives.

    Returns:
        The largest MCC encountered; 0.0 if no cut-off gave a defined,
        non-negative value.
    """
    sortedPrediction = sorted(prediction.items(), key=lambda x: x[1],
                              reverse=True)
    positiveTotal = positiveCount
    negativeTotal = len(negative)
    topCount = 0
    topMCC = 0.0
    topFP = 0.0
    topTP = 0.0
    topScore = 0.0
    tp = 0
    fp = 0
    for count, (name, score) in enumerate(sortedPrediction):
        if name in positive:
            tp = tp + 1
        if name in negative:
            fp = fp + 1
        fn = positiveTotal - tp
        tn = negativeTotal - fp
        denom = (tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)
        # MCC is undefined when any marginal total of the matrix is zero
        if denom > 0.0:
            mcc = (tp*tn-fp*fn)/math.sqrt(denom)
            if mcc >= topMCC:
                topMCC = mcc
                topScore = score
                topCount = count
                topFP = getPercentage(fp, fp + tn)
                topTP = getPercentage(tp, tp + fn)
    if sortedPrediction:
        print("Top ranking prediction %s." % str(sortedPrediction[0]))
    print("Total count of prediction set: %s (tp=%1.2f, fp=%1.2f)." % (topCount, topTP, topFP))
    print("Total count of positive set: %s." % len(positive))
    print("Total count of negative set: %s." % len(negative))
    print("Matthews-Correlation-Coefficient: %s at Score >= %s." % (round(topMCC, 2), topScore))
    return topMCC
163 | |
164 | |
def getNegativeSet(args, filterA, filterB, negativeRequired, jSize=5):
    """Build the set of pair keys regarded as non-interacting.

    Three sources are supported, in this order of precedence:
      1. args.negative names an existing file: the first two columns of
         every row are used as a pair.
      2. args.region_a and args.region_b are set: args.locations is
         scanned for "SUBCELLULAR LOCATION" rows, and identifiers found
         exclusively in one of the two regions are paired across regions.
      3. otherwise identifiers of filterA are paired with filterB.
    For 2. and 3. pairs come from randomPairs() and generation stops once
    negativeRequired distinct keys were collected.

    Args:
        args: parsed command-line namespace (negative, region_a, region_b,
            locations are read).
        filterA, filterB: identifier sets of the organisms of interest.
        negativeRequired: number of generated pairs wanted; values <= 0
            produce no generated pairs.
        jSize: window width handed to randomPairs().

    Returns:
        set of pair keys as produced by getKey().

    Raises:
        ValueError: regions were given without a locations file.
    """
    # determine negative set
    print("Identifying non-interacting pairs...")
    negative = set()
    if args.negative and isfile(args.negative):
        # load from explicit file
        with open(args.negative) as negativeFile:
            for line in negativeFile:
                cols = line.split()
                if len(cols) < 2:
                    # blank or truncated row: nothing to pair up
                    continue
                negative.add(getKey(cols[0], cols[1]))
        return negative

    if args.region_a and args.region_b:
        if not args.locations:
            raise ValueError("A locations file is required when regions are specified.")
        locations = dict()
        regionA = args.region_a.lower()
        regionB = args.region_b.lower()
        locations[regionA] = list()
        locations[regionB] = list()
        regions = [regionA, regionB]
        print("Filtering regions %s" % str(regions))
        searchKey = "SUBCELLULAR LOCATION"
        with open(args.locations) as locFile:
            for line in locFile:
                searchPos = line.find(searchKey)
                if searchPos == -1:
                    continue
                uniId = line.split()[0]
                if uniId not in filterA and uniId not in filterB:
                    continue
                locStart = searchPos + len(searchKey) + 1
                locId = line[locStart:]
                # drop "{...}" annotations, then treat "." like ","
                locId = re.sub(r"\s*{.*}\s*", "", locId)
                locId = locId.replace(".", ",")
                locId = locId.strip().lower()
                # keep only the text before a "note=" section or a ";"
                filter_pos = locId.find("note=")
                if filter_pos >= 0:
                    locId = locId[:filter_pos]
                filter_pos = locId.find(";")
                if filter_pos >= 0:
                    locId = locId[:filter_pos]
                if locId:
                    names = [token.strip() for token in locId.split(",")]
                    # only identifiers seen in exactly one region are used
                    if regionA in names and regionB not in names:
                        locations[regionA].append(uniId)
                    elif regionA not in names and regionB in names:
                        locations[regionB].append(uniId)
        filterAList = sorted(locations[regionA])
        filterBList = sorted(locations[regionB])
    else:
        # sorted so the generated pairs do not depend on set iteration order
        filterAList = sorted(filterA)
        filterBList = sorted(filterB)

    for i, j in randomPairs(len(filterAList), len(filterBList), jSize):
        if len(negative) >= negativeRequired:
            break
        negative.add(getKey(filterAList[i], filterBList[j]))
    return negative
228 | |
229 | |
def randomPairs(iLen, jLen, jSize):
    """Yield index pairs (i, j) covering range(iLen) x range(jLen).

    The j axis is cut into consecutive windows of jSize indices. For each
    window every i is visited in ascending order and paired with each j of
    the window, so early output spreads over all i before moving on to
    further j values. Despite the name the order is deterministic.

    Args:
        iLen: number of indices on the first axis.
        jLen: number of indices on the second axis.
        jSize: window width on the second axis; values <= 0 yield nothing.

    Yields:
        (i, j) tuples; each combination appears exactly once.
    """
    if jSize <= 0:
        return
    # windows are adjacent: the next one starts where the previous ended
    for jStart in range(0, jLen, jSize):
        jMax = min(jStart + jSize, jLen)
        for i in range(iLen):
            for j in range(jStart, jMax):
                yield i, j
241 | |
242 | |
def main(args):
    """Compare the MCC of the prediction file with BioGRID method subsets.

    The positive set is read from the BioGRID file (restricted to
    args.method when given), a negative set is derived with
    getNegativeSet(), and the MCC is computed for the prediction file and
    for each listed BioGRID method other than args.method. The values are
    written as a horizontal bar chart to args.output (PNG).

    Raises:
        ValueError: the input file contains no identifiers.
    """
    # load source files
    filterSets = getFilter(args.input)
    filterKeys = list(filterSets.keys())
    if not filterKeys:
        raise ValueError("No identifiers found in input file: %s" % args.input)
    filterA = filterSets[filterKeys[0]]
    if len(filterKeys) > 1:
        filterB = filterSets[filterKeys[1]]
    else:
        filterB = filterA

    # identify biogrid filter options
    # NOTE(review): column 11 is presumably the experimental-system column
    # of the BioGRID export - confirm against the file format in use.
    # The method is optional; without it the positive set is not restricted.
    filterValues = list()
    if args.method:
        filterValues.append([11, args.method])

    # process biogrid database
    print("Loading positive set from BioGRID file...")
    positive, positiveCount = getReference(args.biogrid, aCol=23, bCol=26,
                                           separator="\t", filterA=filterA,
                                           filterB=filterB, skipFirstLine=True,
                                           filterValues=filterValues)

    # estimate negative set
    negative = getNegativeSet(args, filterA, filterB, positiveCount)

    # get prediction results
    print("Loading prediction file...")
    prediction, _ = getReference(args.input, scoreCol=2, minScore=0.8)
    mcc = getMCC(prediction, positive, positiveCount, negative)
    yValues = [mcc]
    yTicks = ["SPRING"]

    # evaluate the remaining biogrid methods against the same sets
    methods = ("Affinity Capture-MS",
               "Biochemical Activity",
               "Co-crystal Structure",
               "Co-fractionation",
               "Co-localization",
               "Co-purification",
               "Far Western",
               "FRET",
               "PCA",
               "Reconstituted Complex",
               "Two-hybrid")
    for method in methods:
        if args.method != method:
            print("Method: %s" % method)
            # kept separate from `prediction` so the summary below still
            # reports the size of the prediction file
            methodSet, _ = getReference(args.biogrid, aCol=23, bCol=26,
                                        separator="\t", filterA=filterA,
                                        filterB=filterB, skipFirstLine=True,
                                        filterValues=[[11, method]])
            yValues.append(getMCC(methodSet, positive, positiveCount, negative))
            yTicks.append(method)

    # create plot
    print("Producing plot data...")
    print("Total count in prediction file: %d." % len(prediction))
    print("Total count in positive file: %d." % len(positive))
    plt.xlabel("Matthews-Correlation Coefficient (MCC)")
    plt.title("Positive set: %s" % args.method)
    plt.barh(yTicks, yValues)
    plt.tight_layout()
    plt.savefig(args.output, format="png")
306 | |
307 | |
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the comparison.
    # NOTE: --throughput is accepted but not read anywhere in this script.
    parser = argparse.ArgumentParser(description='Create MCC comparison bar plot.')
    parser.add_argument('-i', '--input', help='Input prediction file (2-columns).', required=True)
    parser.add_argument('-b', '--biogrid', help='BioGRID interaction database file', required=True)
    parser.add_argument('-l', '--locations', help='UniProt export table with subcellular locations', required=False)
    parser.add_argument('-ra', '--region_a', help='First subcellular location', required=False)
    parser.add_argument('-rb', '--region_b', help='Second subcellular location', required=False)
    parser.add_argument('-n', '--negative', help='Negative set (2-columns)', required=False)
    parser.add_argument('-t', '--throughput', help='Throughput (low/high)', required=False)
    parser.add_argument('-m', '--method', help='Method e.g. Two-hybrid', required=False)
    parser.add_argument('-o', '--output', help='Output (png)', required=True)
    args = parser.parse_args()
    main(args)