comparison: plot_ml_performance.py @ 4:f234e2e59d76 (draft, default, tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
author:   bgruening
date:     Wed, 07 Aug 2024 10:20:17 +0000
parents:  1c5dcef5ce0f
children: (none)
--- plot_ml_performance.py (3:1c5dcef5ce0f)
+++ plot_ml_performance.py (4:f234e2e59d76)
 import argparse
 
+import matplotlib.pyplot as plt
 import pandas as pd
 import plotly
 import plotly.graph_objs as go
 from galaxy_ml.model_persist import load_model_from_h5
 from galaxy_ml.utils import clean_params
-from sklearn.metrics import (auc, confusion_matrix,
-                             precision_recall_fscore_support, roc_curve)
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from sklearn.preprocessing import label_binarize
 
 
 def main(infile_input, infile_output, infile_trained_model):
     """
     Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots
     Args:
         infile_input: str, input tabular file with true labels
         infile_output: str, input tabular file with predicted labels
-        infile_trained_model: str, input trained model file (zip)
+        infile_trained_model: str, input trained model file (h5mlm)
     """
 
     df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
     df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
     true_labels = df_input.iloc[:, -1].copy()
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
     c_matrix = confusion_matrix(true_labels, predicted_labels)
-    data = [
-        go.Heatmap(
-            z=c_matrix,
-            x=axis_labels,
-            y=axis_labels,
-            colorscale="Portland",
-        )
-    ]
-
-    layout = go.Layout(
-        title="Confusion Matrix between true and predicted class labels",
-        xaxis=dict(title="Predicted class labels"),
-        yaxis=dict(title="True class labels"),
-    )
-
-    fig = go.Figure(data=data, layout=layout)
-    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
+    fig, ax = plt.subplots(figsize=(7, 7))
+    im = plt.imshow(c_matrix, cmap="viridis")
+    # add number of samples to each cell of confusion matrix plot
+    for i in range(len(c_matrix)):
+        for j in range(len(c_matrix)):
+            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+    ax.set_ylabel("True class labels")
+    ax.set_xlabel("Predicted class labels")
+    ax.set_title("Confusion Matrix between true and predicted class labels")
+    ax.set_xticks(axis_labels)
+    ax.set_yticks(axis_labels)
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    plt.savefig("output_confusion.png", dpi=120)
 
     # plot precision, recall and f_score for each class label
     precision, recall, f_score, _ = precision_recall_fscore_support(
         true_labels, predicted_labels
     )
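For a reader who wants to try the new rendering path outside Galaxy, here is a minimal standalone sketch of what this revision switches to: a static matplotlib PNG in place of the old interactive Plotly HTML heatmap. The label arrays below are made-up placeholders (the tool reads them from its tabular inputs), and unlike the committed code this sketch sets tick positions explicitly and attaches labels with set_xticklabels/set_yticklabels so that non-numeric class names would also render:

import matplotlib

matplotlib.use("Agg")  # headless backend, as a Galaxy job would run
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

# hypothetical labels standing in for the tool's tabular inputs
true_labels = np.array([0, 0, 1, 1, 2, 2, 2, 1])
predicted_labels = np.array([0, 1, 1, 1, 2, 0, 2, 1])
axis_labels = sorted(set(true_labels.tolist()))

c_matrix = confusion_matrix(true_labels, predicted_labels)
fig, ax = plt.subplots(figsize=(7, 7))
im = ax.imshow(c_matrix, cmap="viridis")
# annotate each cell with its sample count, as the new revision does
for i in range(len(c_matrix)):
    for j in range(len(c_matrix)):
        ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
ax.set_ylabel("True class labels")
ax.set_xlabel("Predicted class labels")
ax.set_title("Confusion Matrix between true and predicted class labels")
# positions first, then labels, so string class names also work
ax.set_xticks(range(len(axis_labels)))
ax.set_yticks(range(len(axis_labels)))
ax.set_xticklabels(axis_labels)
ax.set_yticklabels(axis_labels)
fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.savefig("output_confusion.png", dpi=120)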