# HG changeset patch # User bgruening # Date 1573123449 18000 # Node ID 4af699d766e4f38cb5c166ff8c929463a63543b1 # Parent 3d5b2491c62bb74e46ff1b73e73fa37e59e70f2a "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53" diff -r 3d5b2491c62b -r 4af699d766e4 simple_model_fit.py --- a/simple_model_fit.py Fri Nov 01 17:31:18 2019 -0400 +++ b/simple_model_fit.py Thu Nov 07 05:44:09 2019 -0500 @@ -7,6 +7,44 @@ from sklearn.pipeline import Pipeline +N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) + + +# TODO import from galaxy_ml.utils in future versions +def clean_params(estimator, n_jobs=None): + """clean unwanted hyperparameter settings + + If n_jobs is not None, set it into the estimator, if applicable + + Return + ------ + Cleaned estimator object + """ + ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', + 'ReduceLROnPlateau', 'CSVLogger', 'None') + + estimator_params = estimator.get_params() + + for name, p in estimator_params.items(): + # all potential unauthorized file write + if name == 'memory' or name.endswith('__memory') \ + or name.endswith('_path'): + new_p = {name: None} + estimator.set_params(**new_p) + elif n_jobs is not None and (name == 'n_jobs' or + name.endswith('__n_jobs')): + new_p = {name: n_jobs} + estimator.set_params(**new_p) + elif name.endswith('callbacks'): + for cb in p: + cb_type = cb['callback_selection']['callback_type'] + if cb_type not in ALLOWED_CALLBACKS: + raise ValueError( + "Prohibited callback type: %s!" % cb_type) + + return estimator + + def _get_X_y(params, infile1, infile2): """ read from inputs and output X and y @@ -107,6 +145,7 @@ # load model with open(infile_estimator, 'rb') as est_handler: estimator = load_model(est_handler) + estimator = clean_params(estimator, n_jobs=N_JOBS) X_train, y_train = _get_X_y(params, infile1, infile2)