comparison ml_visualization_ex.py @ 27:8e49f26b14d3 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author bgruening
date Fri, 13 Sep 2019 12:10:41 -0400
parents
children 7696d389675c
comparison
equal deleted inserted replaced
26:9bb505eafac9 27:8e49f26b14d3
1 import argparse
2 import json
3 import numpy as np
4 import pandas as pd
5 import plotly
6 import plotly.graph_objs as go
7 import warnings
8
9 from keras.models import model_from_json
10 from keras.utils import plot_model
11 from sklearn.feature_selection.base import SelectorMixin
12 from sklearn.metrics import precision_recall_curve, average_precision_score
13 from sklearn.metrics import roc_curve, auc
14 from sklearn.pipeline import Pipeline
15 from galaxy_ml.utils import load_model, read_columns, SafeEval
16
17
# Module-level evaluator instance. The plotting code calls it on the
# user-supplied RFECV "steps" string to turn it into a Python value.
# NOTE(review): SafeEval is imported from galaxy_ml.utils; which expressions
# it accepts is not visible in this file -- confirm before feeding it new
# kinds of input.
safe_eval = SafeEval()
19
20
def _read_score_tables(infile1, infile2):
    """Read the two header-less, tab-separated tables used by curve plots.

    Returns the pair ``(df1, df2)`` read from ``infile1`` and ``infile2``.
    """
    df1 = pd.read_csv(infile1, sep='\t', header=None)
    df2 = pd.read_csv(infile2, sep='\t', header=None)
    return df1, df2


def _curve_label(key, area):
    """Format the legend entry for one curve and its area value."""
    if key == 'micro':
        return '%s (area = %.2f)' % (key, area)
    return 'column %s (area = %.2f)' % (key, area)


def _feature_importances_figure(params, title, infile_estimator, infile1):
    """Build a bar chart of an estimator's ``coef_``/``feature_importances_``.

    Raises RuntimeError when the (final) estimator exposes neither attribute.
    """
    with open(infile_estimator, 'rb') as estimator_handler:
        estimator = load_model(estimator_handler)

    selector_options = params['plotting_selection']['column_selector_options']
    column_option = selector_options['selected_column_selector_option']
    if column_option in ['by_index_number', 'all_but_by_index_number',
                         'by_header_name', 'all_but_by_header_name']:
        c = selector_options['col1']
    else:
        c = None

    _, input_df = read_columns(infile1, c=c,
                               c_option=column_option,
                               return_df=True,
                               sep='\t', header='infer',
                               parse_dates=True)

    feature_names = input_df.columns.values

    if isinstance(estimator, Pipeline):
        # Drop the names of features removed by any selector step so the
        # names line up with what the final estimator was fitted on.
        for _, step in estimator.steps[:-1]:
            if isinstance(step, SelectorMixin):
                feature_names = feature_names[step.get_support()]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, 'coef_'):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, 'feature_importances_', None)
    if coefs is None:
        raise RuntimeError('The classifier does not expose '
                           '"coef_" or "feature_importances_" '
                           'attributes')

    # A single-row 2-D coefficient array (shape (1, n_features)) would make
    # the 1-D indexing of `feature_names` below fail; flatten it first.
    if (isinstance(coefs, np.ndarray) and coefs.ndim == 2
            and coefs.shape[0] == 1):
        coefs = coefs.ravel()

    threshold = params['plotting_selection']['threshold']
    if threshold is not None:
        # Keep only features whose value lies outside [-threshold, threshold].
        mask = (coefs > threshold) | (coefs < -threshold)
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # Largest value first.
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices],
                   y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    return go.Figure(data=[trace], layout=layout)


def _pr_curve_figure(params, title, infile1, infile2):
    """Build precision-recall curves, one per column of ``infile1``.

    When there is more than one column a "micro" curve over the flattened
    tables is added as well.
    """
    df1, df2 = _read_score_tables(infile1, infile2)

    precision = {}
    recall = {}
    ap = {}

    # NOTE(review): pos_label stays a string here -- confirm it matches the
    # dtype pandas infers for the label columns.
    pos_label = params['plotting_selection']['pos_label'].strip() or None
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        precision[col], recall[col], _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label)
        ap[col] = average_precision_score(
            y_true, y_score, pos_label=pos_label or 1)

    if len(df1.columns) > 1:
        precision["micro"], recall["micro"], _ = precision_recall_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        ap['micro'] = average_precision_score(
            df1.values, df2.values, average='micro',
            pos_label=pos_label or 1)

    data = [
        go.Scatter(x=recall[key], y=precision[key], mode='lines',
                   name=_curve_label(key, ap[key]))
        for key in precision
    ]

    layout = go.Layout(
        title=title or "Precision-Recall curve",
        xaxis=dict(title='Recall'),
        yaxis=dict(title='Precision')
    )
    return go.Figure(data=data, layout=layout)


def _roc_curve_figure(params, title, infile1, infile2):
    """Build ROC curves, one per column of ``infile1``, plus the diagonal.

    When there is more than one column a "micro" curve over the flattened
    tables is added as well.
    """
    df1, df2 = _read_score_tables(infile1, infile2)

    fpr = {}
    tpr = {}
    roc_auc = {}

    # NOTE(review): pos_label stays a string here -- confirm it matches the
    # dtype pandas infers for the label columns.
    pos_label = params['plotting_selection']['pos_label'].strip() or None
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        fpr[col], tpr[col], _ = roc_curve(
            y_true, y_score, pos_label=pos_label)
        roc_auc[col] = auc(fpr[col], tpr[col])

    if len(df1.columns) > 1:
        fpr["micro"], tpr["micro"], _ = roc_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        roc_auc['micro'] = auc(fpr["micro"], tpr["micro"])

    data = [
        go.Scatter(x=fpr[key], y=tpr[key], mode='lines',
                   name=_curve_label(key, roc_auc[key]))
        for key in fpr
    ]

    # Dashed reference diagonal from (0, 0) to (1, 1), kept out of the legend.
    data.append(go.Scatter(x=[0, 1], y=[0, 1],
                           mode='lines',
                           line=dict(color='black', dash='dash'),
                           showlegend=False))

    layout = go.Layout(
        title=title or "Receiver operating characteristic curve",
        xaxis=dict(title='False Positive Rate'),
        yaxis=dict(title='True Positive Rate')
    )
    return go.Figure(data=data, layout=layout)


def _rfecv_gridscores_figure(params, title, infile1):
    """Build a line plot of the first column of ``infile1`` (RFECV scores)."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    scores = input_df.iloc[:, 0]
    steps = params['plotting_selection']['steps'].strip()
    steps = safe_eval(steps)

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        # Per-point text is only attached when `steps` evaluates truthy.
        text=[str(step) for step in steps] if steps else None,
        mode='lines'
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=title or None
    )
    return go.Figure(data=[data], layout=layout)


def _learning_curve_figure(params, title, infile1):
    """Build train/test score lines from a learning-curve result table."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    plot_std_err = params['plotting_selection']['plot_std_err']

    def score_trace(mean_column, std_column, name):
        # Error bars are only drawn when the user asked for them.
        return go.Scatter(
            x=input_df['train_sizes_abs'],
            y=input_df[mean_column],
            error_y=dict(
                array=input_df[std_column]
            ) if plot_std_err else None,
            mode='lines',
            name=name,
        )

    data1 = score_trace('mean_train_scores', 'std_train_scores',
                        "Train Scores")
    data2 = score_trace('mean_test_scores', 'std_test_scores',
                        "Test Scores")
    layout = dict(
        xaxis=dict(
            title='No. of samples'
        ),
        yaxis=dict(
            title='Performance Score'
        ),
        title=title or 'Learning Curve'
    )
    return go.Figure(data=[data1, data2], layout=layout)


def _render_keras_model(model_config):
    """Draw the Keras model described by a JSON config file to ``output``."""
    import os

    with open(model_config, 'r') as f:
        model_str = f.read()
    model = model_from_json(model_str)
    plot_model(model, to_file="output.png")
    # Drop the extension so the file ends up under the name `output`.
    os.rename('output.png', 'output')


def main(inputs, infile_estimator=None, infile1=None,
         infile2=None, outfile_result=None,
         outfile_object=None, groups=None,
         ref_seq=None, intervals=None,
         targets=None, fasta_path=None,
         model_config=None):
    """
    Render the plot selected in the tool parameters to the file ``output``.

    The parameter file is JSON; its ``plotting_selection`` section supplies
    ``title``, ``plot_type`` and the options specific to that plot type.
    Supported plot types are ``feature_importances``, ``pr_curve``,
    ``roc_curve``, ``rfecv_gridscores``, ``learning_curve`` and
    ``keras_plot_model``.

    Parameter
    ---------
    inputs : str
        File path to galaxy tool parameter

    infile_estimator : str, default is None
        File path to estimator

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result
        (not used by this function)

    outfile_object : str, default is None
        File path to save searchCV object (not used by this function)

    groups : str, default is None
        File path to dataset containing groups labels
        (not used by this function)

    ref_seq : str, default is None
        File path to dataset containing genome sequence file
        (not used by this function)

    intervals : str, default is None
        File path to dataset containing interval file
        (not used by this function)

    targets : str, default is None
        File path to dataset compressed target bed file
        (not used by this function)

    fasta_path : str, default is None
        File path to dataset containing fasta file
        (not used by this function)

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks

    Returns
    -------
    0 for ``keras_plot_model``, otherwise None.

    Raises
    ------
    ValueError
        If ``plot_type`` is not one of the supported values.
    RuntimeError
        If a ``feature_importances`` plot is requested for an estimator
        exposing neither ``coef_`` nor ``feature_importances_``.
    """
    import os

    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    title = params['plotting_selection']['title'].strip()
    plot_type = params['plotting_selection']['plot_type']

    if plot_type == 'keras_plot_model':
        # Produces an image directly; no plotly figure is involved.
        _render_keras_model(model_config)
        return 0

    if plot_type == 'feature_importances':
        fig = _feature_importances_figure(params, title,
                                          infile_estimator, infile1)
    elif plot_type == 'pr_curve':
        fig = _pr_curve_figure(params, title, infile1, infile2)
    elif plot_type == 'roc_curve':
        fig = _roc_curve_figure(params, title, infile1, infile2)
    elif plot_type == 'rfecv_gridscores':
        fig = _rfecv_gridscores_figure(params, title, infile1)
    elif plot_type == 'learning_curve':
        fig = _learning_curve_figure(params, title, infile1)
    else:
        # Fail with a clear message instead of crashing later on an
        # undefined figure.
        raise ValueError('Unsupported plot type: %r' % (plot_type,))

    plotly.offline.plot(fig, filename="output.html",
                        auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')
284
if __name__ == '__main__':
    # Command-line entry point: every option maps one-to-one onto a keyword
    # argument of main(); only the parameter file (-i/--inputs) is required.
    cli = argparse.ArgumentParser()
    cli.add_argument("-i", "--inputs", dest="inputs", required=True)

    # (short flag, long flag, destination attribute) for the optional paths.
    for short_flag, long_flag, dest_name in (
            ("-e", "--estimator", "infile_estimator"),
            ("-X", "--infile1", "infile1"),
            ("-y", "--infile2", "infile2"),
            ("-O", "--outfile_result", "outfile_result"),
            ("-o", "--outfile_object", "outfile_object"),
            ("-g", "--groups", "groups"),
            ("-r", "--ref_seq", "ref_seq"),
            ("-b", "--intervals", "intervals"),
            ("-t", "--targets", "targets"),
            ("-f", "--fasta_path", "fasta_path"),
            ("-c", "--model_config", "model_config"),
    ):
        cli.add_argument(short_flag, long_flag, dest=dest_name)

    args = cli.parse_args()

    main(
        inputs=args.inputs,
        infile_estimator=args.infile_estimator,
        infile1=args.infile1,
        infile2=args.infile2,
        outfile_result=args.outfile_result,
        outfile_object=args.outfile_object,
        groups=args.groups,
        ref_seq=args.ref_seq,
        intervals=args.intervals,
        targets=args.targets,
        fasta_path=args.fasta_path,
        model_config=args.model_config,
    )