annotate plot_ml_performance.py @ 4:f234e2e59d76 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author bgruening
date Wed, 07 Aug 2024 10:20:17 +0000
parents 1c5dcef5ce0f
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
1 import argparse
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
2
4
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
3 import matplotlib.pyplot as plt
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
4 import pandas as pd
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
5 import plotly
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
6 import plotly.graph_objs as go
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
7 from galaxy_ml.model_persist import load_model_from_h5
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
8 from galaxy_ml.utils import clean_params
4
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
9 from sklearn.metrics import (
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
10 auc,
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
11 confusion_matrix,
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
12 precision_recall_fscore_support,
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
13 roc_curve,
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
14 )
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
15 from sklearn.preprocessing import label_binarize
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
16
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
17
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
18 def main(infile_input, infile_output, infile_trained_model):
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
19 """
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
20 Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
21 Args:
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
22 infile_input: str, input tabular file with true labels
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
23 infile_output: str, input tabular file with predicted labels
4
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
24 infile_trained_model: str, input trained model file (h5mlm)
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
25 """
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
26
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
27 df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
28 df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
29 true_labels = df_input.iloc[:, -1].copy()
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
30 predicted_labels = df_output.iloc[:, -1].copy()
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
31 axis_labels = list(set(true_labels))
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
32 c_matrix = confusion_matrix(true_labels, predicted_labels)
4
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
33 fig, ax = plt.subplots(figsize=(7, 7))
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
34 im = plt.imshow(c_matrix, cmap="viridis")
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
35 # add number of samples to each cell of confusion matrix plot
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
36 for i in range(len(c_matrix)):
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
37 for j in range(len(c_matrix)):
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
38 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
39 ax.set_ylabel("True class labels")
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
40 ax.set_xlabel("Predicted class labels")
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
41 ax.set_title("Confusion Matrix between true and predicted class labels")
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
42 ax.set_xticks(axis_labels)
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
43 ax.set_yticks(axis_labels)
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
44 fig.colorbar(im, ax=ax)
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
45 fig.tight_layout()
f234e2e59d76 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
bgruening
parents: 3
diff changeset
46 plt.savefig("output_confusion.png", dpi=120)
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
47
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
48 # plot precision, recall and f_score for each class label
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
49 precision, recall, f_score, _ = precision_recall_fscore_support(
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
50 true_labels, predicted_labels
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
51 )
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
52
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
53 trace_precision = go.Scatter(
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
54 x=axis_labels, y=precision, mode="lines+markers", name="Precision"
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
55 )
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
56
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
57 trace_recall = go.Scatter(
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
58 x=axis_labels, y=recall, mode="lines+markers", name="Recall"
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
59 )
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
60
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
61 trace_fscore = go.Scatter(
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
62 x=axis_labels, y=f_score, mode="lines+markers", name="F-score"
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
63 )
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
64
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
65 layout_prf = go.Layout(
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
66 title="Precision, recall and f-score of true and predicted class labels",
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
67 xaxis=dict(title="Class labels"),
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
68 yaxis=dict(title="Precision, recall and f-score"),
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
69 )
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
70
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
71 data_prf = [trace_precision, trace_recall, trace_fscore]
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
72 fig_prf = go.Figure(data=data_prf, layout=layout_prf)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
73 plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
74
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
75 # plot roc and auc curves for different classes
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
76 classifier_object = load_model_from_h5(infile_trained_model)
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
77 model = clean_params(classifier_object)
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
78
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
79 # remove the last column (label column)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
80 test_data = df_input.iloc[:, :-1]
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
81 model_items = dir(model)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
82
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
83 try:
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
84 # find the probability estimating method
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
85 if "predict_proba" in model_items:
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
86 y_score = model.predict_proba(test_data)
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
87 elif "decision_function" in model_items:
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
88 y_score = model.decision_function(test_data)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
89
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
90 true_labels_list = true_labels.tolist()
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
91 one_hot_labels = label_binarize(true_labels_list, classes=axis_labels)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
92 data_roc = list()
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
93
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
94 if len(axis_labels) > 2:
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
95 fpr = dict()
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
96 tpr = dict()
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
97 roc_auc = dict()
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
98 for i in axis_labels:
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
99 fpr[i], tpr[i], _ = roc_curve(one_hot_labels[:, i], y_score[:, i])
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
100 roc_auc[i] = auc(fpr[i], tpr[i])
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
101 for i in range(len(axis_labels)):
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
102 trace = go.Scatter(
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
103 x=fpr[i],
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
104 y=tpr[i],
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
105 mode="lines+markers",
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
106 name="ROC curve of class {0} (AUC = {1:0.2f})".format(
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
107 i, roc_auc[i]
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
108 ),
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
109 )
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
110 data_roc.append(trace)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
111 else:
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
112 try:
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
113 y_score_binary = y_score[:, 1]
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
114 except Exception:
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
115 y_score_binary = y_score
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
116 fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
117 roc_auc = auc(fpr, tpr)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
118 trace = go.Scatter(
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
119 x=fpr,
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
120 y=tpr,
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
121 mode="lines+markers",
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
122 name="ROC curve (AUC = {0:0.2f})".format(roc_auc),
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
123 )
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
124 data_roc.append(trace)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
125
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
126 trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance")
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
127 data_roc.append(trace_diag)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
128 layout_roc = go.Layout(
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
129 title="Receiver operating characteristics (ROC) and area under curve (AUC)",
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
130 xaxis=dict(title="False positive rate"),
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
131 yaxis=dict(title="True positive rate"),
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
132 )
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
133
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
134 fig_roc = go.Figure(data=data_roc, layout=layout_roc)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
135 plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
136
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
137 except Exception as exp:
3
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
138 print(
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
139 "Plotting the ROC-AUC graph failed. This exception was raised: {}".format(
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
140 exp
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
141 )
1c5dcef5ce0f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
bgruening
parents: 2
diff changeset
142 )
0
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
143
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
144
4fac53da862f planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
bgruening
parents:
diff changeset
if __name__ == "__main__":
    # Command-line entry point: every option is mandatory and is forwarded
    # positionally to main() in the order declared below.
    # Each entry is (short flag, long flag, argparse destination attribute).
    cli_options = (
        ("-i", "--input", "infile_input"),
        ("-j", "--output", "infile_output"),
        ("-k", "--model", "infile_trained_model"),
    )

    arg_parser = argparse.ArgumentParser()
    for short_flag, long_flag, dest_name in cli_options:
        arg_parser.add_argument(short_flag, long_flag, dest=dest_name, required=True)

    parsed = arg_parser.parse_args()
    # Same order as the declarations above: input, output, trained model.
    main(*[getattr(parsed, dest_name) for _, _, dest_name in cli_options])