diff stacking_ensembles.py @ 4:9349ed2749c6 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:01:50 +0000
parents af2624d5ab32
children
line wrap: on
line diff
--- a/stacking_ensembles.py	Thu Aug 11 09:38:31 2022 +0000
+++ b/stacking_ensembles.py	Wed Aug 09 13:01:50 2023 +0000
@@ -1,22 +1,22 @@
 import argparse
 import ast
 import json
-import pickle
 import sys
 import warnings
+from distutils.version import LooseVersion as Version
 
 import mlxtend.classifier
 import mlxtend.regressor
-import pandas as pd
-from galaxy_ml.utils import (get_cv, get_estimator, get_search_params,
-                             load_model)
+from galaxy_ml import __version__ as galaxy_ml_version
+from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
+from galaxy_ml.utils import get_cv, get_estimator
 
 warnings.filterwarnings("ignore")
 
 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
 
 
-def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None):
+def main(inputs_path, output_obj, base_paths=None, meta_path=None):
     """
     Parameter
     ---------
@@ -31,9 +31,6 @@
 
     meta_path : str
         File path
-
-    outfile_params : str
-        File path for params output
     """
     with open(inputs_path, "r") as param_handler:
         params = json.load(param_handler)
@@ -43,8 +40,7 @@
     base_estimators = []
     for idx, base_file in enumerate(base_paths.split(",")):
         if base_file and base_file != "None":
-            with open(base_file, "rb") as handler:
-                model = load_model(handler)
+            model = load_model_from_h5(base_file)
         else:
             estimator_json = params["base_est_builder"][idx]["estimator_selector"]
             model = get_estimator(estimator_json)
@@ -59,8 +55,7 @@
     # get meta estimator, if applicable
     if estimator_type.startswith("mlxtend"):
         if meta_path:
-            with open(meta_path, "rb") as f:
-                meta_estimator = load_model(f)
+            meta_estimator = load_model_from_h5(meta_path)
         else:
             estimator_json = params["algo_selection"]["meta_estimator"][
                 "estimator_selector"
@@ -71,7 +66,9 @@
 
     cv_selector = options.pop("cv_selector", None)
     if cv_selector:
-        splitter, _groups = get_cv(cv_selector)
+        if Version(galaxy_ml_version) < Version("0.8.3"):
+            cv_selector.pop("n_stratification_bins", None)
+        splitter, groups = get_cv(cv_selector)
         options["cv"] = splitter
         # set n_jobs
         options["n_jobs"] = N_JOBS
@@ -104,13 +101,7 @@
     for base_est in base_estimators:
         print(base_est)
 
-    with open(output_obj, "wb") as out_handler:
-        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)
-
-    if params["get_params"] and outfile_params:
-        results = get_search_params(ensemble_estimator)
-        df = pd.DataFrame(results, columns=["", "Parameter", "Value"])
-        df.to_csv(outfile_params, sep="\t", index=False)
+    dump_model_to_h5(ensemble_estimator, output_obj)
 
 
 if __name__ == "__main__":
@@ -119,13 +110,6 @@
     aparser.add_argument("-m", "--meta", dest="meta")
     aparser.add_argument("-i", "--inputs", dest="inputs")
     aparser.add_argument("-o", "--outfile", dest="outfile")
-    aparser.add_argument("-p", "--outfile_params", dest="outfile_params")
     args = aparser.parse_args()
 
-    main(
-        args.inputs,
-        args.outfile,
-        base_paths=args.bases,
-        meta_path=args.meta,
-        outfile_params=args.outfile_params,
-    )
+    main(args.inputs, args.outfile, base_paths=args.bases, meta_path=args.meta)