Previous changeset 14:818f9b69d8a0 (2023-10-02) Next changeset 16:2af1346e68c9 (2023-11-05) |
Commit message:
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f |
modified:
keras_train_and_eval.py, main_macros.xml |
b |
diff -r 818f9b69d8a0 -r bba502278c25 keras_train_and_eval.py --- a/keras_train_and_eval.py Mon Oct 02 10:00:27 2023 +0000 +++ b/keras_train_and_eval.py Fri Nov 03 23:05:31 2023 +0000 |
[ |
@@ -10,6 +10,8 @@ from galaxy_ml.keras_galaxy_models import ( _predict_generator, KerasGBatchClassifier, + KerasGClassifier, + KerasGRegressor ) from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 from galaxy_ml.model_validations import train_test_split @@ -400,8 +402,16 @@ # handle scorer, convert to scorer dict scoring = params["experiment_schemes"]["metrics"]["scoring"] scorer = get_scoring(scoring) - if not isinstance(scorer, (dict, list)): - scorer = [scoring["primary_scoring"]] + + # We get 'None' back from the call to 'get_scoring()' if + # the primary scoring is 'default'. Replace 'default' with + # the default scoring for classification/regression (accuracy/r2) + if scorer is None: + if isinstance(estimator, KerasGClassifier): + scorer = ['accuracy'] + if isinstance(estimator, KerasGRegressor): + scorer = ['r2'] + scorer = _check_multimetric_scoring(estimator, scoring=scorer) # handle test (first) split @@ -499,8 +509,15 @@ else: predictions = estimator.predict(X_test) - y_true = y_test - sk_scores = _score(estimator, X_test, y_test, scorer) + # Un-do OHE of the validation labels + if len(y_test.shape) == 2: + rounded_test_labels = np.argmax(y_test, axis=1) + y_true = rounded_test_labels + sk_scores = _score(estimator, X_test, rounded_test_labels, scorer) + else: + y_true = y_test + sk_scores = _score(estimator, X_test, y_true, scorer) + scores.update(sk_scores) # handle output |
b |
diff -r 818f9b69d8a0 -r bba502278c25 main_macros.xml --- a/main_macros.xml Mon Oct 02 10:00:27 2023 +0000 +++ b/main_macros.xml Fri Nov 03 23:05:31 2023 +0000 |
b |
@@ -1,5 +1,5 @@ <macros> - <token name="@VERSION@">1.0.10.0</token> + <token name="@VERSION@">1.0.11.0</token> <token name="@PROFILE@">21.05</token> <xml name="python_requirements"> |