comparison stacking_ensembles.py @ 0:59e8b4328c82 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:40:10 +0000
parents
children f93f0cdbaf18
comparison
equal deleted inserted replaced
-1:000000000000 0:59e8b4328c82
1 import argparse
2 import ast
3 import json
4 import pickle
5 import sys
6 import warnings
7
8 import mlxtend.classifier
9 import mlxtend.regressor
10 import pandas as pd
11 from galaxy_ml.utils import get_cv, get_estimator, get_search_params, load_model
12
13
# Suppress every warning category for the whole run so that library
# warnings do not end up in the tool's output.
warnings.filterwarnings("ignore")

# Number of parallel workers, read from the GALAXY_SLOTS environment
# variable; falls back to 1 when the variable is not set.
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
17
18
def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None):
    """
    Build a stacking ensemble estimator and pickle it to `output_obj`.

    Parameters
    ----------
    inputs_path : str
        File path for Galaxy parameters (a JSON document).

    output_obj : str
        File path for ensemble estimator output

    base_paths : str
        File path or paths concatenated by comma. An entry that is empty
        or the literal string "None" means "no saved model": the
        estimator at that position is built from
        params["base_est_builder"] instead. When `base_paths` itself is
        None, every entry of params["base_est_builder"] is built.

    meta_path : str
        File path of a saved meta estimator. Only consulted for mlxtend
        estimator types; when falsy, the meta estimator is built from
        params["algo_selection"]["meta_estimator"].

    outfile_params : str
        File path for params output
    """
    # Function-scope import: only this function needs importlib.
    import importlib

    with open(inputs_path, "r") as param_handler:
        params = json.load(param_handler)

    estimator_type = params["algo_selection"]["estimator_type"]

    # One slot per base estimator. Without explicit paths there are no
    # saved models, so every slot is built from its JSON description.
    if base_paths is None:
        base_files = [None] * len(params["base_est_builder"])
    else:
        base_files = base_paths.split(",")

    # get base estimators
    base_estimators = []
    for idx, base_file in enumerate(base_files):
        if base_file and base_file != "None":
            # NOTE(review): load_model comes from galaxy_ml.utils; confirm
            # it restricts what may be unpickled from this file.
            with open(base_file, "rb") as handler:
                model = load_model(handler)
        else:
            estimator_json = params["base_est_builder"][idx]["estimator_selector"]
            model = get_estimator(estimator_json)

        if estimator_type.startswith("sklearn"):
            # sklearn ensembles receive (name, estimator) pairs; the index
            # keeps names unique when the same class appears twice.
            named = model.__class__.__name__.lower()
            named = "base_%d_%s" % (idx, named)
            base_estimators.append((named, model))
        else:
            base_estimators.append(model)

    # get meta estimator, if applicable
    meta_estimator = None
    if estimator_type.startswith("mlxtend"):
        if meta_path:
            with open(meta_path, "rb") as f:
                meta_estimator = load_model(f)
        else:
            estimator_json = params["algo_selection"]["meta_estimator"]["estimator_selector"]
            meta_estimator = get_estimator(estimator_json)

    # The remaining options are forwarded as keyword arguments to the
    # ensemble class, so helper-only keys are popped out first.
    options = params["algo_selection"]["options"]

    cv_selector = options.pop("cv_selector", None)
    if cv_selector:
        # get_cv returns a (splitter, groups) pair; groups are not used here.
        splitter, _groups = get_cv(cv_selector)
        options["cv"] = splitter
        # set n_jobs
        options["n_jobs"] = N_JOBS

    weights = options.pop("weights", None)
    if weights:
        # weights arrive as the text of a Python literal
        weights = ast.literal_eval(weights)
        if weights:
            options["weights"] = weights

    # estimator_type has the form "<module path>_<class name>".
    mod_and_name = estimator_type.split("_")
    # import_module returns the already-loaded module when present and
    # imports it otherwise, so this no longer depends on another package
    # having imported the target module as a side effect.
    mod = importlib.import_module(mod_and_name[0])
    klass = getattr(mod, mod_and_name[1])

    if estimator_type.startswith("sklearn"):
        options["n_jobs"] = N_JOBS
        ensemble_estimator = klass(base_estimators, **options)

    elif mod is mlxtend.classifier:
        ensemble_estimator = klass(classifiers=base_estimators, meta_classifier=meta_estimator, **options)

    else:
        ensemble_estimator = klass(regressors=base_estimators, meta_regressor=meta_estimator, **options)

    # Echo the assembled ensemble and its base estimators to stdout.
    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, "wb") as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    # Optionally write the ensemble's parameters as a tab-separated table.
    if params["get_params"] and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
        df.to_csv(outfile_params, sep="\t", index=False)
108
109
if __name__ == "__main__":
    # Command-line entry point: every option is an optional string whose
    # destination attribute is the long flag without its leading dashes.
    cli = argparse.ArgumentParser()
    for short_opt, long_opt in (
        ("-b", "--bases"),
        ("-m", "--meta"),
        ("-i", "--inputs"),
        ("-o", "--outfile"),
        ("-p", "--outfile_params"),
    ):
        cli.add_argument(short_opt, long_opt, dest=long_opt[2:])
    parsed = cli.parse_args()

    # Hand the parsed paths to main(); options not supplied are None.
    main(
        parsed.inputs,
        parsed.outfile,
        base_paths=parsed.bases,
        meta_path=parsed.meta,
        outfile_params=parsed.outfile_params,
    )