Mercurial > repos > bgruening > sklearn_model_fit
changeset 1:49525d37843a draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
author | bgruening |
---|---|
date | Thu, 07 Nov 2019 05:31:59 -0500 |
parents | 734c66aa945a |
children | 8861ece0b66f |
files | simple_model_fit.py |
diffstat | 1 files changed, 39 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/simple_model_fit.py Fri Nov 01 17:18:28 2019 -0400 +++ b/simple_model_fit.py Thu Nov 07 05:31:59 2019 -0500 @@ -7,6 +7,44 @@ from sklearn.pipeline import Pipeline +N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) + + +# TODO import from galaxy_ml.utils in future versions +def clean_params(estimator, n_jobs=None): + """clean unwanted hyperparameter settings + + If n_jobs is not None, set it into the estimator, if applicable + + Return + ------ + Cleaned estimator object + """ + ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', + 'ReduceLROnPlateau', 'CSVLogger', 'None') + + estimator_params = estimator.get_params() + + for name, p in estimator_params.items(): + # all potential unauthorized file write + if name == 'memory' or name.endswith('__memory') \ + or name.endswith('_path'): + new_p = {name: None} + estimator.set_params(**new_p) + elif n_jobs is not None and (name == 'n_jobs' or + name.endswith('__n_jobs')): + new_p = {name: n_jobs} + estimator.set_params(**new_p) + elif name.endswith('callbacks'): + for cb in p: + cb_type = cb['callback_selection']['callback_type'] + if cb_type not in ALLOWED_CALLBACKS: + raise ValueError( + "Prohibited callback type: %s!" % cb_type) + + return estimator + + def _get_X_y(params, infile1, infile2): """ read from inputs and output X and y @@ -107,6 +145,7 @@ # load model with open(infile_estimator, 'rb') as est_handler: estimator = load_model(est_handler) + estimator = clean_params(estimator, n_jobs=N_JOBS) X_train, y_train = _get_X_y(params, infile1, infile2)