comparison spring_mcc.py @ 39:172398348efd draft

"planemo upload commit 26b4018c88041ee0ca7c2976e0a012015173d7b6-dirty"
author guerler
date Fri, 22 Jan 2021 15:50:27 +0000
parents
children f316caf098a6
comparison
equal deleted inserted replaced
38:80a4b98121b6 39:172398348efd
1 #! /usr/bin/env python
2 import argparse
3 import math
4 from os.path import isfile
5 import re
6 from matplotlib import pyplot as plt
7
8
def getIds(rawIds):
    """Split a pipe-delimited identifier string into its parts.

    :param rawIds: string holding one or more ids separated by ``|``.
    :return: list of the id substrings in their original order.
    """
    idList = rawIds.split("|")
    return idList
11
12
def getCenterId(rawId):
    """Return the second ``|``-separated field of an identifier.

    :param rawId: identifier string, optionally of the form ``x|id|y``.
    :return: the field at position 1 when the string contains at least one
        ``|``; otherwise the unchanged input.
    """
    parts = rawId.split("|")
    return parts[1] if len(parts) > 1 else rawId
18
19
def getOrganism(rawId):
    """Return the text following the last underscore of an identifier.

    :param rawId: identifier string such as ``NAME_ORGANISM``.
    :return: the suffix after the final ``_``; the whole string when it
        contains no underscore.
    """
    # rpartition yields ('', '', rawId) when no underscore is present
    _, _, suffix = rawId.rpartition("_")
    return suffix
23
24
def getKey(a, b):
    """Build an order-independent key for a pair of identifiers.

    The larger of the two values (by ``>`` comparison) is placed first, so
    ``getKey(a, b)`` and ``getKey(b, a)`` produce the same string.

    :param a: first identifier.
    :param b: second identifier.
    :return: string of the form ``<larger>_<smaller>``.
    """
    larger, smaller = (a, b) if a > b else (b, a)
    return "%s_%s" % (larger, smaller)
31
32
def getPercentage(rate, denominator):
    """Express ``rate`` as a percentage of ``denominator``.

    :param rate: numerator value.
    :param denominator: value representing 100 percent.
    :return: ``100.0 * rate / denominator`` when the denominator is
        positive, otherwise ``0.0``.
    """
    # guard clause: anything that is not strictly positive yields 0.0
    if not denominator > 0:
        return 0.0
    return 100.0 * rate / denominator
37
38
def getFilter(filterName):
    """Collect the identifiers of a file, grouped by organism suffix.

    Each line is split on whitespace and only its first two columns are
    inspected. For every such column the centre id (``getCenterId``) is
    stored in the set belonging to the column's organism suffix
    (``getOrganism``).

    :param filterName: path of the text file to read.
    :return: dict mapping organism suffix to the set of centre ids seen.
    """
    print("Loading target organism(s)...")
    filterSets = {}
    with open(filterName) as filterFile:
        for line in filterFile:
            # slicing tolerates lines with fewer than two columns
            for colEntry in line.split()[:2]:
                centerId = getCenterId(colEntry)
                organism = getOrganism(colEntry)
                filterSets.setdefault(organism, set()).add(centerId)
    print("Organism(s) in set: %s." % filterSets.keys())
    return filterSets
56
57
def getReference(fileName, filterA=None, filterB=None, minScore=None, aCol=0,
                 bCol=1, scoreCol=-1, separator=None,
                 skipFirstLine=False, filterValues=None):
    """Load identifier pairs, and optionally their scores, from a text file.

    Every line is split on ``separator``. With an explicit separator the two
    identifier columns are expanded with ``getIds`` and all combinations are
    considered; with whitespace splitting each column contributes its
    ``getCenterId`` value. Only the first occurrence of a pair key is kept.

    :param fileName: path of the file to read.
    :param filterA: optional container of accepted ids for one side.
    :param filterB: optional container of accepted ids for the other side.
        The id filter is only applied when both containers are given; a
        pair passes when one id is in ``filterA`` and the other in
        ``filterB``, in either order.
    :param minScore: optional score threshold, see the note in the body.
    :param aCol: index of the first identifier column.
    :param bCol: index of the second identifier column.
    :param scoreCol: index of the score column; a negative value or a
        missing column stores ``1.0`` for the pair.
    :param separator: column separator passed to ``str.split``.
    :param skipFirstLine: skip the first line of the file when true.
    :param filterValues: optional list of ``[column, text]`` entries. A line
        is rejected when the column exists and does not contain the text
        (case-insensitive).
    :return: tuple ``(index, count)``; ``index`` maps pair keys to scores
        and ``count`` is the number of lines that added at least one new
        pair.
    """
    # a fresh list per call avoids sharing a mutable default argument
    if filterValues is None:
        filterValues = []
    searchTerms = [(f[0], f[1].lower()) for f in filterValues]
    useIdFilter = filterA is not None and filterB is not None

    index = dict()
    count = 0
    with open(fileName) as fp:
        if skipFirstLine:
            next(fp, None)
        for line in fp:
            # drop the line ending so it cannot leak into the last column
            ls = line.rstrip("\r\n").split(separator)
            try:
                rawA = ls[aCol]
                rawB = ls[bCol]
            except IndexError:
                # blank or truncated line: nothing to read here
                continue
            if separator is not None:
                aList = getIds(rawA)
                bList = getIds(rawB)
            else:
                aList = [getCenterId(rawA)]
                bList = [getCenterId(rawB)]
            # the column filters depend on the line only, so test them once
            lineAccepted = all(len(ls) <= col or term in ls[col].lower()
                               for col, term in searchTerms)
            if not lineAccepted:
                continue
            validEntry = False
            for a in aList:
                for b in bList:
                    # "-" marks a missing identifier and is always skipped
                    if a == "-" or b == "-":
                        continue
                    if useIdFilter and not ((a in filterA and b in filterB)
                                            or (a in filterB and b in filterA)):
                        continue
                    name = getKey(a, b)
                    if name in index:
                        continue
                    validEntry = True
                    if scoreCol >= 0 and len(ls) > scoreCol:
                        score = float(ls[scoreCol])
                        # NOTE(review): reading stops at the first score
                        # below minScore, which presumably assumes the file
                        # is sorted by descending score - confirm.
                        if minScore is not None and minScore > score:
                            return index, count
                        index[name] = score
                    else:
                        index[name] = 1.0
            if validEntry:
                count = count + 1
    return index, count
112
113
def getMCC(prediction, positive, positiveCount, negative):
    """Find the best Matthews correlation coefficient over all score cutoffs.

    Predictions are visited in order of descending score. After each one
    the confusion counts are updated (a key found in ``positive`` counts as
    true positive, a key found in ``negative`` as false positive) and the
    MCC is evaluated. The highest non-negative value is reported together
    with the score at which it occurred; a summary is printed.

    :param prediction: dict mapping pair key to score.
    :param positive: container of pair keys regarded as interacting.
    :param positiveCount: total used for the positive class when deriving
        the false-negative count.
    :param negative: sized container of pair keys regarded as
        non-interacting.
    :return: the best MCC found, ``0.0`` when none could be computed.
    """
    sortedPrediction = sorted(prediction.items(), key=lambda x: x[1],
                              reverse=True)
    positiveTotal = positiveCount
    negativeTotal = len(negative)
    topCount = 0
    topMCC = 0.0
    topFP = 0.0
    topTP = 0.0
    topScore = 0.0
    tp = 0
    fp = 0
    # count starts at 1 so that it includes the current prediction, in line
    # with tp and fp which are updated before the MCC is evaluated
    for count, (name, score) in enumerate(sortedPrediction, start=1):
        if name in positive:
            tp = tp + 1
        if name in negative:
            fp = fp + 1
        fn = positiveTotal - tp
        tn = negativeTotal - fp
        denom = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
        # the MCC is undefined when any marginal total is zero
        if denom > 0.0:
            mcc = (tp * tn - fp * fn) / math.sqrt(denom)
            if mcc >= topMCC:
                topMCC = mcc
                topScore = score
                topCount = count
                topFP = getPercentage(fp, fp + tn)
                topTP = getPercentage(tp, tp + fn)
    if len(sortedPrediction) > 0:
        print("Top ranking prediction %s." % str(sortedPrediction[0]))
    print("Total count of prediction set: %s (tp=%1.2f, fp=%1.2f)." % (topCount, topTP, topFP))
    print("Total count of positive set: %s." % len(positive))
    print("Total count of negative set: %s." % len(negative))
    print("Matthews-Correlation-Coefficient: %s at Score >= %s." % (round(topMCC, 2), topScore))
    return topMCC
163
164
def getNegativeSet(args, filterA, filterB, negativeRequired, jSize=5):
    """Determine the set of pair keys treated as non-interacting.

    Three sources are supported, in this order of precedence:

    1. ``args.negative`` names an existing file: every line with at least
       two whitespace-separated columns contributes one pair key.
    2. ``args.region_a`` and ``args.region_b`` are set: ``args.locations``
       is scanned for lines containing ``SUBCELLULAR LOCATION``. Ids (first
       token of the line) that belong to one of the filters and list
       exactly one of the two regions are collected per region, and pairs
       are formed across the two regions.
    3. Otherwise pairs are formed between ``filterA`` and ``filterB``.

    For cases 2 and 3 pairs are enumerated by ``randomPairs`` until
    ``negativeRequired`` distinct keys have been collected.

    :param args: parsed command line options (``negative``, ``region_a``,
        ``region_b``, ``locations``).
    :param filterA: ids of the first group.
    :param filterB: ids of the second group.
    :param negativeRequired: number of generated pairs wanted; not applied
        when an explicit negative file is loaded.
    :param jSize: window width handed to ``randomPairs``.
    :return: set of pair keys.
    """
    print("Identifying non-interacting pairs...")
    negative = set()
    if args.negative and isfile(args.negative):
        # load from explicit file
        with open(args.negative) as negativeFile:
            for line in negativeFile:
                cols = line.split()
                # blank or single-column lines carry no pair
                if len(cols) < 2:
                    continue
                negative.add(getKey(cols[0], cols[1]))
        return negative

    if args.region_a and args.region_b:
        locations = dict()
        regionA = args.region_a.lower()
        regionB = args.region_b.lower()
        locations[regionA] = list()
        locations[regionB] = list()
        regions = [regionA, regionB]
        print("Filtering regions %s" % str(regions))
        searchKey = "SUBCELLULAR LOCATION"
        with open(args.locations) as locFile:
            for line in locFile:
                searchPos = line.find(searchKey)
                if searchPos == -1:
                    continue
                uniId = line.split()[0]
                if uniId not in filterA and uniId not in filterB:
                    continue
                # text following the key (and one separator character)
                locStart = searchPos + len(searchKey) + 1
                locId = line[locStart:]
                # remove the brace-enclosed annotation, then unify separators
                locId = re.sub(r"\s*{.*}\s*", "", locId)
                locId = locId.replace(".", ",")
                locId = locId.strip().lower()
                # keep only the part before a note or the first semicolon
                filter_pos = locId.find("note=")
                if filter_pos >= 0:
                    locId = locId[:filter_pos]
                filter_pos = locId.find(";")
                if filter_pos >= 0:
                    locId = locId[:filter_pos]
                if locId:
                    locId = [entry.strip() for entry in locId.split(",")]
                    # ids listed for both regions are ambiguous and dropped
                    if regionA in locId and regionB not in locId:
                        locations[regionA].append(uniId)
                    elif regionA not in locId and regionB in locId:
                        locations[regionB].append(uniId)
        filterAList = sorted(locations[regionA])
        filterBList = sorted(locations[regionB])
    else:
        # sorted, like the region branch, so the selection is reproducible
        # regardless of set iteration order
        filterAList = sorted(filterA)
        filterBList = sorted(filterB)

    # nothing to generate when no pairs are requested
    if negativeRequired > 0:
        for i, j in randomPairs(len(filterAList), len(filterBList), jSize):
            key = getKey(filterAList[i], filterBList[j])
            if key not in negative:
                negative.add(key)
                negativeRequired = negativeRequired - 1
                if negativeRequired == 0:
                    break
    return negative
228
229
def randomPairs(iLen, jLen, jSize):
    """Yield index pairs ``(i, j)`` window by window.

    The ``j`` range is cut into consecutive windows of ``jSize`` columns.
    For each window every ``i`` is paired with every ``j`` of the window
    before the next window starts. Despite the name the order is fully
    deterministic. Windows are contiguous, so every ``(i, j)`` combination
    is produced exactly once.

    :param iLen: number of row indices.
    :param jLen: number of column indices.
    :param jSize: window width; values below 1 yield nothing.
    :return: generator of ``(i, j)`` tuples.
    """
    # a non-positive width cannot advance through the columns
    if jSize <= 0:
        return
    for jStart in range(0, jLen, jSize):
        jMax = min(jStart + jSize, jLen)
        for i in range(iLen):
            for j in range(jStart, jMax):
                yield i, j
241
242
def main(args):
    """Compare the MCC of the prediction file with BioGRID method subsets.

    The ids of the input file define the organism filters. The positive set
    is read from the BioGRID file, optionally restricted to lines whose
    column 11 contains ``args.method``. A negative set is derived, the MCC
    of the prediction file is computed and, for comparison, the MCC of each
    other listed method subset. The values are written as a horizontal bar
    chart to ``args.output``.

    :param args: parsed command line options.
    :raises ValueError: when the input file contains no identifiers.
    """
    # load source files
    filterSets = getFilter(args.input)
    filterKeys = list(filterSets.keys())
    if not filterKeys:
        raise ValueError("No identifiers found in input file: %s" % args.input)
    filterA = filterSets[filterKeys[0]]
    # with a single organism both sides of a pair use the same id set
    filterB = filterSets[filterKeys[1]] if len(filterKeys) > 1 else filterA

    # identify biogrid filter options; --method is optional, so only
    # filter on it when it was actually supplied
    # NOTE(review): columns 11, 23 and 26 are presumably the BioGRID
    # experimental system and the two identifier columns - confirm.
    filterValues = [[11, args.method]] if args.method else []

    # process biogrid database
    print("Loading positive set from BioGRID file...")
    positive, positiveCount = getReference(args.biogrid, aCol=23, bCol=26,
                                           separator="\t", filterA=filterA,
                                           filterB=filterB, skipFirstLine=True,
                                           filterValues=filterValues)

    # estimate negative set
    negative = getNegativeSet(args, filterA, filterB, positiveCount)

    # get prediction results
    print("Loading prediction file...")
    prediction, _ = getReference(args.input, scoreCol=2, minScore=0.8)
    mcc = getMCC(prediction, positive, positiveCount, negative)
    yValues = [mcc]
    yTicks = ["SPRING"]

    # evaluate every other biogrid method against the same reference sets
    for method in ["Affinity Capture-MS",
                   "Biochemical Activity",
                   "Co-crystal Structure",
                   "Co-fractionation",
                   "Co-localization",
                   "Co-purification",
                   "Far Western",
                   "FRET",
                   "PCA",
                   "Reconstituted Complex",
                   "Two-hybrid"]:
        if args.method != method:
            print("Method: %s" % method)
            # a separate name keeps the prediction-file result intact for
            # the summary printed below
            methodPairs, _ = getReference(args.biogrid, aCol=23, bCol=26,
                                          separator="\t", filterA=filterA,
                                          filterB=filterB, skipFirstLine=True,
                                          filterValues=[[11, method]])
            yValues.append(getMCC(methodPairs, positive, positiveCount, negative))
            yTicks.append(method)

    # create plot
    print("Producing plot data...")
    print("Total count in prediction file: %d." % len(prediction))
    print("Total count in positive file: %d." % len(positive))
    plt.xlabel("Matthews-Correlation Coefficient (MCC)")
    plt.title("Positive set: %s" % (args.method or "All methods"))
    plt.barh(yTicks, yValues)
    plt.tight_layout()
    plt.savefig(args.output, format="png")
306
307
if __name__ == "__main__":
    # Command line entry point: collect the options and hand them to main().
    # The description names the plot this script actually produces (a bar
    # chart of MCC values).
    parser = argparse.ArgumentParser(description='Create MCC comparison plot.')
    parser.add_argument('-i', '--input', help='Input prediction file (2-columns).', required=True)
    parser.add_argument('-b', '--biogrid', help='BioGRID interaction database file', required=True)
    parser.add_argument('-l', '--locations', help='UniProt export table with subcellular locations', required=False)
    parser.add_argument('-ra', '--region_a', help='First subcellular location', required=False)
    parser.add_argument('-rb', '--region_b', help='Second subcellular location', required=False)
    parser.add_argument('-n', '--negative', help='Negative set (2-columns)', required=False)
    # --throughput is accepted but not read anywhere in the visible code
    parser.add_argument('-t', '--throughput', help='Throughput (low/high)', required=False)
    parser.add_argument('-m', '--method', help='Method e.g. Two-hybrid', required=False)
    parser.add_argument('-o', '--output', help='Output (png)', required=True)
    args = parser.parse_args()
    main(args)