comparison ml_visualization_ex.py @ 0:af2624d5ab32 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:24:32 +0000
parents
children 9349ed2749c6
comparison
equal deleted inserted replaced
-1:000000000000 0:af2624d5ab32
1 import argparse
2 import json
3 import os
4 import warnings
5
6 import matplotlib
7 import matplotlib.pyplot as plt
8 import numpy as np
9 import pandas as pd
10 import plotly
11 import plotly.graph_objs as go
12 from galaxy_ml.utils import load_model, read_columns, SafeEval
13 from keras.models import model_from_json
14 from keras.utils import plot_model
15 from sklearn.feature_selection.base import SelectorMixin
16 from sklearn.metrics import (auc, average_precision_score, confusion_matrix,
17 precision_recall_curve, roc_curve)
18 from sklearn.pipeline import Pipeline
19
# Module-level evaluator shared by the plotting code; main() uses it to turn
# the ``steps`` text from the tool parameters (rfecv_gridscores plots) into a
# Python value.
# NOTE(review): presumably restricts evaluation to a safe subset of Python,
# as the galaxy_ml ``SafeEval`` name suggests — confirm against galaxy_ml.
safe_eval = SafeEval()
21
# Plotly's default qualitative palette as (description, hex) pairs; curves
# cycle through it by column index.
_PLOTLY_PALETTE = (
    ("muted blue", "#1f77b4"),
    ("safety orange", "#ff7f0e"),
    ("cooked asparagus green", "#2ca02c"),
    ("brick red", "#d62728"),
    ("muted purple", "#9467bd"),
    ("chestnut brown", "#8c564b"),
    ("raspberry yogurt pink", "#e377c2"),
    ("middle gray", "#7f7f7f"),
    ("curry yellow-green", "#bcbd22"),
    ("blue-teal", "#17becf"),
)

# plotly default colors (hex strings only)
default_colors = [hex_code for _description, hex_code in _PLOTLY_PALETTE]
35
36
def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """Write precision-recall curves to an interactive plotly HTML page.

    Column ``i`` of ``df1`` (true labels) is paired positionally with column
    ``i`` of ``df2`` (scores); each pair becomes one line trace named
    ``"<i> (area = <average precision>)"``. The page is saved as
    ``output.html`` and renamed to ``output`` for Galaxy's ``from_work_dir``.

    df1 : pandas.DataFrame
        Containing y_true
    df2 : pandas.DataFrame
        Containing y_score, column-aligned with df1
    pos_label : str or None
        The label of positive class
    title : str
        Plot title; "Precision-Recall Curve" is used when empty or None
    """
    palette_size = len(default_colors)
    traces = []
    for col in range(df1.shape[1]):
        truth = df1.iloc[:, col].values
        scores = df2.iloc[:, col].values

        precision, recall, _ = precision_recall_curve(
            truth, scores, pos_label=pos_label
        )
        avg_prec = average_precision_score(truth, scores, pos_label=pos_label or 1)

        traces.append(
            go.Scatter(
                x=recall,
                y=precision,
                mode="lines",
                marker=dict(color=default_colors[col % palette_size]),
                name="%s (area = %.3f)" % (col, avg_prec),
            )
        )

    line_style = dict(linecolor="lightslategray", linewidth=1)
    title_spec = dict(
        text=title or "Precision-Recall Curve",
        x=0.5,
        y=0.92,
        xanchor="center",
        yanchor="top",
    )
    layout = go.Layout(
        xaxis=dict(title="Recall", **line_style),
        yaxis=dict(title="Precision", **line_style),
        title=title_spec,
        font=dict(family="sans-serif", size=11),
        # transparent plot-area background
        plot_bgcolor="rgba(255,255,255,0)",
    )

    figure = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(figure, filename="output.html", auto_open=False)
    # Galaxy collects the result under the fixed name ``output`` (from_work_dir)
    os.rename("output.html", "output")
102
103
def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """Visualize precision-recall curves with matplotlib and output an SVG image.

    Column ``i`` of ``df1`` (true labels) is paired positionally with column
    ``i`` of ``df2`` (scores); every pair is drawn as a faint black step curve.
    The SVG is written to ``output.svg`` in the current working directory and
    then renamed to ``output`` for Galaxy's ``from_work_dir``.

    df1 : pandas.DataFrame
        Containing y_true
    df2 : pandas.DataFrame
        Containing y_score, column-aligned with df1
    pos_label : str or None
        The label of positive class
    title : str
        Plot title; "Precision-Recall Curve" is used when empty or None
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    # "seaborn-colorblind" was renamed to "seaborn-v0_8-colorblind" in
    # matplotlib 3.6 and the old name was later removed, so requesting it
    # unconditionally crashes on newer installs. Use whichever exists.
    for style_name in ("seaborn-colorblind", "seaborn-v0_8-colorblind"):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break
    fig = plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)

        # Plain "-" fmt: the previous "r-" was overridden by color="black"
        # anyway and only produced a redundant-color warning.
        plt.step(
            recall,
            precision,
            "-",
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, ap),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(title or "Precision-Recall Curve")
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
141
142
def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
    """Write ROC curves to an interactive plotly HTML page.

    Column ``i`` of ``df1`` (true labels) is paired positionally with column
    ``i`` of ``df2`` (scores); each pair becomes one line trace named
    ``"<i> (area = <ROC AUC>)"``. The page is saved as ``output.html`` and
    renamed to ``output`` for Galaxy's ``from_work_dir``.

    df1 : pandas.DataFrame
        Containing y_true
    df2 : pandas.DataFrame
        Containing y_score, column-aligned with df1
    pos_label : str or None
        The label of positive class
    drop_intermediate : bool
        Forwarded to ``roc_curve``: whether to drop some suboptimal thresholds
    title : str
        Plot title; "Receiver Operating Characteristic (ROC) Curve" is used
        when empty or None
    """
    palette_size = len(default_colors)
    traces = []
    for col in range(df1.shape[1]):
        truth = df1.iloc[:, col].values
        scores = df2.iloc[:, col].values

        false_pos, true_pos, _ = roc_curve(
            truth, scores, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        area = auc(false_pos, true_pos)

        traces.append(
            go.Scatter(
                x=false_pos,
                y=true_pos,
                mode="lines",
                marker=dict(color=default_colors[col % palette_size]),
                name="%s (area = %.3f)" % (col, area),
            )
        )

    line_style = dict(linecolor="lightslategray", linewidth=1)
    title_spec = dict(
        text=title or "Receiver Operating Characteristic (ROC) Curve",
        x=0.5,
        y=0.92,
        xanchor="center",
        yanchor="top",
    )
    layout = go.Layout(
        xaxis=dict(title="False Positive Rate", **line_style),
        yaxis=dict(title="True Positive Rate", **line_style),
        title=title_spec,
        font=dict(family="sans-serif", size=11),
        # transparent plot-area background
        plot_bgcolor="rgba(255,255,255,0)",
    )

    figure = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(figure, filename="output.html", auto_open=False)
    # Galaxy collects the result under the fixed name ``output`` (from_work_dir)
    os.rename("output.html", "output")
213
214
def visualize_roc_curve_matplotlib(
    df1, df2, pos_label, drop_intermediate=True, title=None
):
    """Visualize ROC curves with matplotlib and output an SVG image.

    Column ``i`` of ``df1`` (true labels) is paired positionally with column
    ``i`` of ``df2`` (scores); every pair is drawn as a faint black step curve.
    The SVG is written to ``output.svg`` in the current working directory and
    then renamed to ``output`` for Galaxy's ``from_work_dir``.

    df1 : pandas.DataFrame
        Containing y_true
    df2 : pandas.DataFrame
        Containing y_score, column-aligned with df1
    pos_label : str or None
        The label of positive class
    drop_intermediate : bool
        Forwarded to ``roc_curve``: whether to drop some suboptimal thresholds
    title : str
        Plot title; "Receiver Operating Characteristic (ROC) Curve" is used
        when empty or None
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    # "seaborn-colorblind" was renamed to "seaborn-v0_8-colorblind" in
    # matplotlib 3.6 and the old name was later removed, so requesting it
    # unconditionally crashes on newer installs. Use whichever exists.
    for style_name in ("seaborn-colorblind", "seaborn-v0_8-colorblind"):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break
    fig = plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)

        # Plain "-" fmt: the previous "r-" was overridden by color="black"
        # anyway and only produced a redundant-color warning.
        plt.step(
            fpr,
            tpr,
            "-",
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, roc_auc),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title or "Receiver Operating Characteristic (ROC) Curve")
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
254
255
def get_dataframe(file_path, plot_selection, header_name, column_name):
    """Read the user-selected columns of a tab-separated file as a DataFrame.

    file_path : str
        Path to the tabular input.
    plot_selection : dict
        The ``plotting_selection`` section of the tool parameters.
    header_name : str
        Key in ``plot_selection`` whose truthiness says whether the file has
        a header row (``header="infer"``) or not (``header=None``).
    column_name : str
        Key in ``plot_selection`` holding the column-selector sub-dict with
        ``selected_column_selector_option`` and, for explicit selections,
        ``col1``.

    Returns the DataFrame produced by galaxy_ml's ``read_columns``.
    """
    header_mode = None
    if plot_selection[header_name]:
        header_mode = "infer"

    selector = plot_selection[column_name]
    column_option = selector["selected_column_selector_option"]
    # Only the explicit index/name selectors carry a column spec in "col1".
    explicit_options = (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    )
    col = selector["col1"] if column_option in explicit_options else None

    read_kwargs = dict(
        c=col,
        c_option=column_option,
        return_df=True,
        sep="\t",
        header=header_mode,
        parse_dates=True,
    )
    _unused, frame = read_columns(file_path, **read_kwargs)
    return frame
278
279
def main(
    inputs,
    infile_estimator=None,
    infile1=None,
    infile2=None,
    outfile_result=None,
    outfile_object=None,
    groups=None,
    ref_seq=None,
    intervals=None,
    targets=None,
    fasta_path=None,
    model_config=None,
    true_labels=None,
    predicted_labels=None,
    plot_color=None,
    title=None,
):
    """Produce the plot selected in the Galaxy tool parameters.

    The result is always written to a file named ``output`` in the current
    working directory.

    Parameter
    ---------
    inputs : str
        File path to galaxy tool parameter (JSON with a ``plotting_selection``
        section whose ``plot_type`` chooses the plot)

    infile_estimator : str, default is None
        File path to estimator (feature_importances)

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result, outfile_object, groups, ref_seq, intervals, targets,
    fasta_path : str, default is None
        Accepted for command-line compatibility; not used by any plot type.

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks

    true_labels : str, default is None
        File path to dataset containing true labels

    predicted_labels : str, default is None
        File path to dataset containing predicted labels

    plot_color : str, default is None
        Colormap of the confusion matrix heatmap

    title : str, default is None
        Fallback plot title, used when the tool parameters' title is empty

    Returns
    -------
    int
        0 on success.

    Raises
    ------
    ValueError
        If ``plot_type`` is not one of the supported plot types.
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    plot_selection = params["plotting_selection"]
    # The tool-parameter title wins; the ``title`` argument was previously
    # discarded unconditionally, now it is the fallback.
    title = (plot_selection.get("title") or "").strip() or title or ""
    plot_type = plot_selection["plot_type"]

    if plot_type == "feature_importances":
        _plot_feature_importances(plot_selection, infile_estimator, infile1, title)
    elif plot_type in ("pr_curve", "roc_curve"):
        _plot_pr_roc_curves(plot_selection, plot_type, infile1, infile2, title)
    elif plot_type == "rfecv_gridscores":
        _plot_rfecv_gridscores(plot_selection, infile1, title)
    elif plot_type == "learning_curve":
        _plot_learning_curve(plot_selection, infile1, title)
    elif plot_type == "keras_plot_model":
        _plot_keras_model(model_config)
    elif plot_type == "classification_confusion_matrix":
        _plot_confusion_matrix(
            plot_selection, true_labels, predicted_labels, plot_color, title
        )
    else:
        raise ValueError("Unsupported plot_type: %r" % (plot_type,))

    return 0


def _write_plotly_output(fig):
    """Save ``fig`` as standalone HTML and rename it to ``output`` (from_work_dir)."""
    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    os.rename("output.html", "output")


def _plot_feature_importances(plot_selection, infile_estimator, infile1, title):
    """Bar-plot the final estimator's ``coef_`` / ``feature_importances_``."""
    with open(infile_estimator, "rb") as estimator_handler:
        estimator = load_model(estimator_handler)

    column_selector = plot_selection["column_selector_options"]
    column_option = column_selector["selected_column_selector_option"]
    if column_option in (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ):
        c = column_selector["col1"]
    else:
        c = None

    _, input_df = read_columns(
        infile1,
        c=c,
        c_option=column_option,
        return_df=True,
        sep="\t",
        header="infer",
        parse_dates=True,
    )
    feature_names = input_df.columns.values

    if isinstance(estimator, Pipeline):
        # Drop names removed by intermediate selector steps so they line up
        # with the inputs of the final estimator.
        for _, step in estimator.steps[:-1]:
            if isinstance(step, SelectorMixin):
                feature_names = feature_names[step.get_support()]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, "coef_"):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, "feature_importances_", None)
    if coefs is None:
        raise RuntimeError(
            "The classifier does not expose "
            '"coef_" or "feature_importances_" '
            "attributes"
        )

    if hasattr(coefs, "toarray"):
        # sparse coef_ (e.g. a linear model after ``sparsify()``)
        coefs = coefs.toarray()
    coefs = np.asarray(coefs)
    if coefs.ndim > 1:
        if coefs.shape[0] == 1:
            # binary linear models store coef_ as (1, n_features); keep signs
            coefs = coefs.ravel()
        else:
            # one row per class: collapse with the L1 norm across classes,
            # the same convention SelectFromModel uses by default
            coefs = np.abs(coefs).sum(axis=0)

    threshold = plot_selection["threshold"]
    if threshold is not None:
        mask = np.abs(coefs) > threshold
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # largest value first
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices], y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    _write_plotly_output(go.Figure(data=[trace], layout=layout))


def _plot_pr_roc_curves(plot_selection, plot_type, infile1, infile2, title):
    """Plot PR or ROC curves for positionally paired label/score columns."""
    df1 = pd.read_csv(infile1, sep="\t", header="infer")
    df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)

    minimum = plot_selection["report_minimum_n_positives"]
    # filter out columns whose n_positives (column sum of df1, i.e. the
    # positive count for 0/1 labels) is below the threshold; the mask is
    # applied positionally so df2 need not share df1's column headers
    if minimum:
        keep = (df1.sum(axis=0) >= minimum).values
        df1 = df1.loc[:, keep]
        df2 = df2.loc[:, keep]

    pos_label = plot_selection["pos_label"].strip() or None
    plot_format = plot_selection["plot_format"]

    if plot_type == "pr_curve":
        if plot_format == "plotly_html":
            visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
        else:
            visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
    else:  # "roc_curve"
        drop_intermediate = plot_selection["drop_intermediate"]
        if plot_format == "plotly_html":
            visualize_roc_curve_plotly(
                df1,
                df2,
                pos_label,
                drop_intermediate=drop_intermediate,
                title=title,
            )
        else:
            visualize_roc_curve_matplotlib(
                df1,
                df2,
                pos_label,
                drop_intermediate=drop_intermediate,
                title=title,
            )


def _plot_rfecv_gridscores(plot_selection, infile1, title):
    """Line-plot RFECV scores taken from the first column of ``infile1``."""
    input_df = pd.read_csv(infile1, sep="\t", header="infer")
    scores = input_df.iloc[:, 0]
    # ``steps`` is text from the tool parameters, parsed via SafeEval;
    # when non-empty its items become per-point hover text
    steps = safe_eval(plot_selection["steps"].strip())

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(step) for step in steps] if steps else None,
        mode="lines",
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=dict(
            text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    _write_plotly_output(go.Figure(data=[data], layout=layout))


def _plot_learning_curve(plot_selection, infile1, title):
    """Plot mean train/test scores (optionally with std error bars) vs. sample size."""
    input_df = pd.read_csv(infile1, sep="\t", header="infer")
    plot_std_err = plot_selection["plot_std_err"]
    data1 = go.Scatter(
        x=input_df["train_sizes_abs"],
        y=input_df["mean_train_scores"],
        error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
        mode="lines",
        name="Train Scores",
    )
    data2 = go.Scatter(
        x=input_df["train_sizes_abs"],
        y=input_df["mean_test_scores"],
        error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
        mode="lines",
        name="Test Scores",
    )
    layout = dict(
        xaxis=dict(title="No. of samples"),
        yaxis=dict(title="Performance Score"),
        title=dict(
            text=title or "Learning Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    _write_plotly_output(go.Figure(data=[data1, data2], layout=layout))


def _plot_keras_model(model_config):
    """Render the architecture of a Keras model described by a JSON config file."""
    with open(model_config, "r") as f:
        model_str = f.read()
    model = model_from_json(model_str)
    plot_model(model, to_file="output.png")
    os.rename("output.png", "output")


def _plot_confusion_matrix(plot_selection, true_labels, predicted_labels, plot_color, title):
    """Draw an annotated confusion-matrix heatmap as a PNG."""
    input_true = get_dataframe(
        true_labels, plot_selection, "header_true", "column_selector_options_true"
    )
    header_predicted = "infer" if plot_selection["header_predicted"] else None
    input_predicted = pd.read_csv(
        predicted_labels, sep="\t", parse_dates=True, header=header_predicted
    )
    true_classes = input_true.iloc[:, -1]
    predicted_classes = input_predicted.iloc[:, -1]

    # Fix the row/column order explicitly (sorted union of true and predicted
    # labels) so the tick labels below match the matrix. Previously ticks came
    # from an unordered set of true labels only and were used as positions.
    axis_labels = sorted(set(true_classes) | set(predicted_classes))
    c_matrix = confusion_matrix(true_classes, predicted_classes, labels=axis_labels)
    n_classes = len(axis_labels)

    fig, ax = plt.subplots(figsize=(7, 7))
    im = ax.imshow(c_matrix, cmap=plot_color)
    for i in range(n_classes):
        for j in range(n_classes):
            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
    ax.set_ylabel("True class labels")
    ax.set_xlabel("Predicted class labels")
    ax.set_title(title)
    tick_positions = np.arange(n_classes)
    tick_text = [str(label) for label in axis_labels]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_text)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_text)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.savefig("output.png", dpi=125)
    plt.close(fig)
    os.rename("output.png", "output")
606
607
if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    # Optional paths/settings; every ``dest`` equals a keyword parameter of
    # main(), so the parsed namespace can be forwarded as-is.
    optional_flags = (
        ("-e", "--estimator", "infile_estimator"),
        ("-X", "--infile1", "infile1"),
        ("-y", "--infile2", "infile2"),
        ("-O", "--outfile_result", "outfile_result"),
        ("-o", "--outfile_object", "outfile_object"),
        ("-g", "--groups", "groups"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-b", "--intervals", "intervals"),
        ("-t", "--targets", "targets"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-c", "--model_config", "model_config"),
        ("-tl", "--true_labels", "true_labels"),
        ("-pl", "--predicted_labels", "predicted_labels"),
        ("-pc", "--plot_color", "plot_color"),
        ("-pt", "--title", "title"),
    )
    for short_opt, long_opt, dest_name in optional_flags:
        aparser.add_argument(short_opt, long_opt, dest=dest_name)
    args = aparser.parse_args()

    main(**vars(args))
645 )