diff plot_ml_performance.py @ 3:1c5dcef5ce0f draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 271a4454eea5902e29da4b8dfa7b9124fefac6bc
author bgruening
date Tue, 07 May 2024 14:11:16 +0000
parents 62e3a4e8c54c
children
--- a/plot_ml_performance.py	Thu Jan 16 13:49:49 2020 -0500
+++ b/plot_ml_performance.py	Tue May 07 14:11:16 2024 +0000
@@ -1,9 +1,12 @@
 import argparse
+
 import pandas as pd
 import plotly
-import pickle
 import plotly.graph_objs as go
-from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import clean_params
+from sklearn.metrics import (auc, confusion_matrix,
+                             precision_recall_fscore_support, roc_curve)
 from sklearn.preprocessing import label_binarize
 
 
@@ -16,8 +19,8 @@
         infile_trained_model: str, input trained model file (zip)
     """
 
-    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
-    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
+    df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True)
+    df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True)
     true_labels = df_input.iloc[:, -1].copy()
     predicted_labels = df_output.iloc[:, -1].copy()
     axis_labels = list(set(true_labels))
@@ -27,47 +30,40 @@
             z=c_matrix,
             x=axis_labels,
             y=axis_labels,
-            colorscale='Portland',
+            colorscale="Portland",
         )
     ]
 
     layout = go.Layout(
-        title='Confusion Matrix between true and predicted class labels',
-        xaxis=dict(title='Predicted class labels'),
-        yaxis=dict(title='True class labels')
+        title="Confusion Matrix between true and predicted class labels",
+        xaxis=dict(title="Predicted class labels"),
+        yaxis=dict(title="True class labels"),
     )
 
     fig = go.Figure(data=data, layout=layout)
     plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
 
     # plot precision, recall and f_score for each class label
-    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)
+    precision, recall, f_score, _ = precision_recall_fscore_support(
+        true_labels, predicted_labels
+    )
 
     trace_precision = go.Scatter(
-        x=axis_labels,
-        y=precision,
-        mode='lines+markers',
-        name='Precision'
+        x=axis_labels, y=precision, mode="lines+markers", name="Precision"
     )
 
     trace_recall = go.Scatter(
-        x=axis_labels,
-        y=recall,
-        mode='lines+markers',
-        name='Recall'
+        x=axis_labels, y=recall, mode="lines+markers", name="Recall"
     )
 
     trace_fscore = go.Scatter(
-        x=axis_labels,
-        y=f_score,
-        mode='lines+markers',
-        name='F-score'
+        x=axis_labels, y=f_score, mode="lines+markers", name="F-score"
     )
 
     layout_prf = go.Layout(
-        title='Precision, recall and f-score of true and predicted class labels',
-        xaxis=dict(title='Class labels'),
-        yaxis=dict(title='Precision, recall and f-score')
+        title="Precision, recall and f-score of true and predicted class labels",
+        xaxis=dict(title="Class labels"),
+        yaxis=dict(title="Precision, recall and f-score"),
     )
 
     data_prf = [trace_precision, trace_recall, trace_fscore]
@@ -75,8 +71,8 @@
     plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)
 
     # plot roc and auc curves for different classes
-    with open(infile_trained_model, 'rb') as model_file:
-        model = pickle.load(model_file)
+    classifier_object = load_model_from_h5(infile_trained_model)
+    model = clean_params(classifier_object)
 
     # remove the last column (label column)
     test_data = df_input.iloc[:, :-1]
@@ -84,9 +80,9 @@
 
     try:
         # find the probability estimating method
-        if 'predict_proba' in model_items:
+        if "predict_proba" in model_items:
             y_score = model.predict_proba(test_data)
-        elif 'decision_function' in model_items:
+        elif "decision_function" in model_items:
             y_score = model.decision_function(test_data)
 
         true_labels_list = true_labels.tolist()
@@ -104,43 +100,44 @@
                 trace = go.Scatter(
                     x=fpr[i],
                     y=tpr[i],
-                    mode='lines+markers',
-                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(i, roc_auc[i])
+                    mode="lines+markers",
+                    name="ROC curve of class {0} (AUC = {1:0.2f})".format(
+                        i, roc_auc[i]
+                    ),
                 )
                 data_roc.append(trace)
         else:
             try:
                 y_score_binary = y_score[:, 1]
-            except:
+            except Exception:
                 y_score_binary = y_score
             fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
             roc_auc = auc(fpr, tpr)
             trace = go.Scatter(
                 x=fpr,
                 y=tpr,
-                mode='lines+markers',
-                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
+                mode="lines+markers",
+                name="ROC curve (AUC = {0:0.2f})".format(roc_auc),
             )
             data_roc.append(trace)
 
-        trace_diag = go.Scatter(
-            x=[0, 1],
-            y=[0, 1],
-            mode='lines',
-            name='Chance'
-        )
+        trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance")
         data_roc.append(trace_diag)
         layout_roc = go.Layout(
-            title='Receiver operating characteristics (ROC) and area under curve (AUC)',
-            xaxis=dict(title='False positive rate'),
-            yaxis=dict(title='True positive rate')
+            title="Receiver operating characteristics (ROC) and area under curve (AUC)",
+            xaxis=dict(title="False positive rate"),
+            yaxis=dict(title="True positive rate"),
         )
 
         fig_roc = go.Figure(data=data_roc, layout=layout_roc)
         plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)
 
     except Exception as exp:
-        print("Plotting the ROC-AUC graph failed. This exception was raised: {}".format(exp))
+        print(
+            "Plotting the ROC-AUC graph failed. This exception was raised: {}".format(
+                exp
+            )
+        )
 
 
 if __name__ == "__main__":
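
The functional change in this revision swaps pickle deserialization for galaxy_ml's HDF5 model loading. Below is a minimal sketch (not part of the diff) of how that new loading path is used, assuming load_model_from_h5 and clean_params behave as they are called above; the file names and the hasattr probe are illustrative stand-ins, not taken from the commit.

    import pandas as pd
    from galaxy_ml.model_persist import load_model_from_h5
    from galaxy_ml.utils import clean_params

    # Hypothetical tabular input: feature columns followed by a final label column.
    df_input = pd.read_csv("test_rows.tabular", sep="\t")
    test_data = df_input.iloc[:, :-1]  # drop the label column, as the script does

    # Load the trained estimator from HDF5 and strip Galaxy-specific parameters
    # (mirrors the load_model_from_h5 / clean_params calls in the diff).
    model = clean_params(load_model_from_h5("trained_model.h5"))

    # Obtain a score for ROC plotting; hasattr is a simplified stand-in for the
    # script's membership test against model_items.
    if hasattr(model, "predict_proba"):
        y_score = model.predict_proba(test_data)
    elif hasattr(model, "decision_function"):
        y_score = model.decision_function(test_data)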