Mercurial > repos > bgruening > keras_train_and_eval
changeset 14:818f9b69d8a0 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
author | bgruening |
---|---|
date | Mon, 02 Oct 2023 10:00:27 +0000 |
parents | 0af678661e20 |
children | bba502278c25 |
files | keras_train_and_eval.py keras_train_and_eval.xml |
diffstat | 2 files changed, 26 insertions(+), 8 deletions(-) [+] |
line wrap: on
line diff
--- a/keras_train_and_eval.py Fri Sep 22 17:04:27 2023 +0000 +++ b/keras_train_and_eval.py Mon Oct 02 10:00:27 2023 +0000 @@ -188,6 +188,7 @@ infile1, infile2, outfile_result, + outfile_history=None, outfile_object=None, outfile_y_true=None, outfile_y_preds=None, @@ -215,6 +216,9 @@ outfile_result : str File path to save the results, either cv_results or test result. + outfile_history : str, optional + File path to save the training history. + outfile_object : str, optional File path to save searchCV object. @@ -253,9 +257,7 @@ swapping = params["experiment_schemes"]["hyperparams_swapping"] swap_params = _eval_swap_params(swapping) estimator.set_params(**swap_params) - estimator_params = estimator.get_params() - # store read dataframe object loaded_df = {} @@ -448,12 +450,20 @@ # train and eval if hasattr(estimator, "config") and hasattr(estimator, "model_type"): if exp_scheme == "train_val_test": - estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) + history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) else: - estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) + history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) else: - estimator.fit(X_train, y_train) - + history = estimator.fit(X_train, y_train) + if "callbacks" in estimator_params: + for cb in estimator_params["callbacks"]: + if cb["callback_selection"]["callback_type"] == "CSVLogger": + hist_df = pd.DataFrame(history.history) + hist_df["epoch"] = np.arange(1, estimator_params["epochs"] + 1) + epo_col = hist_df.pop('epoch') + hist_df.insert(0, 'epoch', epo_col) + hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False) + break if isinstance(estimator, KerasGBatchClassifier): scores = {} steps = estimator.prediction_steps @@ -526,6 +536,7 @@ aparser.add_argument("-X", "--infile1", dest="infile1") aparser.add_argument("-y", "--infile2", dest="infile2") aparser.add_argument("-O", "--outfile_result", dest="outfile_result") + aparser.add_argument("-hi", "--outfile_history", dest="outfile_history") aparser.add_argument("-o", "--outfile_object", dest="outfile_object") aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") @@ -542,6 +553,7 @@ args.infile1, args.infile2, args.outfile_result, + outfile_history=args.outfile_history, outfile_object=args.outfile_object, outfile_y_true=args.outfile_y_true, outfile_y_preds=args.outfile_y_preds,
--- a/keras_train_and_eval.xml Fri Sep 22 17:04:27 2023 +0000 +++ b/keras_train_and_eval.xml Mon Oct 02 10:00:27 2023 +0000 @@ -8,7 +8,7 @@ <expand macro="macro_stdio" /> <version_command>echo "@VERSION@"</version_command> <command> - <![CDATA[ + <![CDATA[ export HDF5_USE_FILE_LOCKING='FALSE'; #if $input_options.selected_input == 'refseq_and_interval' bgzip -c '$input_options.target_file' > '${target_file.element_identifier}.gz' && @@ -29,6 +29,9 @@ #end if --infile2 '$input_options.infile2' --outfile_result '$outfile_result' + #if $save and 'save_csvlogger' in str($save) + --outfile_history '$outfile_history' + #end if #if $save and 'save_estimator' in str($save) --outfile_object '$outfile_object' #end if @@ -39,7 +42,6 @@ #if $experiment_schemes.test_split.split_algos.shuffle == 'group' --groups '$experiment_schemes.test_split.split_algos.groups_selector.infile_g' #end if - ]]> </command> <configfiles> @@ -81,10 +83,14 @@ <param name="save" type="select" multiple='true' display="checkboxes" label="Save the fitted model" optional="true" help="Evaluation scores will be output by default."> <option value="save_estimator" selected="true">Fitted estimator</option> <option value="save_prediction">True labels and prediction results from evaluation for downstream analysis</option> + <option value="save_csvlogger">Display CSVLogger if selected as a callback in the Keras model builder tool</option> </param> </inputs> <outputs> <data format="tabular" name="outfile_result" /> + <data format="tabular" name="outfile_history" label="Deep learning training history log on ${on_string}"> + <filter>str(save) and 'save_csvlogger' in str(save)</filter> + </data> <data format="h5mlm" name="outfile_object" label="Fitted estimator or estimator skeleton on ${on_string}"> <filter>str(save) and 'save_estimator' in str(save)</filter> </data>