# HG changeset patch # User bgruening # Date 1699051957 0 # Node ID a6a308bf9262077be7ecdad2167b518ad07c494b # Parent 34e063ac042820fed7da056b892dbcb2f76f0c06 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f diff -r 34e063ac0428 -r a6a308bf9262 keras_train_and_eval.py --- a/keras_train_and_eval.py Mon Oct 02 10:21:10 2023 +0000 +++ b/keras_train_and_eval.py Fri Nov 03 22:52:37 2023 +0000 @@ -10,6 +10,8 @@ from galaxy_ml.keras_galaxy_models import ( _predict_generator, KerasGBatchClassifier, + KerasGClassifier, + KerasGRegressor ) from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 from galaxy_ml.model_validations import train_test_split @@ -400,8 +402,16 @@ # handle scorer, convert to scorer dict scoring = params["experiment_schemes"]["metrics"]["scoring"] scorer = get_scoring(scoring) - if not isinstance(scorer, (dict, list)): - scorer = [scoring["primary_scoring"]] + + # We get 'None' back from the call to 'get_scoring()' if + # the primary scoring is 'default'. Replace 'default' with + # the default scoring for classification/regression (accuracy/r2) + if scorer is None: + if isinstance(estimator, KerasGClassifier): + scorer = ['accuracy'] + if isinstance(estimator, KerasGRegressor): + scorer = ['r2'] + scorer = _check_multimetric_scoring(estimator, scoring=scorer) # handle test (first) split @@ -499,8 +509,15 @@ else: predictions = estimator.predict(X_test) - y_true = y_test - sk_scores = _score(estimator, X_test, y_test, scorer) + # Un-do OHE of the validation labels + if len(y_test.shape) == 2: + rounded_test_labels = np.argmax(y_test, axis=1) + y_true = rounded_test_labels + sk_scores = _score(estimator, X_test, rounded_test_labels, scorer) + else: + y_true = y_test + sk_scores = _score(estimator, X_test, y_true, scorer) + scores.update(sk_scores) # handle output diff -r 34e063ac0428 -r a6a308bf9262 main_macros.xml --- a/main_macros.xml Mon Oct 02 10:21:10 2023 +0000 +++ b/main_macros.xml Fri Nov 03 22:52:37 2023 +0000 @@ -1,5 +1,5 @@ - 1.0.10.0 + 1.0.11.0 21.05