### comparison ml_visualization_ex.py @ 35:eeaf989f1024draft

author bgruening Tue, 13 Apr 2021 18:09:01 +0000 237ae72d58a3 420a4bf99244
comparison
equal inserted replaced
34:2d032cff49eb 35:eeaf989f1024
20 20
21 safe_eval = SafeEval() 21 safe_eval = SafeEval()
22 22
23 # plotly default colors 23 # plotly default colors
24 default_colors = [ 24 default_colors = [
25 '#1f77b4', # muted blue 25 "#1f77b4", # muted blue
26 '#ff7f0e', # safety orange 26 "#ff7f0e", # safety orange
27 '#2ca02c', # cooked asparagus green 27 "#2ca02c", # cooked asparagus green
28 '#d62728', # brick red 28 "#d62728", # brick red
29 '#9467bd', # muted purple 29 "#9467bd", # muted purple
30 '#8c564b', # chestnut brown 30 "#8c564b", # chestnut brown
31 '#e377c2', # raspberry yogurt pink 31 "#e377c2", # raspberry yogurt pink
32 '#7f7f7f', # middle gray 32 "#7f7f7f", # middle gray
33 '#bcbd22', # curry yellow-green 33 "#bcbd22", # curry yellow-green
34 '#17becf' # blue-teal 34 "#17becf", # blue-teal
35 ] 35 ]
36 36
37 37
38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): 38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
39 """output pr-curve in html using plotly 39 """output pr-curve in html using plotly
50 data = [] 50 data = []
51 for idx in range(df1.shape[1]): 51 for idx in range(df1.shape[1]):
52 y_true = df1.iloc[:, idx].values 52 y_true = df1.iloc[:, idx].values
53 y_score = df2.iloc[:, idx].values 53 y_score = df2.iloc[:, idx].values
54 54
55 precision, recall, _ = precision_recall_curve( 55 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label)
56 y_true, y_score, pos_label=pos_label) 56 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
57 ap = average_precision_score(
58 y_true, y_score, pos_label=pos_label or 1)
59 57
60 trace = go.Scatter( 58 trace = go.Scatter(
61 x=recall, 59 x=recall,
62 y=precision, 60 y=precision,
63 mode='lines', 61 mode="lines",
64 marker=dict( 62 marker=dict(color=default_colors[idx % len(default_colors)]),
65 color=default_colors[idx % len(default_colors)] 63 name="%s (area = %.3f)" % (idx, ap),
66 ),
67 name='%s (area = %.3f)' % (idx, ap)
68 ) 64 )
69 data.append(trace) 65 data.append(trace)
70 66
71 layout = go.Layout( 67 layout = go.Layout(
72 xaxis=dict( 68 xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
73 title='Recall', 69 yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
74 linecolor='lightslategray',
75 linewidth=1
76 ),
77 yaxis=dict(
78 title='Precision',
79 linecolor='lightslategray',
80 linewidth=1
81 ),
82 title=dict( 70 title=dict(
83 text=title or 'Precision-Recall Curve', 71 text=title or "Precision-Recall Curve",
84 x=0.5, 72 x=0.5,
85 y=0.92, 73 y=0.92,
86 xanchor='center', 74 xanchor="center",
87 yanchor='top' 75 yanchor="top",
88 ), 76 ),
89 font=dict( 77 font=dict(family="sans-serif", size=11),
90 family="sans-serif",
91 size=11
92 ),
93 # control backgroud colors 78 # control backgroud colors
94 plot_bgcolor='rgba(255,255,255,0)' 79 plot_bgcolor="rgba(255,255,255,0)",
95 ) 80 )
96 """ 81 """
97 legend=dict( 82 legend=dict(
98 x=0.95, 83 x=0.95,
99 y=0, 84 y=0,
110 95
111 fig = go.Figure(data=data, layout=layout) 96 fig = go.Figure(data=data, layout=layout)
112 97
113 plotly.offline.plot(fig, filename="output.html", auto_open=False) 98 plotly.offline.plot(fig, filename="output.html", auto_open=False)
114 # to be discovered by `from_work_dir` 99 # to be discovered by `from_work_dir`
115 os.rename('output.html', 'output') 100 os.rename("output.html", "output")
116 101
117 102
118 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): 103 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
119 """visualize pr-curve using matplotlib and output svg image 104 """visualize pr-curve using matplotlib and output svg image"""
120 """
121 backend = matplotlib.get_backend() 105 backend = matplotlib.get_backend()
122 if "inline" not in backend: 106 if "inline" not in backend:
123 matplotlib.use("SVG") 107 matplotlib.use("SVG")
124 plt.style.use('seaborn-colorblind') 108 plt.style.use("seaborn-colorblind")
125 plt.figure() 109 plt.figure()
126 110
127 for idx in range(df1.shape[1]): 111 for idx in range(df1.shape[1]):
128 y_true = df1.iloc[:, idx].values 112 y_true = df1.iloc[:, idx].values
129 y_score = df2.iloc[:, idx].values 113 y_score = df2.iloc[:, idx].values
130 114
131 precision, recall, _ = precision_recall_curve( 115 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label)
132 y_true, y_score, pos_label=pos_label) 116 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
133 ap = average_precision_score( 117
134 y_true, y_score, pos_label=pos_label or 1) 118 plt.step(
135 119 recall,
136 plt.step(recall, precision, 'r-', color="black", alpha=0.3, 120 precision,
137 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap)) 121 "r-",
122 color="black",
123 alpha=0.3,
124 lw=1,
125 where="post",
126 label="%s (area = %.3f)" % (idx, ap),
127 )
138 128
139 plt.xlim([0.0, 1.0]) 129 plt.xlim([0.0, 1.0])
140 plt.ylim([0.0, 1.05]) 130 plt.ylim([0.0, 1.05])
141 plt.xlabel('Recall') 131 plt.xlabel("Recall")
142 plt.ylabel('Precision') 132 plt.ylabel("Precision")
143 title = title or 'Precision-Recall Curve' 133 title = title or "Precision-Recall Curve"
144 plt.title(title) 134 plt.title(title)
145 folder = os.getcwd() 135 folder = os.getcwd()
146 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 136 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
147 os.rename(os.path.join(folder, "output.svg"), 137 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
148 os.path.join(folder, "output")) 138
149 139
150 140 def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
151 def visualize_roc_curve_plotly(df1, df2, pos_label,
152 drop_intermediate=True,
153 title=None):
154 """output roc-curve in html using plotly 141 """output roc-curve in html using plotly
155 142
156 df1 : pandas.DataFrame 143 df1 : pandas.DataFrame
157 Containing y_true 144 Containing y_true
158 df2 : pandas.DataFrame 145 df2 : pandas.DataFrame
167 data = [] 154 data = []
168 for idx in range(df1.shape[1]): 155 for idx in range(df1.shape[1]):
169 y_true = df1.iloc[:, idx].values 156 y_true = df1.iloc[:, idx].values
170 y_score = df2.iloc[:, idx].values 157 y_score = df2.iloc[:, idx].values
171 158
172 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, 159 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate)
173 drop_intermediate=drop_intermediate)
174 roc_auc = auc(fpr, tpr) 160 roc_auc = auc(fpr, tpr)
175 161
176 trace = go.Scatter( 162 trace = go.Scatter(
177 x=fpr, 163 x=fpr,
178 y=tpr, 164 y=tpr,
179 mode='lines', 165 mode="lines",
180 marker=dict( 166 marker=dict(color=default_colors[idx % len(default_colors)]),
181 color=default_colors[idx % len(default_colors)] 167 name="%s (area = %.3f)" % (idx, roc_auc),
182 ),
183 name='%s (area = %.3f)' % (idx, roc_auc)
184 ) 168 )
185 data.append(trace) 169 data.append(trace)
186 170
187 layout = go.Layout( 171 layout = go.Layout(
188 xaxis=dict( 172 xaxis=dict(title="False Positive Rate", linecolor="lightslategray", linewidth=1),
189 title='False Positive Rate', 173 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
190 linecolor='lightslategray',
191 linewidth=1
192 ),
193 yaxis=dict(
194 title='True Positive Rate',
195 linecolor='lightslategray',
196 linewidth=1
197 ),
198 title=dict( 174 title=dict(
199 text=title or 'Receiver Operating Characteristic (ROC) Curve', 175 text=title or "Receiver Operating Characteristic (ROC) Curve",
200 x=0.5, 176 x=0.5,
201 y=0.92, 177 y=0.92,
202 xanchor='center', 178 xanchor="center",
203 yanchor='top' 179 yanchor="top",
204 ), 180 ),
205 font=dict( 181 font=dict(family="sans-serif", size=11),
206 family="sans-serif",
207 size=11
208 ),
209 # control backgroud colors 182 # control backgroud colors
210 plot_bgcolor='rgba(255,255,255,0)' 183 plot_bgcolor="rgba(255,255,255,0)",
211 ) 184 )
212 """ 185 """
213 # legend=dict( 186 # legend=dict(
214 # x=0.95, 187 # x=0.95,
215 # y=0, 188 # y=0,
227 200
228 fig = go.Figure(data=data, layout=layout) 201 fig = go.Figure(data=data, layout=layout)
229 202
230 plotly.offline.plot(fig, filename="output.html", auto_open=False) 203 plotly.offline.plot(fig, filename="output.html", auto_open=False)
231 # to be discovered by `from_work_dir` 204 # to be discovered by `from_work_dir`
232 os.rename('output.html', 'output') 205 os.rename("output.html", "output")
233 206
234 207
235 def visualize_roc_curve_matplotlib(df1, df2, pos_label, 208 def visualize_roc_curve_matplotlib(df1, df2, pos_label, drop_intermediate=True, title=None):
236 drop_intermediate=True, 209 """visualize roc-curve using matplotlib and output svg image"""
237 title=None):
238 """visualize roc-curve using matplotlib and output svg image
239 """
240 backend = matplotlib.get_backend() 210 backend = matplotlib.get_backend()
241 if "inline" not in backend: 211 if "inline" not in backend:
242 matplotlib.use("SVG") 212 matplotlib.use("SVG")
243 plt.style.use('seaborn-colorblind') 213 plt.style.use("seaborn-colorblind")
244 plt.figure() 214 plt.figure()
245 215
246 for idx in range(df1.shape[1]): 216 for idx in range(df1.shape[1]):
247 y_true = df1.iloc[:, idx].values 217 y_true = df1.iloc[:, idx].values
248 y_score = df2.iloc[:, idx].values 218 y_score = df2.iloc[:, idx].values
249 219
250 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, 220 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate)
251 drop_intermediate=drop_intermediate)
252 roc_auc = auc(fpr, tpr) 221 roc_auc = auc(fpr, tpr)
253 222
254 plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1, 223 plt.step(
255 where="post", label='%s (area = %.3f)' % (idx, roc_auc)) 224 fpr,
225 tpr,
226 "r-",
227 color="black",
228 alpha=0.3,
229 lw=1,
230 where="post",
231 label="%s (area = %.3f)" % (idx, roc_auc),
232 )
256 233
257 plt.xlim([0.0, 1.0]) 234 plt.xlim([0.0, 1.0])
258 plt.ylim([0.0, 1.05]) 235 plt.ylim([0.0, 1.05])
259 plt.xlabel('False Positive Rate') 236 plt.xlabel("False Positive Rate")
260 plt.ylabel('True Positive Rate') 237 plt.ylabel("True Positive Rate")
261 title = title or 'Receiver Operating Characteristic (ROC) Curve' 238 title = title or "Receiver Operating Characteristic (ROC) Curve"
262 plt.title(title) 239 plt.title(title)
263 folder = os.getcwd() 240 folder = os.getcwd()
264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 241 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
265 os.rename(os.path.join(folder, "output.svg"), 242 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
266 os.path.join(folder, "output"))
267 243
268 244
269 def get_dataframe(file_path, plot_selection, header_name, column_name): 245 def get_dataframe(file_path, plot_selection, header_name, column_name):
271 column_option = plot_selection[column_name]["selected_column_selector_option"] 247 column_option = plot_selection[column_name]["selected_column_selector_option"]
272 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: 248 if column_option in [
249 "by_index_number",
250 "all_but_by_index_number",
253 ]:
273 col = plot_selection[column_name]["col1"] 254 col = plot_selection[column_name]["col1"]
274 else: 255 else:
275 col = None 256 col = None
276 _, input_df = read_columns(file_path, c=col, 257 _, input_df = read_columns(file_path, c=col,
277 c_option=column_option, 258 c_option=column_option,
278 return_df=True, 259 return_df=True,
280 parse_dates=True) 261 parse_dates=True)
281 return input_df 262 return input_df
282 263
283 264
284 def main(inputs, infile_estimator=None, infile1=None, 265 def main(
285 infile2=None, outfile_result=None, 266 inputs,
286 outfile_object=None, groups=None, 267 infile_estimator=None,
287 ref_seq=None, intervals=None, 268 infile1=None,
288 targets=None, fasta_path=None, 269 infile2=None,
289 model_config=None, true_labels=None, 270 outfile_result=None,
290 predicted_labels=None, plot_color=None, 271 outfile_object=None,
291 title=None): 272 groups=None,
273 ref_seq=None,
274 intervals=None,
275 targets=None,
276 fasta_path=None,
277 model_config=None,
278 true_labels=None,
279 predicted_labels=None,
280 plot_color=None,
281 title=None,
282 ):
292 """ 283 """
293 Parameter 284 Parameter
294 --------- 285 ---------
295 inputs : str 286 inputs : str
296 File path to galaxy tool parameter 287 File path to galaxy tool parameter
339 Color of the confusion matrix heatmap 330 Color of the confusion matrix heatmap
340 331
341 title : str, default is None 332 title : str, default is None
342 Title of the confusion matrix heatmap 333 Title of the confusion matrix heatmap
343 """ 334 """
344 warnings.simplefilter('ignore') 335 warnings.simplefilter("ignore")
345 336
346 with open(inputs, 'r') as param_handler: 337 with open(inputs, "r") as param_handler:
348 339
349 title = params['plotting_selection']['title'].strip() 340 title = params["plotting_selection"]["title"].strip()
350 plot_type = params['plotting_selection']['plot_type'] 341 plot_type = params["plotting_selection"]["plot_type"]
351 plot_format = params['plotting_selection']['plot_format'] 342 plot_format = params["plotting_selection"]["plot_format"]
352 343
353 if plot_type == 'feature_importances': 344 if plot_type == "feature_importances":
354 with open(infile_estimator, 'rb') as estimator_handler: 345 with open(infile_estimator, "rb") as estimator_handler:
356 347
357 column_option = (params['plotting_selection'] 348 column_option = params["plotting_selection"]["column_selector_options"]["selected_column_selector_option"]
358 ['column_selector_options'] 349 if column_option in [
359 ['selected_column_selector_option']) 350 "by_index_number",
360 if column_option in ['by_index_number', 'all_but_by_index_number', 351 "all_but_by_index_number",
362 c = (params['plotting_selection'] 353 "all_but_by_header_name",
363 ['column_selector_options']['col1']) 354 ]:
355 c = params["plotting_selection"]["column_selector_options"]["col1"]
364 else: 356 else:
365 c = None 357 c = None
366 358
368 c_option=column_option, 360 infile1,
369 return_df=True, 361 c=c,
371 parse_dates=True) 363 return_df=True,
364 sep="\t",
366 parse_dates=True,
367 )
372 368
373 feature_names = input_df.columns.values 369 feature_names = input_df.columns.values
374 370
375 if isinstance(estimator, Pipeline): 371 if isinstance(estimator, Pipeline):
376 for st in estimator.steps[:-1]: 372 for st in estimator.steps[:-1]:
377 if isinstance(st[-1], SelectorMixin): 373 if isinstance(st[-1], SelectorMixin):
380 estimator = estimator.steps[-1][-1] 376 estimator = estimator.steps[-1][-1]
381 377
382 if hasattr(estimator, 'coef_'): 378 if hasattr(estimator, "coef_"):
383 coefs = estimator.coef_ 379 coefs = estimator.coef_
384 else: 380 else:
385 coefs = getattr(estimator, 'feature_importances_', None) 381 coefs = getattr(estimator, "feature_importances_", None)
386 if coefs is None: 382 if coefs is None:
387 raise RuntimeError('The classifier does not expose ' 383 raise RuntimeError("The classifier does not expose " '"coef_" or "feature_importances_" ' "attributes")
388 '"coef_" or "feature_importances_" ' 384
389 'attributes') 385 threshold = params["plotting_selection"]["threshold"]
390
391 threshold = params['plotting_selection']['threshold']
392 if threshold is not None: 386 if threshold is not None:
393 mask = (coefs > threshold) | (coefs < -threshold) 387 mask = (coefs > threshold) | (coefs < -threshold)
396 390
397 # sort 391 # sort
398 indices = np.argsort(coefs)[::-1] 392 indices = np.argsort(coefs)[::-1]
399 393
400 trace = go.Bar(x=feature_names[indices], 394 trace = go.Bar(x=feature_names[indices], y=coefs[indices])
401 y=coefs[indices])
402 layout = go.Layout(title=title or "Feature Importances") 395 layout = go.Layout(title=title or "Feature Importances")
403 fig = go.Figure(data=[trace], layout=layout) 396 fig = go.Figure(data=[trace], layout=layout)
404 397
405 plotly.offline.plot(fig, filename="output.html", 398 plotly.offline.plot(fig, filename="output.html", auto_open=False)
406 auto_open=False)
407 # to be discovered by `from_work_dir` 399 # to be discovered by `from_work_dir`
408 os.rename('output.html', 'output') 400 os.rename("output.html", "output")
409 401
410 return 0 402 return 0
411 403
412 elif plot_type in ('pr_curve', 'roc_curve'): 404 elif plot_type in ("pr_curve", "roc_curve"):
415 407
416 minimum = params['plotting_selection']['report_minimum_n_positives'] 408 minimum = params["plotting_selection"]["report_minimum_n_positives"]
417 # filter out columns whose n_positives is beblow the threhold 409 # filter out columns whose n_positives is beblow the threhold
418 if minimum: 410 if minimum:
419 mask = df1.sum(axis=0) >= minimum 411 mask = df1.sum(axis=0) >= minimum
422 414
423 pos_label = params['plotting_selection']['pos_label'].strip() \ 415 pos_label = params["plotting_selection"]["pos_label"].strip() or None
424 or None 416
425 417 if plot_type == "pr_curve":
426 if plot_type == 'pr_curve': 418 if plot_format == "plotly_html":
427 if plot_format == 'plotly_html':
428 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) 419 visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
429 else: 420 else:
430 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) 421 visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
431 else: # 'roc_curve' 422 else: # 'roc_curve'
432 drop_intermediate = (params['plotting_selection'] 423 drop_intermediate = params["plotting_selection"]["drop_intermediate"]
433 ['drop_intermediate']) 424 if plot_format == "plotly_html":
434 if plot_format == 'plotly_html': 425 visualize_roc_curve_plotly(
435 visualize_roc_curve_plotly(df1, df2, pos_label, 426 df1,
436 drop_intermediate=drop_intermediate, 427 df2,
437 title=title) 428 pos_label,
429 drop_intermediate=drop_intermediate,
430 title=title,
431 )
438 else: 432 else:
439 visualize_roc_curve_matplotlib( 433 visualize_roc_curve_matplotlib(
440 df1, df2, pos_label, 434 df1,
435 df2,
436 pos_label,
441 drop_intermediate=drop_intermediate, 437 drop_intermediate=drop_intermediate,
442 title=title) 438 title=title,
439 )
443 440
444 return 0 441 return 0
445 442
446 elif plot_type == 'rfecv_gridscores': 443 elif plot_type == "rfecv_gridscores":
448 scores = input_df.iloc[:, 0] 445 scores = input_df.iloc[:, 0]
449 steps = params['plotting_selection']['steps'].strip() 446 steps = params["plotting_selection"]["steps"].strip()
450 steps = safe_eval(steps) 447 steps = safe_eval(steps)
451 448
452 data = go.Scatter( 449 data = go.Scatter(
453 x=list(range(len(scores))), 450 x=list(range(len(scores))),
454 y=scores, 451 y=scores,
455 text=[str(_) for _ in steps] if steps else None, 452 text=[str(_) for _ in steps] if steps else None,
456 mode='lines' 453 mode="lines",
457 ) 454 )
458 layout = go.Layout( 455 layout = go.Layout(
459 xaxis=dict(title="Number of features selected"), 456 xaxis=dict(title="Number of features selected"),
460 yaxis=dict(title="Cross validation score"), 457 yaxis=dict(title="Cross validation score"),
461 title=dict( 458 title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"),
462 text=title or None, 459 font=dict(family="sans-serif", size=11),
463 x=0.5,
464 y=0.92,
465 xanchor='center',
466 yanchor='top'
467 ),
468 font=dict(
469 family="sans-serif",
470 size=11
471 ),
472 # control backgroud colors 460 # control backgroud colors
473 plot_bgcolor='rgba(255,255,255,0)' 461 plot_bgcolor="rgba(255,255,255,0)",
474 ) 462 )
475 """ 463 """
476 # legend=dict( 464 # legend=dict(
477 # x=0.95, 465 # x=0.95,
478 # y=0, 466 # y=0,
487 # borderwidth=2 475 # borderwidth=2
488 # ), 476 # ),
489 """ 477 """
490 478
491 fig = go.Figure(data=[data], layout=layout) 479 fig = go.Figure(data=[data], layout=layout)
492 plotly.offline.plot(fig, filename="output.html", 480 plotly.offline.plot(fig, filename="output.html", auto_open=False)
493 auto_open=False)
494 # to be discovered by `from_work_dir` 481 # to be discovered by `from_work_dir`
495 os.rename('output.html', 'output') 482 os.rename("output.html", "output")
496 483
497 return 0 484 return 0
498 485
499 elif plot_type == 'learning_curve': 486 elif plot_type == "learning_curve":
501 plot_std_err = params['plotting_selection']['plot_std_err'] 488 plot_std_err = params["plotting_selection"]["plot_std_err"]
502 data1 = go.Scatter( 489 data1 = go.Scatter(
503 x=input_df['train_sizes_abs'], 490 x=input_df["train_sizes_abs"],
504 y=input_df['mean_train_scores'], 491 y=input_df["mean_train_scores"],
505 error_y=dict( 492 error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
506 array=input_df['std_train_scores'] 493 mode="lines",
507 ) if plot_std_err else None,
508 mode='lines',
509 name="Train Scores", 494 name="Train Scores",
510 ) 495 )
511 data2 = go.Scatter( 496 data2 = go.Scatter(
512 x=input_df['train_sizes_abs'], 497 x=input_df["train_sizes_abs"],
513 y=input_df['mean_test_scores'], 498 y=input_df["mean_test_scores"],
514 error_y=dict( 499 error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
515 array=input_df['std_test_scores'] 500 mode="lines",
516 ) if plot_std_err else None,
517 mode='lines',
518 name="Test Scores", 501 name="Test Scores",
519 ) 502 )
520 layout = dict( 503 layout = dict(
521 xaxis=dict( 504 xaxis=dict(title="No. of samples"),
522 title='No. of samples' 505 yaxis=dict(title="Performance Score"),
523 ),
524 yaxis=dict(
525 title='Performance Score'
526 ),
527 # modify these configurations to customize image 506 # modify these configurations to customize image
528 title=dict( 507 title=dict(
529 text=title or 'Learning Curve', 508 text=title or "Learning Curve",
530 x=0.5, 509 x=0.5,
531 y=0.92, 510 y=0.92,
532 xanchor='center', 511 xanchor="center",
533 yanchor='top' 512 yanchor="top",
534 ), 513 ),
535 font=dict( 514 font=dict(family="sans-serif", size=11),
536 family="sans-serif",
537 size=11
538 ),
539 # control backgroud colors 515 # control backgroud colors
540 plot_bgcolor='rgba(255,255,255,0)' 516 plot_bgcolor="rgba(255,255,255,0)",
541 ) 517 )
542 """ 518 """
543 # legend=dict( 519 # legend=dict(
544 # x=0.95, 520 # x=0.95,
545 # y=0, 521 # y=0,
554 # borderwidth=2 530 # borderwidth=2
555 # ), 531 # ),
556 """ 532 """
557 533
558 fig = go.Figure(data=[data1, data2], layout=layout) 534 fig = go.Figure(data=[data1, data2], layout=layout)
559 plotly.offline.plot(fig, filename="output.html", 535 plotly.offline.plot(fig, filename="output.html", auto_open=False)
560 auto_open=False)
561 # to be discovered by `from_work_dir` 536 # to be discovered by `from_work_dir`
562 os.rename('output.html', 'output') 537 os.rename("output.html", "output")
563 538
564 return 0 539 return 0
565 540
566 elif plot_type == 'keras_plot_model': 541 elif plot_type == "keras_plot_model":
567 with open(model_config, 'r') as f: 542 with open(model_config, "r") as f:
569 model = model_from_json(model_str) 544 model = model_from_json(model_str)
570 plot_model(model, to_file="output.png") 545 plot_model(model, to_file="output.png")
571 os.rename('output.png', 'output') 546 os.rename("output.png", "output")
572 547
573 return 0 548 return 0
574 549
575 elif plot_type == 'classification_confusion_matrix': 550 elif plot_type == "classification_confusion_matrix":
576 plot_selection = params["plotting_selection"] 551 plot_selection = params["plotting_selection"]
577 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") 552 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true")
580 true_classes = input_true.iloc[:, -1].copy() 555 true_classes = input_true.iloc[:, -1].copy()
581 predicted_classes = input_predicted.iloc[:, -1].copy() 556 predicted_classes = input_predicted.iloc[:, -1].copy()
582 axis_labels = list(set(true_classes)) 557 axis_labels = list(set(true_classes))
583 c_matrix = confusion_matrix(true_classes, predicted_classes) 558 c_matrix = confusion_matrix(true_classes, predicted_classes)
584 fig, ax = plt.subplots(figsize=(7, 7)) 559 fig, ax = plt.subplots(figsize=(7, 7))
585 im = plt.imshow(c_matrix, cmap=plot_color) 560 im = plt.imshow(c_matrix, cmap=plot_color)
586 for i in range(len(c_matrix)): 561 for i in range(len(c_matrix)):
587 for j in range(len(c_matrix)): 562 for j in range(len(c_matrix)):
588 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") 563 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
589 ax.set_ylabel('True class labels') 564 ax.set_ylabel("True class labels")
590 ax.set_xlabel('Predicted class labels') 565 ax.set_xlabel("Predicted class labels")
591 ax.set_title(title) 566 ax.set_title(title)
592 ax.set_xticks(axis_labels) 567 ax.set_xticks(axis_labels)
593 ax.set_yticks(axis_labels) 568 ax.set_yticks(axis_labels)
594 fig.colorbar(im, ax=ax) 569 fig.colorbar(im, ax=ax)
595 fig.tight_layout() 570 fig.tight_layout()
596 plt.savefig("output.png", dpi=125) 571 plt.savefig("output.png", dpi=125)
597 os.rename('output.png', 'output') 572 os.rename("output.png", "output")
598 573
599 return 0 574 return 0
600 575
601 # save pdf file to disk 576 # save pdf file to disk
602 # fig.write_image("image.pdf", format='pdf') 577 # fig.write_image("image.pdf", format='pdf')
603 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) 578 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
604 579
605 580
606 if __name__ == '__main__': 581 if __name__ == "__main__":
607 aparser = argparse.ArgumentParser() 582 aparser = argparse.ArgumentParser()
624 args = aparser.parse_args() 599 args = aparser.parse_args()
625 600
626 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 601 main(
627 args.outfile_result, outfile_object=args.outfile_object, 602 args.inputs,
628 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, 603 args.infile_estimator,
629 targets=args.targets, fasta_path=args.fasta_path, 604 args.infile1,
630 model_config=args.model_config, true_labels=args.true_labels, 605 args.infile2,
631 predicted_labels=args.predicted_labels, 606 args.outfile_result,
632 plot_color=args.plot_color, 607 outfile_object=args.outfile_object,
633 title=args.title) 608 groups=args.groups,
609 ref_seq=args.ref_seq,
610 intervals=args.intervals,
611 targets=args.targets,
612 fasta_path=args.fasta_path,
613 model_config=args.model_config,
614 true_labels=args.true_labels,
615 predicted_labels=args.predicted_labels,
616 plot_color=args.plot_color,
617 title=args.title,
618 )