comparison ml_visualization_ex.py @ 0:eaddff553324 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
author bgruening
date Fri, 01 Nov 2019 17:15:22 -0400
parents
children cf54bae8ad42
comparison
equal deleted inserted replaced
-1:000000000000 0:eaddff553324
1 import argparse
2 import json
3 import numpy as np
4 import pandas as pd
5 import plotly
6 import plotly.graph_objs as go
7 import warnings
8
9 from keras.models import model_from_json
10 from keras.utils import plot_model
11 from sklearn.feature_selection.base import SelectorMixin
12 from sklearn.metrics import precision_recall_curve, average_precision_score
13 from sklearn.metrics import roc_curve, auc
14 from sklearn.pipeline import Pipeline
15 from galaxy_ml.utils import load_model, read_columns, SafeEval
16
17
# Shared SafeEval instance (from galaxy_ml.utils); used when rendering the
# ``rfecv_gridscores`` plot to evaluate the user-supplied ``steps`` string.
# NOTE(review): presumably a restricted literal evaluator — confirm in galaxy_ml.
18 safe_eval = SafeEval()
19
20
def main(inputs, infile_estimator=None, infile1=None,
         infile2=None, outfile_result=None,
         outfile_object=None, groups=None,
         ref_seq=None, intervals=None,
         targets=None, fasta_path=None,
         model_config=None):
    """Render the plot requested by the Galaxy tool parameters.

    Plotly-based plots are written as an HTML page and Keras model plots as
    a PNG image; either way the result is finally renamed to ``output`` in
    the current working directory so Galaxy can collect it through
    ``from_work_dir``.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON). Its
        ``plotting_selection`` section must provide ``title`` and
        ``plot_type``; the other keys read depend on the plot type.

    infile_estimator : str, default is None
        File path to estimator. Used by ``feature_importances``.

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result.
        Not used by this function.

    outfile_object : str, default is None
        File path to save searchCV object. Not used by this function.

    groups : str, default is None
        File path to dataset containing groups labels. Not used by this
        function.

    ref_seq : str, default is None
        File path to dataset containing genome sequence file. Not used by
        this function.

    intervals : str, default is None
        File path to dataset containing interval file. Not used by this
        function.

    targets : str, default is None
        File path to dataset compressed target bed file. Not used by this
        function.

    fasta_path : str, default is None
        File path to dataset containing fasta file. Not used by this
        function.

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks.
        Used by ``keras_plot_model``.

    Returns
    -------
    int or None
        ``0`` after writing a Keras model plot, otherwise ``None``.

    Raises
    ------
    ValueError
        If ``plot_type`` is not supported, if the estimator's ``coef_`` has
        more than one row, or if the number of feature names does not match
        the number of coefficients.
    RuntimeError
        If the estimator exposes neither ``coef_`` nor
        ``feature_importances_``.
    """
    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    selection = params['plotting_selection']
    title = selection['title'].strip()
    plot_type = selection['plot_type']

    if plot_type == 'feature_importances':
        fig = _plot_feature_importances(selection, title,
                                        infile_estimator, infile1)
    elif plot_type == 'pr_curve':
        fig = _plot_pr_curve(selection, title, infile1, infile2)
    elif plot_type == 'roc_curve':
        fig = _plot_roc_curve(selection, title, infile1, infile2)
    elif plot_type == 'rfecv_gridscores':
        fig = _plot_rfecv_gridscores(selection, title, infile1)
    elif plot_type == 'learning_curve':
        fig = _plot_learning_curve(selection, title, infile1)
    elif plot_type == 'keras_plot_model':
        _plot_keras_model(model_config)
        return 0
    else:
        # Fail loudly instead of reaching plotly with an undefined ``fig``.
        raise ValueError('Unsupported plot_type: %r' % (plot_type,))

    plotly.offline.plot(fig, filename="output.html",
                        auto_open=False)
    # to be discovered by `from_work_dir`
    _publish_output('output.html')


def _publish_output(path):
    """Rename ``path`` to ``output`` so Galaxy's ``from_work_dir`` finds it."""
    import os

    # os.replace overwrites an existing ``output`` on every platform,
    # whereas os.rename raises on Windows when the target already exists.
    os.replace(path, 'output')


def _plot_feature_importances(selection, title, infile_estimator, infile1):
    """Build a bar chart of an estimator's coefficients or importances."""
    with open(infile_estimator, 'rb') as estimator_handler:
        estimator = load_model(estimator_handler)

    column_option = (selection['column_selector_options']
                     ['selected_column_selector_option'])
    if column_option in ['by_index_number', 'all_but_by_index_number',
                         'by_header_name', 'all_but_by_header_name']:
        c = selection['column_selector_options']['col1']
    else:
        c = None

    _, input_df = read_columns(infile1, c=c,
                               c_option=column_option,
                               return_df=True,
                               sep='\t', header='infer',
                               parse_dates=True)

    feature_names = input_df.columns.values

    # Unwrap a Pipeline: apply the support mask of every feature selector so
    # the names line up with the columns the final step was fitted on.
    if isinstance(estimator, Pipeline):
        for _, step in estimator.steps[:-1]:
            if isinstance(step, SelectorMixin):
                feature_names = feature_names[step.get_support()]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, 'coef_'):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, 'feature_importances_', None)
    if coefs is None:
        raise RuntimeError('The classifier does not expose '
                           '"coef_" or "feature_importances_" '
                           'attributes')

    if hasattr(coefs, 'toarray'):
        # ``coef_`` is a scipy sparse matrix after ``sparsify()``.
        coefs = coefs.toarray()
    coefs = np.asarray(coefs)
    if coefs.ndim == 2 and coefs.shape[0] == 1:
        # Binary linear classifiers store ``coef_`` as (1, n_features).
        coefs = coefs[0]
    elif coefs.ndim != 1:
        raise ValueError('Cannot plot coefficients of shape %s; expected '
                         'one value per feature' % (coefs.shape,))

    if len(feature_names) != len(coefs):
        # Without this check the bars would be silently mislabelled.
        raise ValueError('Number of feature names (%d) does not match '
                         'number of coefficients (%d)'
                         % (len(feature_names), len(coefs)))

    threshold = selection['threshold']
    if threshold is not None:
        # Same as (coefs > threshold) | (coefs < -threshold).
        mask = np.abs(coefs) > threshold
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # sort, largest first
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices],
                   y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    return go.Figure(data=[trace], layout=layout)


def _read_scoring_tables(selection, infile1, infile2):
    """Load headerless true-label / score tables and the optional pos_label."""
    df1 = pd.read_csv(infile1, sep='\t', header=None)
    df2 = pd.read_csv(infile2, sep='\t', header=None)
    # An empty string means "use sklearn's default positive label".
    pos_label = selection['pos_label'].strip() or None
    return df1, df2, pos_label


def _curve_label(key, area):
    """Legend entry for a per-column curve or the micro-averaged curve."""
    if key == 'micro':
        return '%s (area = %.2f)' % (key, area)
    return 'column %s (area = %.2f)' % (key, area)


def _plot_pr_curve(selection, title, infile1, infile2):
    """Build precision-recall curves, one per column plus a micro-average."""
    df1, df2, pos_label = _read_scoring_tables(selection, infile1, infile2)

    precision = {}
    recall = {}
    ap = {}

    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        precision[col], recall[col], _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label)
        # average_precision_score defaults pos_label to 1, not None.
        ap[col] = average_precision_score(
            y_true, y_score, pos_label=pos_label or 1)

    if len(df1.columns) > 1:
        precision["micro"], recall["micro"], _ = precision_recall_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        ap['micro'] = average_precision_score(
            df1.values, df2.values, average='micro',
            pos_label=pos_label or 1)

    data = [go.Scatter(x=recall[key],
                       y=precision[key],
                       mode='lines',
                       name=_curve_label(key, ap[key]))
            for key in precision]

    layout = go.Layout(
        title=title or "Precision-Recall curve",
        xaxis=dict(title='Recall'),
        yaxis=dict(title='Precision')
    )
    return go.Figure(data=data, layout=layout)


def _plot_roc_curve(selection, title, infile1, infile2):
    """Build ROC curves, one per column plus a micro-average and a diagonal."""
    df1, df2, pos_label = _read_scoring_tables(selection, infile1, infile2)

    fpr = {}
    tpr = {}
    roc_auc = {}

    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        fpr[col], tpr[col], _ = roc_curve(
            y_true, y_score, pos_label=pos_label)
        roc_auc[col] = auc(fpr[col], tpr[col])

    if len(df1.columns) > 1:
        fpr["micro"], tpr["micro"], _ = roc_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        roc_auc['micro'] = auc(fpr["micro"], tpr["micro"])

    data = [go.Scatter(x=fpr[key],
                       y=tpr[key],
                       mode='lines',
                       name=_curve_label(key, roc_auc[key]))
            for key in fpr]

    # Chance-level reference line.
    data.append(go.Scatter(x=[0, 1], y=[0, 1],
                           mode='lines',
                           line=dict(color='black', dash='dash'),
                           showlegend=False))

    layout = go.Layout(
        title=title or "Receiver operating characteristic curve",
        xaxis=dict(title='False Positive Rate'),
        yaxis=dict(title='True Positive Rate')
    )
    return go.Figure(data=data, layout=layout)


def _plot_rfecv_gridscores(selection, title, infile1):
    """Build a line plot of the first column of an RFECV grid-score table."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    scores = input_df.iloc[:, 0]
    steps = selection['steps'].strip()
    steps = safe_eval(steps)

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(_) for _ in steps] if steps else None,
        mode='lines'
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=title or None
    )
    return go.Figure(data=[data], layout=layout)


def _plot_learning_curve(selection, title, infile1):
    """Build train/test learning curves, optionally with std-dev error bars."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    plot_std_err = selection['plot_std_err']
    data1 = go.Scatter(
        x=input_df['train_sizes_abs'],
        y=input_df['mean_train_scores'],
        error_y=dict(
            array=input_df['std_train_scores']
        ) if plot_std_err else None,
        mode='lines',
        name="Train Scores",
    )
    data2 = go.Scatter(
        x=input_df['train_sizes_abs'],
        y=input_df['mean_test_scores'],
        error_y=dict(
            array=input_df['std_test_scores']
        ) if plot_std_err else None,
        mode='lines',
        name="Test Scores",
    )
    layout = dict(
        xaxis=dict(
            title='No. of samples'
        ),
        yaxis=dict(
            title='Performance Score'
        ),
        title=title or 'Learning Curve'
    )
    return go.Figure(data=[data1, data2], layout=layout)


def _plot_keras_model(model_config):
    """Draw a Keras model (rebuilt from its JSON config) to ``output``."""
    with open(model_config, 'r') as f:
        model_str = f.read()
    model = model_from_json(model_str)
    plot_model(model, to_file="output.png")
    _publish_output('output.png')
284
285
if __name__ == '__main__':
    # (short flag, long flag, destination) for each optional file argument.
    # Every destination is also the name of a keyword parameter of ``main``,
    # which lets the parsed namespace be forwarded wholesale below.
    optional_file_args = (
        ("-e", "--estimator", "infile_estimator"),
        ("-X", "--infile1", "infile1"),
        ("-y", "--infile2", "infile2"),
        ("-O", "--outfile_result", "outfile_result"),
        ("-o", "--outfile_object", "outfile_object"),
        ("-g", "--groups", "groups"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-b", "--intervals", "intervals"),
        ("-t", "--targets", "targets"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-c", "--model_config", "model_config"),
    )

    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    for short_flag, long_flag, destination in optional_file_args:
        cli_parser.add_argument(short_flag, long_flag, dest=destination)
    cli_args = cli_parser.parse_args()

    main(**vars(cli_args))
297 aparser.add_argument("-t", "--targets", dest="targets")
298 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
299 aparser.add_argument("-c", "--model_config", dest="model_config")
300 args = aparser.parse_args()
301
302 main(args.inputs, args.infile_estimator, args.infile1, args.infile2,
303 args.outfile_result, outfile_object=args.outfile_object,
304 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals,
305 targets=args.targets, fasta_path=args.fasta_path,
306 model_config=args.model_config)