Mercurial > repos > bgruening > sklearn_model_fit
comparison ml_visualization_ex.py @ 0:734c66aa945a draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
author | bgruening |
---|---|
date | Fri, 01 Nov 2019 17:18:28 -0400 |
parents | |
children | 8861ece0b66f |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:734c66aa945a |
---|---|
1 import argparse | |
2 import json | |
3 import numpy as np | |
4 import pandas as pd | |
5 import plotly | |
6 import plotly.graph_objs as go | |
7 import warnings | |
8 | |
9 from keras.models import model_from_json | |
10 from keras.utils import plot_model | |
11 from sklearn.feature_selection.base import SelectorMixin | |
12 from sklearn.metrics import precision_recall_curve, average_precision_score | |
13 from sklearn.metrics import roc_curve, auc | |
14 from sklearn.pipeline import Pipeline | |
15 from galaxy_ml.utils import load_model, read_columns, SafeEval | |
16 | |
17 | |
# Module-level galaxy_ml SafeEval instance. main() uses it to turn the
# user-supplied ``steps`` text of the RFECV plot into a Python value
# (presumably a restricted alternative to the built-in eval(); confirm in
# galaxy_ml.utils).
safe_eval = SafeEval()
19 | |
20 | |
def main(inputs, infile_estimator=None, infile1=None,
         infile2=None, outfile_result=None,
         outfile_object=None, groups=None,
         ref_seq=None, intervals=None,
         targets=None, fasta_path=None,
         model_config=None):
    """Render the plot selected in the Galaxy tool parameters.

    The plot is written to a file named ``output`` in the current working
    directory (HTML for the plotly plots, PNG for ``keras_plot_model``) so
    that Galaxy's ``from_work_dir`` can discover it.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON).

    infile_estimator : str, default is None
        File path to estimator. Used by ``feature_importances``.

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        Accepted for command-line compatibility; not used here.

    outfile_object : str, default is None
        Accepted for command-line compatibility; not used here.

    groups : str, default is None
        Accepted for command-line compatibility; not used here.

    ref_seq : str, default is None
        Accepted for command-line compatibility; not used here.

    intervals : str, default is None
        Accepted for command-line compatibility; not used here.

    targets : str, default is None
        Accepted for command-line compatibility; not used here.

    fasta_path : str, default is None
        Accepted for command-line compatibility; not used here.

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks.
        Used by ``keras_plot_model``.

    Returns
    -------
    0 for ``keras_plot_model``; None for every other plot type.

    Raises
    ------
    ValueError
        If ``plot_type`` is not a supported plot type.
    RuntimeError
        If a ``feature_importances`` estimator exposes no usable
        ``coef_``/``feature_importances_``.
    """
    import os

    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    selection = params['plotting_selection']
    title = selection['title'].strip()
    plot_type = selection['plot_type']

    if plot_type == 'keras_plot_model':
        with open(model_config, 'r') as f:
            model = model_from_json(f.read())
        plot_model(model, to_file="output.png")
        # to be discovered by `from_work_dir`
        os.rename('output.png', 'output')
        return 0

    if plot_type == 'feature_importances':
        fig = _plot_feature_importances(selection, title,
                                        infile_estimator, infile1)
    elif plot_type == 'pr_curve':
        fig = _plot_pr_curve(selection, title, infile1, infile2)
    elif plot_type == 'roc_curve':
        fig = _plot_roc_curve(selection, title, infile1, infile2)
    elif plot_type == 'rfecv_gridscores':
        fig = _plot_rfecv_gridscores(selection, title, infile1)
    elif plot_type == 'learning_curve':
        fig = _plot_learning_curve(selection, title, infile1)
    else:
        # Previously fell through and crashed on an unbound `fig`.
        raise ValueError("Unsupported plot_type: %r" % (plot_type,))

    plotly.offline.plot(fig, filename="output.html",
                        auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')


def _as_1d_coefs(coefs):
    """Return estimator coefficients as a 1-D array (one value per feature).

    scikit-learn linear classifiers store ``coef_`` with shape
    (1, n_features) for binary problems; that single row is flattened.
    A sparse ``coef_`` (anything with ``toarray``) is densified first.
    Any other non-1-D shape (e.g. one row per class) cannot be drawn as a
    single bar chart and raises RuntimeError.
    """
    if hasattr(coefs, 'toarray'):
        coefs = coefs.toarray()
    coefs = np.asarray(coefs)
    if coefs.ndim == 2 and coefs.shape[0] == 1:
        coefs = coefs.ravel()
    if coefs.ndim != 1:
        raise RuntimeError('Cannot plot coefficients of shape %s: expected '
                           'one value per feature (multi-row "coef_" is '
                           'not supported)' % (coefs.shape,))
    return coefs


def _curve_name(key, area):
    """Legend label for a PR/ROC trace; ``key`` is a column or 'micro'."""
    if key == 'micro':
        return '%s (area = %.2f)' % (key, area)
    return 'column %s (area = %.2f)' % (key, area)


def _read_curve_inputs(selection, infile1, infile2):
    """Load true labels, scores and the positive label for PR/ROC plots.

    Both files are read as header-less TSV; column ``i`` of ``infile1``
    pairs with column ``i`` of ``infile2``.
    """
    df1 = pd.read_csv(infile1, sep='\t', header=None)
    df2 = pd.read_csv(infile2, sep='\t', header=None)
    # A blank form value becomes None. NOTE(review): a non-blank value stays
    # a str, so it presumably only matches label columns pandas parsed as
    # strings, not numeric labels -- confirm.
    pos_label = selection['pos_label'].strip() or None
    return df1, df2, pos_label


def _plot_feature_importances(selection, title, infile_estimator, infile1):
    """Bar chart of a fitted estimator's coefficients/feature importances."""
    with open(infile_estimator, 'rb') as estimator_handler:
        estimator = load_model(estimator_handler)

    column_options = selection['column_selector_options']
    column_option = column_options['selected_column_selector_option']
    if column_option in ['by_index_number', 'all_but_by_index_number',
                         'by_header_name', 'all_but_by_header_name']:
        c = column_options['col1']
    else:
        c = None

    _, input_df = read_columns(infile1, c=c,
                               c_option=column_option,
                               return_df=True,
                               sep='\t', header='infer',
                               parse_dates=True)

    feature_names = input_df.columns.values

    # For a Pipeline, narrow the names through every feature-selection step
    # and plot the final step's coefficients.
    if isinstance(estimator, Pipeline):
        for st in estimator.steps[:-1]:
            if isinstance(st[-1], SelectorMixin):
                mask = st[-1].get_support()
                feature_names = feature_names[mask]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, 'coef_'):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, 'feature_importances_', None)
    if coefs is None:
        raise RuntimeError('The classifier does not expose '
                           '"coef_" or "feature_importances_" '
                           'attributes')

    coefs = _as_1d_coefs(coefs)
    if len(coefs) != len(feature_names):
        # Otherwise bars would be silently mislabelled (or IndexError).
        raise RuntimeError('Number of coefficients (%d) does not match '
                           'number of feature names (%d)'
                           % (len(coefs), len(feature_names)))

    threshold = selection['threshold']
    if threshold is not None:
        # Keep features whose coefficient magnitude exceeds the threshold.
        mask = np.abs(coefs) > threshold
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # sort descending by signed value
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices],
                   y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    return go.Figure(data=[trace], layout=layout)


def _plot_pr_curve(selection, title, infile1, infile2):
    """Precision-recall curves per column, plus micro-average if >1 column."""
    df1, df2, pos_label = _read_curve_inputs(selection, infile1, infile2)

    precision = {}
    recall = {}
    ap = {}
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        precision[col], recall[col], _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label)
        # Falls back to 1 (scikit-learn's documented default) when blank.
        ap[col] = average_precision_score(
            y_true, y_score, pos_label=pos_label or 1)

    if len(df1.columns) > 1:
        precision["micro"], recall["micro"], _ = precision_recall_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        ap['micro'] = average_precision_score(
            df1.values, df2.values, average='micro',
            pos_label=pos_label or 1)

    data = [go.Scatter(x=recall[key],
                       y=precision[key],
                       mode='lines',
                       name=_curve_name(key, ap[key]))
            for key in precision]

    layout = go.Layout(
        title=title or "Precision-Recall curve",
        xaxis=dict(title='Recall'),
        yaxis=dict(title='Precision')
    )
    return go.Figure(data=data, layout=layout)


def _plot_roc_curve(selection, title, infile1, infile2):
    """ROC curves per column, plus micro-average if >1 column."""
    df1, df2, pos_label = _read_curve_inputs(selection, infile1, infile2)

    fpr = {}
    tpr = {}
    roc_auc = {}
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        fpr[col], tpr[col], _ = roc_curve(
            y_true, y_score, pos_label=pos_label)
        roc_auc[col] = auc(fpr[col], tpr[col])

    if len(df1.columns) > 1:
        fpr["micro"], tpr["micro"], _ = roc_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        roc_auc['micro'] = auc(fpr["micro"], tpr["micro"])

    data = [go.Scatter(x=fpr[key],
                       y=tpr[key],
                       mode='lines',
                       name=_curve_name(key, roc_auc[key]))
            for key in fpr]

    # dashed diagonal reference line, hidden from the legend
    data.append(go.Scatter(x=[0, 1], y=[0, 1],
                           mode='lines',
                           line=dict(color='black', dash='dash'),
                           showlegend=False))

    layout = go.Layout(
        title=title or "Receiver operating characteristic curve",
        xaxis=dict(title='False Positive Rate'),
        yaxis=dict(title='True Positive Rate')
    )
    return go.Figure(data=data, layout=layout)


def _plot_rfecv_gridscores(selection, title, infile1):
    """Line plot of RFECV scores read from the first column of ``infile1``."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    scores = input_df.iloc[:, 0]
    steps = safe_eval(selection['steps'].strip())

    # NOTE(review): x starts at 0 although the axis is titled "Number of
    # features selected"; confirm against how the scores were produced.
    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(_) for _ in steps] if steps else None,
        mode='lines'
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=title or None
    )
    return go.Figure(data=[data], layout=layout)


def _plot_learning_curve(selection, title, infile1):
    """Train/test learning curves, optionally with std-dev error bars.

    ``infile1`` must contain the columns ``train_sizes_abs``,
    ``mean_train_scores``, ``std_train_scores``, ``mean_test_scores`` and
    ``std_test_scores``.
    """
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    plot_std_err = selection['plot_std_err']
    data1 = go.Scatter(
        x=input_df['train_sizes_abs'],
        y=input_df['mean_train_scores'],
        error_y=dict(
            array=input_df['std_train_scores']
        ) if plot_std_err else None,
        mode='lines',
        name="Train Scores",
    )
    data2 = go.Scatter(
        x=input_df['train_sizes_abs'],
        y=input_df['mean_test_scores'],
        error_y=dict(
            array=input_df['std_test_scores']
        ) if plot_std_err else None,
        mode='lines',
        name="Test Scores",
    )
    layout = dict(
        xaxis=dict(
            title='No. of samples'
        ),
        yaxis=dict(
            title='Performance Score'
        ),
        title=title or 'Learning Curve'
    )
    return go.Figure(data=[data1, data2], layout=layout)
284 | |
285 | |
if __name__ == '__main__':
    # Command-line entry point: parse the Galaxy-supplied options and
    # forward them to main(). Only --inputs is mandatory; every other
    # option is an optional file path that defaults to None.
    cli = argparse.ArgumentParser()
    cli.add_argument("-i", "--inputs", dest="inputs", required=True)

    optional_paths = (
        ("-e", "--estimator", "infile_estimator"),
        ("-X", "--infile1", "infile1"),
        ("-y", "--infile2", "infile2"),
        ("-O", "--outfile_result", "outfile_result"),
        ("-o", "--outfile_object", "outfile_object"),
        ("-g", "--groups", "groups"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-b", "--intervals", "intervals"),
        ("-t", "--targets", "targets"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-c", "--model_config", "model_config"),
    )
    for short_opt, long_opt, dest_name in optional_paths:
        cli.add_argument(short_opt, long_opt, dest=dest_name)

    opts = cli.parse_args()

    main(opts.inputs,
         infile_estimator=opts.infile_estimator,
         infile1=opts.infile1,
         infile2=opts.infile2,
         outfile_result=opts.outfile_result,
         outfile_object=opts.outfile_object,
         groups=opts.groups,
         ref_seq=opts.ref_seq,
         intervals=opts.intervals,
         targets=opts.targets,
         fasta_path=opts.fasta_path,
         model_config=opts.model_config)