comparison stacking_ensembles.py @ 0:af2624d5ab32 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:24:32 +0000
parents
children 9349ed2749c6
comparison
equal deleted inserted replaced
-1:000000000000 0:af2624d5ab32
1 import argparse
2 import ast
3 import json
4 import pickle
5 import sys
6 import warnings
7
8 import mlxtend.classifier
9 import mlxtend.regressor
10 import pandas as pd
11 from galaxy_ml.utils import (get_cv, get_estimator, get_search_params,
12 load_model)
13
# Suppress every warning category for the whole run of this script.
warnings.filterwarnings(action="ignore")

# Number of parallel jobs handed to the ensemble estimators; read from the
# GALAXY_SLOTS environment variable, falling back to 1 when it is unset.
# `os` is pulled in via __import__ so no extra file-level import is required.
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
17
18
def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None):
    """
    Build an ensemble estimator from Galaxy tool parameters and pickle it.

    Parameters
    ----------
    inputs_path : str
        File path for Galaxy parameters (JSON).

    output_obj : str
        File path for ensemble estimator output (pickle).

    base_paths : str
        File path or paths concatenated by comma. An empty entry or the
        literal string "None" means the base estimator at that position is
        built from ``params["base_est_builder"]`` instead of being loaded
        from a file.

    meta_path : str
        File path of a saved meta estimator. Only consulted when the
        estimator type starts with "mlxtend"; when falsy, the meta estimator
        is built from ``params["algo_selection"]["meta_estimator"]``.

    outfile_params : str
        File path for params output. Written as a tab-separated table, and
        only when ``params["get_params"]`` is truthy.

    Raises
    ------
    ValueError
        If ``base_paths`` is None.
    """
    # Imported here so this function does not rely on a file-level import.
    import importlib

    with open(inputs_path, "r") as param_handler:
        params = json.load(param_handler)

    if base_paths is None:
        # Fail with a clear message instead of AttributeError on None.split.
        raise ValueError(
            "base_paths is required: pass a comma-separated list of file "
            "paths (use 'None' for estimators defined in the parameters)."
        )

    estimator_type = params["algo_selection"]["estimator_type"]
    is_sklearn = estimator_type.startswith("sklearn")

    # get base estimators
    base_estimators = []
    for idx, base_file in enumerate(base_paths.split(",")):
        if base_file and base_file != "None":
            with open(base_file, "rb") as handler:
                model = load_model(handler)
        else:
            estimator_json = params["base_est_builder"][idx]["estimator_selector"]
            model = get_estimator(estimator_json)

        if is_sklearn:
            # sklearn ensembles receive (name, estimator) tuples; the index
            # keeps names unique when the same class appears more than once.
            named = "base_%d_%s" % (idx, model.__class__.__name__.lower())
            base_estimators.append((named, model))
        else:
            base_estimators.append(model)

    # get meta estimator, if applicable
    if estimator_type.startswith("mlxtend"):
        if meta_path:
            with open(meta_path, "rb") as f:
                meta_estimator = load_model(f)
        else:
            estimator_json = params["algo_selection"]["meta_estimator"][
                "estimator_selector"
            ]
            meta_estimator = get_estimator(estimator_json)

    options = params["algo_selection"]["options"]

    cv_selector = options.pop("cv_selector", None)
    if cv_selector:
        splitter, _groups = get_cv(cv_selector)
        options["cv"] = splitter
        # set n_jobs
        options["n_jobs"] = N_JOBS

    weights = options.pop("weights", None)
    if weights:
        # weights arrives as a string holding a Python literal.
        weights = ast.literal_eval(weights)
        if weights:
            options["weights"] = weights

    # estimator_type has the form "<module path>_<class name>".
    mod_and_name = estimator_type.split("_")
    # import_module returns the already-loaded module when present and
    # imports it otherwise, so this no longer depends on some other import
    # having populated sys.modules as a side effect.
    mod = importlib.import_module(mod_and_name[0])
    klass = getattr(mod, mod_and_name[1])

    if is_sklearn:
        options["n_jobs"] = N_JOBS
        ensemble_estimator = klass(base_estimators, **options)

    elif mod is mlxtend.classifier:
        ensemble_estimator = klass(
            classifiers=base_estimators, meta_classifier=meta_estimator, **options
        )

    else:
        ensemble_estimator = klass(
            regressors=base_estimators, meta_regressor=meta_estimator, **options
        )

    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, "wb") as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params["get_params"] and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
        df.to_csv(outfile_params, sep="\t", index=False)
114
115
if __name__ == "__main__":
    # Command-line entry point: every option is an optional string that is
    # forwarded unchanged to main().
    cli = argparse.ArgumentParser()
    cli_options = (
        ("-b", "--bases", "bases"),
        ("-m", "--meta", "meta"),
        ("-i", "--inputs", "inputs"),
        ("-o", "--outfile", "outfile"),
        ("-p", "--outfile_params", "outfile_params"),
    )
    for short_flag, long_flag, dest_name in cli_options:
        cli.add_argument(short_flag, long_flag, dest=dest_name)
    opts = cli.parse_args()

    main(
        opts.inputs,
        opts.outfile,
        base_paths=opts.bases,
        meta_path=opts.meta,
        outfile_params=opts.outfile_params,
    )