comparison model_prediction.py @ 15:2eb5c017958d draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:15:27 +0000
parents caf7d2b71a48
children
comparison
equal deleted inserted replaced
14:4d1637cac794 15:2eb5c017958d
2 import json 2 import json
3 import warnings 3 import warnings
4 4
5 import numpy as np 5 import numpy as np
6 import pandas as pd 6 import pandas as pd
7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr 7 from galaxy_ml.model_persist import load_model_from_h5
8 from galaxy_ml.utils import (clean_params, get_module, read_columns,
9 try_get_attr)
8 from scipy.io import mmread 10 from scipy.io import mmread
9 from sklearn.pipeline import Pipeline
10 11
11 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) 12 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
12 13
13 14
14 def main( 15 def main(
15 inputs, 16 inputs,
16 infile_estimator, 17 infile_estimator,
17 outfile_predict, 18 outfile_predict,
18 infile_weights=None,
19 infile1=None, 19 infile1=None,
20 fasta_path=None, 20 fasta_path=None,
21 ref_seq=None, 21 ref_seq=None,
22 vcf_path=None, 22 vcf_path=None,
23 ): 23 ):
25 Parameter 25 Parameter
26 --------- 26 ---------
27 inputs : str 27 inputs : str
28 File path to galaxy tool parameter 28 File path to galaxy tool parameter
29 29
30 infile_estimator : strgit 30 infile_estimator : str
31 File path to trained estimator input 31 File path to trained estimator input
32 32
33 outfile_predict : str 33 outfile_predict : str
34 File path to save the prediction results, tabular 34 File path to save the prediction results, tabular
35
36 infile_weights : str
37 File path to weights input
38 35
39 infile1 : str 36 infile1 : str
40 File path to dataset containing features 37 File path to dataset containing features
41 38
42 fasta_path : str 39 fasta_path : str
52 49
53 with open(inputs, "r") as param_handler: 50 with open(inputs, "r") as param_handler:
54 params = json.load(param_handler) 51 params = json.load(param_handler)
55 52
56 # load model 53 # load model
57 with open(infile_estimator, "rb") as est_handler: 54 estimator = load_model_from_h5(infile_estimator)
58 estimator = load_model(est_handler) 55 estimator = clean_params(estimator)
59
60 main_est = estimator
61 if isinstance(estimator, Pipeline):
62 main_est = estimator.steps[-1][-1]
63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
64 if not infile_weights or infile_weights == "None":
65 raise ValueError(
66 "The selected model skeleton asks for weights, "
67 "but dataset for weights wan not selected!"
68 )
69 main_est.load_weights(infile_weights)
70 56
71 # handle data input 57 # handle data input
72 input_type = params["input_options"]["selected_input"] 58 input_type = params["input_options"]["selected_input"]
73 # tabular input 59 # tabular input
74 if input_type == "tabular": 60 if input_type == "tabular":
219 205
220 if __name__ == "__main__": 206 if __name__ == "__main__":
221 aparser = argparse.ArgumentParser() 207 aparser = argparse.ArgumentParser()
222 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 208 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
223 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") 209 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator")
224 aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
225 aparser.add_argument("-X", "--infile1", dest="infile1") 210 aparser.add_argument("-X", "--infile1", dest="infile1")
226 aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") 211 aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict")
227 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 212 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
228 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") 213 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
229 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") 214 aparser.add_argument("-v", "--vcf_path", dest="vcf_path")
231 216
232 main( 217 main(
233 args.inputs, 218 args.inputs,
234 args.infile_estimator, 219 args.infile_estimator,
235 args.outfile_predict, 220 args.outfile_predict,
236 infile_weights=args.infile_weights,
237 infile1=args.infile1, 221 infile1=args.infile1,
238 fasta_path=args.fasta_path, 222 fasta_path=args.fasta_path,
239 ref_seq=args.ref_seq, 223 ref_seq=args.ref_seq,
240 vcf_path=args.vcf_path, 224 vcf_path=args.vcf_path,
241 ) 225 )