comparison fitted_model_eval.py @ 17:a01fa4e8fe4f draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 12:54:40 +0000
parents c9ddd20d25d0
children
comparison
equal deleted inserted replaced
16:d0352e8b4c10 17:a01fa4e8fe4f
1 import argparse 1 import argparse
2 import json 2 import json
3 import warnings 3 import warnings
4 4
5 import pandas as pd 5 import pandas as pd
6 from galaxy_ml.utils import get_scoring, load_model, read_columns 6 from galaxy_ml.model_persist import load_model_from_h5
7 from galaxy_ml.utils import clean_params, get_scoring, read_columns
7 from scipy.io import mmread 8 from scipy.io import mmread
8 from sklearn.metrics.scorer import _check_multimetric_scoring 9 from sklearn.metrics._scorer import _check_multimetric_scoring
9 from sklearn.model_selection._validation import _score 10 from sklearn.model_selection._validation import _score
10 from sklearn.pipeline import Pipeline
11 11
12 12
13 def _get_X_y(params, infile1, infile2): 13 def _get_X_y(params, infile1, infile2):
14 """read from inputs and output X and y 14 """read from inputs and output X and y
15 15
73 else: 73 else:
74 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) 74 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
75 loaded_df[df_key] = infile2 75 loaded_df[df_key] = infile2
76 76
77 y = read_columns( 77 y = read_columns(
78 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True 78 infile2,
79 c=c,
80 c_option=column_option,
81 sep="\t",
82 header=header,
83 parse_dates=True,
79 ) 84 )
80 if len(y.shape) == 2 and y.shape[1] == 1: 85 if len(y.shape) == 2 and y.shape[1] == 1:
81 y = y.ravel() 86 y = y.ravel()
82 87
83 return X, y 88 return X, y
84 89
85 90
86 def main( 91 def main(inputs, infile_estimator, outfile_eval, infile1=None, infile2=None):
87 inputs,
88 infile_estimator,
89 outfile_eval,
90 infile_weights=None,
91 infile1=None,
92 infile2=None,
93 ):
94 """ 92 """
95 Parameter 93 Parameter
96 --------- 94 ---------
97 inputs : str 95 inputs : str
98 File path to galaxy tool parameter 96 File path to galaxy tool parameter
100 infile_estimator : strgit 98 infile_estimator : strgit
101 File path to trained estimator input 99 File path to trained estimator input
102 100
103 outfile_eval : str 101 outfile_eval : str
104 File path to save the evalulation results, tabular 102 File path to save the evalulation results, tabular
105
106 infile_weights : str
107 File path to weights input
108 103
109 infile1 : str 104 infile1 : str
110 File path to dataset containing features 105 File path to dataset containing features
111 106
112 infile2 : str 107 infile2 : str
118 params = json.load(param_handler) 113 params = json.load(param_handler)
119 114
120 X_test, y_test = _get_X_y(params, infile1, infile2) 115 X_test, y_test = _get_X_y(params, infile1, infile2)
121 116
122 # load model 117 # load model
123 with open(infile_estimator, "rb") as est_handler: 118 estimator = load_model_from_h5(infile_estimator)
124 estimator = load_model(est_handler) 119 estimator = clean_params(estimator)
125
126 main_est = estimator
127 if isinstance(estimator, Pipeline):
128 main_est = estimator.steps[-1][-1]
129 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
130 if not infile_weights or infile_weights == "None":
131 raise ValueError(
132 "The selected model skeleton asks for weights, "
133 "but no dataset for weights was provided!"
134 )
135 main_est.load_weights(infile_weights)
136 120
137 # handle scorer, convert to scorer dict 121 # handle scorer, convert to scorer dict
138 # Check if scoring is specified
139 scoring = params["scoring"] 122 scoring = params["scoring"]
140 if scoring is not None:
141 # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
142 # Check if secondary_scoring is specified
143 secondary_scoring = scoring.get("secondary_scoring", None)
144 if secondary_scoring is not None:
145 # If secondary_scoring is specified, convert the list into comman separated string
146 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
147
148 scorer = get_scoring(scoring) 123 scorer = get_scoring(scoring)
149 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) 124 if not isinstance(scorer, (dict, list)):
125 scorer = [scoring["primary_scoring"]]
126 scorer = _check_multimetric_scoring(estimator, scoring=scorer)
150 127
151 if hasattr(estimator, "evaluate"): 128 if hasattr(estimator, "evaluate"):
152 scores = estimator.evaluate( 129 scores = estimator.evaluate(X_test, y_test=y_test, scorer=scorer)
153 X_test, y_test=y_test, scorer=scorer, is_multimetric=True
154 )
155 else: 130 else:
156 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True) 131 scores = _score(estimator, X_test, y_test, scorer)
157 132
158 # handle output 133 # handle output
159 for name, score in scores.items(): 134 for name, score in scores.items():
160 scores[name] = [score] 135 scores[name] = [score]
161 df = pd.DataFrame(scores) 136 df = pd.DataFrame(scores)
165 140
166 if __name__ == "__main__": 141 if __name__ == "__main__":
167 aparser = argparse.ArgumentParser() 142 aparser = argparse.ArgumentParser()
168 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 143 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
169 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") 144 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator")
170 aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
171 aparser.add_argument("-X", "--infile1", dest="infile1") 145 aparser.add_argument("-X", "--infile1", dest="infile1")
172 aparser.add_argument("-y", "--infile2", dest="infile2") 146 aparser.add_argument("-y", "--infile2", dest="infile2")
173 aparser.add_argument("-O", "--outfile_eval", dest="outfile_eval") 147 aparser.add_argument("-O", "--outfile_eval", dest="outfile_eval")
174 args = aparser.parse_args() 148 args = aparser.parse_args()
175 149
176 main( 150 main(
177 args.inputs, 151 args.inputs,
178 args.infile_estimator, 152 args.infile_estimator,
179 args.outfile_eval, 153 args.outfile_eval,
180 infile_weights=args.infile_weights,
181 infile1=args.infile1, 154 infile1=args.infile1,
182 infile2=args.infile2, 155 infile2=args.infile2,
183 ) 156 )