comparison keras_train_and_eval.py @ 9:5e1581dfb419 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f
author bgruening
date Fri, 03 Nov 2023 22:54:44 +0000
parents e8d6b4acb7c6
children
comparison
equal deleted inserted replaced
8:e8d6b4acb7c6 9:5e1581dfb419
8 import numpy as np 8 import numpy as np
9 import pandas as pd 9 import pandas as pd
10 from galaxy_ml.keras_galaxy_models import ( 10 from galaxy_ml.keras_galaxy_models import (
11 _predict_generator, 11 _predict_generator,
12 KerasGBatchClassifier, 12 KerasGBatchClassifier,
13 KerasGClassifier,
14 KerasGRegressor
13 ) 15 )
14 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 16 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
15 from galaxy_ml.model_validations import train_test_split 17 from galaxy_ml.model_validations import train_test_split
16 from galaxy_ml.utils import ( 18 from galaxy_ml.utils import (
17 clean_params, 19 clean_params,
398 main_est.set_params(memory=memory) 400 main_est.set_params(memory=memory)
399 401
400 # handle scorer, convert to scorer dict 402 # handle scorer, convert to scorer dict
401 scoring = params["experiment_schemes"]["metrics"]["scoring"] 403 scoring = params["experiment_schemes"]["metrics"]["scoring"]
402 scorer = get_scoring(scoring) 404 scorer = get_scoring(scoring)
403 if not isinstance(scorer, (dict, list)): 405
404 scorer = [scoring["primary_scoring"]] 406 # We get 'None' back from the call to 'get_scoring()' if
407 # the primary scoring is 'default'. Replace 'default' with
408 # the default scoring for classification/regression (accuracy/r2)
409 if scorer is None:
410 if isinstance(estimator, KerasGClassifier):
411 scorer = ['accuracy']
412 if isinstance(estimator, KerasGRegressor):
413 scorer = ['r2']
414
405 scorer = _check_multimetric_scoring(estimator, scoring=scorer) 415 scorer = _check_multimetric_scoring(estimator, scoring=scorer)
406 416
407 # handle test (first) split 417 # handle test (first) split
408 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] 418 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
409 419
497 if hasattr(estimator, "predict_proba"): 507 if hasattr(estimator, "predict_proba"):
498 predictions = estimator.predict_proba(X_test) 508 predictions = estimator.predict_proba(X_test)
499 else: 509 else:
500 predictions = estimator.predict(X_test) 510 predictions = estimator.predict(X_test)
501 511
502 y_true = y_test 512 # Un-do OHE of the validation labels
503 sk_scores = _score(estimator, X_test, y_test, scorer) 513 if len(y_test.shape) == 2:
514 rounded_test_labels = np.argmax(y_test, axis=1)
515 y_true = rounded_test_labels
516 sk_scores = _score(estimator, X_test, rounded_test_labels, scorer)
517 else:
518 y_true = y_test
519 sk_scores = _score(estimator, X_test, y_true, scorer)
520
504 scores.update(sk_scores) 521 scores.update(sk_scores)
505 522
506 # handle output 523 # handle output
507 if outfile_y_true: 524 if outfile_y_true:
508 try: 525 try: