comparison ml_visualization_ex.py @ 0:59e8b4328c82 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:40:10 +0000
parents
children f93f0cdbaf18
1 import argparse
2 import json
3 import os
4 import warnings
5
6 import matplotlib
7 import matplotlib.pyplot as plt
8 import numpy as np
9 import pandas as pd
10 import plotly
11 import plotly.graph_objs as go
12 from galaxy_ml.utils import load_model, read_columns, SafeEval
13 from keras.models import model_from_json
14 from keras.utils import plot_model
15 from sklearn.feature_selection.base import SelectorMixin
16 from sklearn.metrics import auc, average_precision_score, confusion_matrix, precision_recall_curve, roc_curve
17 from sklearn.pipeline import Pipeline
18
19
20 safe_eval = SafeEval()
21
22 # plotly default colors
23 default_colors = [
24 "#1f77b4", # muted blue
25 "#ff7f0e", # safety orange
26 "#2ca02c", # cooked asparagus green
27 "#d62728", # brick red
28 "#9467bd", # muted purple
29 "#8c564b", # chestnut brown
30 "#e377c2", # raspberry yogurt pink
31 "#7f7f7f", # middle gray
32 "#bcbd22", # curry yellow-green
33 "#17becf", # blue-teal
34 ]
35
36
37 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
38 """output pr-curve in html using plotly
39
40 df1 : pandas.DataFrame
41 Containing y_true
42 df2 : pandas.DataFrame
43 Containing y_score
44 pos_label : int or str, default is None
45 The label of the positive class
46 title : str
47 Plot title
48 """
49 data = []
50 for idx in range(df1.shape[1]):
51 y_true = df1.iloc[:, idx].values
52 y_score = df2.iloc[:, idx].values
53
54 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label)
55 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
56
57 trace = go.Scatter(
58 x=recall,
59 y=precision,
60 mode="lines",
61 marker=dict(color=default_colors[idx % len(default_colors)]),
62 name="%s (area = %.3f)" % (idx, ap),
63 )
64 data.append(trace)
65
66 layout = go.Layout(
67 xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
68 yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
69 title=dict(
70 text=title or "Precision-Recall Curve",
71 x=0.5,
72 y=0.92,
73 xanchor="center",
74 yanchor="top",
75 ),
76 font=dict(family="sans-serif", size=11),
77 # control background color
78 plot_bgcolor="rgba(255,255,255,0)",
79 )
80 """
81 legend=dict(
82 x=0.95,
83 y=0,
84 traceorder="normal",
85 font=dict(
86 family="sans-serif",
87 size=9,
88 color="black"
89 ),
90 bgcolor="LightSteelBlue",
91 bordercolor="Black",
92 borderwidth=2
93 ),"""
94
95 fig = go.Figure(data=data, layout=layout)
96
97 plotly.offline.plot(fig, filename="output.html", auto_open=False)
98 # to be discovered by `from_work_dir`
99 os.rename("output.html", "output")
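
# Editor's sketch (not part of the original tool): illustrative usage of the function
# above. `df1` carries one column of ground-truth labels per curve and `df2` the matching
# column of predicted scores; one curve is drawn per column. The toy values below are
# hypothetical and this helper is never called by the tool itself.
def _example_pr_curve_plotly_usage():
    y_true = pd.DataFrame({"class_a": [0, 1, 1, 0, 1, 0], "class_b": [1, 0, 1, 1, 0, 0]})
    y_score = pd.DataFrame({"class_a": [0.1, 0.8, 0.65, 0.3, 0.9, 0.2], "class_b": [0.7, 0.2, 0.6, 0.55, 0.1, 0.3]})
    # writes an interactive plot to a file named `output` in the working directory
    visualize_pr_curve_plotly(y_true, y_score, pos_label=None, title="PR curve (toy data)")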
100
101
102 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
103 """visualize pr-curve using matplotlib and output svg image"""
104 backend = matplotlib.get_backend()
105 if "inline" not in backend:
106 matplotlib.use("SVG")
107 plt.style.use("seaborn-colorblind")
108 plt.figure()
109
110 for idx in range(df1.shape[1]):
111 y_true = df1.iloc[:, idx].values
112 y_score = df2.iloc[:, idx].values
113
114 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label)
115 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
116
117 plt.step(
118 recall,
119 precision,
120 "r-",
121 color="black",
122 alpha=0.3,
123 lw=1,
124 where="post",
125 label="%s (area = %.3f)" % (idx, ap),
126 )
127
128 plt.xlim([0.0, 1.0])
129 plt.ylim([0.0, 1.05])
130 plt.xlabel("Recall")
131 plt.ylabel("Precision")
132 title = title or "Precision-Recall Curve"
133 plt.title(title)
134 folder = os.getcwd()
135 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
136 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
137
138
139 def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
140 """output roc-curve in html using plotly
141
142 df1 : pandas.DataFrame
143 Containing y_true
144 df2 : pandas.DataFrame
145 Containing y_score
146 pos_label : int or str, default is None
147 The label of the positive class
148 drop_intermediate : bool
149 Whether to drop some suboptimal thresholds
150 title : str
151 Plot title
152 """
153 data = []
154 for idx in range(df1.shape[1]):
155 y_true = df1.iloc[:, idx].values
156 y_score = df2.iloc[:, idx].values
157
158 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate)
159 roc_auc = auc(fpr, tpr)
160
161 trace = go.Scatter(
162 x=fpr,
163 y=tpr,
164 mode="lines",
165 marker=dict(color=default_colors[idx % len(default_colors)]),
166 name="%s (area = %.3f)" % (idx, roc_auc),
167 )
168 data.append(trace)
169
170 layout = go.Layout(
171 xaxis=dict(title="False Positive Rate", linecolor="lightslategray", linewidth=1),
172 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
173 title=dict(
174 text=title or "Receiver Operating Characteristic (ROC) Curve",
175 x=0.5,
176 y=0.92,
177 xanchor="center",
178 yanchor="top",
179 ),
180 font=dict(family="sans-serif", size=11),
181 # control background color
182 plot_bgcolor="rgba(255,255,255,0)",
183 )
184 """
185 # legend=dict(
186 # x=0.95,
187 # y=0,
188 # traceorder="normal",
189 # font=dict(
190 # family="sans-serif",
191 # size=9,
192 # color="black"
193 # ),
194 # bgcolor="LightSteelBlue",
195 # bordercolor="Black",
196 # borderwidth=2
197 # ),
198 """
199
200 fig = go.Figure(data=data, layout=layout)
201
202 plotly.offline.plot(fig, filename="output.html", auto_open=False)
203 # to be discovered by `from_work_dir`
204 os.rename("output.html", "output")
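
# Editor's sketch (not part of the original tool): with string class labels, the positive
# class must be named explicitly via `pos_label`. As with the PR helper, one curve is
# drawn per column of df1/df2. The toy values below are hypothetical and this helper is
# never called by the tool itself.
def _example_roc_curve_plotly_usage():
    y_true = pd.DataFrame({"label": ["neg", "pos", "pos", "neg", "pos", "neg"]})
    y_score = pd.DataFrame({"label": [0.2, 0.9, 0.55, 0.35, 0.8, 0.1]})
    visualize_roc_curve_plotly(y_true, y_score, pos_label="pos", drop_intermediate=True, title="ROC curve (toy data)")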
205
206
207 def visualize_roc_curve_matplotlib(df1, df2, pos_label, drop_intermediate=True, title=None):
208 """visualize roc-curve using matplotlib and output svg image"""
209 backend = matplotlib.get_backend()
210 if "inline" not in backend:
211 matplotlib.use("SVG")
212 plt.style.use("seaborn-colorblind")
213 plt.figure()
214
215 for idx in range(df1.shape[1]):
216 y_true = df1.iloc[:, idx].values
217 y_score = df2.iloc[:, idx].values
218
219 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate)
220 roc_auc = auc(fpr, tpr)
221
222 plt.step(
223 fpr,
224 tpr,
225 "r-",
226 color="black",
227 alpha=0.3,
228 lw=1,
229 where="post",
230 label="%s (area = %.3f)" % (idx, roc_auc),
231 )
232
233 plt.xlim([0.0, 1.0])
234 plt.ylim([0.0, 1.05])
235 plt.xlabel("False Positive Rate")
236 plt.ylabel("True Positive Rate")
237 title = title or "Receiver Operating Characteristic (ROC) Curve"
238 plt.title(title)
239 folder = os.getcwd()
240 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
241 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
242
243
244 def get_dataframe(file_path, plot_selection, header_name, column_name):
245 header = "infer" if plot_selection[header_name] else None
246 column_option = plot_selection[column_name]["selected_column_selector_option"]
247 if column_option in [
248 "by_index_number",
249 "all_but_by_index_number",
250 "by_header_name",
251 "all_but_by_header_name",
252 ]:
253 col = plot_selection[column_name]["col1"]
254 else:
255 col = None
256 _, input_df = read_columns(file_path, c=col,
257 c_option=column_option,
258 return_df=True,
259 sep='\t', header=header,
260 parse_dates=True)
261 return input_df
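
# Editor's sketch (not part of the original tool): `plot_selection` is the
# "plotting_selection" section of the Galaxy parameter JSON; the keys accessed above are
# shown below with hypothetical values (the exact format of "col1" is determined by
# Galaxy's column selector). This helper is never called by the tool itself.
def _example_get_dataframe_usage():
    plot_selection = {
        "header_true": True,  # first row of the tabular file is a header
        "column_selector_options_true": {
            "selected_column_selector_option": "by_index_number",
            "col1": [1],  # hypothetical 1-based column index
        },
    }
    # "true_labels.tabular" is a placeholder path
    return get_dataframe("true_labels.tabular", plot_selection, "header_true", "column_selector_options_true")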
262
263
264 def main(
265 inputs,
266 infile_estimator=None,
267 infile1=None,
268 infile2=None,
269 outfile_result=None,
270 outfile_object=None,
271 groups=None,
272 ref_seq=None,
273 intervals=None,
274 targets=None,
275 fasta_path=None,
276 model_config=None,
277 true_labels=None,
278 predicted_labels=None,
279 plot_color=None,
280 title=None,
281 ):
282 """
283 Parameters
284 ----------
285 inputs : str
286 File path to galaxy tool parameter
287
288 infile_estimator : str, default is None
289 File path to estimator
290
291 infile1 : str, default is None
292 File path to dataset containing features or true labels.
293
294 infile2 : str, default is None
295 File path to dataset containing target values or predicted
296 probabilities.
297
298 outfile_result : str, default is None
299 File path to save the results, either cv_results or test result
300
301 outfile_object : str, default is None
302 File path to save searchCV object
303
304 groups : str, default is None
305 File path to dataset containing groups labels
306
307 ref_seq : str, default is None
308 File path to dataset containing genome sequence file
309
310 intervals : str, default is None
311 File path to dataset containing interval file
312
313 targets : str, default is None
314 File path to dataset compressed target bed file
315
316 fasta_path : str, default is None
317 File path to dataset containing fasta file
318
319 model_config : str, default is None
320 File path to dataset containing JSON config for neural networks
321
322 true_labels : str, default is None
323 File path to dataset containing true labels
324
325 predicted_labels : str, default is None
326 File path to dataset containing predicted labels
327
328 plot_color : str, default is None
329 Color of the confusion matrix heatmap
330
331 title : str, default is None
332 Title of the confusion matrix heatmap
333 """
334 warnings.simplefilter("ignore")
335
336 with open(inputs, "r") as param_handler:
337 params = json.load(param_handler)
338
339 title = params["plotting_selection"]["title"].strip()
340 plot_type = params["plotting_selection"]["plot_type"]
341 plot_format = params["plotting_selection"]["plot_format"]
342
343 if plot_type == "feature_importances":
344 with open(infile_estimator, "rb") as estimator_handler:
345 estimator = load_model(estimator_handler)
346
347 column_option = params["plotting_selection"]["column_selector_options"]["selected_column_selector_option"]
348 if column_option in [
349 "by_index_number",
350 "all_but_by_index_number",
351 "by_header_name",
352 "all_but_by_header_name",
353 ]:
354 c = params["plotting_selection"]["column_selector_options"]["col1"]
355 else:
356 c = None
357
358 _, input_df = read_columns(
359 infile1,
360 c=c,
361 c_option=column_option,
362 return_df=True,
363 sep="\t",
364 header="infer",
365 parse_dates=True,
366 )
367
368 feature_names = input_df.columns.values
369
370 if isinstance(estimator, Pipeline):
371 for st in estimator.steps[:-1]:
372 if isinstance(st[-1], SelectorMixin):
373 mask = st[-1].get_support()
374 feature_names = feature_names[mask]
375 estimator = estimator.steps[-1][-1]
376
377 if hasattr(estimator, "coef_"):
378 coefs = estimator.coef_
379 else:
380 coefs = getattr(estimator, "feature_importances_", None)
381 if coefs is None:
382 raise RuntimeError('The classifier does not expose "coef_" or "feature_importances_" attributes')
383
384 threshold = params["plotting_selection"]["threshold"]
385 if threshold is not None:
386 mask = (coefs > threshold) | (coefs < -threshold)
387 coefs = coefs[mask]
388 feature_names = feature_names[mask]
389
390 # sort
391 indices = np.argsort(coefs)[::-1]
392
393 trace = go.Bar(x=feature_names[indices], y=coefs[indices])
394 layout = go.Layout(title=title or "Feature Importances")
395 fig = go.Figure(data=[trace], layout=layout)
396
397 plotly.offline.plot(fig, filename="output.html", auto_open=False)
398 # to be discovered by `from_work_dir`
399 os.rename("output.html", "output")
400
401 return 0
402
403 elif plot_type in ("pr_curve", "roc_curve"):
404 df1 = pd.read_csv(infile1, sep="\t", header="infer")
405 df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)
406
407 minimum = params["plotting_selection"]["report_minimum_n_positives"]
408 # filter out columns whose n_positives is below the threshold
409 if minimum:
410 mask = df1.sum(axis=0) >= minimum
411 df1 = df1.loc[:, mask]
412 df2 = df2.loc[:, mask]
413
414 pos_label = params["plotting_selection"]["pos_label"].strip() or None
415
416 if plot_type == "pr_curve":
417 if plot_format == "plotly_html":
418 visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
419 else:
420 visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
421 else: # 'roc_curve'
422 drop_intermediate = params["plotting_selection"]["drop_intermediate"]
423 if plot_format == "plotly_html":
424 visualize_roc_curve_plotly(
425 df1,
426 df2,
427 pos_label,
428 drop_intermediate=drop_intermediate,
429 title=title,
430 )
431 else:
432 visualize_roc_curve_matplotlib(
433 df1,
434 df2,
435 pos_label,
436 drop_intermediate=drop_intermediate,
437 title=title,
438 )
439
440 return 0
441
442 elif plot_type == "rfecv_gridscores":
443 input_df = pd.read_csv(infile1, sep="\t", header="infer")
444 scores = input_df.iloc[:, 0]
445 steps = params["plotting_selection"]["steps"].strip()
446 steps = safe_eval(steps)
447
448 data = go.Scatter(
449 x=list(range(len(scores))),
450 y=scores,
451 text=[str(_) for _ in steps] if steps else None,
452 mode="lines",
453 )
454 layout = go.Layout(
455 xaxis=dict(title="Number of features selected"),
456 yaxis=dict(title="Cross validation score"),
457 title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"),
458 font=dict(family="sans-serif", size=11),
459 # control background color
460 plot_bgcolor="rgba(255,255,255,0)",
461 )
462 """
463 # legend=dict(
464 # x=0.95,
465 # y=0,
466 # traceorder="normal",
467 # font=dict(
468 # family="sans-serif",
469 # size=9,
470 # color="black"
471 # ),
472 # bgcolor="LightSteelBlue",
473 # bordercolor="Black",
474 # borderwidth=2
475 # ),
476 """
477
478 fig = go.Figure(data=[data], layout=layout)
479 plotly.offline.plot(fig, filename="output.html", auto_open=False)
480 # to be discovered by `from_work_dir`
481 os.rename("output.html", "output")
482
483 return 0
484
485 elif plot_type == "learning_curve":
486 input_df = pd.read_csv(infile1, sep="\t", header="infer")
487 plot_std_err = params["plotting_selection"]["plot_std_err"]
488 data1 = go.Scatter(
489 x=input_df["train_sizes_abs"],
490 y=input_df["mean_train_scores"],
491 error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
492 mode="lines",
493 name="Train Scores",
494 )
495 data2 = go.Scatter(
496 x=input_df["train_sizes_abs"],
497 y=input_df["mean_test_scores"],
498 error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
499 mode="lines",
500 name="Test Scores",
501 )
502 layout = dict(
503 xaxis=dict(title="No. of samples"),
504 yaxis=dict(title="Performance Score"),
505 # modify these configurations to customize image
506 title=dict(
507 text=title or "Learning Curve",
508 x=0.5,
509 y=0.92,
510 xanchor="center",
511 yanchor="top",
512 ),
513 font=dict(family="sans-serif", size=11),
514 # control background color
515 plot_bgcolor="rgba(255,255,255,0)",
516 )
517 """
518 # legend=dict(
519 # x=0.95,
520 # y=0,
521 # traceorder="normal",
522 # font=dict(
523 # family="sans-serif",
524 # size=9,
525 # color="black"
526 # ),
527 # bgcolor="LightSteelBlue",
528 # bordercolor="Black",
529 # borderwidth=2
530 # ),
531 """
532
533 fig = go.Figure(data=[data1, data2], layout=layout)
534 plotly.offline.plot(fig, filename="output.html", auto_open=False)
535 # to be discovered by `from_work_dir`
536 os.rename("output.html", "output")
537
538 return 0
539
540 elif plot_type == "keras_plot_model":
541 with open(model_config, "r") as f:
542 model_str = f.read()
543 model = model_from_json(model_str)
544 plot_model(model, to_file="output.png")
545 os.rename("output.png", "output")
546
547 return 0
548
549 elif plot_type == "classification_confusion_matrix":
550 plot_selection = params["plotting_selection"]
551 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true")
552 header_predicted = "infer" if plot_selection["header_predicted"] else None
553 input_predicted = pd.read_csv(predicted_labels, sep="\t", parse_dates=True, header=header_predicted)
554 true_classes = input_true.iloc[:, -1].copy()
555 predicted_classes = input_predicted.iloc[:, -1].copy()
556 axis_labels = sorted(set(true_classes))
557 c_matrix = confusion_matrix(true_classes, predicted_classes, labels=axis_labels)
558 fig, ax = plt.subplots(figsize=(7, 7))
559 im = plt.imshow(c_matrix, cmap=plot_color)
560 for i in range(len(c_matrix)):
561 for j in range(len(c_matrix)):
562 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
563 ax.set_ylabel("True class labels")
564 ax.set_xlabel("Predicted class labels")
565 ax.set_title(title)
566 plt.xticks(np.arange(len(axis_labels)), axis_labels)
567 plt.yticks(np.arange(len(axis_labels)), axis_labels)
568 fig.colorbar(im, ax=ax)
569 fig.tight_layout()
570 plt.savefig("output.png", dpi=125)
571 os.rename("output.png", "output")
572
573 return 0
574
575 # save pdf file to disk
576 # fig.write_image("image.pdf", format='pdf')
577 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
578
579
580 if __name__ == "__main__":
581 aparser = argparse.ArgumentParser()
582 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
583 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
584 aparser.add_argument("-X", "--infile1", dest="infile1")
585 aparser.add_argument("-y", "--infile2", dest="infile2")
586 aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
587 aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
588 aparser.add_argument("-g", "--groups", dest="groups")
589 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
590 aparser.add_argument("-b", "--intervals", dest="intervals")
591 aparser.add_argument("-t", "--targets", dest="targets")
592 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
593 aparser.add_argument("-c", "--model_config", dest="model_config")
594 aparser.add_argument("-tl", "--true_labels", dest="true_labels")
595 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
596 aparser.add_argument("-pc", "--plot_color", dest="plot_color")
597 aparser.add_argument("-pt", "--title", dest="title")
598 args = aparser.parse_args()
599
600 main(
601 args.inputs,
602 args.infile_estimator,
603 args.infile1,
604 args.infile2,
605 args.outfile_result,
606 outfile_object=args.outfile_object,
607 groups=args.groups,
608 ref_seq=args.ref_seq,
609 intervals=args.intervals,
610 targets=args.targets,
611 fasta_path=args.fasta_path,
612 model_config=args.model_config,
613 true_labels=args.true_labels,
614 predicted_labels=args.predicted_labels,
615 plot_color=args.plot_color,
616 title=args.title,
617 )
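
# Editor's sketch: an example invocation (in practice the Galaxy tool wrapper supplies the
# flags and file paths defined above, e.g. for a PR/ROC plot):
#   python ml_visualization_ex.py -i inputs.json -X y_true.tabular -y y_score.tabular
# The selected plot is written to a file named `output` in the working directory so the
# wrapper can collect it via `from_work_dir`.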