comparison stacking_ensembles.py @ 6:13b9ac5d277c draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:24:07 +0000
parents 0985b0dd6f1a
children 3312fb686ffb
comparison
equal deleted inserted replaced
5:ce2fd1edbc6e 6:13b9ac5d277c
1 import argparse 1 import argparse
2 import ast 2 import ast
3 import json 3 import json
4 import mlxtend.regressor
5 import mlxtend.classifier
6 import pandas as pd
7 import pickle 4 import pickle
8 import sklearn
9 import sys 5 import sys
10 import warnings 6 import warnings
11 from sklearn import ensemble
12 7
13 from galaxy_ml.utils import (load_model, get_cv, get_estimator, 8 import mlxtend.classifier
14 get_search_params) 9 import mlxtend.regressor
10 import pandas as pd
11 from galaxy_ml.utils import get_cv, get_estimator, get_search_params, load_model
15 12
16 13
17 warnings.filterwarnings('ignore') 14 warnings.filterwarnings("ignore")
18 15
19 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) 16 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
20 17
21 18
22 def main(inputs_path, output_obj, base_paths=None, meta_path=None, 19 def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None):
23 outfile_params=None):
24 """ 20 """
25 Parameter 21 Parameter
26 --------- 22 ---------
27 inputs_path : str 23 inputs_path : str
28 File path for Galaxy parameters 24 File path for Galaxy parameters
37 File path 33 File path
38 34
39 outfile_params : str 35 outfile_params : str
40 File path for params output 36 File path for params output
41 """ 37 """
42 with open(inputs_path, 'r') as param_handler: 38 with open(inputs_path, "r") as param_handler:
43 params = json.load(param_handler) 39 params = json.load(param_handler)
44 40
45 estimator_type = params['algo_selection']['estimator_type'] 41 estimator_type = params["algo_selection"]["estimator_type"]
46 # get base estimators 42 # get base estimators
47 base_estimators = [] 43 base_estimators = []
48 for idx, base_file in enumerate(base_paths.split(',')): 44 for idx, base_file in enumerate(base_paths.split(",")):
49 if base_file and base_file != 'None': 45 if base_file and base_file != "None":
50 with open(base_file, 'rb') as handler: 46 with open(base_file, "rb") as handler:
51 model = load_model(handler) 47 model = load_model(handler)
52 else: 48 else:
53 estimator_json = (params['base_est_builder'][idx] 49 estimator_json = params["base_est_builder"][idx]["estimator_selector"]
54 ['estimator_selector'])
55 model = get_estimator(estimator_json) 50 model = get_estimator(estimator_json)
56 51
57 if estimator_type.startswith('sklearn'): 52 if estimator_type.startswith("sklearn"):
58 named = model.__class__.__name__.lower() 53 named = model.__class__.__name__.lower()
59 named = 'base_%d_%s' % (idx, named) 54 named = "base_%d_%s" % (idx, named)
60 base_estimators.append((named, model)) 55 base_estimators.append((named, model))
61 else: 56 else:
62 base_estimators.append(model) 57 base_estimators.append(model)
63 58
64 # get meta estimator, if applicable 59 # get meta estimator, if applicable
65 if estimator_type.startswith('mlxtend'): 60 if estimator_type.startswith("mlxtend"):
66 if meta_path: 61 if meta_path:
67 with open(meta_path, 'rb') as f: 62 with open(meta_path, "rb") as f:
68 meta_estimator = load_model(f) 63 meta_estimator = load_model(f)
69 else: 64 else:
70 estimator_json = (params['algo_selection'] 65 estimator_json = params["algo_selection"]["meta_estimator"]["estimator_selector"]
71 ['meta_estimator']['estimator_selector'])
72 meta_estimator = get_estimator(estimator_json) 66 meta_estimator = get_estimator(estimator_json)
73 67
74 options = params['algo_selection']['options'] 68 options = params["algo_selection"]["options"]
75 69
76 cv_selector = options.pop('cv_selector', None) 70 cv_selector = options.pop("cv_selector", None)
77 if cv_selector: 71 if cv_selector:
78 splitter, groups = get_cv(cv_selector) 72 splitter, _groups = get_cv(cv_selector)
79 options['cv'] = splitter 73 options["cv"] = splitter
80 # set n_jobs 74 # set n_jobs
81 options['n_jobs'] = N_JOBS 75 options["n_jobs"] = N_JOBS
82 76
83 weights = options.pop('weights', None) 77 weights = options.pop("weights", None)
84 if weights: 78 if weights:
85 weights = ast.literal_eval(weights) 79 weights = ast.literal_eval(weights)
86 if weights: 80 if weights:
87 options['weights'] = weights 81 options["weights"] = weights
88 82
89 mod_and_name = estimator_type.split('_') 83 mod_and_name = estimator_type.split("_")
90 mod = sys.modules[mod_and_name[0]] 84 mod = sys.modules[mod_and_name[0]]
91 klass = getattr(mod, mod_and_name[1]) 85 klass = getattr(mod, mod_and_name[1])
92 86
93 if estimator_type.startswith('sklearn'): 87 if estimator_type.startswith("sklearn"):
94 options['n_jobs'] = N_JOBS 88 options["n_jobs"] = N_JOBS
95 ensemble_estimator = klass(base_estimators, **options) 89 ensemble_estimator = klass(base_estimators, **options)
96 90
97 elif mod == mlxtend.classifier: 91 elif mod == mlxtend.classifier:
98 ensemble_estimator = klass( 92 ensemble_estimator = klass(classifiers=base_estimators, meta_classifier=meta_estimator, **options)
99 classifiers=base_estimators,
100 meta_classifier=meta_estimator,
101 **options)
102 93
103 else: 94 else:
104 ensemble_estimator = klass( 95 ensemble_estimator = klass(regressors=base_estimators, meta_regressor=meta_estimator, **options)
105 regressors=base_estimators,
106 meta_regressor=meta_estimator,
107 **options)
108 96
109 print(ensemble_estimator) 97 print(ensemble_estimator)
110 for base_est in base_estimators: 98 for base_est in base_estimators:
111 print(base_est) 99 print(base_est)
112 100
113 with open(output_obj, 'wb') as out_handler: 101 with open(output_obj, "wb") as out_handler:
114 pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL) 102 pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)
115 103
116 if params['get_params'] and outfile_params: 104 if params["get_params"] and outfile_params:
117 results = get_search_params(ensemble_estimator) 105 results = get_search_params(ensemble_estimator)
118 df = pd.DataFrame(results, columns=['', 'Parameter', 'Value']) 106 df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
119 df.to_csv(outfile_params, sep='\t', index=False) 107 df.to_csv(outfile_params, sep="\t", index=False)
120 108
121 109
122 if __name__ == '__main__': 110 if __name__ == "__main__":
123 aparser = argparse.ArgumentParser() 111 aparser = argparse.ArgumentParser()
124 aparser.add_argument("-b", "--bases", dest="bases") 112 aparser.add_argument("-b", "--bases", dest="bases")
125 aparser.add_argument("-m", "--meta", dest="meta") 113 aparser.add_argument("-m", "--meta", dest="meta")
126 aparser.add_argument("-i", "--inputs", dest="inputs") 114 aparser.add_argument("-i", "--inputs", dest="inputs")
127 aparser.add_argument("-o", "--outfile", dest="outfile") 115 aparser.add_argument("-o", "--outfile", dest="outfile")
128 aparser.add_argument("-p", "--outfile_params", dest="outfile_params") 116 aparser.add_argument("-p", "--outfile_params", dest="outfile_params")
129 args = aparser.parse_args() 117 args = aparser.parse_args()
130 118
131 main(args.inputs, args.outfile, base_paths=args.bases, 119 main(
132 meta_path=args.meta, outfile_params=args.outfile_params) 120 args.inputs,
121 args.outfile,
122 base_paths=args.bases,
123 meta_path=args.meta,
124 outfile_params=args.outfile_params,
125 )