diff simple_model_fit.py @ 10:b3093f953091 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:30:51 +0000
parents ead8f1822587
children
line wrap: on
line diff
--- a/simple_model_fit.py	Thu Aug 11 09:15:54 2022 +0000
+++ b/simple_model_fit.py	Wed Aug 09 13:30:51 2023 +0000
@@ -1,9 +1,9 @@
 import argparse
 import json
-import pickle
 
 import pandas as pd
-from galaxy_ml.utils import load_model, read_columns
+from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
+from galaxy_ml.utils import read_columns
 from scipy.io import mmread
 from sklearn.pipeline import Pipeline
 
@@ -148,9 +148,9 @@
         params = json.load(param_handler)
 
     # load model
-    with open(infile_estimator, "rb") as est_handler:
-        estimator = load_model(est_handler)
-    estimator = clean_params(estimator, n_jobs=N_JOBS)
+    estimator = load_model_from_h5(infile_estimator)
+
+    estimator = clean_params(estimator)
 
     X_train, y_train = _get_X_y(params, infile1, infile2)
 
@@ -170,8 +170,7 @@
         if getattr(main_est, "data_generator_", None):
             del main_est.data_generator_
 
-    with open(out_object, "wb") as output_handler:
-        pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
+    dump_model_to_h5(estimator, out_object)
 
 
 if __name__ == "__main__":