Mercurial > repos > bgruening > sklearn_mlxtend_association_rules
changeset 9:a6a308bf9262 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f
author | bgruening |
---|---|
date | Fri, 03 Nov 2023 22:52:37 +0000 (14 months ago) |
parents | 34e063ac0428 |
children | e1919c102646 |
files | keras_train_and_eval.py main_macros.xml |
diffstat | 2 files changed, 22 insertions(+), 5 deletions(-) [+] |
line wrap: on
line diff
--- a/keras_train_and_eval.py Mon Oct 02 10:21:10 2023 +0000 +++ b/keras_train_and_eval.py Fri Nov 03 22:52:37 2023 +0000 @@ -10,6 +10,8 @@ from galaxy_ml.keras_galaxy_models import ( _predict_generator, KerasGBatchClassifier, + KerasGClassifier, + KerasGRegressor ) from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 from galaxy_ml.model_validations import train_test_split @@ -400,8 +402,16 @@ # handle scorer, convert to scorer dict scoring = params["experiment_schemes"]["metrics"]["scoring"] scorer = get_scoring(scoring) - if not isinstance(scorer, (dict, list)): - scorer = [scoring["primary_scoring"]] + + # We get 'None' back from the call to 'get_scoring()' if + # the primary scoring is 'default'. Replace 'default' with + # the default scoring for classification/regression (accuracy/r2) + if scorer is None: + if isinstance(estimator, KerasGClassifier): + scorer = ['accuracy'] + if isinstance(estimator, KerasGRegressor): + scorer = ['r2'] + scorer = _check_multimetric_scoring(estimator, scoring=scorer) # handle test (first) split @@ -499,8 +509,15 @@ else: predictions = estimator.predict(X_test) - y_true = y_test - sk_scores = _score(estimator, X_test, y_test, scorer) + # Un-do OHE of the validation labels + if len(y_test.shape) == 2: + rounded_test_labels = np.argmax(y_test, axis=1) + y_true = rounded_test_labels + sk_scores = _score(estimator, X_test, rounded_test_labels, scorer) + else: + y_true = y_test + sk_scores = _score(estimator, X_test, y_true, scorer) + scores.update(sk_scores) # handle output