comparison ml_visualization_ex.py @ 21:14fa42b095c4 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:53:41 +0000
parents a5665e1b06b0
children b878e4cdd63a
comparison
equal deleted inserted replaced
20:a5665e1b06b0 21:14fa42b095c4
11 import plotly.graph_objs as go 11 import plotly.graph_objs as go
12 from galaxy_ml.utils import load_model, read_columns, SafeEval 12 from galaxy_ml.utils import load_model, read_columns, SafeEval
13 from keras.models import model_from_json 13 from keras.models import model_from_json
14 from keras.utils import plot_model 14 from keras.utils import plot_model
15 from sklearn.feature_selection.base import SelectorMixin 15 from sklearn.feature_selection.base import SelectorMixin
16 from sklearn.metrics import auc, average_precision_score, confusion_matrix, precision_recall_curve, roc_curve 16 from sklearn.metrics import (auc, average_precision_score, confusion_matrix,
17 precision_recall_curve, roc_curve)
17 from sklearn.pipeline import Pipeline 18 from sklearn.pipeline import Pipeline
18
19 19
20 safe_eval = SafeEval() 20 safe_eval = SafeEval()
21 21
22 # plotly default colors 22 # plotly default colors
23 default_colors = [ 23 default_colors = [
49 data = [] 49 data = []
50 for idx in range(df1.shape[1]): 50 for idx in range(df1.shape[1]):
51 y_true = df1.iloc[:, idx].values 51 y_true = df1.iloc[:, idx].values
52 y_score = df2.iloc[:, idx].values 52 y_score = df2.iloc[:, idx].values
53 53
54 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) 54 precision, recall, _ = precision_recall_curve(
55 y_true, y_score, pos_label=pos_label
56 )
55 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) 57 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
56 58
57 trace = go.Scatter( 59 trace = go.Scatter(
58 x=recall, 60 x=recall,
59 y=precision, 61 y=precision,
109 111
110 for idx in range(df1.shape[1]): 112 for idx in range(df1.shape[1]):
111 y_true = df1.iloc[:, idx].values 113 y_true = df1.iloc[:, idx].values
112 y_score = df2.iloc[:, idx].values 114 y_score = df2.iloc[:, idx].values
113 115
114 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) 116 precision, recall, _ = precision_recall_curve(
117 y_true, y_score, pos_label=pos_label
118 )
115 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) 119 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
116 120
117 plt.step( 121 plt.step(
118 recall, 122 recall,
119 precision, 123 precision,
153 data = [] 157 data = []
154 for idx in range(df1.shape[1]): 158 for idx in range(df1.shape[1]):
155 y_true = df1.iloc[:, idx].values 159 y_true = df1.iloc[:, idx].values
156 y_score = df2.iloc[:, idx].values 160 y_score = df2.iloc[:, idx].values
157 161
158 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) 162 fpr, tpr, _ = roc_curve(
163 y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
164 )
159 roc_auc = auc(fpr, tpr) 165 roc_auc = auc(fpr, tpr)
160 166
161 trace = go.Scatter( 167 trace = go.Scatter(
162 x=fpr, 168 x=fpr,
163 y=tpr, 169 y=tpr,
166 name="%s (area = %.3f)" % (idx, roc_auc), 172 name="%s (area = %.3f)" % (idx, roc_auc),
167 ) 173 )
168 data.append(trace) 174 data.append(trace)
169 175
170 layout = go.Layout( 176 layout = go.Layout(
171 xaxis=dict(title="False Positive Rate", linecolor="lightslategray", linewidth=1), 177 xaxis=dict(
178 title="False Positive Rate", linecolor="lightslategray", linewidth=1
179 ),
172 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1), 180 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
173 title=dict( 181 title=dict(
174 text=title or "Receiver Operating Characteristic (ROC) Curve", 182 text=title or "Receiver Operating Characteristic (ROC) Curve",
175 x=0.5, 183 x=0.5,
176 y=0.92, 184 y=0.92,
202 plotly.offline.plot(fig, filename="output.html", auto_open=False) 210 plotly.offline.plot(fig, filename="output.html", auto_open=False)
203 # to be discovered by `from_work_dir` 211 # to be discovered by `from_work_dir`
204 os.rename("output.html", "output") 212 os.rename("output.html", "output")
205 213
206 214
207 def visualize_roc_curve_matplotlib(df1, df2, pos_label, drop_intermediate=True, title=None): 215 def visualize_roc_curve_matplotlib(
216 df1, df2, pos_label, drop_intermediate=True, title=None
217 ):
208 """visualize roc-curve using matplotlib and output svg image""" 218 """visualize roc-curve using matplotlib and output svg image"""
209 backend = matplotlib.get_backend() 219 backend = matplotlib.get_backend()
210 if "inline" not in backend: 220 if "inline" not in backend:
211 matplotlib.use("SVG") 221 matplotlib.use("SVG")
212 plt.style.use("seaborn-colorblind") 222 plt.style.use("seaborn-colorblind")
214 224
215 for idx in range(df1.shape[1]): 225 for idx in range(df1.shape[1]):
216 y_true = df1.iloc[:, idx].values 226 y_true = df1.iloc[:, idx].values
217 y_score = df2.iloc[:, idx].values 227 y_score = df2.iloc[:, idx].values
218 228
219 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) 229 fpr, tpr, _ = roc_curve(
230 y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
231 )
220 roc_auc = auc(fpr, tpr) 232 roc_auc = auc(fpr, tpr)
221 233
222 plt.step( 234 plt.step(
223 fpr, 235 fpr,
224 tpr, 236 tpr,
251 "all_but_by_header_name", 263 "all_but_by_header_name",
252 ]: 264 ]:
253 col = plot_selection[column_name]["col1"] 265 col = plot_selection[column_name]["col1"]
254 else: 266 else:
255 col = None 267 col = None
256 _, input_df = read_columns(file_path, c=col, 268 _, input_df = read_columns(
257 c_option=column_option, 269 file_path,
258 return_df=True, 270 c=col,
259 sep='\t', header=header, 271 c_option=column_option,
260 parse_dates=True) 272 return_df=True,
273 sep="\t",
274 header=header,
275 parse_dates=True,
276 )
261 return input_df 277 return input_df
262 278
263 279
264 def main( 280 def main(
265 inputs, 281 inputs,
342 358
343 if plot_type == "feature_importances": 359 if plot_type == "feature_importances":
344 with open(infile_estimator, "rb") as estimator_handler: 360 with open(infile_estimator, "rb") as estimator_handler:
345 estimator = load_model(estimator_handler) 361 estimator = load_model(estimator_handler)
346 362
347 column_option = params["plotting_selection"]["column_selector_options"]["selected_column_selector_option"] 363 column_option = params["plotting_selection"]["column_selector_options"][
364 "selected_column_selector_option"
365 ]
348 if column_option in [ 366 if column_option in [
349 "by_index_number", 367 "by_index_number",
350 "all_but_by_index_number", 368 "all_but_by_index_number",
351 "by_header_name", 369 "by_header_name",
352 "all_but_by_header_name", 370 "all_but_by_header_name",
377 if hasattr(estimator, "coef_"): 395 if hasattr(estimator, "coef_"):
378 coefs = estimator.coef_ 396 coefs = estimator.coef_
379 else: 397 else:
380 coefs = getattr(estimator, "feature_importances_", None) 398 coefs = getattr(estimator, "feature_importances_", None)
381 if coefs is None: 399 if coefs is None:
382 raise RuntimeError("The classifier does not expose " '"coef_" or "feature_importances_" ' "attributes") 400 raise RuntimeError(
401 "The classifier does not expose "
402 '"coef_" or "feature_importances_" '
403 "attributes"
404 )
383 405
384 threshold = params["plotting_selection"]["threshold"] 406 threshold = params["plotting_selection"]["threshold"]
385 if threshold is not None: 407 if threshold is not None:
386 mask = (coefs > threshold) | (coefs < -threshold) 408 mask = (coefs > threshold) | (coefs < -threshold)
387 coefs = coefs[mask] 409 coefs = coefs[mask]
452 mode="lines", 474 mode="lines",
453 ) 475 )
454 layout = go.Layout( 476 layout = go.Layout(
455 xaxis=dict(title="Number of features selected"), 477 xaxis=dict(title="Number of features selected"),
456 yaxis=dict(title="Cross validation score"), 478 yaxis=dict(title="Cross validation score"),
457 title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"), 479 title=dict(
480 text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"
481 ),
458 font=dict(family="sans-serif", size=11), 482 font=dict(family="sans-serif", size=11),
459 # control backgroud colors 483 # control backgroud colors
460 plot_bgcolor="rgba(255,255,255,0)", 484 plot_bgcolor="rgba(255,255,255,0)",
461 ) 485 )
462 """ 486 """
546 570
547 return 0 571 return 0
548 572
549 elif plot_type == "classification_confusion_matrix": 573 elif plot_type == "classification_confusion_matrix":
550 plot_selection = params["plotting_selection"] 574 plot_selection = params["plotting_selection"]
551 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") 575 input_true = get_dataframe(
576 true_labels, plot_selection, "header_true", "column_selector_options_true"
577 )
552 header_predicted = "infer" if plot_selection["header_predicted"] else None 578 header_predicted = "infer" if plot_selection["header_predicted"] else None
553 input_predicted = pd.read_csv(predicted_labels, sep="\t", parse_dates=True, header=header_predicted) 579 input_predicted = pd.read_csv(
580 predicted_labels, sep="\t", parse_dates=True, header=header_predicted
581 )
554 true_classes = input_true.iloc[:, -1].copy() 582 true_classes = input_true.iloc[:, -1].copy()
555 predicted_classes = input_predicted.iloc[:, -1].copy() 583 predicted_classes = input_predicted.iloc[:, -1].copy()
556 axis_labels = list(set(true_classes)) 584 axis_labels = list(set(true_classes))
557 c_matrix = confusion_matrix(true_classes, predicted_classes) 585 c_matrix = confusion_matrix(true_classes, predicted_classes)
558 fig, ax = plt.subplots(figsize=(7, 7)) 586 fig, ax = plt.subplots(figsize=(7, 7))