Mercurial > repos > bgruening > sklearn_train_test_eval
changeset 19:f12383ee3234 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5eca9041ce0154eded5aec07195502d5eb3cdd4f
| author | bgruening | 
|---|---|
| date | Fri, 03 Nov 2023 23:01:13 +0000 | 
| parents | ddc4fe317a92 | 
| children | f6ea9fc38223 | 
| files | keras_train_and_eval.py main_macros.xml | 
| diffstat | 2 files changed, 22 insertions(+), 5 deletions(-) [+] | 
line wrap: on
 line diff
--- a/keras_train_and_eval.py Mon Oct 02 10:28:18 2023 +0000 +++ b/keras_train_and_eval.py Fri Nov 03 23:01:13 2023 +0000 @@ -10,6 +10,8 @@ from galaxy_ml.keras_galaxy_models import ( _predict_generator, KerasGBatchClassifier, + KerasGClassifier, + KerasGRegressor ) from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 from galaxy_ml.model_validations import train_test_split @@ -400,8 +402,16 @@ # handle scorer, convert to scorer dict scoring = params["experiment_schemes"]["metrics"]["scoring"] scorer = get_scoring(scoring) - if not isinstance(scorer, (dict, list)): - scorer = [scoring["primary_scoring"]] + + # We get 'None' back from the call to 'get_scoring()' if + # the primary scoring is 'default'. Replace 'default' with + # the default scoring for classification/regression (accuracy/r2) + if scorer is None: + if isinstance(estimator, KerasGClassifier): + scorer = ['accuracy'] + if isinstance(estimator, KerasGRegressor): + scorer = ['r2'] + scorer = _check_multimetric_scoring(estimator, scoring=scorer) # handle test (first) split @@ -499,8 +509,15 @@ else: predictions = estimator.predict(X_test) - y_true = y_test - sk_scores = _score(estimator, X_test, y_test, scorer) + # Un-do OHE of the validation labels + if len(y_test.shape) == 2: + rounded_test_labels = np.argmax(y_test, axis=1) + y_true = rounded_test_labels + sk_scores = _score(estimator, X_test, rounded_test_labels, scorer) + else: + y_true = y_test + sk_scores = _score(estimator, X_test, y_true, scorer) + scores.update(sk_scores) # handle output
