Mercurial > repos > bgruening > sklearn_nn_classifier
comparison stacking_ensembles.py @ 10:e9ba818e7877 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit c0a3a186966888e5787335a7628bf0a4382637e7
author | bgruening |
---|---|
date | Tue, 14 May 2019 18:09:43 -0400 |
parents | |
children | d0efc68a3ddb |
comparison
equal
deleted
inserted
replaced
9:ed7b1654e841 | 10:e9ba818e7877 |
---|---|
1 import argparse | |
2 import json | |
3 import pandas as pd | |
4 import pickle | |
5 import xgboost | |
6 import warnings | |
7 from sklearn import (cluster, compose, decomposition, ensemble, | |
8 feature_extraction, feature_selection, | |
9 gaussian_process, kernel_approximation, metrics, | |
10 model_selection, naive_bayes, neighbors, | |
11 pipeline, preprocessing, svm, linear_model, | |
12 tree, discriminant_analysis) | |
13 from sklearn.model_selection._split import check_cv | |
14 from feature_selectors import (DyRFE, DyRFECV, | |
15 MyPipeline, MyimbPipeline) | |
16 from iraps_classifier import (IRAPSCore, IRAPSClassifier, | |
17 BinarizeTargetClassifier, | |
18 BinarizeTargetRegressor) | |
19 from preprocessors import Z_RandomOverSampler | |
20 from utils import load_model, get_cv, get_estimator, get_search_params | |
21 | |
22 from mlxtend.regressor import StackingCVRegressor, StackingRegressor | |
23 from mlxtend.classifier import StackingCVClassifier, StackingClassifier | |
24 | |
25 | |
# Silence every warning category for the whole run of this script.
warnings.filterwarnings('ignore')

# Degree of parallelism, read from the GALAXY_SLOTS environment variable;
# falls back to 1 when the variable is not set. `os` is pulled in through
# __import__ because the module is not imported at the top of this file.
N_JOBS = int(__import__('os').getenv('GALAXY_SLOTS', 1))
29 | |
30 | |
def main(inputs_path, output_obj, base_paths=None, meta_path=None,
         outfile_params=None):
    """Build a stacking ensemble estimator and pickle it to `output_obj`.

    Parameters
    ----------
    inputs_path : str
        File path for Galaxy parameters (read as JSON).

    output_obj : str
        File path for ensemble estimator output (written with pickle).

    base_paths : str, optional
        File path or paths concatenated by comma. An entry that is empty
        or the literal string 'None' means the base estimator at that
        position is built from ``params['base_est_builder']`` instead of
        being loaded from a file. ``None`` is treated like an empty string.

    meta_path : str, optional
        File path of the meta estimator. When falsy, the meta estimator is
        built from ``params['meta_estimator']``.

    outfile_params : str, optional
        File path for params output (tab-separated). Only written when
        ``params['get_params']`` is truthy as well.
    """
    with open(inputs_path, 'r') as param_handler:
        params = json.load(param_handler)

    # `base_paths` defaults to None; calling `.split` on it directly would
    # raise AttributeError, so fall back to the empty string, which yields a
    # single empty entry exactly like passing `-b ""` does.
    base_files = (base_paths or '').split(',')

    base_estimators = []
    for idx, base_file in enumerate(base_files):
        if base_file and base_file != 'None':
            with open(base_file, 'rb') as handler:
                model = load_model(handler)
        else:
            # No file given for this slot: build it from the tool parameters.
            estimator_json = (params['base_est_builder'][idx]
                              ['estimator_selector'])
            model = get_estimator(estimator_json)
        base_estimators.append(model)

    if meta_path:
        with open(meta_path, 'rb') as f:
            meta_estimator = load_model(f)
    else:
        estimator_json = params['meta_estimator']['estimator_selector']
        meta_estimator = get_estimator(estimator_json)

    options = params['algo_selection']['options']

    cv_selector = options.pop('cv_selector', None)
    if cv_selector:
        # Only the splitter is needed here; the second value returned by
        # get_cv is not used by this script.
        splitter, _ = get_cv(cv_selector)
        options['cv'] = splitter
        # set n_jobs
        # NOTE(review): assumed to belong to the cv branch (indentation was
        # lost in the rendered source) -- confirm against the repository.
        options['n_jobs'] = N_JOBS

    estimator_type = params['algo_selection']['estimator_type']

    if estimator_type == 'StackingCVClassifier':
        ensemble_estimator = StackingCVClassifier(
            classifiers=base_estimators,
            meta_classifier=meta_estimator,
            **options)

    elif estimator_type == 'StackingClassifier':
        ensemble_estimator = StackingClassifier(
            classifiers=base_estimators,
            meta_classifier=meta_estimator,
            **options)

    elif estimator_type == 'StackingCVRegressor':
        ensemble_estimator = StackingCVRegressor(
            regressors=base_estimators,
            meta_regressor=meta_estimator,
            **options)

    else:
        # Any other value falls through to the plain stacking regressor.
        ensemble_estimator = StackingRegressor(
            regressors=base_estimators,
            meta_regressor=meta_estimator,
            **options)

    # Echo the assembled ensemble and its base estimators to stdout.
    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, 'wb') as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params['get_params'] and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=['', 'Parameter', 'Value'])
        df.to_csv(outfile_params, sep='\t', index=False)
116 | |
117 | |
if __name__ == '__main__':
    # Command-line entry point: declare the five options, parse them and
    # forward the values to main().
    parser = argparse.ArgumentParser()
    cli_options = (
        ("-b", "--bases", "bases"),
        ("-m", "--meta", "meta"),
        ("-i", "--inputs", "inputs"),
        ("-o", "--outfile", "outfile"),
        ("-p", "--outfile_params", "outfile_params"),
    )
    for short_flag, long_flag, dest_name in cli_options:
        parser.add_argument(short_flag, long_flag, dest=dest_name)
    cli = parser.parse_args()

    main(
        cli.inputs,
        cli.outfile,
        base_paths=cli.bases,
        meta_path=cli.meta,
        outfile_params=cli.outfile_params,
    )