comparison stacking_ensembles.py @ 28:0cc5098a9bff draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author bgruening
date Tue, 13 Apr 2021 19:13:15 +0000
parents 3703b5796442
children 9c8ede554b7c
comparison
equal deleted inserted replaced
27:7d98dfb04d27 28:0cc5098a9bff
3 import json 3 import json
4 import mlxtend.regressor 4 import mlxtend.regressor
5 import mlxtend.classifier 5 import mlxtend.classifier
6 import pandas as pd 6 import pandas as pd
7 import pickle 7 import pickle
8 import sklearn
9 import sys 8 import sys
10 import warnings 9 import warnings
11 from sklearn import ensemble 10 from galaxy_ml.utils import load_model, get_cv, get_estimator, get_search_params
12
13 from galaxy_ml.utils import (load_model, get_cv, get_estimator,
14 get_search_params)
15 11
16 12
# Suppress every warning category for the whole run.
warnings.filterwarnings("ignore")

# Parallelism for the ensemble: taken from the GALAXY_SLOTS environment
# variable, falling back to a single job when it is unset.
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
20 16
21 17
def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None):
    """
    Build an ensemble estimator from Galaxy parameters and pickle it.

    Parameter
    ---------
    inputs_path : str
        File path for Galaxy parameters (a JSON document).

    output_obj : str
        File path the pickled ensemble estimator is written to.

    base_paths : str or None
        Comma-separated file paths of serialized base estimators. An empty
        entry or the literal string ``"None"`` means the base estimator at
        that position is built from ``params["base_est_builder"][idx]``
        instead. When ``None``, every base estimator is built from
        ``params["base_est_builder"]``.

    meta_path : str or None
        File path of a serialized meta estimator; only used for ``mlxtend``
        ensembles. When falsy, the meta estimator is built from
        ``params["algo_selection"]["meta_estimator"]``.

    outfile_params : str or None
        File path for the tab-separated parameter listing output.

    Raises
    ------
    ValueError
        If no base estimator could be collected.
    """
    import importlib

    with open(inputs_path, "r") as param_handler:
        params = json.load(param_handler)

    estimator_type = params["algo_selection"]["estimator_type"]

    # get base estimators
    if base_paths is None:
        # No files supplied: build every base estimator from its JSON spec.
        base_files = [None] * len(params.get("base_est_builder", []))
    else:
        base_files = base_paths.split(",")

    base_estimators = []
    for idx, base_file in enumerate(base_files):
        if base_file and base_file != "None":
            with open(base_file, "rb") as handler:
                model = load_model(handler)
        else:
            estimator_json = params["base_est_builder"][idx]["estimator_selector"]
            model = get_estimator(estimator_json)

        if estimator_type.startswith("sklearn"):
            # sklearn ensembles take (name, estimator) pairs; the index prefix
            # keeps names distinct even when two bases share a class.
            named = f"base_{idx}_{model.__class__.__name__.lower()}"
            base_estimators.append((named, model))
        else:
            base_estimators.append(model)

    if not base_estimators:
        raise ValueError("At least one base estimator is required.")

    # get meta estimator, if applicable (only the mlxtend ensembles use one)
    meta_estimator = None
    if estimator_type.startswith("mlxtend"):
        if meta_path:
            with open(meta_path, "rb") as f:
                meta_estimator = load_model(f)
        else:
            estimator_json = params["algo_selection"]["meta_estimator"]["estimator_selector"]
            meta_estimator = get_estimator(estimator_json)

    options = params["algo_selection"]["options"]

    cv_selector = options.pop("cv_selector", None)
    if cv_selector:
        splitter, _ = get_cv(cv_selector)
        options["cv"] = splitter
        # set n_jobs -- only together with cv here; presumably because only
        # the CV-based mlxtend ensembles accept n_jobs (TODO confirm).
        options["n_jobs"] = N_JOBS

    weights = options.pop("weights", None)
    if weights:
        # weights arrive as a Python-literal string, e.g. "[1, 2]".
        weights = ast.literal_eval(weights)
        if weights:
            options["weights"] = weights

    # estimator_type is expected to be "<module>_<ClassName>".
    mod_and_name = estimator_type.split("_")
    # import_module returns the already-loaded module if present and imports
    # it otherwise; a bare sys.modules lookup raises KeyError when nothing
    # else happened to import the module first.
    mod = importlib.import_module(mod_and_name[0])
    klass = getattr(mod, mod_and_name[1])

    if estimator_type.startswith("sklearn"):
        options["n_jobs"] = N_JOBS
        ensemble_estimator = klass(base_estimators, **options)

    elif mod is mlxtend.classifier:
        ensemble_estimator = klass(classifiers=base_estimators, meta_classifier=meta_estimator, **options)

    else:
        ensemble_estimator = klass(regressors=base_estimators, meta_regressor=meta_estimator, **options)

    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, "wb") as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params.get("get_params") and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
        df.to_csv(outfile_params, sep="\t", index=False)
120 107
121 108
if __name__ == "__main__":
    # Command-line entry point: parse the file paths and delegate to main().
    parser = argparse.ArgumentParser()
    # (short flag, long flag, destination attribute), registered in order.
    cli_options = (
        ("-b", "--bases", "bases"),
        ("-m", "--meta", "meta"),
        ("-i", "--inputs", "inputs"),
        ("-o", "--outfile", "outfile"),
        ("-p", "--outfile_params", "outfile_params"),
    )
    for short_flag, long_flag, dest_name in cli_options:
        parser.add_argument(short_flag, long_flag, dest=dest_name)
    cli_args = parser.parse_args()

    main(
        cli_args.inputs,
        cli_args.outfile,
        base_paths=cli_args.bases,
        meta_path=cli_args.meta,
        outfile_params=cli_args.outfile_params,
    )