comparison utils.py @ 7:372582a7a34d draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 57f4407e278a615f47a377a3328782b1d8e0b54d
author bgruening
date Sun, 30 Dec 2018 01:50:39 -0500
parents 1c5989b930e3
children 1a9d5a8fff12
comparison
equal deleted inserted replaced
6:90f2d6532262 7:372582a7a34d
1 import sys 1 import json
2 import numpy as np
2 import os 3 import os
3 import pandas 4 import pandas
5 import pickle
4 import re 6 import re
5 import pickle
6 import warnings
7 import numpy as np
8 import xgboost
9 import scipy 7 import scipy
10 import sklearn 8 import sklearn
9 import sys
10 import warnings
11 import xgboost
12
11 from asteval import Interpreter, make_symbol_table 13 from asteval import Interpreter, make_symbol_table
12 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection, 14 from sklearn import (cluster, compose, decomposition, ensemble, feature_extraction,
13 gaussian_process, kernel_approximation, metrics, 15 feature_selection, gaussian_process, kernel_approximation, metrics,
14 model_selection, naive_bayes, neighbors, pipeline, preprocessing, 16 model_selection, naive_bayes, neighbors, pipeline, preprocessing,
15 svm, linear_model, tree, discriminant_analysis) 17 svm, linear_model, tree, discriminant_analysis)
16 18
19 try:
20 import skrebate
21 except ModuleNotFoundError:
22 pass
23
24
17 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1)) 25 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
26
27 try:
28 sk_whitelist
29 except NameError:
30 sk_whitelist = None
18 31
19 32
20 class SafePickler(pickle.Unpickler): 33 class SafePickler(pickle.Unpickler):
21 """ 34 """
22 Used to safely deserialize scikit-learn model objects serialized by cPickle.dump 35 Used to safely deserialize scikit-learn model objects serialized by cPickle.dump
23 Usage: 36 Usage:
24 eg.: SafePickler.load(pickled_file_object) 37 eg.: SafePickler.load(pickled_file_object)
25 """ 38 """
26 def find_class(self, module, name): 39 def find_class(self, module, name):
40
41 # sk_whitelist could be read from tool
42 global sk_whitelist
43 if not sk_whitelist:
44 whitelist_file = os.path.join(os.path.dirname(__file__), 'sk_whitelist.json')
45 with open(whitelist_file, 'r') as f:
46 sk_whitelist = json.load(f)
27 47
28 bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue', 48 bad_names = ('and', 'as', 'assert', 'break', 'class', 'continue',
29 'def', 'del', 'elif', 'else', 'except', 'exec', 49 'def', 'del', 'elif', 'else', 'except', 'exec',
30 'finally', 'for', 'from', 'global', 'if', 'import', 50 'finally', 'for', 'from', 'global', 'if', 'import',
31 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 51 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print',
44 fullname = module + '.' + name 64 fullname = module + '.' + name
45 if (fullname in good_names)\ 65 if (fullname in good_names)\
46 or ( ( module.startswith('sklearn.') 66 or ( ( module.startswith('sklearn.')
47 or module.startswith('xgboost.') 67 or module.startswith('xgboost.')
48 or module.startswith('skrebate.') 68 or module.startswith('skrebate.')
69 or module.startswith('imblearn')
49 or module.startswith('numpy.') 70 or module.startswith('numpy.')
50 or module == 'numpy' 71 or module == 'numpy'
51 ) 72 )
52 and (name not in bad_names) 73 and (name not in bad_names)
53 ): 74 ):
54 # TODO: replace with a whitelist checker 75 # TODO: replace with a whitelist checker
55 if fullname not in sk_whitelist['SK_NAMES'] + sk_whitelist['SKR_NAMES'] + sk_whitelist['XGB_NAMES'] + sk_whitelist['NUMPY_NAMES'] + good_names: 76 if fullname not in sk_whitelist['SK_NAMES'] + sk_whitelist['SKR_NAMES'] + sk_whitelist['XGB_NAMES'] + sk_whitelist['NUMPY_NAMES'] + sk_whitelist['IMBLEARN_NAMES'] + good_names:
56 print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname) 77 print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname)
57 mod = sys.modules[module] 78 mod = sys.modules[module]
58 return getattr(mod, name) 79 return getattr(mod, name)
59 80
60 raise pickle.UnpicklingError("global '%s' is forbidden" % fullname) 81 raise pickle.UnpicklingError("global '%s' is forbidden" % fullname)
81 y = data.values 102 y = data.values
82 if return_df: 103 if return_df:
83 return y, data 104 return y, data
84 else: 105 else:
85 return y 106 return y
86 return y
87 107
88 108
89 ## generate an instance for one of sklearn.feature_selection classes 109 ## generate an instance for one of sklearn.feature_selection classes
90 def feature_selector(inputs): 110 def feature_selector(inputs):
91 selector = inputs["selected_algorithm"] 111 selector = inputs['selected_algorithm']
92 selector = getattr(sklearn.feature_selection, selector) 112 selector = getattr(sklearn.feature_selection, selector)
93 options = inputs["options"] 113 options = inputs['options']
94 114
95 if inputs['selected_algorithm'] == 'SelectFromModel': 115 if inputs['selected_algorithm'] == 'SelectFromModel':
96 if not options['threshold'] or options['threshold'] == 'None': 116 if not options['threshold'] or options['threshold'] == 'None':
97 options['threshold'] = None 117 options['threshold'] = None
118 else:
119 try:
120 options['threshold'] = float(options['threshold'])
121 except ValueError:
122 pass
98 if inputs['model_inputter']['input_mode'] == 'prefitted': 123 if inputs['model_inputter']['input_mode'] == 'prefitted':
99 model_file = inputs['model_inputter']['fitted_estimator'] 124 model_file = inputs['model_inputter']['fitted_estimator']
100 with open(model_file, 'rb') as model_handler: 125 with open(model_file, 'rb') as model_handler:
101 fitted_estimator = load_model(model_handler) 126 fitted_estimator = load_model(model_handler)
102 new_selector = selector(fitted_estimator, prefit=True, **options) 127 new_selector = selector(fitted_estimator, prefit=True, **options)
103 else: 128 else:
104 estimator_json = inputs['model_inputter']["estimator_selector"] 129 estimator_json = inputs['model_inputter']['estimator_selector']
105 estimator = get_estimator(estimator_json) 130 estimator = get_estimator(estimator_json)
106 new_selector = selector(estimator, **options) 131 new_selector = selector(estimator, **options)
107 132
108 elif inputs['selected_algorithm'] == 'RFE': 133 elif inputs['selected_algorithm'] == 'RFE':
109 estimator = get_estimator(inputs["estimator_selector"]) 134 estimator = get_estimator(inputs['estimator_selector'])
135 step = options.get('step', None)
136 if step and step >= 1.0:
137 options['step'] = int(step)
110 new_selector = selector(estimator, **options) 138 new_selector = selector(estimator, **options)
111 139
112 elif inputs['selected_algorithm'] == 'RFECV': 140 elif inputs['selected_algorithm'] == 'RFECV':
113 options['scoring'] = get_scoring(options['scoring']) 141 options['scoring'] = get_scoring(options['scoring'])
114 options['n_jobs'] = N_JOBS 142 options['n_jobs'] = N_JOBS
115 options['cv'] = get_cv(options['cv'].strip()) 143 splitter, groups = get_cv(options.pop('cv_selector'))
116 estimator = get_estimator(inputs["estimator_selector"]) 144 # TODO support group cv splitters
145 options['cv'] = splitter
146 step = options.get('step', None)
147 if step and step >= 1.0:
148 options['step'] = int(step)
149 estimator = get_estimator(inputs['estimator_selector'])
117 new_selector = selector(estimator, **options) 150 new_selector = selector(estimator, **options)
118 151
119 elif inputs['selected_algorithm'] == "VarianceThreshold": 152 elif inputs['selected_algorithm'] == 'VarianceThreshold':
120 new_selector = selector(**options) 153 new_selector = selector(**options)
121 154
122 else: 155 else:
123 score_func = inputs["score_func"] 156 score_func = inputs['score_func']
124 score_func = getattr(sklearn.feature_selection, score_func) 157 score_func = getattr(sklearn.feature_selection, score_func)
125 new_selector = selector(score_func, **options) 158 new_selector = selector(score_func, **options)
126 159
127 return new_selector 160 return new_selector
128 161
129 162
130 def get_X_y(params, file1, file2): 163 def get_X_y(params, file1, file2):
131 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] 164 input_type = params['selected_tasks']['selected_algorithms']['input_options']['selected_input']
132 if input_type == "tabular": 165 if input_type == 'tabular':
133 header = 'infer' if params["selected_tasks"]["selected_algorithms"]["input_options"]["header1"] else None 166 header = 'infer' if params['selected_tasks']['selected_algorithms']['input_options']['header1'] else None
134 column_option = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["selected_column_selector_option"] 167 column_option = params['selected_tasks']['selected_algorithms']['input_options']['column_selector_options_1']['selected_column_selector_option']
135 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: 168 if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']:
136 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_1"]["col1"] 169 c = params['selected_tasks']['selected_algorithms']['input_options']['column_selector_options_1']['col1']
137 else: 170 else:
138 c = None 171 c = None
139 X = read_columns( 172 X = read_columns(
140 file1, 173 file1,
141 c=c, 174 c=c,
145 parse_dates=True 178 parse_dates=True
146 ) 179 )
147 else: 180 else:
148 X = mmread(file1) 181 X = mmread(file1)
149 182
150 header = 'infer' if params["selected_tasks"]["selected_algorithms"]["input_options"]["header2"] else None 183 header = 'infer' if params['selected_tasks']['selected_algorithms']['input_options']['header2'] else None
151 column_option = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_2"]["selected_column_selector_option2"] 184 column_option = params['selected_tasks']['selected_algorithms']['input_options']['column_selector_options_2']['selected_column_selector_option2']
152 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: 185 if column_option in ['by_index_number', 'all_but_by_index_number', 'by_header_name', 'all_but_by_header_name']:
153 c = params["selected_tasks"]["selected_algorithms"]["input_options"]["column_selector_options_2"]["col2"] 186 c = params['selected_tasks']['selected_algorithms']['input_options']['column_selector_options_2']['col2']
154 else: 187 else:
155 c = None 188 c = None
156 y = read_columns( 189 y = read_columns(
157 file2, 190 file2,
158 c=c, 191 c=c,
165 return X, y 198 return X, y
166 199
167 200
168 class SafeEval(Interpreter): 201 class SafeEval(Interpreter):
169 202
170 def __init__(self, load_scipy=False, load_numpy=False): 203 def __init__(self, load_scipy=False, load_numpy=False, load_estimators=False):
171 204
172 # File opening and other unneeded functions could be dropped 205 # File opening and other unneeded functions could be dropped
173 unwanted = ['open', 'type', 'dir', 'id', 'str', 'repr'] 206 unwanted = ['open', 'type', 'dir', 'id', 'str', 'repr']
174 207
175 # Allowed symbol table. Add more if needed. 208 # Allowed symbol table. Add more if needed.
197 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform', 230 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform',
198 'vonmises', 'wald', 'weibull', 'zipf'] 231 'vonmises', 'wald', 'weibull', 'zipf']
199 for f in from_numpy_random: 232 for f in from_numpy_random:
200 syms['np_random_' + f] = getattr(np.random, f) 233 syms['np_random_' + f] = getattr(np.random, f)
201 234
235 if load_estimators:
236 estimator_table = {
237 'sklearn_svm' : getattr(sklearn, 'svm'),
238 'sklearn_tree' : getattr(sklearn, 'tree'),
239 'sklearn_ensemble' : getattr(sklearn, 'ensemble'),
240 'sklearn_neighbors' : getattr(sklearn, 'neighbors'),
241 'sklearn_naive_bayes' : getattr(sklearn, 'naive_bayes'),
242 'sklearn_linear_model' : getattr(sklearn, 'linear_model'),
243 'sklearn_cluster' : getattr(sklearn, 'cluster'),
244 'sklearn_decomposition' : getattr(sklearn, 'decomposition'),
245 'sklearn_preprocessing' : getattr(sklearn, 'preprocessing'),
246 'sklearn_feature_selection' : getattr(sklearn, 'feature_selection'),
247 'sklearn_kernel_approximation' : getattr(sklearn, 'kernel_approximation'),
248 'skrebate_ReliefF': getattr(skrebate, 'ReliefF'),
249 'skrebate_SURF': getattr(skrebate, 'SURF'),
250 'skrebate_SURFstar': getattr(skrebate, 'SURFstar'),
251 'skrebate_MultiSURF': getattr(skrebate, 'MultiSURF'),
252 'skrebate_MultiSURFstar': getattr(skrebate, 'MultiSURFstar'),
253 'skrebate_TuRF': getattr(skrebate, 'TuRF'),
254 'xgboost_XGBClassifier' : getattr(xgboost, 'XGBClassifier'),
255 'xgboost_XGBRegressor' : getattr(xgboost, 'XGBRegressor')
256 }
257 syms.update(estimator_table)
258
202 for key in unwanted: 259 for key in unwanted:
203 syms.pop(key, None) 260 syms.pop(key, None)
204 261
205 super(SafeEval, self).__init__(symtable=syms, use_numpy=False, minimal=False, 262 super(SafeEval, self).__init__(symtable=syms, use_numpy=False, minimal=False,
206 no_if=True, no_for=True, no_while=True, no_try=True, 263 no_if=True, no_for=True, no_while=True, no_try=True,
207 no_functiondef=True, no_ifexp=True, no_listcomp=False, 264 no_functiondef=True, no_ifexp=True, no_listcomp=False,
208 no_augassign=False, no_assert=True, no_delete=True, 265 no_augassign=False, no_assert=True, no_delete=True,
209 no_raise=True, no_print=True) 266 no_raise=True, no_print=True)
210 267
211 268
212 def get_search_params(params_builder):
213 search_params = {}
214 safe_eval = SafeEval(load_scipy=True, load_numpy=True)
215
216 for p in params_builder['param_set']:
217 search_p = p['search_param_selector']['search_p']
218 if search_p.strip() == '':
219 continue
220 param_type = p['search_param_selector']['selected_param_type']
221
222 lst = search_p.split(":")
223 assert (len(lst) == 2), "Error, make sure there is one and only one colon in search parameter input."
224 literal = lst[1].strip()
225 ev = safe_eval(literal)
226 if param_type == "final_estimator_p":
227 search_params["estimator__" + lst[0].strip()] = ev
228 else:
229 search_params["preprocessing_" + param_type[5:6] + "__" + lst[0].strip()] = ev
230
231 return search_params
232
233 269
234 def get_estimator(estimator_json): 270 def get_estimator(estimator_json):
271
235 estimator_module = estimator_json['selected_module'] 272 estimator_module = estimator_json['selected_module']
273
274 if estimator_module == 'customer_estimator':
275 c_estimator = estimator_json['c_estimator']
276 with open(c_estimator, 'rb') as model_handler:
277 new_model = load_model(model_handler)
278 return new_model
279
236 estimator_cls = estimator_json['selected_estimator'] 280 estimator_cls = estimator_json['selected_estimator']
237 281
238 if estimator_module == "xgboost": 282 if estimator_module == 'xgboost':
239 cls = getattr(xgboost, estimator_cls) 283 cls = getattr(xgboost, estimator_cls)
240 else: 284 else:
241 module = getattr(sklearn, estimator_module) 285 module = getattr(sklearn, estimator_module)
242 cls = getattr(module, estimator_cls) 286 cls = getattr(module, estimator_cls)
243 287
244 estimator = cls() 288 estimator = cls()
245 289
246 estimator_params = estimator_json['text_params'].strip() 290 estimator_params = estimator_json['text_params'].strip()
247 if estimator_params != "": 291 if estimator_params != '':
248 try: 292 try:
249 params = safe_eval('dict(' + estimator_params + ')') 293 params = safe_eval('dict(' + estimator_params + ')')
250 except ValueError: 294 except ValueError:
251 sys.exit("Unsupported parameter input: `%s`" % estimator_params) 295 sys.exit("Unsupported parameter input: `%s`" % estimator_params)
252 estimator.set_params(**params) 296 estimator.set_params(**params)
254 estimator.set_params(n_jobs=N_JOBS) 298 estimator.set_params(n_jobs=N_JOBS)
255 299
256 return estimator 300 return estimator
257 301
258 302
259 def get_cv(literal): 303 def get_cv(cv_json):
260 safe_eval = SafeEval() 304 """
261 if literal == "": 305 cv_json:
262 return None 306 e.g.:
263 if literal.isdigit(): 307 {
264 return int(literal) 308 'selected_cv': 'StratifiedKFold',
265 m = re.match(r'^(?P<method>\w+)\((?P<args>.*)\)$', literal) 309 'n_splits': 3,
266 if m: 310 'shuffle': True,
267 my_class = getattr(model_selection, m.group('method')) 311 'random_state': 0
268 args = safe_eval('dict('+ m.group('args') + ')') 312 }
269 return my_class(**args) 313 """
270 sys.exit("Unsupported CV input: %s" % literal) 314 cv = cv_json.pop('selected_cv')
315 if cv == 'default':
316 return cv_json['n_splits'], None
317
318 groups = cv_json.pop('groups', None)
319 if groups:
320 groups = groups.strip()
321 if groups != '':
322 if groups.startswith('__ob__'):
323 groups = groups[6:]
324 if groups.endswith('__cb__'):
325 groups = groups[:-6]
326 groups = [int(x.strip()) for x in groups.split(',')]
327
328 for k, v in cv_json.items():
329 if v == '':
330 cv_json[k] = None
331
332 test_fold = cv_json.get('test_fold', None)
333 if test_fold:
334 if test_fold.startswith('__ob__'):
335 test_fold = test_fold[6:]
336 if test_fold.endswith('__cb__'):
337 test_fold = test_fold[:-6]
338 cv_json['test_fold'] = [int(x.strip()) for x in test_fold.split(',')]
339
340 test_size = cv_json.get('test_size', None)
341 if test_size and test_size > 1.0:
342 cv_json['test_size'] = int(test_size)
343
344 cv_class = getattr(model_selection, cv)
345 splitter = cv_class(**cv_json)
346
347 return splitter, groups
348
349
350 # needed when sklearn < v0.20
351 def balanced_accuracy_score(y_true, y_pred):
352 C = metrics.confusion_matrix(y_true, y_pred)
353 with np.errstate(divide='ignore', invalid='ignore'):
354 per_class = np.diag(C) / C.sum(axis=1)
355 if np.any(np.isnan(per_class)):
356 warnings.warn('y_pred contains classes not in y_true')
357 per_class = per_class[~np.isnan(per_class)]
358 score = np.mean(per_class)
359 return score
271 360
272 361
273 def get_scoring(scoring_json): 362 def get_scoring(scoring_json):
274 def balanced_accuracy_score(y_true, y_pred): 363
275 C = metrics.confusion_matrix(y_true, y_pred) 364 if scoring_json['primary_scoring'] == 'default':
276 with np.errstate(divide='ignore', invalid='ignore'):
277 per_class = np.diag(C) / C.sum(axis=1)
278 if np.any(np.isnan(per_class)):
279 warnings.warn('y_pred contains classes not in y_true')
280 per_class = per_class[~np.isnan(per_class)]
281 score = np.mean(per_class)
282 return score
283
284 if scoring_json['primary_scoring'] == "default":
285 return None 365 return None
286 366
287 my_scorers = metrics.SCORERS 367 my_scorers = metrics.SCORERS
288 if 'balanced_accuracy' not in my_scorers: 368 if 'balanced_accuracy' not in my_scorers:
289 my_scorers['balanced_accuracy'] = metrics.make_scorer(balanced_accuracy_score) 369 my_scorers['balanced_accuracy'] = metrics.make_scorer(balanced_accuracy_score)