changeset 14:818f9b69d8a0 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
author bgruening
date Mon, 02 Oct 2023 10:00:27 +0000
parents 0af678661e20
children bba502278c25
files keras_train_and_eval.py keras_train_and_eval.xml
diffstat 2 files changed, 26 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/keras_train_and_eval.py	Fri Sep 22 17:04:27 2023 +0000
+++ b/keras_train_and_eval.py	Mon Oct 02 10:00:27 2023 +0000
@@ -188,6 +188,7 @@
     infile1,
     infile2,
     outfile_result,
+    outfile_history=None,
     outfile_object=None,
     outfile_y_true=None,
     outfile_y_preds=None,
@@ -215,6 +216,9 @@
     outfile_result : str
         File path to save the results, either cv_results or test result.
 
+    outfile_history : str, optional
+        File path to save the training history.
+
     outfile_object : str, optional
         File path to save searchCV object.
 
@@ -253,9 +257,7 @@
     swapping = params["experiment_schemes"]["hyperparams_swapping"]
     swap_params = _eval_swap_params(swapping)
     estimator.set_params(**swap_params)
-
     estimator_params = estimator.get_params()
-
     # store read dataframe object
     loaded_df = {}
 
@@ -448,12 +450,20 @@
     # train and eval
     if hasattr(estimator, "config") and hasattr(estimator, "model_type"):
         if exp_scheme == "train_val_test":
-            estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
+            history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
         else:
-            estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
+            history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
     else:
-        estimator.fit(X_train, y_train)
-
+        history = estimator.fit(X_train, y_train)
+    if "callbacks" in estimator_params:
+        for cb in estimator_params["callbacks"]:
+            if cb["callback_selection"]["callback_type"] == "CSVLogger":
+                hist_df = pd.DataFrame(history.history)
+                hist_df["epoch"] = np.arange(1, estimator_params["epochs"] + 1)
+                epo_col = hist_df.pop('epoch')
+                hist_df.insert(0, 'epoch', epo_col)
+                hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False)
+                break
     if isinstance(estimator, KerasGBatchClassifier):
         scores = {}
         steps = estimator.prediction_steps
@@ -526,6 +536,7 @@
     aparser.add_argument("-X", "--infile1", dest="infile1")
     aparser.add_argument("-y", "--infile2", dest="infile2")
     aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
+    aparser.add_argument("-hi", "--outfile_history", dest="outfile_history")
     aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
     aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true")
     aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds")
@@ -542,6 +553,7 @@
         args.infile1,
         args.infile2,
         args.outfile_result,
+        outfile_history=args.outfile_history,
         outfile_object=args.outfile_object,
         outfile_y_true=args.outfile_y_true,
         outfile_y_preds=args.outfile_y_preds,
--- a/keras_train_and_eval.xml	Fri Sep 22 17:04:27 2023 +0000
+++ b/keras_train_and_eval.xml	Mon Oct 02 10:00:27 2023 +0000
@@ -8,7 +8,7 @@
     <expand macro="macro_stdio" />
     <version_command>echo "@VERSION@"</version_command>
     <command>
-        <![CDATA[
+        <![CDATA[        
         export HDF5_USE_FILE_LOCKING='FALSE';
         #if $input_options.selected_input == 'refseq_and_interval'
         bgzip -c '$input_options.target_file' > '${target_file.element_identifier}.gz' &&
@@ -29,6 +29,9 @@
             #end if
             --infile2 '$input_options.infile2'
             --outfile_result '$outfile_result'
+            #if $save and 'save_csvlogger' in str($save)
+            --outfile_history '$outfile_history'
+            #end if
             #if $save and 'save_estimator' in str($save)
             --outfile_object '$outfile_object'
             #end if
@@ -39,7 +42,6 @@
             #if $experiment_schemes.test_split.split_algos.shuffle == 'group'
             --groups '$experiment_schemes.test_split.split_algos.groups_selector.infile_g'
             #end if
-
         ]]>
     </command>
     <configfiles>
@@ -81,10 +83,14 @@
         <param name="save" type="select" multiple='true' display="checkboxes" label="Save the fitted model" optional="true" help="Evaluation scores will be output by default.">
             <option value="save_estimator" selected="true">Fitted estimator</option>
             <option value="save_prediction">True labels and prediction results from evaluation for downstream analysis</option>
+            <option value="save_csvlogger">Display CSVLogger if selected as a callback in the Keras model builder tool</option>
         </param>
     </inputs>
     <outputs>
         <data format="tabular" name="outfile_result" />
+         <data format="tabular" name="outfile_history" label="Deep learning training history log on ${on_string}">
+            <filter>str(save) and 'save_csvlogger' in str(save)</filter>
+        </data>
         <data format="h5mlm" name="outfile_object" label="Fitted estimator or estimator skeleton on ${on_string}">
             <filter>str(save) and 'save_estimator' in str(save)</filter>
         </data>