plot_ml_performance.py @ 0:4fac53da862f (draft)

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit 8496ba724e35ba551172ea975b0fed091d4bbe88
author bgruening
date Thu, 11 Oct 2018 14:37:54 -0400
parents (none)
children 85da91bbdbfb
import argparse
import pickle

import pandas as pd
import plotly
import plotly.graph_objs as go
from sklearn.metrics import auc, confusion_matrix, precision_recall_fscore_support, roc_curve
from sklearn.preprocessing import label_binarize


def main(infile_input, infile_output, infile_trained_model):
    """
    Produce an interactive confusion matrix (heatmap) and precision, recall,
    f-score and ROC/AUC plots, written as output_confusion.html,
    output_prf.html and output_roc.html.
    Args:
        infile_input: str, input tabular file with true labels
        infile_output: str, input tabular file with predicted labels
        infile_trained_model: str, pickled trained model file
    """
    df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True)
    df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True)
    true_labels = df_input.iloc[:, -1].copy()
    predicted_labels = df_output.iloc[:, -1].copy()
    # sort the class labels so the heatmap axes line up with the
    # row/column order used by confusion_matrix
    axis_labels = sorted(set(true_labels))
    c_matrix = confusion_matrix(true_labels, predicted_labels, labels=axis_labels)
    data = [
        go.Heatmap(
            z=c_matrix,
            x=axis_labels,
            y=axis_labels,
            colorscale='Portland',
        )
    ]

    layout = go.Layout(
        title='Confusion matrix of true and predicted class labels',
        # confusion_matrix puts true labels on rows (heatmap y axis)
        # and predicted labels on columns (heatmap x axis)
        xaxis=dict(title='Predicted class labels'),
        yaxis=dict(title='True class labels')
    )

    fig = go.Figure(data=data, layout=layout)
    plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False)
    # plot precision, recall and f-score for each class label
    precision, recall, f_score, _ = precision_recall_fscore_support(
        true_labels, predicted_labels, labels=axis_labels)

    trace_precision = go.Scatter(
        x=axis_labels,
        y=precision,
        mode='lines+markers',
        name='Precision'
    )

    trace_recall = go.Scatter(
        x=axis_labels,
        y=recall,
        mode='lines+markers',
        name='Recall'
    )

    trace_fscore = go.Scatter(
        x=axis_labels,
        y=f_score,
        mode='lines+markers',
        name='F-score'
    )

    layout_prf = go.Layout(
        title='Precision, recall and f-score of true and predicted class labels',
        xaxis=dict(title='Class labels'),
        yaxis=dict(title='Precision, recall and f-score')
    )

    data_prf = [trace_precision, trace_recall, trace_fscore]
    fig_prf = go.Figure(data=data_prf, layout=layout_prf)
    plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False)

    # plot ROC and AUC curves for the different classes
    with open(infile_trained_model, 'rb') as model_file:
        model = pickle.load(model_file)

    # remove the last column (label column) to recover the feature matrix
    test_data = df_input.iloc[:, :-1]

    try:
        # find a method that yields per-class prediction scores
        if hasattr(model, 'predict_proba'):
            y_score = model.predict_proba(test_data)
        elif hasattr(model, 'decision_function'):
            y_score = model.decision_function(test_data)
        else:
            raise AttributeError(
                'model has neither predict_proba nor decision_function')

        one_hot_labels = label_binarize(true_labels.tolist(), classes=axis_labels)
        data_roc = list()

        if len(axis_labels) > 2:
            # one-vs-rest ROC curve per class; index columns by position so
            # that non-integer class labels are handled correctly
            for index, label in enumerate(axis_labels):
                fpr, tpr, _ = roc_curve(one_hot_labels[:, index], y_score[:, index])
                roc_auc = auc(fpr, tpr)
                trace = go.Scatter(
                    x=fpr,
                    y=tpr,
                    mode='lines+markers',
                    name='ROC curve of class {0} (AUC = {1:0.2f})'.format(label, roc_auc)
                )
                data_roc.append(trace)
        else:
            # binary case: keep the score of the positive class when the
            # model returns one column per class
            if y_score.ndim > 1:
                y_score_binary = y_score[:, 1]
            else:
                y_score_binary = y_score
            fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1)
            roc_auc = auc(fpr, tpr)
            trace = go.Scatter(
                x=fpr,
                y=tpr,
                mode='lines+markers',
                name='ROC curve (AUC = {0:0.2f})'.format(roc_auc)
            )
            data_roc.append(trace)

        # diagonal reference line for a random classifier
        trace_diag = go.Scatter(
            x=[0, 1],
            y=[0, 1],
            mode='lines',
            name='Chance'
        )
        data_roc.append(trace_diag)
        layout_roc = go.Layout(
            title='Receiver operating characteristic (ROC) and area under curve (AUC)',
            xaxis=dict(title='False positive rate'),
            yaxis=dict(title='True positive rate')
        )

        fig_roc = go.Figure(data=data_roc, layout=layout_roc)
        plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False)

    except Exception as exp:
        # do not fail the whole tool if the ROC plot cannot be produced
        print("Could not plot the ROC curves: {0}".format(exp))

if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--input", dest="infile_input", required=True,
                         help="tabular file with true class labels")
    aparser.add_argument("-j", "--output", dest="infile_output", required=True,
                         help="tabular file with predicted class labels")
    aparser.add_argument("-k", "--model", dest="infile_trained_model", required=True,
                         help="pickled trained model")
    args = aparser.parse_args()
    main(args.infile_input, args.infile_output, args.infile_trained_model)
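
A minimal sketch of how the script can be exercised end to end, assuming a
scikit-learn estimator and the tab-separated layout the script expects (last
column is the class label). The file names and the choice of dataset and
estimator here are illustrative, not part of the tool:

    import pickle

    import pandas as pd
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

    # true-labels file: feature columns plus the true class label last
    df_true = pd.DataFrame(X_test)
    df_true['label'] = y_test
    df_true.to_csv('true.tabular', sep='\t', index=False)

    # predicted-labels file: the predicted class label as the last column
    df_pred = pd.DataFrame({'predicted': model.predict(X_test)})
    df_pred.to_csv('predicted.tabular', sep='\t', index=False)

    # pickled model, matching the pickle.load call in the script
    with open('model.pkl', 'wb') as fh:
        pickle.dump(model, fh)

Then:

    python plot_ml_performance.py -i true.tabular -j predicted.tabular -k model.pkl

which writes output_confusion.html, output_prf.html and output_roc.html to the
working directory.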