Previous changeset 23:8a307b946c58 (2019-11-01) Next changeset 25:41b109e70a7f (2019-12-16) |
Commit message:
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53" |
modified:
simple_model_fit.py |
b |
diff -r 8a307b946c58 -r 4170e2bda73d simple_model_fit.py --- a/simple_model_fit.py Fri Nov 01 17:24:11 2019 -0400 +++ b/simple_model_fit.py Thu Nov 07 05:37:06 2019 -0500 |
[ |
@@ -7,6 +7,44 @@ from sklearn.pipeline import Pipeline +N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) + + +# TODO import from galaxy_ml.utils in future versions +def clean_params(estimator, n_jobs=None): + """clean unwanted hyperparameter settings + + If n_jobs is not None, set it into the estimator, if applicable + + Return + ------ + Cleaned estimator object + """ + ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', + 'ReduceLROnPlateau', 'CSVLogger', 'None') + + estimator_params = estimator.get_params() + + for name, p in estimator_params.items(): + # all potential unauthorized file write + if name == 'memory' or name.endswith('__memory') \ + or name.endswith('_path'): + new_p = {name: None} + estimator.set_params(**new_p) + elif n_jobs is not None and (name == 'n_jobs' or + name.endswith('__n_jobs')): + new_p = {name: n_jobs} + estimator.set_params(**new_p) + elif name.endswith('callbacks'): + for cb in p: + cb_type = cb['callback_selection']['callback_type'] + if cb_type not in ALLOWED_CALLBACKS: + raise ValueError( + "Prohibited callback type: %s!" % cb_type) + + return estimator + + def _get_X_y(params, infile1, infile2): """ read from inputs and output X and y @@ -107,6 +145,7 @@ # load model with open(infile_estimator, 'rb') as est_handler: estimator = load_model(est_handler) + estimator = clean_params(estimator, n_jobs=N_JOBS) X_train, y_train = _get_X_y(params, infile1, infile2) |