comparison ml_visualization_ex.py @ 4:6b94d76a1397 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author bgruening
date Mon, 16 Dec 2019 05:40:29 -0500
parents 09efff9a5765
children 222c02df5d55
comparison
equal deleted inserted replaced
3:72c0d2747dc9 4:6b94d76a1397
1 import argparse 1 import argparse
2 import json 2 import json
3 import matplotlib
4 import matplotlib.pyplot as plt
3 import numpy as np 5 import numpy as np
6 import os
4 import pandas as pd 7 import pandas as pd
5 import plotly 8 import plotly
6 import plotly.graph_objs as go 9 import plotly.graph_objs as go
7 import warnings 10 import warnings
8 11
15 from galaxy_ml.utils import load_model, read_columns, SafeEval 18 from galaxy_ml.utils import load_model, read_columns, SafeEval
16 19
17 20
18 safe_eval = SafeEval() 21 safe_eval = SafeEval()
19 22
# plotly default colors, as (hex code, descriptive name) pairs listed in the
# order the plotting functions cycle through them (index modulo length).
_PLOTLY_COLOR_TABLE = (
    ('#1f77b4', 'muted blue'),
    ('#ff7f0e', 'safety orange'),
    ('#2ca02c', 'cooked asparagus green'),
    ('#d62728', 'brick red'),
    ('#9467bd', 'muted purple'),
    ('#8c564b', 'chestnut brown'),
    ('#e377c2', 'raspberry yogurt pink'),
    ('#7f7f7f', 'middle gray'),
    ('#bcbd22', 'curry yellow-green'),
    ('#17becf', 'blue-teal'),
)
default_colors = [hex_code for hex_code, _name in _PLOTLY_COLOR_TABLE]
36
37
def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """Output precision-recall curves in html using plotly.

    One curve is drawn per column: column ``idx`` of ``df1`` is paired
    with column ``idx`` of ``df2``.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Containing y_true, one column per curve.
    df2 : pandas.DataFrame
        Containing y_score, column-aligned with ``df1``.
    pos_label : str or None
        The label of the positive class. ``None`` is passed through to
        ``precision_recall_curve``; ``average_precision_score`` receives
        1 in that case.
    title : str, optional
        Plot title. Defaults to 'Precision-Recall Curve'.

    The figure is written to ``output.html`` in the working directory and
    then renamed to ``output``.
    """
    data = []
    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label)
        ap = average_precision_score(
            y_true, y_score, pos_label=pos_label or 1)

        trace = go.Scatter(
            x=recall,
            y=precision,
            mode='lines',
            line=dict(
                color=default_colors[idx % len(default_colors)],
                # Horizontal-then-vertical steps: average precision is the
                # area under this step curve (not under a linear
                # interpolation), and it matches the where="post" step plot
                # of visualize_pr_curve_matplotlib.
                shape='hv'
            ),
            name='%s (area = %.3f)' % (idx, ap)
        )
        data.append(trace)

    layout = go.Layout(
        xaxis=dict(
            title='Recall',
            linecolor='lightslategray',
            linewidth=1
        ),
        yaxis=dict(
            title='Precision',
            linecolor='lightslategray',
            linewidth=1
        ),
        title=dict(
            text=title or 'Precision-Recall Curve',
            x=0.5,
            y=0.92,
            xanchor='center',
            yanchor='top'
        ),
        font=dict(
            family="sans-serif",
            size=11
        ),
        # control background colors
        plot_bgcolor='rgba(255,255,255,0)'
    )

    fig = go.Figure(data=data, layout=layout)

    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')
116
117
def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """Visualize precision-recall curves using matplotlib and output an svg image.

    One curve is drawn per column: column ``idx`` of ``df1`` is paired
    with column ``idx`` of ``df2``.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Containing y_true, one column per curve.
    df2 : pandas.DataFrame
        Containing y_score, column-aligned with ``df1``.
    pos_label : str or None
        The label of the positive class. ``None`` is passed through to
        ``precision_recall_curve``; ``average_precision_score`` receives
        1 in that case.
    title : str, optional
        Plot title. Defaults to 'Precision-Recall Curve'.

    The image is saved as ``output.svg`` in the current working directory
    and then renamed to ``output``.
    """
    backend = matplotlib.get_backend()
    # Leave backends whose name contains "inline" (presumably notebook
    # display) alone; otherwise render with the SVG backend.
    if "inline" not in backend:
        matplotlib.use("SVG")
    # 'seaborn-colorblind' was renamed 'seaborn-v0_8-colorblind' in
    # matplotlib 3.6 and the old name was removed in 3.8.
    style = 'seaborn-colorblind'
    if style not in plt.style.available:
        style = 'seaborn-v0_8-colorblind'
    plt.style.use(style)
    fig = plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label)
        ap = average_precision_score(
            y_true, y_score, pos_label=pos_label or 1)

        # A single color source (no conflicting 'r-' fmt string).
        plt.step(recall, precision, color="black", alpha=0.3,
                 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap))

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(title or 'Precision-Recall Curve')
    # Without a legend the per-column "(area = ...)" labels are never shown.
    plt.legend()
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    os.rename(os.path.join(folder, "output.svg"),
              os.path.join(folder, "output"))
149
150
def visualize_roc_curve_plotly(df1, df2, pos_label,
                               drop_intermediate=True,
                               title=None):
    """Output ROC curves in html using plotly.

    One curve is drawn per column: column ``idx`` of ``df1`` is paired
    with column ``idx`` of ``df2``.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Containing y_true, one column per curve.
    df2 : pandas.DataFrame
        Containing y_score, column-aligned with ``df1``.
    pos_label : str or None
        The label of the positive class, passed to ``roc_curve``.
    drop_intermediate : bool
        Whether to drop some suboptimal thresholds (forwarded to
        ``roc_curve``).
    title : str, optional
        Plot title. Defaults to
        'Receiver Operating Characteristic (ROC) Curve'.

    The figure is written to ``output.html`` in the working directory and
    then renamed to ``output``.
    """
    data = []
    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label,
                                drop_intermediate=drop_intermediate)
        # Trapezoidal area, consistent with the straight segments drawn below.
        roc_auc = auc(fpr, tpr)

        trace = go.Scatter(
            x=fpr,
            y=tpr,
            mode='lines',
            # A lines-only trace is styled through `line`, not `marker`.
            line=dict(
                color=default_colors[idx % len(default_colors)]
            ),
            name='%s (area = %.3f)' % (idx, roc_auc)
        )
        data.append(trace)

    layout = go.Layout(
        xaxis=dict(
            title='False Positive Rate',
            linecolor='lightslategray',
            linewidth=1
        ),
        yaxis=dict(
            title='True Positive Rate',
            linecolor='lightslategray',
            linewidth=1
        ),
        title=dict(
            text=title or 'Receiver Operating Characteristic (ROC) Curve',
            x=0.5,
            y=0.92,
            xanchor='center',
            yanchor='top'
        ),
        font=dict(
            family="sans-serif",
            size=11
        ),
        # control background colors
        plot_bgcolor='rgba(255,255,255,0)'
    )

    fig = go.Figure(data=data, layout=layout)

    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')
233
234
def visualize_roc_curve_matplotlib(df1, df2, pos_label,
                                   drop_intermediate=True,
                                   title=None):
    """Visualize ROC curves using matplotlib and output an svg image.

    One curve is drawn per column: column ``idx`` of ``df1`` is paired
    with column ``idx`` of ``df2``.

    Parameters
    ----------
    df1 : pandas.DataFrame
        Containing y_true, one column per curve.
    df2 : pandas.DataFrame
        Containing y_score, column-aligned with ``df1``.
    pos_label : str or None
        The label of the positive class, passed to ``roc_curve``.
    drop_intermediate : bool
        Whether to drop some suboptimal thresholds (forwarded to
        ``roc_curve``).
    title : str, optional
        Plot title. Defaults to
        'Receiver Operating Characteristic (ROC) Curve'.

    The image is saved as ``output.svg`` in the current working directory
    and then renamed to ``output``.
    """
    backend = matplotlib.get_backend()
    # Leave backends whose name contains "inline" (presumably notebook
    # display) alone; otherwise render with the SVG backend.
    if "inline" not in backend:
        matplotlib.use("SVG")
    # 'seaborn-colorblind' was renamed 'seaborn-v0_8-colorblind' in
    # matplotlib 3.6 and the old name was removed in 3.8.
    style = 'seaborn-colorblind'
    if style not in plt.style.available:
        style = 'seaborn-v0_8-colorblind'
    plt.style.use(style)
    fig = plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label,
                                drop_intermediate=drop_intermediate)
        roc_auc = auc(fpr, tpr)

        # Straight segments between ROC points: auc() uses the trapezoidal
        # rule and drop_intermediate assumes linear joins. A step plot would
        # draw tied-score (diagonal) segments as a staircase that disagrees
        # with the reported area.
        plt.plot(fpr, tpr, color="black", alpha=0.3, lw=1,
                 label='%s (area = %.3f)' % (idx, roc_auc))

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title or 'Receiver Operating Characteristic (ROC) Curve')
    # Without a legend the per-column "(area = ...)" labels are never shown.
    plt.legend()
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)
    os.rename(os.path.join(folder, "output.svg"),
              os.path.join(folder, "output"))
267
20 268
21 def main(inputs, infile_estimator=None, infile1=None, 269 def main(inputs, infile_estimator=None, infile1=None,
22 infile2=None, outfile_result=None, 270 infile2=None, outfile_result=None,
23 outfile_object=None, groups=None, 271 outfile_object=None, groups=None,
24 ref_seq=None, intervals=None, 272 ref_seq=None, intervals=None,
69 with open(inputs, 'r') as param_handler: 317 with open(inputs, 'r') as param_handler:
70 params = json.load(param_handler) 318 params = json.load(param_handler)
71 319
72 title = params['plotting_selection']['title'].strip() 320 title = params['plotting_selection']['title'].strip()
73 plot_type = params['plotting_selection']['plot_type'] 321 plot_type = params['plotting_selection']['plot_type']
322 plot_format = params['plotting_selection']['plot_format']
323
74 if plot_type == 'feature_importances': 324 if plot_type == 'feature_importances':
75 with open(infile_estimator, 'rb') as estimator_handler: 325 with open(infile_estimator, 'rb') as estimator_handler:
76 estimator = load_model(estimator_handler) 326 estimator = load_model(estimator_handler)
77 327
78 column_option = (params['plotting_selection'] 328 column_option = (params['plotting_selection']
121 trace = go.Bar(x=feature_names[indices], 371 trace = go.Bar(x=feature_names[indices],
122 y=coefs[indices]) 372 y=coefs[indices])
123 layout = go.Layout(title=title or "Feature Importances") 373 layout = go.Layout(title=title or "Feature Importances")
124 fig = go.Figure(data=[trace], layout=layout) 374 fig = go.Figure(data=[trace], layout=layout)
125 375
126 elif plot_type == 'pr_curve': 376 plotly.offline.plot(fig, filename="output.html",
127 df1 = pd.read_csv(infile1, sep='\t', header=None) 377 auto_open=False)
128 df2 = pd.read_csv(infile2, sep='\t', header=None) 378 # to be discovered by `from_work_dir`
129 379 os.rename('output.html', 'output')
130 precision = {} 380
131 recall = {} 381 return 0
132 ap = {} 382
383 elif plot_type in ('pr_curve', 'roc_curve'):
384 df1 = pd.read_csv(infile1, sep='\t', header='infer')
385 df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32)
386
387 minimum = params['plotting_selection']['report_minimum_n_positives']
388 # filter out columns whose n_positives is beblow the threhold
389 if minimum:
390 mask = df1.sum(axis=0) >= minimum
391 df1 = df1.loc[:, mask]
392 df2 = df2.loc[:, mask]
133 393
134 pos_label = params['plotting_selection']['pos_label'].strip() \ 394 pos_label = params['plotting_selection']['pos_label'].strip() \
135 or None 395 or None
136 for col in df1.columns: 396
137 y_true = df1[col].values 397 if plot_type == 'pr_curve':
138 y_score = df2[col].values 398 if plot_format == 'plotly_html':
139 399 visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
140 precision[col], recall[col], _ = precision_recall_curve( 400 else:
141 y_true, y_score, pos_label=pos_label) 401 visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
142 ap[col] = average_precision_score( 402 else: # 'roc_curve'
143 y_true, y_score, pos_label=pos_label or 1) 403 drop_intermediate = (params['plotting_selection']
144 404 ['drop_intermediate'])
145 if len(df1.columns) > 1: 405 if plot_format == 'plotly_html':
146 precision["micro"], recall["micro"], _ = precision_recall_curve( 406 visualize_roc_curve_plotly(df1, df2, pos_label,
147 df1.values.ravel(), df2.values.ravel(), pos_label=pos_label) 407 drop_intermediate=drop_intermediate,
148 ap['micro'] = average_precision_score( 408 title=title)
149 df1.values, df2.values, average='micro', 409 else:
150 pos_label=pos_label or 1) 410 visualize_roc_curve_matplotlib(
151 411 df1, df2, pos_label,
152 data = [] 412 drop_intermediate=drop_intermediate,
153 for key in precision.keys(): 413 title=title)
154 trace = go.Scatter( 414
155 x=recall[key], 415 return 0
156 y=precision[key],
157 mode='lines',
158 name='%s (area = %.2f)' % (key, ap[key]) if key == 'micro'
159 else 'column %s (area = %.2f)' % (key, ap[key])
160 )
161 data.append(trace)
162
163 layout = go.Layout(
164 title=title or "Precision-Recall curve",
165 xaxis=dict(title='Recall'),
166 yaxis=dict(title='Precision')
167 )
168
169 fig = go.Figure(data=data, layout=layout)
170
171 elif plot_type == 'roc_curve':
172 df1 = pd.read_csv(infile1, sep='\t', header=None)
173 df2 = pd.read_csv(infile2, sep='\t', header=None)
174
175 fpr = {}
176 tpr = {}
177 roc_auc = {}
178
179 pos_label = params['plotting_selection']['pos_label'].strip() \
180 or None
181 for col in df1.columns:
182 y_true = df1[col].values
183 y_score = df2[col].values
184
185 fpr[col], tpr[col], _ = roc_curve(
186 y_true, y_score, pos_label=pos_label)
187 roc_auc[col] = auc(fpr[col], tpr[col])
188
189 if len(df1.columns) > 1:
190 fpr["micro"], tpr["micro"], _ = roc_curve(
191 df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
192 roc_auc['micro'] = auc(fpr["micro"], tpr["micro"])
193
194 data = []
195 for key in fpr.keys():
196 trace = go.Scatter(
197 x=fpr[key],
198 y=tpr[key],
199 mode='lines',
200 name='%s (area = %.2f)' % (key, roc_auc[key]) if key == 'micro'
201 else 'column %s (area = %.2f)' % (key, roc_auc[key])
202 )
203 data.append(trace)
204
205 trace = go.Scatter(x=[0, 1], y=[0, 1],
206 mode='lines',
207 line=dict(color='black', dash='dash'),
208 showlegend=False)
209 data.append(trace)
210
211 layout = go.Layout(
212 title=title or "Receiver operating characteristic curve",
213 xaxis=dict(title='False Positive Rate'),
214 yaxis=dict(title='True Positive Rate')
215 )
216
217 fig = go.Figure(data=data, layout=layout)
218 416
219 elif plot_type == 'rfecv_gridscores': 417 elif plot_type == 'rfecv_gridscores':
220 input_df = pd.read_csv(infile1, sep='\t', header='infer') 418 input_df = pd.read_csv(infile1, sep='\t', header='infer')
221 scores = input_df.iloc[:, 0] 419 scores = input_df.iloc[:, 0]
222 steps = params['plotting_selection']['steps'].strip() 420 steps = params['plotting_selection']['steps'].strip()
229 mode='lines' 427 mode='lines'
230 ) 428 )
231 layout = go.Layout( 429 layout = go.Layout(
232 xaxis=dict(title="Number of features selected"), 430 xaxis=dict(title="Number of features selected"),
233 yaxis=dict(title="Cross validation score"), 431 yaxis=dict(title="Cross validation score"),
234 title=title or None 432 title=dict(
235 ) 433 text=title or None,
434 x=0.5,
435 y=0.92,
436 xanchor='center',
437 yanchor='top'
438 ),
439 font=dict(
440 family="sans-serif",
441 size=11
442 ),
443 # control backgroud colors
444 plot_bgcolor='rgba(255,255,255,0)'
445 )
446 """
447 # legend=dict(
448 # x=0.95,
449 # y=0,
450 # traceorder="normal",
451 # font=dict(
452 # family="sans-serif",
453 # size=9,
454 # color="black"
455 # ),
456 # bgcolor="LightSteelBlue",
457 # bordercolor="Black",
458 # borderwidth=2
459 # ),
460 """
236 461
237 fig = go.Figure(data=[data], layout=layout) 462 fig = go.Figure(data=[data], layout=layout)
463 plotly.offline.plot(fig, filename="output.html",
464 auto_open=False)
465 # to be discovered by `from_work_dir`
466 os.rename('output.html', 'output')
467
468 return 0
238 469
239 elif plot_type == 'learning_curve': 470 elif plot_type == 'learning_curve':
240 input_df = pd.read_csv(infile1, sep='\t', header='infer') 471 input_df = pd.read_csv(infile1, sep='\t', header='infer')
241 plot_std_err = params['plotting_selection']['plot_std_err'] 472 plot_std_err = params['plotting_selection']['plot_std_err']
242 data1 = go.Scatter( 473 data1 = go.Scatter(
262 title='No. of samples' 493 title='No. of samples'
263 ), 494 ),
264 yaxis=dict( 495 yaxis=dict(
265 title='Performance Score' 496 title='Performance Score'
266 ), 497 ),
267 title=title or 'Learning Curve' 498 # modify these configurations to customize image
268 ) 499 title=dict(
500 text=title or 'Learning Curve',
501 x=0.5,
502 y=0.92,
503 xanchor='center',
504 yanchor='top'
505 ),
506 font=dict(
507 family="sans-serif",
508 size=11
509 ),
510 # control backgroud colors
511 plot_bgcolor='rgba(255,255,255,0)'
512 )
513 """
514 # legend=dict(
515 # x=0.95,
516 # y=0,
517 # traceorder="normal",
518 # font=dict(
519 # family="sans-serif",
520 # size=9,
521 # color="black"
522 # ),
523 # bgcolor="LightSteelBlue",
524 # bordercolor="Black",
525 # borderwidth=2
526 # ),
527 """
528
269 fig = go.Figure(data=[data1, data2], layout=layout) 529 fig = go.Figure(data=[data1, data2], layout=layout)
530 plotly.offline.plot(fig, filename="output.html",
531 auto_open=False)
532 # to be discovered by `from_work_dir`
533 os.rename('output.html', 'output')
534
535 return 0
270 536
271 elif plot_type == 'keras_plot_model': 537 elif plot_type == 'keras_plot_model':
272 with open(model_config, 'r') as f: 538 with open(model_config, 'r') as f:
273 model_str = f.read() 539 model_str = f.read()
274 model = model_from_json(model_str) 540 model = model_from_json(model_str)
275 plot_model(model, to_file="output.png") 541 plot_model(model, to_file="output.png")
276 __import__('os').rename('output.png', 'output') 542 os.rename('output.png', 'output')
277 543
278 return 0 544 return 0
279 545
280 plotly.offline.plot(fig, filename="output.html", 546 # save pdf file to disk
281 auto_open=False) 547 # fig.write_image("image.pdf", format='pdf')
282 # to be discovered by `from_work_dir` 548 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
283 __import__('os').rename('output.html', 'output')
284 549
285 550
286 if __name__ == '__main__': 551 if __name__ == '__main__':
287 aparser = argparse.ArgumentParser() 552 aparser = argparse.ArgumentParser()
288 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 553 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)