Mercurial > repos > bgruening > keras_train_and_eval
comparison stacking_ensembles.py @ 0:03f61bb3ca43 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author | bgruening |
---|---|
date | Mon, 16 Dec 2019 05:36:53 -0500 |
parents | |
children | 3866911c93ae |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:03f61bb3ca43 |
---|---|
import argparse
import ast
import json
import os
import pickle
import sys
import warnings

import mlxtend.classifier
import mlxtend.regressor
import pandas as pd
import sklearn
from sklearn import ensemble

from galaxy_ml.utils import (load_model, get_cv, get_estimator,
                             get_search_params)
15 | |
16 | |
# Suppress all warnings for this script (presumably to keep the tool's
# job output free of library deprecation noise).
warnings.filterwarnings('ignore')

# Parallelism passed as ``n_jobs`` to estimators that accept it; read from
# the GALAXY_SLOTS environment variable, defaulting to 1 when unset.
N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
20 | |
21 | |
def _get_base_estimators(params, base_paths, estimator_type):
    """Load or build the base estimators of the ensemble.

    Parameters
    ----------
    params : dict
        Galaxy parameters. ``params['base_est_builder'][idx]
        ['estimator_selector']`` is used for every base estimator that has
        no file.

    base_paths : str or None
        Comma-separated file paths, one entry per base estimator. An empty
        entry or the literal string ``'None'`` means "build from params".
        ``None`` builds every entry of ``params['base_est_builder']``.

    estimator_type : str
        Selected ensemble type; only its ``'sklearn'`` prefix matters here.

    Returns
    -------
    list
        ``(name, estimator)`` tuples for sklearn ensembles, bare estimators
        otherwise.
    """
    if base_paths is None:
        # No files supplied at all: build each estimator from its JSON spec.
        base_files = [None] * len(params.get('base_est_builder', []))
    else:
        base_files = base_paths.split(',')

    base_estimators = []
    for idx, base_file in enumerate(base_files):
        if base_file and base_file != 'None':
            with open(base_file, 'rb') as handler:
                model = load_model(handler)
        else:
            estimator_json = (params['base_est_builder'][idx]
                              ['estimator_selector'])
            model = get_estimator(estimator_json)

        if estimator_type.startswith('sklearn'):
            # sklearn ensembles take (name, estimator) pairs; the index keeps
            # names unique when the same estimator class appears twice.
            name = 'base_%d_%s' % (idx, model.__class__.__name__.lower())
            base_estimators.append((name, model))
        else:
            base_estimators.append(model)

    return base_estimators


def _get_meta_estimator(params, meta_path):
    """Load the meta estimator from ``meta_path`` or build it from params.

    An empty/``None`` path or the literal string ``'None'`` falls back to
    ``params['algo_selection']['meta_estimator']['estimator_selector']``.
    """
    if meta_path and meta_path != 'None':
        with open(meta_path, 'rb') as f:
            return load_model(f)

    estimator_json = (params['algo_selection']
                      ['meta_estimator']['estimator_selector'])
    return get_estimator(estimator_json)


def _get_options(params):
    """Return the ensemble constructor keyword arguments.

    ``cv_selector`` is replaced by a ``cv`` splitter (with ``n_jobs`` set
    alongside it), and ``weights`` is parsed from its Python-literal string
    form. The options dict inside ``params`` is modified in place.
    """
    options = params['algo_selection']['options']

    cv_selector = options.pop('cv_selector', None)
    if cv_selector:
        splitter, _groups = get_cv(cv_selector)
        options['cv'] = splitter
        # n_jobs is only added together with cv here; sklearn ensembles get
        # it separately in main().
        options['n_jobs'] = N_JOBS

    weights = options.pop('weights', None)
    if weights:
        # literal_eval only accepts Python literals, so it is safe on input.
        weights = ast.literal_eval(weights)
        if weights:
            options['weights'] = weights

    return options


def main(inputs_path, output_obj, base_paths=None, meta_path=None,
         outfile_params=None):
    """Build an ensemble estimator from Galaxy parameters and pickle it.

    Parameters
    ----------
    inputs_path : str
        File path for Galaxy parameters (JSON).

    output_obj : str
        File path for the pickled ensemble estimator output.

    base_paths : str or None
        File path or paths concatenated by comma, one per base estimator.
        Entries that are empty or ``'None'`` are built from the parameters;
        ``None`` builds all base estimators from the parameters.

    meta_path : str or None
        File path of a pickled meta estimator (mlxtend ensembles only).

    outfile_params : str or None
        File path for the tab-separated parameters output, written only
        when ``params['get_params']`` is truthy.
    """
    with open(inputs_path, 'r') as param_handler:
        params = json.load(param_handler)

    estimator_type = params['algo_selection']['estimator_type']
    base_estimators = _get_base_estimators(params, base_paths,
                                           estimator_type)

    # Only mlxtend ensembles use a meta estimator.
    meta_estimator = None
    if estimator_type.startswith('mlxtend'):
        meta_estimator = _get_meta_estimator(params, meta_path)

    options = _get_options(params)

    # estimator_type is '<module>_<ClassName>'; the module must already be
    # imported so that it is present in sys.modules.
    mod_and_name = estimator_type.split('_')
    mod = sys.modules[mod_and_name[0]]
    klass = getattr(mod, mod_and_name[1])

    if estimator_type.startswith('sklearn'):
        options['n_jobs'] = N_JOBS
        ensemble_estimator = klass(base_estimators, **options)
    elif mod is mlxtend.classifier:
        ensemble_estimator = klass(
            classifiers=base_estimators,
            meta_classifier=meta_estimator,
            **options)
    else:
        ensemble_estimator = klass(
            regressors=base_estimators,
            meta_regressor=meta_estimator,
            **options)

    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, 'wb') as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params.get('get_params') and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=['', 'Parameter', 'Value'])
        df.to_csv(outfile_params, sep='\t', index=False)
120 | |
121 | |
if __name__ == '__main__':
    aparser = argparse.ArgumentParser(
        description='Build an ensemble estimator from Galaxy parameters '
                    'and pickle it.')
    aparser.add_argument("-b", "--bases", dest="bases",
                         help="Comma-separated base estimator file paths")
    aparser.add_argument("-m", "--meta", dest="meta",
                         help="Meta estimator file path")
    # main() opens both of these unconditionally, so fail early in argparse
    # rather than with a TypeError from open(None).
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True,
                         help="JSON file with the Galaxy parameters")
    aparser.add_argument("-o", "--outfile", dest="outfile", required=True,
                         help="Output path for the pickled ensemble")
    aparser.add_argument("-p", "--outfile_params", dest="outfile_params",
                         help="Optional tabular output of estimator params")
    args = aparser.parse_args()

    main(args.inputs, args.outfile, base_paths=args.bases,
         meta_path=args.meta, outfile_params=args.outfile_params)