comparison stacking_ensembles.py @ 24:b9ed7b774ba3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ab963ec9498bd05d2fb2f24f75adb2fccae7958c
author bgruening
date Wed, 15 May 2019 07:43:48 -0400
parents
children d0ed8e976b79
comparison
equal deleted inserted replaced
23:27c0b1a050df 24:b9ed7b774ba3
1 import argparse
2 import json
3 import pandas as pd
4 import pickle
5 import xgboost
6 import warnings
7 from sklearn import (cluster, compose, decomposition, ensemble,
8 feature_extraction, feature_selection,
9 gaussian_process, kernel_approximation, metrics,
10 model_selection, naive_bayes, neighbors,
11 pipeline, preprocessing, svm, linear_model,
12 tree, discriminant_analysis)
13 from sklearn.model_selection._split import check_cv
14 from feature_selectors import (DyRFE, DyRFECV,
15 MyPipeline, MyimbPipeline)
16 from iraps_classifier import (IRAPSCore, IRAPSClassifier,
17 BinarizeTargetClassifier,
18 BinarizeTargetRegressor)
19 from preprocessors import Z_RandomOverSampler
20 from utils import load_model, get_cv, get_estimator, get_search_params
21
22 from mlxtend.regressor import StackingCVRegressor, StackingRegressor
23 from mlxtend.classifier import StackingCVClassifier, StackingClassifier
24
25
# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')

# Worker count comes from the GALAXY_SLOTS environment variable;
# falls back to 1 when the variable is not set.
N_JOBS = int(__import__('os').getenv('GALAXY_SLOTS', 1))
29
30
def main(inputs_path, output_obj, base_paths=None, meta_path=None,
         outfile_params=None):
    """Assemble a stacking ensemble estimator and pickle it to disk.

    Parameters
    ----------
    inputs_path : str
        File path for Galaxy parameters (JSON).

    output_obj : str
        File path for ensemble estimator output.

    base_paths : str
        File path or paths concatenated by comma. An entry that is empty
        or the literal string ``'None'`` means the base estimator at that
        position is built from ``params['base_est_builder']`` instead of
        being loaded from a file.

    meta_path : str
        File path of a saved meta estimator. When falsy, the meta
        estimator is built from ``params['meta_estimator']``.

    outfile_params : str
        File path for params output. Only written when the Galaxy
        parameter ``get_params`` is truthy.

    Raises
    ------
    ValueError
        If `base_paths` is None; at least one (possibly placeholder)
        entry is needed to know how many base estimators to build.
    """
    with open(inputs_path, 'r') as param_handler:
        params = json.load(param_handler)

    # Fail with an explicit message instead of an AttributeError on
    # ``None.split``.
    if base_paths is None:
        raise ValueError(
            "base_paths is required: pass a comma-separated list of "
            "estimator file paths (use 'None' for estimators defined in "
            "the Galaxy parameters)")

    base_estimators = []
    for idx, base_file in enumerate(base_paths.split(',')):
        if base_file and base_file != 'None':
            # A real path: load the pre-built estimator from file.
            with open(base_file, 'rb') as handler:
                model = load_model(handler)
        else:
            # Placeholder entry: build the estimator from the parameter
            # block at the same position.
            estimator_json = (params['base_est_builder'][idx]
                              ['estimator_selector'])
            model = get_estimator(estimator_json)
        base_estimators.append(model)

    if meta_path:
        with open(meta_path, 'rb') as f:
            meta_estimator = load_model(f)
    else:
        estimator_json = params['meta_estimator']['estimator_selector']
        meta_estimator = get_estimator(estimator_json)

    # Remaining entries of `options` are forwarded as keyword arguments
    # to the ensemble constructor.
    options = params['algo_selection']['options']

    cv_selector = options.pop('cv_selector', None)
    if cv_selector:
        # get_cv returns a (splitter, groups) pair; only the splitter is
        # used here.
        splitter, _ = get_cv(cv_selector)
        options['cv'] = splitter
        # n_jobs is only set together with cv, i.e. for the
        # cross-validated stacking variants.
        options['n_jobs'] = N_JOBS

    estimator_type = params['algo_selection']['estimator_type']

    if estimator_type == 'StackingCVClassifier':
        ensemble_estimator = StackingCVClassifier(
            classifiers=base_estimators,
            meta_classifier=meta_estimator,
            **options)

    elif estimator_type == 'StackingClassifier':
        ensemble_estimator = StackingClassifier(
            classifiers=base_estimators,
            meta_classifier=meta_estimator,
            **options)

    elif estimator_type == 'StackingCVRegressor':
        ensemble_estimator = StackingCVRegressor(
            regressors=base_estimators,
            meta_regressor=meta_estimator,
            **options)

    else:
        # Any other estimator_type value falls through to the plain
        # stacking regressor.
        ensemble_estimator = StackingRegressor(
            regressors=base_estimators,
            meta_regressor=meta_estimator,
            **options)

    # Echo the assembled ensemble and its base estimators to stdout.
    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, 'wb') as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params['get_params'] and outfile_params:
        # Dump the ensemble's parameters as a tab-separated table.
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=['', 'Parameter', 'Value'])
        df.to_csv(outfile_params, sep='\t', index=False)
116
117
if __name__ == '__main__':
    # Command-line entry point: every option is forwarded to main().
    cli = argparse.ArgumentParser()
    for short_flag, long_flag in (("-b", "--bases"),
                                  ("-m", "--meta"),
                                  ("-i", "--inputs"),
                                  ("-o", "--outfile"),
                                  ("-p", "--outfile_params")):
        # dest is the long option name without its leading dashes.
        cli.add_argument(short_flag, long_flag, dest=long_flag[2:])
    opts = cli.parse_args()

    main(opts.inputs, opts.outfile, base_paths=opts.bases,
         meta_path=opts.meta, outfile_params=opts.outfile_params)