comparison ml_visualization_ex.py @ 0:13226b2ddfb4 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
author bgruening
date Wed, 22 Jan 2020 07:51:20 -0500
parents
children 9b70bf3306e5
comparison
equal deleted inserted replaced
-1:000000000000 0:13226b2ddfb4
1 import argparse
2 import json
3 import matplotlib
4 import matplotlib.pyplot as plt
5 import numpy as np
6 import os
7 import pandas as pd
8 import plotly
9 import plotly.graph_objs as go
10 import warnings
11
12 from keras.models import model_from_json
13 from keras.utils import plot_model
14 from sklearn.feature_selection.base import SelectorMixin
15 from sklearn.metrics import precision_recall_curve, average_precision_score
16 from sklearn.metrics import roc_curve, auc
17 from sklearn.pipeline import Pipeline
18 from galaxy_ml.utils import load_model, read_columns, SafeEval
19
20
# Shared evaluator instance; `main` uses it to parse the user-supplied
# `steps` parameter.
# NOTE(review): presumably a restricted expression evaluator from
# galaxy_ml -- confirm before relying on it for untrusted input.
safe_eval = SafeEval()

# Plotly's default trace color cycle. The plotting functions index into
# this list modulo its length, so any number of curves can be drawn.
default_colors = [
    '#1f77b4',  # muted blue
    '#ff7f0e',  # safety orange
    '#2ca02c',  # cooked asparagus green
    '#d62728',  # brick red
    '#9467bd',  # muted purple
    '#8c564b',  # chestnut brown
    '#e377c2',  # raspberry yogurt pink
    '#7f7f7f',  # middle gray
    '#bcbd22',  # curry yellow-green
    '#17becf',  # blue-teal
]
36
37
def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """Draw one precision-recall curve per column and save it as HTML.

    Column `i` of `df1` is paired with column `i` of `df2`. The figure is
    rendered with plotly to `output.html`, which is then renamed to
    `output`.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Containing y_true, one target per column
    df2 : pandas.DataFrame
        Containing y_score, one target per column
    pos_label : str or None
        The label of positive class. When falsy, 1 is used for the
        average precision computation.
    title : str
        Plot title; falls back to 'Precision-Recall Curve'
    """
    palette_size = len(default_colors)
    traces = []
    for col in range(df1.shape[1]):
        truth = df1.iloc[:, col].values
        scores = df2.iloc[:, col].values

        precision, recall, _ = precision_recall_curve(
            truth, scores, pos_label=pos_label)
        avg_precision = average_precision_score(
            truth, scores, pos_label=pos_label or 1)

        traces.append(go.Scatter(
            x=recall,
            y=precision,
            mode='lines',
            # cycle through the palette when there are many columns
            marker=dict(color=default_colors[col % palette_size]),
            name='%s (area = %.3f)' % (col, avg_precision)
        ))

    # both axes share the same frame styling
    axis_frame = dict(linecolor='lightslategray', linewidth=1)
    layout = go.Layout(
        xaxis=dict(title='Recall', **axis_frame),
        yaxis=dict(title='Precision', **axis_frame),
        title=dict(
            text=title or 'Precision-Recall Curve',
            x=0.5,
            y=0.92,
            xanchor='center',
            yanchor='top'
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor='rgba(255,255,255,0)'
    )

    figure = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(figure, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')
116
117
def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """Draw one precision-recall curve per column and save it as SVG.

    Column `i` of `df1` is paired with column `i` of `df2`. Every curve is
    drawn as a thin, semi-transparent black step line. The image is
    written in SVG format to a file named `output` in the current working
    directory.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Containing y_true, one target per column
    df2 : pandas.DataFrame
        Containing y_score, one target per column
    pos_label : str or None
        The label of positive class. When falsy, 1 is used for the
        average precision computation.
    title : str
        Plot title; falls back to 'Precision-Recall Curve'
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    try:
        plt.style.use('seaborn-colorblind')
    except OSError:
        # matplotlib 3.6 renamed the bundled seaborn styles and 3.8
        # removed the old names
        plt.style.use('seaborn-v0_8-colorblind')

    fig = plt.figure()
    try:
        for idx in range(df1.shape[1]):
            y_true = df1.iloc[:, idx].values
            y_score = df2.iloc[:, idx].values

            precision, recall, _ = precision_recall_curve(
                y_true, y_score, pos_label=pos_label)
            ap = average_precision_score(
                y_true, y_score, pos_label=pos_label or 1)

            # No format string here: the explicit `color` keyword already
            # decides the color, and solid is the default line style.
            # The label is attached to the line, but no legend is drawn.
            plt.step(recall, precision, color="black", alpha=0.3,
                     lw=1, where="post",
                     label='%s (area = %.3f)' % (idx, ap))

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(title or 'Precision-Recall Curve')

        # `format` is explicit, so the extension-less file name that
        # `from_work_dir` expects can be written directly
        plt.savefig(os.path.join(os.getcwd(), "output"), format="svg")
    finally:
        # release the figure even when plotting or saving fails
        plt.close(fig)
149
150
def visualize_roc_curve_plotly(df1, df2, pos_label,
                               drop_intermediate=True,
                               title=None):
    """Draw one ROC curve per column and save the figure as HTML.

    Column `i` of `df1` is paired with column `i` of `df2`. The figure is
    rendered with plotly to `output.html`, which is then renamed to
    `output`.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Containing y_true, one target per column
    df2 : pandas.DataFrame
        Containing y_score, one target per column
    pos_label : str or None
        The label of positive class
    drop_intermediate : bool
        Whether to drop some suboptimal thresholds
    title : str
        Plot title; falls back to
        'Receiver Operating Characteristic (ROC) Curve'
    """
    palette_size = len(default_colors)
    traces = []
    for col in range(df1.shape[1]):
        truth = df1.iloc[:, col].values
        scores = df2.iloc[:, col].values

        fpr, tpr, _ = roc_curve(truth, scores, pos_label=pos_label,
                                drop_intermediate=drop_intermediate)
        area = auc(fpr, tpr)

        traces.append(go.Scatter(
            x=fpr,
            y=tpr,
            mode='lines',
            # cycle through the palette when there are many columns
            marker=dict(color=default_colors[col % palette_size]),
            name='%s (area = %.3f)' % (col, area)
        ))

    # both axes share the same frame styling
    axis_frame = dict(linecolor='lightslategray', linewidth=1)
    layout = go.Layout(
        xaxis=dict(title='False Positive Rate', **axis_frame),
        yaxis=dict(title='True Positive Rate', **axis_frame),
        title=dict(
            text=title or 'Receiver Operating Characteristic (ROC) Curve',
            x=0.5,
            y=0.92,
            xanchor='center',
            yanchor='top'
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor='rgba(255,255,255,0)'
    )

    figure = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(figure, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')
233
234
def visualize_roc_curve_matplotlib(df1, df2, pos_label,
                                   drop_intermediate=True,
                                   title=None):
    """Draw one ROC curve per column and save the figure as SVG.

    Column `i` of `df1` is paired with column `i` of `df2`. Every curve is
    drawn as a thin, semi-transparent black step line. The image is
    written in SVG format to a file named `output` in the current working
    directory.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Containing y_true, one target per column
    df2 : pandas.DataFrame
        Containing y_score, one target per column
    pos_label : str or None
        The label of positive class
    drop_intermediate : bool
        Whether to drop some suboptimal thresholds
    title : str
        Plot title; falls back to
        'Receiver Operating Characteristic (ROC) Curve'
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    try:
        plt.style.use('seaborn-colorblind')
    except OSError:
        # matplotlib 3.6 renamed the bundled seaborn styles and 3.8
        # removed the old names
        plt.style.use('seaborn-v0_8-colorblind')

    fig = plt.figure()
    try:
        for idx in range(df1.shape[1]):
            y_true = df1.iloc[:, idx].values
            y_score = df2.iloc[:, idx].values

            fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label,
                                    drop_intermediate=drop_intermediate)
            roc_auc = auc(fpr, tpr)

            # No format string here: the explicit `color` keyword already
            # decides the color, and solid is the default line style.
            # The label is attached to the line, but no legend is drawn.
            plt.step(fpr, tpr, color="black", alpha=0.3, lw=1,
                     where="post",
                     label='%s (area = %.3f)' % (idx, roc_auc))

        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(title or 'Receiver Operating Characteristic (ROC) Curve')

        # `format` is explicit, so the extension-less file name that
        # `from_work_dir` expects can be written directly
        plt.savefig(os.path.join(os.getcwd(), "output"), format="svg")
    finally:
        # release the figure even when plotting or saving fails
        plt.close(fig)
267
268
def _save_plotly_html(fig):
    """Render `fig` with plotly to `output.html`, then rename to `output`."""
    # NOTE: a static image could be written instead with e.g.
    # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')


def _centered_title(text):
    """Return the plotly title settings shared by the line plots."""
    return dict(text=text, x=0.5, y=0.92, xanchor='center', yanchor='top')


def _plot_feature_importances(selection, infile_estimator, infile1, title):
    """Plot coefficients / feature importances of an estimator as bars."""
    with open(infile_estimator, 'rb') as estimator_handler:
        estimator = load_model(estimator_handler)

    selector_options = selection['column_selector_options']
    column_option = selector_options['selected_column_selector_option']
    if column_option in ['by_index_number', 'all_but_by_index_number',
                         'by_header_name', 'all_but_by_header_name']:
        c = selector_options['col1']
    else:
        c = None

    _, input_df = read_columns(infile1, c=c,
                               c_option=column_option,
                               return_df=True,
                               sep='\t', header='infer',
                               parse_dates=True)

    feature_names = input_df.columns.values

    if isinstance(estimator, Pipeline):
        # keep only the names that survive each feature selection step
        for st in estimator.steps[:-1]:
            if isinstance(st[-1], SelectorMixin):
                mask = st[-1].get_support()
                feature_names = feature_names[mask]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, 'coef_'):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, 'feature_importances_', None)
    if coefs is None:
        raise RuntimeError('The classifier does not expose '
                           '"coef_" or "feature_importances_" '
                           'attributes')

    # Linear models may expose `coef_` as a 2-D array, e.g. shape
    # (1, n_features) for binary classifiers. The masking and sorting
    # below need one value per feature, so flatten when the array is
    # effectively a vector and refuse anything truly multi-dimensional.
    if np.ndim(coefs) > 1:
        if np.size(coefs) != max(np.shape(coefs)):
            raise ValueError(
                'Cannot plot feature importances from coefficients of '
                'shape %s: expected one value per feature'
                % (np.shape(coefs),))
        coefs = np.ravel(coefs)

    threshold = selection['threshold']
    if threshold is not None:
        mask = (coefs > threshold) | (coefs < -threshold)
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # sort in descending order
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices],
                   y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    _save_plotly_html(go.Figure(data=[trace], layout=layout))


def _plot_pr_or_roc_curve(selection, plot_type, plot_format,
                          infile1, infile2, title):
    """Load labels and scores, then dispatch to the curve plotters."""
    df1 = pd.read_csv(infile1, sep='\t', header='infer')
    df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32)

    minimum = selection['report_minimum_n_positives']
    # filter out columns whose n_positives is below the threshold
    if minimum:
        # Use a positional (ndarray) mask: a boolean Series would be
        # aligned on df1's column labels and fail on df2 whenever the two
        # files carry different headers.
        mask = (df1.sum(axis=0) >= minimum).values
        df1 = df1.loc[:, mask]
        df2 = df2.loc[:, mask]

    # an empty string means "no explicit positive label"
    pos_label = selection['pos_label'].strip() or None

    if plot_type == 'pr_curve':
        if plot_format == 'plotly_html':
            visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
        else:
            visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
    else:  # 'roc_curve'
        drop_intermediate = selection['drop_intermediate']
        if plot_format == 'plotly_html':
            visualize_roc_curve_plotly(df1, df2, pos_label,
                                       drop_intermediate=drop_intermediate,
                                       title=title)
        else:
            visualize_roc_curve_matplotlib(
                df1, df2, pos_label,
                drop_intermediate=drop_intermediate,
                title=title)


def _plot_rfecv_gridscores(selection, infile1, title):
    """Plot the first column of `infile1` as a line of CV scores."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    scores = input_df.iloc[:, 0]
    steps = selection['steps'].strip()
    steps = safe_eval(steps)

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(_) for _ in steps] if steps else None,
        mode='lines'
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=_centered_title(title or None),
        font=dict(
            family="sans-serif",
            size=11
        ),
        # control background colors
        plot_bgcolor='rgba(255,255,255,0)'
    )
    _save_plotly_html(go.Figure(data=[data], layout=layout))


def _plot_learning_curve(selection, infile1, title):
    """Plot mean train/test scores against the training set size."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    plot_std_err = selection['plot_std_err']
    data1 = go.Scatter(
        x=input_df['train_sizes_abs'],
        y=input_df['mean_train_scores'],
        error_y=dict(
            array=input_df['std_train_scores']
        ) if plot_std_err else None,
        mode='lines',
        name="Train Scores",
    )
    data2 = go.Scatter(
        x=input_df['train_sizes_abs'],
        y=input_df['mean_test_scores'],
        error_y=dict(
            array=input_df['std_test_scores']
        ) if plot_std_err else None,
        mode='lines',
        name="Test Scores",
    )
    layout = dict(
        xaxis=dict(
            title='No. of samples'
        ),
        yaxis=dict(
            title='Performance Score'
        ),
        # modify these configurations to customize image
        title=_centered_title(title or 'Learning Curve'),
        font=dict(
            family="sans-serif",
            size=11
        ),
        # control background colors
        plot_bgcolor='rgba(255,255,255,0)'
    )
    _save_plotly_html(go.Figure(data=[data1, data2], layout=layout))


def _plot_keras_model(model_config):
    """Draw the architecture of a keras model described by a JSON file."""
    with open(model_config, 'r') as f:
        model_str = f.read()
    model = model_from_json(model_str)
    plot_model(model, to_file="output.png")
    os.rename('output.png', 'output')


def main(inputs, infile_estimator=None, infile1=None,
         infile2=None, outfile_result=None,
         outfile_object=None, groups=None,
         ref_seq=None, intervals=None,
         targets=None, fasta_path=None,
         model_config=None):
    """Create the plot requested in the galaxy tool parameter file.

    The plot type, plot format and title are read from the
    'plotting_selection' section of the JSON file `inputs`. Every
    supported plot ends up in a file named `output` in the current
    working directory.

    Parameter
    ---------
    inputs : str
        File path to galaxy tool parameter

    infile_estimator : str, default is None
        File path to estimator

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result.
        Not used by this function.

    outfile_object : str, default is None
        File path to save searchCV object. Not used by this function.

    groups : str, default is None
        File path to dataset containing groups labels. Not used by this
        function.

    ref_seq : str, default is None
        File path to dataset containing genome sequence file. Not used by
        this function.

    intervals : str, default is None
        File path to dataset containing interval file. Not used by this
        function.

    targets : str, default is None
        File path to dataset compressed target bed file. Not used by this
        function.

    fasta_path : str, default is None
        File path to dataset containing fasta file. Not used by this
        function.

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks

    Returns
    -------
    int or None
        0 when the plot type is recognised, None otherwise.

    Raises
    ------
    RuntimeError
        For 'feature_importances', when the estimator exposes neither
        "coef_" nor "feature_importances_".
    ValueError
        For 'feature_importances', when the coefficients do not reduce to
        one value per feature.
    """
    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    selection = params['plotting_selection']
    title = selection['title'].strip()
    plot_type = selection['plot_type']
    plot_format = selection['plot_format']

    if plot_type == 'feature_importances':
        _plot_feature_importances(selection, infile_estimator,
                                  infile1, title)
    elif plot_type in ('pr_curve', 'roc_curve'):
        _plot_pr_or_roc_curve(selection, plot_type, plot_format,
                              infile1, infile2, title)
    elif plot_type == 'rfecv_gridscores':
        _plot_rfecv_gridscores(selection, infile1, title)
    elif plot_type == 'learning_curve':
        _plot_learning_curve(selection, infile1, title)
    elif plot_type == 'keras_plot_model':
        _plot_keras_model(model_config)
    else:
        # unknown plot types are ignored, nothing is written
        return None

    return 0
549
550
if __name__ == '__main__':
    # Optional command line flags as (short flag, long flag, destination).
    # Each destination is also the name of the matching `main` parameter.
    optional_flags = (
        ("-e", "--estimator", "infile_estimator"),
        ("-X", "--infile1", "infile1"),
        ("-y", "--infile2", "infile2"),
        ("-O", "--outfile_result", "outfile_result"),
        ("-o", "--outfile_object", "outfile_object"),
        ("-g", "--groups", "groups"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-b", "--intervals", "intervals"),
        ("-t", "--targets", "targets"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-c", "--model_config", "model_config"),
    )

    aparser = argparse.ArgumentParser()
    # the tool parameter file is the only mandatory argument
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    for short_flag, long_flag, destination in optional_flags:
        aparser.add_argument(short_flag, long_flag, dest=destination)
    args = aparser.parse_args()

    # every parsed attribute maps onto a `main` parameter of the same name
    main(**vars(args))