comparison train_test_eval.py @ 5:b650955a20cc draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:04:21 +0000
parents f93f0cdbaf18
children
comparison
equal deleted inserted replaced
4:14180f9c831e 5:b650955a20cc
1 import argparse 1 import argparse
2 import json 2 import json
3 import os 3 import os
4 import pickle
5 import warnings 4 import warnings
6 from itertools import chain 5 from itertools import chain
7 6
8 import joblib 7 import joblib
9 import numpy as np 8 import numpy as np
10 import pandas as pd 9 import pandas as pd
10 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
11 from galaxy_ml.model_validations import train_test_split 11 from galaxy_ml.model_validations import train_test_split
12 from galaxy_ml.utils import (get_module, get_scoring, load_model, 12 from galaxy_ml.utils import (
13 read_columns, SafeEval, try_get_attr) 13 clean_params,
14 get_module,
15 get_scoring,
16 read_columns,
17 SafeEval,
18 try_get_attr
19 )
14 from scipy.io import mmread 20 from scipy.io import mmread
15 from sklearn import pipeline 21 from sklearn import pipeline
16 from sklearn.metrics.scorer import _check_multimetric_scoring
17 from sklearn.model_selection import _search, _validation 22 from sklearn.model_selection import _search, _validation
18 from sklearn.model_selection._validation import _score 23 from sklearn.model_selection._validation import _score
19 from sklearn.utils import indexable, safe_indexing 24 from sklearn.utils import _safe_indexing, indexable
20 25
21 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") 26 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
22 setattr(_search, "_fit_and_score", _fit_and_score) 27 setattr(_search, "_fit_and_score", _fit_and_score)
23 setattr(_validation, "_fit_and_score", _fit_and_score) 28 setattr(_validation, "_fit_and_score", _fit_and_score)
24 29
91 index_arr = np.arange(n_samples) 96 index_arr = np.arange(n_samples)
92 test = index_arr[np.isin(groups, group_names)] 97 test = index_arr[np.isin(groups, group_names)]
93 train = index_arr[~np.isin(groups, group_names)] 98 train = index_arr[~np.isin(groups, group_names)]
94 rval = list( 99 rval = list(
95 chain.from_iterable( 100 chain.from_iterable(
96 (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays 101 (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays
97 ) 102 )
98 ) 103 )
99 else: 104 else:
100 rval = train_test_split(*new_arrays, **kwargs) 105 rval = train_test_split(*new_arrays, **kwargs)
101 106
162 167
163 with open(inputs, "r") as param_handler: 168 with open(inputs, "r") as param_handler:
164 params = json.load(param_handler) 169 params = json.load(param_handler)
165 170
166 # load estimator 171 # load estimator
167 with open(infile_estimator, "rb") as estimator_handler: 172 estimator = load_model_from_h5(infile_estimator)
168 estimator = load_model(estimator_handler) 173 estimator = clean_params(estimator)
169 174
170 # swap hyperparameter 175 # swap hyperparameter
171 swapping = params["experiment_schemes"]["hyperparams_swapping"] 176 swapping = params["experiment_schemes"]["hyperparams_swapping"]
172 swap_params = _eval_swap_params(swapping) 177 swap_params = _eval_swap_params(swapping)
173 estimator.set_params(**swap_params) 178 estimator.set_params(**swap_params)
346 secondary_scoring = scoring.get("secondary_scoring", None) 351 secondary_scoring = scoring.get("secondary_scoring", None)
347 if secondary_scoring is not None: 352 if secondary_scoring is not None:
348 # If secondary_scoring is specified, convert the list into comma separated string 353 # If secondary_scoring is specified, convert the list into comma separated string
349 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"]) 354 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
350 scorer = get_scoring(scoring) 355 scorer = get_scoring(scoring)
351 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer)
352 356
353 # handle test (first) split 357 # handle test (first) split
354 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] 358 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
355 359
356 if test_split_options["shuffle"] == "group": 360 if test_split_options["shuffle"] == "group":
410 if hasattr(estimator, "evaluate"): 414 if hasattr(estimator, "evaluate"):
411 scores = estimator.evaluate( 415 scores = estimator.evaluate(
412 X_test, y_test=y_test, scorer=scorer, is_multimetric=True 416 X_test, y_test=y_test, scorer=scorer, is_multimetric=True
413 ) 417 )
414 else: 418 else:
415 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True) 419 scores = _score(estimator, X_test, y_test, scorer)
416 # handle output 420 # handle output
417 for name, score in scores.items(): 421 for name, score in scores.items():
418 scores[name] = [score] 422 scores[name] = [score]
419 df = pd.DataFrame(scores) 423 df = pd.DataFrame(scores)
420 df = df[sorted(df.columns)] 424 df = df[sorted(df.columns)]
439 if getattr(main_est, "validation_data", None): 443 if getattr(main_est, "validation_data", None):
440 del main_est.validation_data 444 del main_est.validation_data
441 if getattr(main_est, "data_generator_", None): 445 if getattr(main_est, "data_generator_", None):
442 del main_est.data_generator_ 446 del main_est.data_generator_
443 447
444 with open(outfile_object, "wb") as output_handler: 448 dump_model_to_h5(estimator, outfile_object)
445 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
446 449
447 450
448 if __name__ == "__main__": 451 if __name__ == "__main__":
449 aparser = argparse.ArgumentParser() 452 aparser = argparse.ArgumentParser()
450 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 453 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)