comparison ml_visualization_ex.py @ 6:13b9ac5d277c draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:24:07 +0000
parents 145208b3579d
children 3312fb686ffb
comparison
equal deleted inserted replaced
5:ce2fd1edbc6e 6:13b9ac5d277c
1 import argparse 1 import argparse
2 import json 2 import json
3 import os
4 import warnings
5
3 import matplotlib 6 import matplotlib
4 import matplotlib.pyplot as plt 7 import matplotlib.pyplot as plt
5 import numpy as np 8 import numpy as np
6 import os
7 import pandas as pd 9 import pandas as pd
8 import plotly 10 import plotly
9 import plotly.graph_objs as go 11 import plotly.graph_objs as go
10 import warnings 12 from galaxy_ml.utils import load_model, read_columns, SafeEval
11
12 from keras.models import model_from_json 13 from keras.models import model_from_json
13 from keras.utils import plot_model 14 from keras.utils import plot_model
14 from sklearn.feature_selection.base import SelectorMixin 15 from sklearn.feature_selection.base import SelectorMixin
15 from sklearn.metrics import precision_recall_curve, average_precision_score 16 from sklearn.metrics import auc, average_precision_score, confusion_matrix, precision_recall_curve, roc_curve
16 from sklearn.metrics import roc_curve, auc, confusion_matrix
17 from sklearn.pipeline import Pipeline 17 from sklearn.pipeline import Pipeline
18 from galaxy_ml.utils import load_model, read_columns, SafeEval
19 18
20 19
21 safe_eval = SafeEval() 20 safe_eval = SafeEval()
22 21
23 # plotly default colors 22 # plotly default colors
24 default_colors = [ 23 default_colors = [
25 '#1f77b4', # muted blue 24 "#1f77b4", # muted blue
26 '#ff7f0e', # safety orange 25 "#ff7f0e", # safety orange
27 '#2ca02c', # cooked asparagus green 26 "#2ca02c", # cooked asparagus green
28 '#d62728', # brick red 27 "#d62728", # brick red
29 '#9467bd', # muted purple 28 "#9467bd", # muted purple
30 '#8c564b', # chestnut brown 29 "#8c564b", # chestnut brown
31 '#e377c2', # raspberry yogurt pink 30 "#e377c2", # raspberry yogurt pink
32 '#7f7f7f', # middle gray 31 "#7f7f7f", # middle gray
33 '#bcbd22', # curry yellow-green 32 "#bcbd22", # curry yellow-green
34 '#17becf' # blue-teal 33 "#17becf", # blue-teal
35 ] 34 ]
36 35
37 36
38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): 37 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
39 """output pr-curve in html using plotly 38 """output pr-curve in html using plotly
50 data = [] 49 data = []
51 for idx in range(df1.shape[1]): 50 for idx in range(df1.shape[1]):
52 y_true = df1.iloc[:, idx].values 51 y_true = df1.iloc[:, idx].values
53 y_score = df2.iloc[:, idx].values 52 y_score = df2.iloc[:, idx].values
54 53
55 precision, recall, _ = precision_recall_curve( 54 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label)
56 y_true, y_score, pos_label=pos_label) 55 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
57 ap = average_precision_score(
58 y_true, y_score, pos_label=pos_label or 1)
59 56
60 trace = go.Scatter( 57 trace = go.Scatter(
61 x=recall, 58 x=recall,
62 y=precision, 59 y=precision,
63 mode='lines', 60 mode="lines",
64 marker=dict( 61 marker=dict(color=default_colors[idx % len(default_colors)]),
65 color=default_colors[idx % len(default_colors)] 62 name="%s (area = %.3f)" % (idx, ap),
66 ),
67 name='%s (area = %.3f)' % (idx, ap)
68 ) 63 )
69 data.append(trace) 64 data.append(trace)
70 65
71 layout = go.Layout( 66 layout = go.Layout(
72 xaxis=dict( 67 xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
73 title='Recall', 68 yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
74 linecolor='lightslategray',
75 linewidth=1
76 ),
77 yaxis=dict(
78 title='Precision',
79 linecolor='lightslategray',
80 linewidth=1
81 ),
82 title=dict( 69 title=dict(
83 text=title or 'Precision-Recall Curve', 70 text=title or "Precision-Recall Curve",
84 x=0.5, 71 x=0.5,
85 y=0.92, 72 y=0.92,
86 xanchor='center', 73 xanchor="center",
87 yanchor='top' 74 yanchor="top",
88 ), 75 ),
89 font=dict( 76 font=dict(family="sans-serif", size=11),
90 family="sans-serif",
91 size=11
92 ),
93 # control backgroud colors 77 # control backgroud colors
94 plot_bgcolor='rgba(255,255,255,0)' 78 plot_bgcolor="rgba(255,255,255,0)",
95 ) 79 )
96 """ 80 """
97 legend=dict( 81 legend=dict(
98 x=0.95, 82 x=0.95,
99 y=0, 83 y=0,
110 94
111 fig = go.Figure(data=data, layout=layout) 95 fig = go.Figure(data=data, layout=layout)
112 96
113 plotly.offline.plot(fig, filename="output.html", auto_open=False) 97 plotly.offline.plot(fig, filename="output.html", auto_open=False)
114 # to be discovered by `from_work_dir` 98 # to be discovered by `from_work_dir`
115 os.rename('output.html', 'output') 99 os.rename("output.html", "output")
116 100
117 101
118 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): 102 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
119 """visualize pr-curve using matplotlib and output svg image 103 """visualize pr-curve using matplotlib and output svg image"""
120 """
121 backend = matplotlib.get_backend() 104 backend = matplotlib.get_backend()
122 if "inline" not in backend: 105 if "inline" not in backend:
123 matplotlib.use("SVG") 106 matplotlib.use("SVG")
124 plt.style.use('seaborn-colorblind') 107 plt.style.use("seaborn-colorblind")
125 plt.figure() 108 plt.figure()
126 109
127 for idx in range(df1.shape[1]): 110 for idx in range(df1.shape[1]):
128 y_true = df1.iloc[:, idx].values 111 y_true = df1.iloc[:, idx].values
129 y_score = df2.iloc[:, idx].values 112 y_score = df2.iloc[:, idx].values
130 113
131 precision, recall, _ = precision_recall_curve( 114 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label)
132 y_true, y_score, pos_label=pos_label) 115 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
133 ap = average_precision_score( 116
134 y_true, y_score, pos_label=pos_label or 1) 117 plt.step(
135 118 recall,
136 plt.step(recall, precision, 'r-', color="black", alpha=0.3, 119 precision,
137 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap)) 120 "r-",
121 color="black",
122 alpha=0.3,
123 lw=1,
124 where="post",
125 label="%s (area = %.3f)" % (idx, ap),
126 )
138 127
139 plt.xlim([0.0, 1.0]) 128 plt.xlim([0.0, 1.0])
140 plt.ylim([0.0, 1.05]) 129 plt.ylim([0.0, 1.05])
141 plt.xlabel('Recall') 130 plt.xlabel("Recall")
142 plt.ylabel('Precision') 131 plt.ylabel("Precision")
143 title = title or 'Precision-Recall Curve' 132 title = title or "Precision-Recall Curve"
144 plt.title(title) 133 plt.title(title)
145 folder = os.getcwd() 134 folder = os.getcwd()
146 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 135 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
147 os.rename(os.path.join(folder, "output.svg"), 136 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
148 os.path.join(folder, "output")) 137
149 138
150 139 def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
151 def visualize_roc_curve_plotly(df1, df2, pos_label,
152 drop_intermediate=True,
153 title=None):
154 """output roc-curve in html using plotly 140 """output roc-curve in html using plotly
155 141
156 df1 : pandas.DataFrame 142 df1 : pandas.DataFrame
157 Containing y_true 143 Containing y_true
158 df2 : pandas.DataFrame 144 df2 : pandas.DataFrame
167 data = [] 153 data = []
168 for idx in range(df1.shape[1]): 154 for idx in range(df1.shape[1]):
169 y_true = df1.iloc[:, idx].values 155 y_true = df1.iloc[:, idx].values
170 y_score = df2.iloc[:, idx].values 156 y_score = df2.iloc[:, idx].values
171 157
172 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, 158 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate)
173 drop_intermediate=drop_intermediate)
174 roc_auc = auc(fpr, tpr) 159 roc_auc = auc(fpr, tpr)
175 160
176 trace = go.Scatter( 161 trace = go.Scatter(
177 x=fpr, 162 x=fpr,
178 y=tpr, 163 y=tpr,
179 mode='lines', 164 mode="lines",
180 marker=dict( 165 marker=dict(color=default_colors[idx % len(default_colors)]),
181 color=default_colors[idx % len(default_colors)] 166 name="%s (area = %.3f)" % (idx, roc_auc),
182 ),
183 name='%s (area = %.3f)' % (idx, roc_auc)
184 ) 167 )
185 data.append(trace) 168 data.append(trace)
186 169
187 layout = go.Layout( 170 layout = go.Layout(
188 xaxis=dict( 171 xaxis=dict(title="False Positive Rate", linecolor="lightslategray", linewidth=1),
189 title='False Positive Rate', 172 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
190 linecolor='lightslategray',
191 linewidth=1
192 ),
193 yaxis=dict(
194 title='True Positive Rate',
195 linecolor='lightslategray',
196 linewidth=1
197 ),
198 title=dict( 173 title=dict(
199 text=title or 'Receiver Operating Characteristic (ROC) Curve', 174 text=title or "Receiver Operating Characteristic (ROC) Curve",
200 x=0.5, 175 x=0.5,
201 y=0.92, 176 y=0.92,
202 xanchor='center', 177 xanchor="center",
203 yanchor='top' 178 yanchor="top",
204 ), 179 ),
205 font=dict( 180 font=dict(family="sans-serif", size=11),
206 family="sans-serif",
207 size=11
208 ),
209 # control backgroud colors 181 # control backgroud colors
210 plot_bgcolor='rgba(255,255,255,0)' 182 plot_bgcolor="rgba(255,255,255,0)",
211 ) 183 )
212 """ 184 """
213 # legend=dict( 185 # legend=dict(
214 # x=0.95, 186 # x=0.95,
215 # y=0, 187 # y=0,
227 199
228 fig = go.Figure(data=data, layout=layout) 200 fig = go.Figure(data=data, layout=layout)
229 201
230 plotly.offline.plot(fig, filename="output.html", auto_open=False) 202 plotly.offline.plot(fig, filename="output.html", auto_open=False)
231 # to be discovered by `from_work_dir` 203 # to be discovered by `from_work_dir`
232 os.rename('output.html', 'output') 204 os.rename("output.html", "output")
233 205
234 206
235 def visualize_roc_curve_matplotlib(df1, df2, pos_label, 207 def visualize_roc_curve_matplotlib(df1, df2, pos_label, drop_intermediate=True, title=None):
236 drop_intermediate=True, 208 """visualize roc-curve using matplotlib and output svg image"""
237 title=None):
238 """visualize roc-curve using matplotlib and output svg image
239 """
240 backend = matplotlib.get_backend() 209 backend = matplotlib.get_backend()
241 if "inline" not in backend: 210 if "inline" not in backend:
242 matplotlib.use("SVG") 211 matplotlib.use("SVG")
243 plt.style.use('seaborn-colorblind') 212 plt.style.use("seaborn-colorblind")
244 plt.figure() 213 plt.figure()
245 214
246 for idx in range(df1.shape[1]): 215 for idx in range(df1.shape[1]):
247 y_true = df1.iloc[:, idx].values 216 y_true = df1.iloc[:, idx].values
248 y_score = df2.iloc[:, idx].values 217 y_score = df2.iloc[:, idx].values
249 218
250 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, 219 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate)
251 drop_intermediate=drop_intermediate)
252 roc_auc = auc(fpr, tpr) 220 roc_auc = auc(fpr, tpr)
253 221
254 plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1, 222 plt.step(
255 where="post", label='%s (area = %.3f)' % (idx, roc_auc)) 223 fpr,
224 tpr,
225 "r-",
226 color="black",
227 alpha=0.3,
228 lw=1,
229 where="post",
230 label="%s (area = %.3f)" % (idx, roc_auc),
231 )
256 232
257 plt.xlim([0.0, 1.0]) 233 plt.xlim([0.0, 1.0])
258 plt.ylim([0.0, 1.05]) 234 plt.ylim([0.0, 1.05])
259 plt.xlabel('False Positive Rate') 235 plt.xlabel("False Positive Rate")
260 plt.ylabel('True Positive Rate') 236 plt.ylabel("True Positive Rate")
261 title = title or 'Receiver Operating Characteristic (ROC) Curve' 237 title = title or "Receiver Operating Characteristic (ROC) Curve"
262 plt.title(title) 238 plt.title(title)
263 folder = os.getcwd() 239 folder = os.getcwd()
264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 240 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
265 os.rename(os.path.join(folder, "output.svg"), 241 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
266 os.path.join(folder, "output"))
267 242
268 243
269 def get_dataframe(file_path, plot_selection, header_name, column_name): 244 def get_dataframe(file_path, plot_selection, header_name, column_name):
270 header = 'infer' if plot_selection[header_name] else None 245 header = "infer" if plot_selection[header_name] else None
271 column_option = plot_selection[column_name]["selected_column_selector_option"] 246 column_option = plot_selection[column_name]["selected_column_selector_option"]
272 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: 247 if column_option in [
248 "by_index_number",
249 "all_but_by_index_number",
250 "by_header_name",
251 "all_but_by_header_name",
252 ]:
273 col = plot_selection[column_name]["col1"] 253 col = plot_selection[column_name]["col1"]
274 else: 254 else:
275 col = None 255 col = None
276 _, input_df = read_columns(file_path, c=col, 256 _, input_df = read_columns(file_path, c=col,
277 c_option=column_option, 257 c_option=column_option,
278 return_df=True, 258 return_df=True,
279 sep='\t', header=header, 259 sep='\t', header=header,
280 parse_dates=True) 260 parse_dates=True)
281 return input_df 261 return input_df
282 262
283 263
284 def main(inputs, infile_estimator=None, infile1=None, 264 def main(
285 infile2=None, outfile_result=None, 265 inputs,
286 outfile_object=None, groups=None, 266 infile_estimator=None,
287 ref_seq=None, intervals=None, 267 infile1=None,
288 targets=None, fasta_path=None, 268 infile2=None,
289 model_config=None, true_labels=None, 269 outfile_result=None,
290 predicted_labels=None, plot_color=None, 270 outfile_object=None,
291 title=None): 271 groups=None,
272 ref_seq=None,
273 intervals=None,
274 targets=None,
275 fasta_path=None,
276 model_config=None,
277 true_labels=None,
278 predicted_labels=None,
279 plot_color=None,
280 title=None,
281 ):
292 """ 282 """
293 Parameter 283 Parameter
294 --------- 284 ---------
295 inputs : str 285 inputs : str
296 File path to galaxy tool parameter 286 File path to galaxy tool parameter
339 Color of the confusion matrix heatmap 329 Color of the confusion matrix heatmap
340 330
341 title : str, default is None 331 title : str, default is None
342 Title of the confusion matrix heatmap 332 Title of the confusion matrix heatmap
343 """ 333 """
344 warnings.simplefilter('ignore') 334 warnings.simplefilter("ignore")
345 335
346 with open(inputs, 'r') as param_handler: 336 with open(inputs, "r") as param_handler:
347 params = json.load(param_handler) 337 params = json.load(param_handler)
348 338
349 title = params['plotting_selection']['title'].strip() 339 title = params["plotting_selection"]["title"].strip()
350 plot_type = params['plotting_selection']['plot_type'] 340 plot_type = params["plotting_selection"]["plot_type"]
351 plot_format = params['plotting_selection']['plot_format'] 341 plot_format = params["plotting_selection"]["plot_format"]
352 342
353 if plot_type == 'feature_importances': 343 if plot_type == "feature_importances":
354 with open(infile_estimator, 'rb') as estimator_handler: 344 with open(infile_estimator, "rb") as estimator_handler:
355 estimator = load_model(estimator_handler) 345 estimator = load_model(estimator_handler)
356 346
357 column_option = (params['plotting_selection'] 347 column_option = params["plotting_selection"]["column_selector_options"]["selected_column_selector_option"]
358 ['column_selector_options'] 348 if column_option in [
359 ['selected_column_selector_option']) 349 "by_index_number",
360 if column_option in ['by_index_number', 'all_but_by_index_number', 350 "all_but_by_index_number",
361 'by_header_name', 'all_but_by_header_name']: 351 "by_header_name",
362 c = (params['plotting_selection'] 352 "all_but_by_header_name",
363 ['column_selector_options']['col1']) 353 ]:
354 c = params["plotting_selection"]["column_selector_options"]["col1"]
364 else: 355 else:
365 c = None 356 c = None
366 357
367 _, input_df = read_columns(infile1, c=c, 358 _, input_df = read_columns(
368 c_option=column_option, 359 infile1,
369 return_df=True, 360 c=c,
370 sep='\t', header='infer', 361 c_option=column_option,
371 parse_dates=True) 362 return_df=True,
363 sep="\t",
364 header="infer",
365 parse_dates=True,
366 )
372 367
373 feature_names = input_df.columns.values 368 feature_names = input_df.columns.values
374 369
375 if isinstance(estimator, Pipeline): 370 if isinstance(estimator, Pipeline):
376 for st in estimator.steps[:-1]: 371 for st in estimator.steps[:-1]:
377 if isinstance(st[-1], SelectorMixin): 372 if isinstance(st[-1], SelectorMixin):
378 mask = st[-1].get_support() 373 mask = st[-1].get_support()
379 feature_names = feature_names[mask] 374 feature_names = feature_names[mask]
380 estimator = estimator.steps[-1][-1] 375 estimator = estimator.steps[-1][-1]
381 376
382 if hasattr(estimator, 'coef_'): 377 if hasattr(estimator, "coef_"):
383 coefs = estimator.coef_ 378 coefs = estimator.coef_
384 else: 379 else:
385 coefs = getattr(estimator, 'feature_importances_', None) 380 coefs = getattr(estimator, "feature_importances_", None)
386 if coefs is None: 381 if coefs is None:
387 raise RuntimeError('The classifier does not expose ' 382 raise RuntimeError("The classifier does not expose " '"coef_" or "feature_importances_" ' "attributes")
388 '"coef_" or "feature_importances_" ' 383
389 'attributes') 384 threshold = params["plotting_selection"]["threshold"]
390
391 threshold = params['plotting_selection']['threshold']
392 if threshold is not None: 385 if threshold is not None:
393 mask = (coefs > threshold) | (coefs < -threshold) 386 mask = (coefs > threshold) | (coefs < -threshold)
394 coefs = coefs[mask] 387 coefs = coefs[mask]
395 feature_names = feature_names[mask] 388 feature_names = feature_names[mask]
396 389
397 # sort 390 # sort
398 indices = np.argsort(coefs)[::-1] 391 indices = np.argsort(coefs)[::-1]
399 392
400 trace = go.Bar(x=feature_names[indices], 393 trace = go.Bar(x=feature_names[indices], y=coefs[indices])
401 y=coefs[indices])
402 layout = go.Layout(title=title or "Feature Importances") 394 layout = go.Layout(title=title or "Feature Importances")
403 fig = go.Figure(data=[trace], layout=layout) 395 fig = go.Figure(data=[trace], layout=layout)
404 396
405 plotly.offline.plot(fig, filename="output.html", 397 plotly.offline.plot(fig, filename="output.html", auto_open=False)
406 auto_open=False)
407 # to be discovered by `from_work_dir` 398 # to be discovered by `from_work_dir`
408 os.rename('output.html', 'output') 399 os.rename("output.html", "output")
409 400
410 return 0 401 return 0
411 402
412 elif plot_type in ('pr_curve', 'roc_curve'): 403 elif plot_type in ("pr_curve", "roc_curve"):
413 df1 = pd.read_csv(infile1, sep='\t', header='infer') 404 df1 = pd.read_csv(infile1, sep="\t", header="infer")
414 df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32) 405 df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)
415 406
416 minimum = params['plotting_selection']['report_minimum_n_positives'] 407 minimum = params["plotting_selection"]["report_minimum_n_positives"]
417 # filter out columns whose n_positives is beblow the threhold 408 # filter out columns whose n_positives is beblow the threhold
418 if minimum: 409 if minimum:
419 mask = df1.sum(axis=0) >= minimum 410 mask = df1.sum(axis=0) >= minimum
420 df1 = df1.loc[:, mask] 411 df1 = df1.loc[:, mask]
421 df2 = df2.loc[:, mask] 412 df2 = df2.loc[:, mask]
422 413
423 pos_label = params['plotting_selection']['pos_label'].strip() \ 414 pos_label = params["plotting_selection"]["pos_label"].strip() or None
424 or None 415
425 416 if plot_type == "pr_curve":
426 if plot_type == 'pr_curve': 417 if plot_format == "plotly_html":
427 if plot_format == 'plotly_html':
428 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) 418 visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
429 else: 419 else:
430 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) 420 visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
431 else: # 'roc_curve' 421 else: # 'roc_curve'
432 drop_intermediate = (params['plotting_selection'] 422 drop_intermediate = params["plotting_selection"]["drop_intermediate"]
433 ['drop_intermediate']) 423 if plot_format == "plotly_html":
434 if plot_format == 'plotly_html': 424 visualize_roc_curve_plotly(
435 visualize_roc_curve_plotly(df1, df2, pos_label, 425 df1,
436 drop_intermediate=drop_intermediate, 426 df2,
437 title=title) 427 pos_label,
428 drop_intermediate=drop_intermediate,
429 title=title,
430 )
438 else: 431 else:
439 visualize_roc_curve_matplotlib( 432 visualize_roc_curve_matplotlib(
440 df1, df2, pos_label, 433 df1,
434 df2,
435 pos_label,
441 drop_intermediate=drop_intermediate, 436 drop_intermediate=drop_intermediate,
442 title=title) 437 title=title,
438 )
443 439
444 return 0 440 return 0
445 441
446 elif plot_type == 'rfecv_gridscores': 442 elif plot_type == "rfecv_gridscores":
447 input_df = pd.read_csv(infile1, sep='\t', header='infer') 443 input_df = pd.read_csv(infile1, sep="\t", header="infer")
448 scores = input_df.iloc[:, 0] 444 scores = input_df.iloc[:, 0]
449 steps = params['plotting_selection']['steps'].strip() 445 steps = params["plotting_selection"]["steps"].strip()
450 steps = safe_eval(steps) 446 steps = safe_eval(steps)
451 447
452 data = go.Scatter( 448 data = go.Scatter(
453 x=list(range(len(scores))), 449 x=list(range(len(scores))),
454 y=scores, 450 y=scores,
455 text=[str(_) for _ in steps] if steps else None, 451 text=[str(_) for _ in steps] if steps else None,
456 mode='lines' 452 mode="lines",
457 ) 453 )
458 layout = go.Layout( 454 layout = go.Layout(
459 xaxis=dict(title="Number of features selected"), 455 xaxis=dict(title="Number of features selected"),
460 yaxis=dict(title="Cross validation score"), 456 yaxis=dict(title="Cross validation score"),
461 title=dict( 457 title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"),
462 text=title or None, 458 font=dict(family="sans-serif", size=11),
463 x=0.5,
464 y=0.92,
465 xanchor='center',
466 yanchor='top'
467 ),
468 font=dict(
469 family="sans-serif",
470 size=11
471 ),
472 # control backgroud colors 459 # control backgroud colors
473 plot_bgcolor='rgba(255,255,255,0)' 460 plot_bgcolor="rgba(255,255,255,0)",
474 ) 461 )
475 """ 462 """
476 # legend=dict( 463 # legend=dict(
477 # x=0.95, 464 # x=0.95,
478 # y=0, 465 # y=0,
487 # borderwidth=2 474 # borderwidth=2
488 # ), 475 # ),
489 """ 476 """
490 477
491 fig = go.Figure(data=[data], layout=layout) 478 fig = go.Figure(data=[data], layout=layout)
492 plotly.offline.plot(fig, filename="output.html", 479 plotly.offline.plot(fig, filename="output.html", auto_open=False)
493 auto_open=False)
494 # to be discovered by `from_work_dir` 480 # to be discovered by `from_work_dir`
495 os.rename('output.html', 'output') 481 os.rename("output.html", "output")
496 482
497 return 0 483 return 0
498 484
499 elif plot_type == 'learning_curve': 485 elif plot_type == "learning_curve":
500 input_df = pd.read_csv(infile1, sep='\t', header='infer') 486 input_df = pd.read_csv(infile1, sep="\t", header="infer")
501 plot_std_err = params['plotting_selection']['plot_std_err'] 487 plot_std_err = params["plotting_selection"]["plot_std_err"]
502 data1 = go.Scatter( 488 data1 = go.Scatter(
503 x=input_df['train_sizes_abs'], 489 x=input_df["train_sizes_abs"],
504 y=input_df['mean_train_scores'], 490 y=input_df["mean_train_scores"],
505 error_y=dict( 491 error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
506 array=input_df['std_train_scores'] 492 mode="lines",
507 ) if plot_std_err else None,
508 mode='lines',
509 name="Train Scores", 493 name="Train Scores",
510 ) 494 )
511 data2 = go.Scatter( 495 data2 = go.Scatter(
512 x=input_df['train_sizes_abs'], 496 x=input_df["train_sizes_abs"],
513 y=input_df['mean_test_scores'], 497 y=input_df["mean_test_scores"],
514 error_y=dict( 498 error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
515 array=input_df['std_test_scores'] 499 mode="lines",
516 ) if plot_std_err else None,
517 mode='lines',
518 name="Test Scores", 500 name="Test Scores",
519 ) 501 )
520 layout = dict( 502 layout = dict(
521 xaxis=dict( 503 xaxis=dict(title="No. of samples"),
522 title='No. of samples' 504 yaxis=dict(title="Performance Score"),
523 ),
524 yaxis=dict(
525 title='Performance Score'
526 ),
527 # modify these configurations to customize image 505 # modify these configurations to customize image
528 title=dict( 506 title=dict(
529 text=title or 'Learning Curve', 507 text=title or "Learning Curve",
530 x=0.5, 508 x=0.5,
531 y=0.92, 509 y=0.92,
532 xanchor='center', 510 xanchor="center",
533 yanchor='top' 511 yanchor="top",
534 ), 512 ),
535 font=dict( 513 font=dict(family="sans-serif", size=11),
536 family="sans-serif",
537 size=11
538 ),
539 # control backgroud colors 514 # control backgroud colors
540 plot_bgcolor='rgba(255,255,255,0)' 515 plot_bgcolor="rgba(255,255,255,0)",
541 ) 516 )
542 """ 517 """
543 # legend=dict( 518 # legend=dict(
544 # x=0.95, 519 # x=0.95,
545 # y=0, 520 # y=0,
554 # borderwidth=2 529 # borderwidth=2
555 # ), 530 # ),
556 """ 531 """
557 532
558 fig = go.Figure(data=[data1, data2], layout=layout) 533 fig = go.Figure(data=[data1, data2], layout=layout)
559 plotly.offline.plot(fig, filename="output.html", 534 plotly.offline.plot(fig, filename="output.html", auto_open=False)
560 auto_open=False)
561 # to be discovered by `from_work_dir` 535 # to be discovered by `from_work_dir`
562 os.rename('output.html', 'output') 536 os.rename("output.html", "output")
563 537
564 return 0 538 return 0
565 539
566 elif plot_type == 'keras_plot_model': 540 elif plot_type == "keras_plot_model":
567 with open(model_config, 'r') as f: 541 with open(model_config, "r") as f:
568 model_str = f.read() 542 model_str = f.read()
569 model = model_from_json(model_str) 543 model = model_from_json(model_str)
570 plot_model(model, to_file="output.png") 544 plot_model(model, to_file="output.png")
571 os.rename('output.png', 'output') 545 os.rename("output.png", "output")
572 546
573 return 0 547 return 0
574 548
575 elif plot_type == 'classification_confusion_matrix': 549 elif plot_type == "classification_confusion_matrix":
576 plot_selection = params["plotting_selection"] 550 plot_selection = params["plotting_selection"]
577 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") 551 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true")
578 header_predicted = 'infer' if plot_selection["header_predicted"] else None 552 header_predicted = "infer" if plot_selection["header_predicted"] else None
579 input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted) 553 input_predicted = pd.read_csv(predicted_labels, sep="\t", parse_dates=True, header=header_predicted)
580 true_classes = input_true.iloc[:, -1].copy() 554 true_classes = input_true.iloc[:, -1].copy()
581 predicted_classes = input_predicted.iloc[:, -1].copy() 555 predicted_classes = input_predicted.iloc[:, -1].copy()
582 axis_labels = list(set(true_classes)) 556 axis_labels = list(set(true_classes))
583 c_matrix = confusion_matrix(true_classes, predicted_classes) 557 c_matrix = confusion_matrix(true_classes, predicted_classes)
584 fig, ax = plt.subplots(figsize=(7, 7)) 558 fig, ax = plt.subplots(figsize=(7, 7))
585 im = plt.imshow(c_matrix, cmap=plot_color) 559 im = plt.imshow(c_matrix, cmap=plot_color)
586 for i in range(len(c_matrix)): 560 for i in range(len(c_matrix)):
587 for j in range(len(c_matrix)): 561 for j in range(len(c_matrix)):
588 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") 562 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
589 ax.set_ylabel('True class labels') 563 ax.set_ylabel("True class labels")
590 ax.set_xlabel('Predicted class labels') 564 ax.set_xlabel("Predicted class labels")
591 ax.set_title(title) 565 ax.set_title(title)
592 ax.set_xticks(axis_labels) 566 ax.set_xticks(axis_labels)
593 ax.set_yticks(axis_labels) 567 ax.set_yticks(axis_labels)
594 fig.colorbar(im, ax=ax) 568 fig.colorbar(im, ax=ax)
595 fig.tight_layout() 569 fig.tight_layout()
596 plt.savefig("output.png", dpi=125) 570 plt.savefig("output.png", dpi=125)
597 os.rename('output.png', 'output') 571 os.rename("output.png", "output")
598 572
599 return 0 573 return 0
600 574
601 # save pdf file to disk 575 # save pdf file to disk
602 # fig.write_image("image.pdf", format='pdf') 576 # fig.write_image("image.pdf", format='pdf')
603 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) 577 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
604 578
605 579
606 if __name__ == '__main__': 580 if __name__ == "__main__":
607 aparser = argparse.ArgumentParser() 581 aparser = argparse.ArgumentParser()
608 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 582 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
609 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 583 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
610 aparser.add_argument("-X", "--infile1", dest="infile1") 584 aparser.add_argument("-X", "--infile1", dest="infile1")
611 aparser.add_argument("-y", "--infile2", dest="infile2") 585 aparser.add_argument("-y", "--infile2", dest="infile2")
621 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") 595 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
622 aparser.add_argument("-pc", "--plot_color", dest="plot_color") 596 aparser.add_argument("-pc", "--plot_color", dest="plot_color")
623 aparser.add_argument("-pt", "--title", dest="title") 597 aparser.add_argument("-pt", "--title", dest="title")
624 args = aparser.parse_args() 598 args = aparser.parse_args()
625 599
626 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 600 main(
627 args.outfile_result, outfile_object=args.outfile_object, 601 args.inputs,
628 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, 602 args.infile_estimator,
629 targets=args.targets, fasta_path=args.fasta_path, 603 args.infile1,
630 model_config=args.model_config, true_labels=args.true_labels, 604 args.infile2,
631 predicted_labels=args.predicted_labels, 605 args.outfile_result,
632 plot_color=args.plot_color, 606 outfile_object=args.outfile_object,
633 title=args.title) 607 groups=args.groups,
608 ref_seq=args.ref_seq,
609 intervals=args.intervals,
610 targets=args.targets,
611 fasta_path=args.fasta_path,
612 model_config=args.model_config,
613 true_labels=args.true_labels,
614 predicted_labels=args.predicted_labels,
615 plot_color=args.plot_color,
616 title=args.title,
617 )