diff plot_ml_performance.py @ 0:4fac53da862f draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
author bgruening
date Thu, 11 Oct 2018 14:37:54 -0400
parents
children 85da91bbdbfb
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/plot_ml_performance.py	Thu Oct 11 14:37:54 2018 -0400
@@ -0,0 +1,154 @@
+import argparse
+import pandas as pd
+import plotly
+import pickle
+import plotly.graph_objs as go
+from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc
+from sklearn.preprocessing import label_binarize
+
+
+def main(infile_input, infile_output, infile_trained_model):
+    """
+    Produce an interactive confusion matrix (heatmap) and plots of precision, recall, f-score and ROC-AUC
+    Args:
+        infile_input: str, input tabular file with true labels
+        infile_output: str, input tabular file with predicted labels
+        infile_trained_model: str, input trained model file (zip)
+    """
+
+    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
+    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
+    true_labels = df_input.iloc[:, -1].copy()
+    predicted_labels = df_output.iloc[:, -1].copy()
+    axis_labels = sorted(set(true_labels))  # sort so the order matches confusion_matrix and model.classes_
+    c_matrix = confusion_matrix(true_labels, predicted_labels)
+    data = [
+        go.Heatmap(
+            z=c_matrix,
+            x=axis_labels,
+            y=axis_labels,
+            colorscale='Portland',
+        )
+    ]
+
+    layout = go.Layout(
+        title='Confusion Matrix between true and predicted class labels',
+        xaxis=dict(title='Predicted class labels'),
+        yaxis=dict(title='True class labels')
+    )
+
+    fig = go.Figure(data=data, layout=layout)
+    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
+
+    # plot precision, recall and f_score for each class label
+    precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels)
+
+    trace_precision = go.Scatter(
+        x=axis_labels,
+        y=precision,
+        mode='lines+markers',
+        name='Precision'
+    )
+
+    trace_recall = go.Scatter(
+        x=axis_labels,
+        y=recall,
+        mode='lines+markers',
+        name='Recall'
+    )
+
+    trace_fscore = go.Scatter(
+        x=axis_labels,
+        y=f_score,
+        mode='lines+markers',
+        name='F-score'
+    )
+
+    layout_prf = go.Layout(
+        title='Precision, recall and f-score of true and predicted class labels',
+        xaxis=dict(title='Class labels'),
+        yaxis=dict(title='Precision, recall and f-score')
+    )
+
+    data_prf = [trace_precision, trace_recall, trace_fscore]
+    fig_prf = go.Figure(data=data_prf, layout=layout_prf)
+    plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)
+
+    # plot roc and auc curves for different classes
+    with open(infile_trained_model, 'rb') as model_file:
+        model = pickle.load(model_file)
+
+    # remove the last column (label column)
+    test_data = df_input.iloc[:, :-1]
+    model_items = dir(model)
+
+    try:
+        # find the probability estimating method
+        if 'predict_proba' in model_items:
+            y_score = model.predict_proba(test_data)
+        elif 'decision_function' in model_items:
+            y_score = model.decision_function(test_data)
+        else:
+            raise AttributeError("the model supports neither predict_proba nor decision_function")
+
+        true_labels_list = true_labels.tolist()
+        one_hot_labels = label_binarize(true_labels_list, classes=axis_labels)
+        data_roc = list()
+
+        if len(axis_labels) > 2:
+            fpr = dict()
+            tpr = dict()
+            roc_auc = dict()
+            for i in range(len(axis_labels)):
+                fpr[i], tpr[i], _ = roc_curve(one_hot_labels[:, i], y_score[:, i])
+                roc_auc[i] = auc(fpr[i], tpr[i])
+            for i in range(len(axis_labels)):
+                trace = go.Scatter(
+                    x=fpr[i],
+                    y=tpr[i],
+                    mode='lines+markers',
+                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(axis_labels[i], roc_auc[i])
+                )
+                data_roc.append(trace)
+        else:
+            try:
+                y_score_binary = y_score[:, 1]
+            except IndexError:  # decision_function returns a 1-d array for binary problems
+                y_score_binary = y_score
+            fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
+            roc_auc = auc(fpr, tpr)
+            trace = go.Scatter(
+                x=fpr,
+                y=tpr,
+                mode='lines+markers',
+                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
+            )
+            data_roc.append(trace)
+
+        trace_diag = go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode='lines',
+            name='Chance'
+        )
+        data_roc.append(trace_diag)
+        layout_roc = go.Layout(
+            title='Receiver operating characteristics (ROC) and area under curve (AUC)',
+            xaxis=dict(title='False positive rate'),
+            yaxis=dict(title='True positive rate')
+        )
+
+        fig_roc = go.Figure(data=data_roc, layout=layout_roc)
+        plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)
+
+    except Exception as exp:
+        print("ROC-AUC plots could not be generated: {}".format(exp))
+
+
+if __name__ == "__main__":
+    aparser = argparse.ArgumentParser()
+    aparser.add_argument("-i", "--input", dest="infile_input", required=True)
+    aparser.add_argument("-j", "--output", dest="infile_output", required=True)
+    aparser.add_argument("-k", "--model", dest="infile_trained_model", required=True)
+    args = aparser.parse_args()
+    main(args.infile_input, args.infile_output, args.infile_trained_model)
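
For reference, the script can also be driven directly from Python instead of through the argparse options above. The sketch below uses hypothetical file names; the expected inputs (last column of each tabular file holds the labels, the model file is unpickled) and the names of the generated HTML files follow from the code above.

    # minimal usage sketch -- the input file names are placeholders, not files from the repository
    from plot_ml_performance import main

    main(
        infile_input="test_rows_labels.tabular",    # tabular file whose last column holds the true labels
        infile_output="predicted_labels.tabular",   # tabular file whose last column holds the predicted labels
        infile_trained_model="trained_model.zip",   # pickled scikit-learn estimator with predict_proba or decision_function
    )
    # equivalent CLI call:
    #   python plot_ml_performance.py -i test_rows_labels.tabular -j predicted_labels.tabular -k trained_model.zip
    # writes output_confusion.html, output_prf.html and output_roc.html to the working directory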