Mercurial > repos > bgruening > sklearn_ensemble
diff model_prediction.py @ 41:6546d7c9f08b draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 12:52:25 +0000 |
parents | 4ecc0ce9d0a2 |
children |
line wrap: on
line diff
--- a/model_prediction.py Thu Aug 11 09:18:09 2022 +0000 +++ b/model_prediction.py Wed Aug 09 12:52:25 2023 +0000 @@ -4,9 +4,10 @@ import numpy as np import pandas as pd -from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr +from galaxy_ml.model_persist import load_model_from_h5 +from galaxy_ml.utils import (clean_params, get_module, read_columns, + try_get_attr) from scipy.io import mmread -from sklearn.pipeline import Pipeline N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) @@ -15,7 +16,6 @@ inputs, infile_estimator, outfile_predict, - infile_weights=None, infile1=None, fasta_path=None, ref_seq=None, @@ -27,15 +27,12 @@ inputs : str File path to galaxy tool parameter - infile_estimator : strgit + infile_estimator : str File path to trained estimator input outfile_predict : str File path to save the prediction results, tabular - infile_weights : str - File path to weights input - infile1 : str File path to dataset containing features @@ -54,19 +51,8 @@ params = json.load(param_handler) # load model - with open(infile_estimator, "rb") as est_handler: - estimator = load_model(est_handler) - - main_est = estimator - if isinstance(estimator, Pipeline): - main_est = estimator.steps[-1][-1] - if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): - if not infile_weights or infile_weights == "None": - raise ValueError( - "The selected model skeleton asks for weights, " - "but dataset for weights wan not selected!" - ) - main_est.load_weights(infile_weights) + estimator = load_model_from_h5(infile_estimator) + estimator = clean_params(estimator) # handle data input input_type = params["input_options"]["selected_input"] @@ -221,7 +207,6 @@ aparser = argparse.ArgumentParser() aparser.add_argument("-i", "--inputs", dest="inputs", required=True) aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") - aparser.add_argument("-w", "--infile_weights", dest="infile_weights") aparser.add_argument("-X", "--infile1", dest="infile1") aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") aparser.add_argument("-f", "--fasta_path", dest="fasta_path") @@ -233,7 +218,6 @@ args.inputs, args.infile_estimator, args.outfile_predict, - infile_weights=args.infile_weights, infile1=args.infile1, fasta_path=args.fasta_path, ref_seq=args.ref_seq,