comparison ml_visualization_ex.py @ 0:2d7016b3ae92 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2afb24f3c81d625312186750a714d702363012b5"
author bgruening
date Fri, 02 Oct 2020 08:45:21 +0000
parents
children 132805688fa3
comparison: -1:000000000000 (null) to 0:2d7016b3ae92
1 import argparse
2 import json
3 import matplotlib
4 import matplotlib.pyplot as plt
5 import numpy as np
6 import os
7 import pandas as pd
8 import plotly
9 import plotly.graph_objs as go
10 import warnings
11
12 from keras.models import model_from_json
13 from keras.utils import plot_model
14 from sklearn.feature_selection.base import SelectorMixin  # newer scikit-learn exposes this as: from sklearn.feature_selection import SelectorMixin
15 from sklearn.metrics import precision_recall_curve, average_precision_score
16 from sklearn.metrics import roc_curve, auc, confusion_matrix
17 from sklearn.pipeline import Pipeline
18 from galaxy_ml.utils import load_model, read_columns, SafeEval
19
20
21 safe_eval = SafeEval()  # restricted evaluator for user-supplied expressions (e.g. the rfecv 'steps' list)
22
23 # plotly default colors
24 default_colors = [
25 '#1f77b4', # muted blue
26 '#ff7f0e', # safety orange
27 '#2ca02c', # cooked asparagus green
28 '#d62728', # brick red
29 '#9467bd', # muted purple
30 '#8c564b', # chestnut brown
31 '#e377c2', # raspberry yogurt pink
32 '#7f7f7f', # middle gray
33 '#bcbd22', # curry yellow-green
34 '#17becf' # blue-teal
35 ]
36
37
38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
39 """output pr-curve in html using plotly
40
41 df1 : pandas.DataFrame
42 Containing y_true
43 df2 : pandas.DataFrame
44 Containing y_score
45 pos_label : int or str, default is None
46 The label of the positive class
47 title : str
48 Plot title
49 """
50 data = []
51 for idx in range(df1.shape[1]):
52 y_true = df1.iloc[:, idx].values
53 y_score = df2.iloc[:, idx].values
54
55 precision, recall, _ = precision_recall_curve(
56 y_true, y_score, pos_label=pos_label)
57 ap = average_precision_score(
58 y_true, y_score, pos_label=pos_label or 1)
59
60 trace = go.Scatter(
61 x=recall,
62 y=precision,
63 mode='lines',
64 marker=dict(
65 color=default_colors[idx % len(default_colors)]
66 ),
67 name='%s (area = %.3f)' % (idx, ap)
68 )
69 data.append(trace)
70
71 layout = go.Layout(
72 xaxis=dict(
73 title='Recall',
74 linecolor='lightslategray',
75 linewidth=1
76 ),
77 yaxis=dict(
78 title='Precision',
79 linecolor='lightslategray',
80 linewidth=1
81 ),
82 title=dict(
83 text=title or 'Precision-Recall Curve',
84 x=0.5,
85 y=0.92,
86 xanchor='center',
87 yanchor='top'
88 ),
89 font=dict(
90 family="sans-serif",
91 size=11
92 ),
93 # control background colors
94 plot_bgcolor='rgba(255,255,255,0)'
95 )
96 """
97 legend=dict(
98 x=0.95,
99 y=0,
100 traceorder="normal",
101 font=dict(
102 family="sans-serif",
103 size=9,
104 color="black"
105 ),
106 bgcolor="LightSteelBlue",
107 bordercolor="Black",
108 borderwidth=2
109 ),"""
110
111 fig = go.Figure(data=data, layout=layout)
112
113 plotly.offline.plot(fig, filename="output.html", auto_open=False)
114 # to be discovered by `from_work_dir`
115 os.rename('output.html', 'output')
116
117
118 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
119 """visualize pr-curve using matplotlib and output svg image
120 """
121 backend = matplotlib.get_backend()
122 if "inline" not in backend:
123 matplotlib.use("SVG")
124 plt.style.use('seaborn-colorblind')
125 plt.figure()
126
127 for idx in range(df1.shape[1]):
128 y_true = df1.iloc[:, idx].values
129 y_score = df2.iloc[:, idx].values
130
131 precision, recall, _ = precision_recall_curve(
132 y_true, y_score, pos_label=pos_label)
133 ap = average_precision_score(
134 y_true, y_score, pos_label=pos_label or 1)
135
136 plt.step(recall, precision, 'r-', color="black", alpha=0.3,
137 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap))
138
139 plt.xlim([0.0, 1.0])
140 plt.ylim([0.0, 1.05])
141 plt.xlabel('Recall')
142 plt.ylabel('Precision')
143 title = title or 'Precision-Recall Curve'
144 plt.title(title)
145 folder = os.getcwd()
146 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
147 os.rename(os.path.join(folder, "output.svg"),
148 os.path.join(folder, "output"))
149
150
151 def visualize_roc_curve_plotly(df1, df2, pos_label,
152 drop_intermediate=True,
153 title=None):
154 """output roc-curve in html using plotly
155
156 df1 : pandas.DataFrame
157 Containing y_true
158 df2 : pandas.DataFrame
159 Containing y_score
160 pos_label : int or str, default is None
161 The label of the positive class
162 drop_intermediate : bool
163 Whether to drop some suboptimal thresholds
164 title : str
165 Plot title
166 """
167 data = []
168 for idx in range(df1.shape[1]):
169 y_true = df1.iloc[:, idx].values
170 y_score = df2.iloc[:, idx].values
171
172 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label,
173 drop_intermediate=drop_intermediate)
174 roc_auc = auc(fpr, tpr)
175
176 trace = go.Scatter(
177 x=fpr,
178 y=tpr,
179 mode='lines',
180 marker=dict(
181 color=default_colors[idx % len(default_colors)]
182 ),
183 name='%s (area = %.3f)' % (idx, roc_auc)
184 )
185 data.append(trace)
186
187 layout = go.Layout(
188 xaxis=dict(
189 title='False Positive Rate',
190 linecolor='lightslategray',
191 linewidth=1
192 ),
193 yaxis=dict(
194 title='True Positive Rate',
195 linecolor='lightslategray',
196 linewidth=1
197 ),
198 title=dict(
199 text=title or 'Receiver Operating Characteristic (ROC) Curve',
200 x=0.5,
201 y=0.92,
202 xanchor='center',
203 yanchor='top'
204 ),
205 font=dict(
206 family="sans-serif",
207 size=11
208 ),
209 # control background colors
210 plot_bgcolor='rgba(255,255,255,0)'
211 )
212 """
213 # legend=dict(
214 # x=0.95,
215 # y=0,
216 # traceorder="normal",
217 # font=dict(
218 # family="sans-serif",
219 # size=9,
220 # color="black"
221 # ),
222 # bgcolor="LightSteelBlue",
223 # bordercolor="Black",
224 # borderwidth=2
225 # ),
226 """
227
228 fig = go.Figure(data=data, layout=layout)
229
230 plotly.offline.plot(fig, filename="output.html", auto_open=False)
231 # to be discovered by `from_work_dir`
232 os.rename('output.html', 'output')
233
234
235 def visualize_roc_curve_matplotlib(df1, df2, pos_label,
236 drop_intermediate=True,
237 title=None):
238 """visualize roc-curve using matplotlib and output svg image
239 """
240 backend = matplotlib.get_backend()
241 if "inline" not in backend:
242 matplotlib.use("SVG")
243 plt.style.use('seaborn-colorblind')
244 plt.figure()
245
246 for idx in range(df1.shape[1]):
247 y_true = df1.iloc[:, idx].values
248 y_score = df2.iloc[:, idx].values
249
250 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label,
251 drop_intermediate=drop_intermediate)
252 roc_auc = auc(fpr, tpr)
253
254 plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1,
255 where="post", label='%s (area = %.3f)' % (idx, roc_auc))
256
257 plt.xlim([0.0, 1.0])
258 plt.ylim([0.0, 1.05])
259 plt.xlabel('False Positive Rate')
260 plt.ylabel('True Positive Rate')
261 title = title or 'Receiver Operating Characteristic (ROC) Curve'
262 plt.title(title)
263 folder = os.getcwd()
264 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
265 os.rename(os.path.join(folder, "output.svg"),
266 os.path.join(folder, "output"))
267
268
269 def get_dataframe(file_path, plot_selection, header_name, column_name):
270 header = 'infer' if plot_selection[header_name] else None
271 column_option = plot_selection[column_name]["selected_column_selector_option"]
272 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]:
273 col = plot_selection[column_name]["col1"]
274 else:
275 col = None
276 _, input_df = read_columns(file_path, c=col,
277 c_option=column_option,
278 return_df=True,
279 sep='\t', header=header,
280 parse_dates=True)
281 return input_df
282
283
284 def main(inputs, infile_estimator=None, infile1=None,
285 infile2=None, outfile_result=None,
286 outfile_object=None, groups=None,
287 ref_seq=None, intervals=None,
288 targets=None, fasta_path=None,
289 model_config=None, true_labels=None,
290 predicted_labels=None, plot_color=None,
291 title=None):
292 """
293 Parameters
294 ----------
295 inputs : str
296 File path to Galaxy tool parameters (JSON)
297
298 infile_estimator : str, default is None
299 File path to estimator
300
301 infile1 : str, default is None
302 File path to dataset containing features or true labels.
303
304 infile2 : str, default is None
305 File path to dataset containing target values or predicted
306 probabilities.
307
308 outfile_result : str, default is None
309 File path to save the results, either cv_results or test result
310
311 outfile_object : str, default is None
312 File path to save searchCV object
313
314 groups : str, default is None
315 File path to dataset containing group labels
316
317 ref_seq : str, default is None
318 File path to dataset containing genome sequence file
319
320 intervals : str, default is None
321 File path to dataset containing interval file
322
323 targets : str, default is None
324 File path to dataset containing compressed target BED file
325
326 fasta_path : str, default is None
327 File path to dataset containing fasta file
328
329 model_config : str, default is None
330 File path to dataset containing JSON config for neural networks
331
332 true_labels : str, default is None
333 File path to dataset containing true labels
334
335 predicted_labels : str, default is None
336 File path to dataset containing predicted labels
337
338 plot_color : str, default is None
339 Color of the confusion matrix heatmap
340
341 title : str, default is None
342 Title of the confusion matrix heatmap
343 """
344 warnings.simplefilter('ignore')
345
346 with open(inputs, 'r') as param_handler:
347 params = json.load(param_handler)
348
349 title = params['plotting_selection']['title'].strip()
350 plot_type = params['plotting_selection']['plot_type']
351 plot_format = params['plotting_selection']['plot_format']
352
353 if plot_type == 'feature_importances':
354 with open(infile_estimator, 'rb') as estimator_handler:
355 estimator = load_model(estimator_handler)
356
357 column_option = (params['plotting_selection']
358 ['column_selector_options']
359 ['selected_column_selector_option'])
360 if column_option in ['by_index_number', 'all_but_by_index_number',
361 'by_header_name', 'all_but_by_header_name']:
362 c = (params['plotting_selection']
363 ['column_selector_options']['col1'])
364 else:
365 c = None
366
367 _, input_df = read_columns(infile1, c=c,
368 c_option=column_option,
369 return_df=True,
370 sep='\t', header='infer',
371 parse_dates=True)
372
373 feature_names = input_df.columns.values
374
375 if isinstance(estimator, Pipeline):  # apply feature-selection masks from earlier steps, then keep the final estimator
376 for st in estimator.steps[:-1]:
377 if isinstance(st[-1], SelectorMixin):
378 mask = st[-1].get_support()
379 feature_names = feature_names[mask]
380 estimator = estimator.steps[-1][-1]
381
382 if hasattr(estimator, 'coef_'):
383 coefs = estimator.coef_
384 else:
385 coefs = getattr(estimator, 'feature_importances_', None)
386 if coefs is None:
387 raise RuntimeError('The classifier does not expose '
388 '"coef_" or "feature_importances_" '
389 'attributes')
390
391 threshold = params['plotting_selection']['threshold']
392 if threshold is not None:
393 mask = (coefs > threshold) | (coefs < -threshold)
394 coefs = coefs[mask]
395 feature_names = feature_names[mask]
396
397 # sort features by coefficient/importance value, descending
398 indices = np.argsort(coefs)[::-1]
399
400 trace = go.Bar(x=feature_names[indices],
401 y=coefs[indices])
402 layout = go.Layout(title=title or "Feature Importances")
403 fig = go.Figure(data=[trace], layout=layout)
404
405 plotly.offline.plot(fig, filename="output.html",
406 auto_open=False)
407 # to be discovered by `from_work_dir`
408 os.rename('output.html', 'output')
409
410 return 0
411
412 elif plot_type in ('pr_curve', 'roc_curve'):
413 df1 = pd.read_csv(infile1, sep='\t', header='infer')
414 df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32)
415
416 minimum = params['plotting_selection']['report_minimum_n_positives']
417 # filter out columns whose n_positives is below the threshold
418 if minimum:
419 mask = df1.sum(axis=0) >= minimum
420 df1 = df1.loc[:, mask]
421 df2 = df2.loc[:, mask]
422
423 pos_label = params['plotting_selection']['pos_label'].strip() \
424 or None
425
426 if plot_type == 'pr_curve':
427 if plot_format == 'plotly_html':
428 visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
429 else:
430 visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
431 else: # 'roc_curve'
432 drop_intermediate = (params['plotting_selection']
433 ['drop_intermediate'])
434 if plot_format == 'plotly_html':
435 visualize_roc_curve_plotly(df1, df2, pos_label,
436 drop_intermediate=drop_intermediate,
437 title=title)
438 else:
439 visualize_roc_curve_matplotlib(
440 df1, df2, pos_label,
441 drop_intermediate=drop_intermediate,
442 title=title)
443
444 return 0
445
446 elif plot_type == 'rfecv_gridscores':
447 input_df = pd.read_csv(infile1, sep='\t', header='infer')
448 scores = input_df.iloc[:, 0]
449 steps = params['plotting_selection']['steps'].strip()
450 steps = safe_eval(steps)
451
452 data = go.Scatter(
453 x=list(range(len(scores))),
454 y=scores,
455 text=[str(_) for _ in steps] if steps else None,
456 mode='lines'
457 )
458 layout = go.Layout(
459 xaxis=dict(title="Number of features selected"),
460 yaxis=dict(title="Cross validation score"),
461 title=dict(
462 text=title or None,
463 x=0.5,
464 y=0.92,
465 xanchor='center',
466 yanchor='top'
467 ),
468 font=dict(
469 family="sans-serif",
470 size=11
471 ),
472 # control background colors
473 plot_bgcolor='rgba(255,255,255,0)'
474 )
475 """
476 # legend=dict(
477 # x=0.95,
478 # y=0,
479 # traceorder="normal",
480 # font=dict(
481 # family="sans-serif",
482 # size=9,
483 # color="black"
484 # ),
485 # bgcolor="LightSteelBlue",
486 # bordercolor="Black",
487 # borderwidth=2
488 # ),
489 """
490
491 fig = go.Figure(data=[data], layout=layout)
492 plotly.offline.plot(fig, filename="output.html",
493 auto_open=False)
494 # to be discovered by `from_work_dir`
495 os.rename('output.html', 'output')
496
497 return 0
498
499 elif plot_type == 'learning_curve':
500 input_df = pd.read_csv(infile1, sep='\t', header='infer')
501 plot_std_err = params['plotting_selection']['plot_std_err']
502 data1 = go.Scatter(
503 x=input_df['train_sizes_abs'],
504 y=input_df['mean_train_scores'],
505 error_y=dict(
506 array=input_df['std_train_scores']
507 ) if plot_std_err else None,
508 mode='lines',
509 name="Train Scores",
510 )
511 data2 = go.Scatter(
512 x=input_df['train_sizes_abs'],
513 y=input_df['mean_test_scores'],
514 error_y=dict(
515 array=input_df['std_test_scores']
516 ) if plot_std_err else None,
517 mode='lines',
518 name="Test Scores",
519 )
520 layout = dict(
521 xaxis=dict(
522 title='No. of samples'
523 ),
524 yaxis=dict(
525 title='Performance Score'
526 ),
527 # modify these configurations to customize image
528 title=dict(
529 text=title or 'Learning Curve',
530 x=0.5,
531 y=0.92,
532 xanchor='center',
533 yanchor='top'
534 ),
535 font=dict(
536 family="sans-serif",
537 size=11
538 ),
539 # control background colors
540 plot_bgcolor='rgba(255,255,255,0)'
541 )
542 """
543 # legend=dict(
544 # x=0.95,
545 # y=0,
546 # traceorder="normal",
547 # font=dict(
548 # family="sans-serif",
549 # size=9,
550 # color="black"
551 # ),
552 # bgcolor="LightSteelBlue",
553 # bordercolor="Black",
554 # borderwidth=2
555 # ),
556 """
557
558 fig = go.Figure(data=[data1, data2], layout=layout)
559 plotly.offline.plot(fig, filename="output.html",
560 auto_open=False)
561 # to be discovered by `from_work_dir`
562 os.rename('output.html', 'output')
563
564 return 0
565
566 elif plot_type == 'keras_plot_model':
567 with open(model_config, 'r') as f:
568 model_str = f.read()
569 model = model_from_json(model_str)
570 plot_model(model, to_file="output.png")
571 os.rename('output.png', 'output')
572
573 return 0
574
575 elif plot_type == 'classification_confusion_matrix':
576 plot_selection = params["plotting_selection"]
577 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true")
578 header_predicted = 'infer' if plot_selection["header_predicted"] else None
579 input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted)
580 true_classes = input_true.iloc[:, -1].copy()
581 predicted_classes = input_predicted.iloc[:, -1].copy()
582 axis_labels = sorted(set(true_classes))  # sorted to match the label order used by confusion_matrix
583 c_matrix = confusion_matrix(true_classes, predicted_classes)
584 fig, ax = plt.subplots(figsize=(7, 7))
585 im = plt.imshow(c_matrix, cmap=plot_color)
586 for i in range(len(c_matrix)):
587 for j in range(len(c_matrix)):
588 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
589 ax.set_ylabel('True class labels')
590 ax.set_xlabel('Predicted class labels')
591 ax.set_title(title)
592 ax.set_xticks(axis_labels)  # assumes numeric class labels that double as tick positions
593 ax.set_yticks(axis_labels)
594 fig.colorbar(im, ax=ax)
595 fig.tight_layout()
596 plt.savefig("output.png", dpi=125)
597 os.rename('output.png', 'output')
598
599 return 0
600
601 # save pdf file to disk
602 # fig.write_image("image.pdf", format='pdf')
603 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
604
605
606 if __name__ == '__main__':
607 aparser = argparse.ArgumentParser()
608 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
609 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
610 aparser.add_argument("-X", "--infile1", dest="infile1")
611 aparser.add_argument("-y", "--infile2", dest="infile2")
612 aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
613 aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
614 aparser.add_argument("-g", "--groups", dest="groups")
615 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
616 aparser.add_argument("-b", "--intervals", dest="intervals")
617 aparser.add_argument("-t", "--targets", dest="targets")
618 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
619 aparser.add_argument("-c", "--model_config", dest="model_config")
620 aparser.add_argument("-tl", "--true_labels", dest="true_labels")
621 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
622 aparser.add_argument("-pc", "--plot_color", dest="plot_color")
623 aparser.add_argument("-pt", "--title", dest="title")
624 args = aparser.parse_args()
625
626 main(args.inputs, args.infile_estimator, args.infile1, args.infile2,
627 args.outfile_result, outfile_object=args.outfile_object,
628 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals,
629 targets=args.targets, fasta_path=args.fasta_path,
630 model_config=args.model_config, true_labels=args.true_labels,
631 predicted_labels=args.predicted_labels,
632 plot_color=args.plot_color,
633 title=args.title)
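
For reference, the plotting paths can also be exercised outside of Galaxy. The sketch below is a hypothetical driver (file names and sample values are invented, and it assumes the module's dependencies such as galaxy_ml, keras, plotly, scikit-learn and matplotlib are installed): it writes the minimal inputs JSON keys that main() reads for plot_type='roc_curve', creates two small tab-separated files of true labels and predicted scores, and calls main() directly.

import json

import pandas as pd

from ml_visualization_ex import main  # this module; importing it pulls in keras/galaxy_ml

# Galaxy normally generates this JSON; only the keys main() reads for a ROC curve are set here.
params = {
    "plotting_selection": {
        "title": "ROC curve (demo)",
        "plot_type": "roc_curve",
        "plot_format": "plotly_html",       # anything else falls back to the matplotlib/SVG path
        "report_minimum_n_positives": 1,    # drop label columns with fewer positives than this
        "pos_label": "",                    # empty string is stripped to None (default positive label)
        "drop_intermediate": True,
    }
}
with open("inputs.json", "w") as f:
    json.dump(params, f)

# One column per curve, tab-separated, with a header row (read with header='infer' in main()).
pd.DataFrame({"y_true": [0, 0, 1, 1, 1, 0]}).to_csv("y_true.tsv", sep="\t", index=False)
pd.DataFrame({"y_score": [0.1, 0.4, 0.35, 0.8, 0.7, 0.2]}).to_csv("y_score.tsv", sep="\t", index=False)

# Writes output.html via plotly and renames it to a file literally named 'output'
# in the current working directory (so Galaxy can pick it up via from_work_dir).
main("inputs.json", infile1="y_true.tsv", infile2="y_score.tsv")

The equivalent command-line invocation through the argparse entry point above would be roughly: python ml_visualization_ex.py -i inputs.json -X y_true.tsv -y y_score.tsv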