diff simple_model_fit.py @ 1:49525d37843a draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
author bgruening
date Thu, 07 Nov 2019 05:31:59 -0500
parents 734c66aa945a
children 26decbf4bdb8
line wrap: on
line diff
--- a/simple_model_fit.py	Fri Nov 01 17:18:28 2019 -0400
+++ b/simple_model_fit.py	Thu Nov 07 05:31:59 2019 -0500
@@ -7,6 +7,44 @@
 from sklearn.pipeline import Pipeline
 
 
+N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
+
+
+# TODO import from galaxy_ml.utils in future versions
+def clean_params(estimator, n_jobs=None):
+    """clean unwanted hyperparameter settings
+
+    If n_jobs is not None, set it into the estimator, if applicable
+
+    Return
+    ------
+    Cleaned estimator object
+    """
+    ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN',
+                         'ReduceLROnPlateau', 'CSVLogger', 'None')
+
+    estimator_params = estimator.get_params()
+
+    for name, p in estimator_params.items():
+        # all potential unauthorized file write
+        if name == 'memory' or name.endswith('__memory') \
+                or name.endswith('_path'):
+            new_p = {name: None}
+            estimator.set_params(**new_p)
+        elif n_jobs is not None and (name == 'n_jobs' or
+                                     name.endswith('__n_jobs')):
+            new_p = {name: n_jobs}
+            estimator.set_params(**new_p)
+        elif name.endswith('callbacks'):
+            for cb in p:
+                cb_type = cb['callback_selection']['callback_type']
+                if cb_type not in ALLOWED_CALLBACKS:
+                    raise ValueError(
+                        "Prohibited callback type: %s!" % cb_type)
+
+    return estimator
+
+
 def _get_X_y(params, infile1, infile2):
     """ read from inputs and output X and y
 
@@ -107,6 +145,7 @@
     # load model
     with open(infile_estimator, 'rb') as est_handler:
         estimator = load_model(est_handler)
+    estimator = clean_params(estimator, n_jobs=N_JOBS)
 
     X_train, y_train = _get_X_y(params, infile1, infile2)