Repository 'sklearn_model_fit'
hg clone https://toolshed.g2.bx.psu.edu/repos/bgruening/sklearn_model_fit

Changeset 15:8e447b95e6a8 (2023-10-02)
Previous changeset 14:adb084b901cc (2023-08-09) Next changeset 16:8f469d961e30 (2023-11-03)
Commit message:
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
modified:
keras_train_and_eval.py
ml_visualization_ex.py
removed:
pdb70_cs219.ffdata
b
diff -r adb084b901cc -r 8e447b95e6a8 keras_train_and_eval.py
--- a/keras_train_and_eval.py Wed Aug 09 14:26:26 2023 +0000
+++ b/keras_train_and_eval.py Mon Oct 02 09:58:11 2023 +0000
[
@@ -188,6 +188,7 @@
     infile1,
     infile2,
     outfile_result,
+    outfile_history=None,
     outfile_object=None,
     outfile_y_true=None,
     outfile_y_preds=None,
@@ -215,6 +216,9 @@
     outfile_result : str
         File path to save the results, either cv_results or test result.
 
+    outfile_history : str, optional
+        File path to save the training history.
+
     outfile_object : str, optional
         File path to save searchCV object.
 
@@ -253,9 +257,7 @@
     swapping = params["experiment_schemes"]["hyperparams_swapping"]
     swap_params = _eval_swap_params(swapping)
     estimator.set_params(**swap_params)
-
     estimator_params = estimator.get_params()
-
     # store read dataframe object
     loaded_df = {}
 
@@ -448,12 +450,20 @@
     # train and eval
     if hasattr(estimator, "config") and hasattr(estimator, "model_type"):
         if exp_scheme == "train_val_test":
-            estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
+            history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
         else:
-            estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
+            history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
     else:
-        estimator.fit(X_train, y_train)
-
+        history = estimator.fit(X_train, y_train)
+    # Save the per-epoch training history when a CSVLogger callback was configured.
+    callbacks = estimator_params.get("callbacks") or []
+    for cb in callbacks:
+        if cb["callback_selection"]["callback_type"] == "CSVLogger":
+            hist_df = pd.DataFrame(history.history)
+            # number the epochs actually recorded; may be fewer than `epochs` if training stopped early
+            hist_df.insert(0, "epoch", np.arange(1, len(hist_df) + 1))
+            hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False)
+            break
     if isinstance(estimator, KerasGBatchClassifier):
         scores = {}
         steps = estimator.prediction_steps
@@ -526,6 +536,7 @@
     aparser.add_argument("-X", "--infile1", dest="infile1")
     aparser.add_argument("-y", "--infile2", dest="infile2")
     aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
+    aparser.add_argument("-hi", "--outfile_history", dest="outfile_history")
     aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
     aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true")
     aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds")
@@ -542,6 +553,7 @@
         args.infile1,
         args.infile2,
         args.outfile_result,
+        outfile_history=args.outfile_history,
         outfile_object=args.outfile_object,
         outfile_y_true=args.outfile_y_true,
         outfile_y_preds=args.outfile_y_preds,
b
diff -r adb084b901cc -r 8e447b95e6a8 ml_visualization_ex.py
--- a/ml_visualization_ex.py Wed Aug 09 14:26:26 2023 +0000
+++ b/ml_visualization_ex.py Mon Oct 02 09:58:11 2023 +0000
[
@@ -15,6 +15,7 @@
 from sklearn.metrics import (
     auc,
     average_precision_score,
+    confusion_matrix,
     precision_recall_curve,
     roc_curve,
 )
@@ -258,6 +259,30 @@
     os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
 
 
+def get_dataframe(file_path, plot_selection, header_name, column_name):
+    """Load a tab-separated file via read_columns using a column-selector param."""
+    header = "infer" if plot_selection[header_name] else None
+    selector = plot_selection[column_name]
+    column_option = selector["selected_column_selector_option"]
+    needs_col1 = column_option in (
+        "by_index_number",
+        "all_but_by_index_number",
+        "by_header_name",
+        "all_but_by_header_name",
+    )
+    col = selector["col1"] if needs_col1 else None
+    _, input_df = read_columns(
+        file_path,
+        c=col,
+        c_option=column_option,
+        return_df=True,
+        sep="\t",
+        header=header,
+        parse_dates=True,
+    )
+    return input_df
+
+
 def main(
     inputs,
     infile_estimator=None,
@@ -271,6 +296,10 @@
     targets=None,
     fasta_path=None,
     model_config=None,
+    true_labels=None,
+    predicted_labels=None,
+    plot_color=None,
+    title=None,
 ):
     """
     Parameter
@@ -311,6 +340,18 @@
 
     model_config : str, default is None
         File path to dataset containing JSON config for neural networks
+
+    true_labels : str, default is None
+        File path to dataset containing true labels
+
+    predicted_labels : str, default is None
+        File path to dataset containing predicted labels
+
+    plot_color : str, default is None
+        Color of the confusion matrix heatmap
+
+    title : str, default is None
+        Title of the confusion matrix heatmap
     """
     warnings.simplefilter("ignore")
 
@@ -534,6 +575,36 @@
 
         return 0
 
+    elif plot_type == "classification_confusion_matrix":
+        plot_selection = params["plotting_selection"]
+        input_true = get_dataframe(
+            true_labels, plot_selection, "header_true", "column_selector_options_true"
+        )
+        header_predicted = "infer" if plot_selection["header_predicted"] else None
+        input_predicted = pd.read_csv(
+            predicted_labels, sep="\t", parse_dates=True, header=header_predicted
+        )
+        true_classes = input_true.iloc[:, -1].copy()
+        predicted_classes = input_predicted.iloc[:, -1].copy()
+        axis_labels = sorted(set(true_classes) | set(predicted_classes))  # row/col order of c_matrix
+        c_matrix = confusion_matrix(true_classes, predicted_classes, labels=axis_labels)
+        fig, ax = plt.subplots(figsize=(7, 7))
+        im = ax.imshow(c_matrix, cmap=plot_color)
+        for i in range(len(c_matrix)):
+            for j in range(len(c_matrix)):
+                ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+        ax.set_ylabel("True class labels")
+        ax.set_xlabel("Predicted class labels")
+        ax.set_title(title)
+        plt.xticks(range(len(axis_labels)), axis_labels)
+        plt.yticks(range(len(axis_labels)), axis_labels)
+        fig.colorbar(im, ax=ax)
+        fig.tight_layout()
+        plt.savefig("output.png", dpi=125)
+        os.rename("output.png", "output")
+
+        return 0
+
     # save pdf file to disk
     # fig.write_image("image.pdf", format='pdf')
     # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
@@ -553,6 +624,10 @@
     aparser.add_argument("-t", "--targets", dest="targets")
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
     aparser.add_argument("-c", "--model_config", dest="model_config")
+    aparser.add_argument("-tl", "--true_labels", dest="true_labels")
+    aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
+    aparser.add_argument("-pc", "--plot_color", dest="plot_color")
+    aparser.add_argument("-pt", "--title", dest="title")
     args = aparser.parse_args()
 
     main(
@@ -568,4 +643,8 @@
         targets=args.targets,
         fasta_path=args.fasta_path,
         model_config=args.model_config,
+        true_labels=args.true_labels,
+        predicted_labels=args.predicted_labels,
+        plot_color=args.plot_color,
+        title=args.title,
     )
b
diff -r adb084b901cc -r 8e447b95e6a8 pdb70_cs219.ffdata
--- a/pdb70_cs219.ffdata Wed Aug 09 14:26:26 2023 +0000
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
[
@@ -1,189 +0,0 @@
-
-
-
-
-<!DOCTYPE HTML>
-<html>
-    <!--base.mako-->
-    
-
-
-    <head>
-        <meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
-        <meta name = "viewport" content = "maximum-scale=1.0">
-        <meta http-equiv="X-UA-Compatible" content="IE=Edge,chrome=1">
-
-        <title>
-            Galaxy
-            | Europe
-            | 
-        </title>
-
-        <link rel="index" href="/"/>
-
-        
-        
-    <link href="/static/style/bootstrap-tour.css?v=1618364054" media="screen" rel="stylesheet" type="text/css" />
-    <link href="/static/dist/base.css?v=1618364054" media="screen" rel="stylesheet" type="text/css" />
-
-        
-    <script src="/static/dist/libs.chunk.js?v=1618364054" type="text/javascript"></script>
-<script src="/static/dist/base.chunk.js?v=1618364054" type="text/javascript"></script>
-<script src="/static/dist/generic.bundled.js?v=1618364054" type="text/javascript"></script>
-
-        
-    <!-- message.mako javascript_app() -->
-    
-
-    
-    <script type="text/javascript">
-        // galaxy_client_app.mako, load
-
-        var bootstrapped;
-        try {
-            bootstrapped = 
-{}
-;
-        } catch(err) {
-            console.warn("Unable to parse bootstrapped variable", err);
-            bootstrapped = {};
-        }
-
-        var options = {
-            root: '/',
-            config: 
-    
-{
-"display_galaxy_brand": true,
-"chunk_upload_size": 104857600,
-"use_remote_user": null,
-"enable_oidc": true,
-"mailing_join_addr": null,
-"select_type_workflow_threshold": -1,
-"myexperiment_target_url": "www.myexperiment.org:80",
-"tool_recommendation_model_path": "https://github.com/galaxyproject/galaxy-test-data/raw/master/tool_recommendation_model.hdf5",
-"simplified_workflow_run_ui_target_history": "current",
-"interactivetools_enable": true,
-"is_admin_user": false,
-"show_welcome_with_login": true,
-"welcome_url": "/static/welcome.html",
-"allow_user_impersonation": true,
-"overwrite_model_recommendations": false,
-"topk_recommendations": 10,
-"user_library_import_dir_available": false,
-"ga_code": null,
-"enable_beta_markdown_export": true,
-"visualizations_visible": true,
-"enable_tool_recommendations": true,
-"enable_unique_workflow_defaults": false,
-"registration_warning_message": "Please register only one account. The usegalaxy.eu service is provided free of charge and has limited computational and data storage resources. <strong>Registration and usage of multiple accounts is tracked and such accounts are subject to termination and data deletion.<\/strong>",
-"logo_src": "/static/favicon.png",
-"enable_quotas": true,
-"server_mail_configured": true,
-"citation_url": "https://galaxyproject.org/citing-galaxy",
-"allow_user_dataset_purge": true,
-"ftp_upload_site": "ftp://ftp.usegalaxy.eu",
-"terms_url": "https://usegalaxy.eu/terms",
-"upload_from_form_button": "always-on",
-"wiki_url": "https://galaxyproject.org/",
-"logo_src_secondary": null,
-"aws_estimate": true,
-"single_user": false,
-"datatypes_disable_auto": false,
-"brand": "Europe",
-"mailing_lists": "https://galaxyproject.org/mailing-lists/",
-"python": [
-3,
-6
-],
-"release_doc_base_url": "https://docs.galaxyproject.org/en/release_",
-"enable_openid": false,
-"cookie_domain": null,
-"message_box_content": "You are using the new UseGalaxy.eu backend server, let us know if you encounter any issues!",
-"admin_tool_recommendations_path": "/opt/galaxy/config/tool_recommendations_overwrite.yml",
-"search_url": "https://galaxyproject.org/search/",
-"remote_user_logout_href": null,
-"default_locale": "auto",
-"screencasts_url": "https://vimeo.com/galaxyproject",
-"quota_url": "https://galaxyproject.org/support/account-quotas/",
-"version_major": "21.01",
-"simplified_workflow_run_ui": "prefer",
-"allow_user_creation": true,
-"lims_doc_url": "https://usegalaxy.org/u/rkchak/p/sts",
-"message_box_visible": false,
-"has_user_tool_filters": true,
-"message_box_class": "info",
-"require_login": false,
-"logo_url": "/",
-"support_url": "https://galaxyproject.org/support/",
-"simplified_workflow_run_ui_job_cache": "off",
-"server_startttime": 1618364054,
-"oidc": {
-"elixir": {
-"icon": "https://elixir-europe.org/sites/default/files/images/login-button-orange.png"
-}
-},
-"version_minor": "",
-"helpsite_url": "https://help.galaxyproject.org/c/usegalaxy-eu-support",
-"file_sources_configured": true,
-"inactivity_box_content": "Your account has not been activated yet.  Feel free to browse around and see what's available, but you won't be able to upload data or run jobs until you have verified your email address.",
-"nginx_upload_path": "/_upload"
-}
-,
-            user: 
-    
-{
-"total_disk_usage": 0,
-"nice_total_disk_usage": "0 bytes",
-"quota_percent": null
-}
-,
-            session_csrf_token: 'c3ae71f65be7de55dd5bd5f97f316000'
-        };
-
-        config.set({
-            options: options,
-            bootstrapped: bootstrapped
-        });
-
-
-    </script>
-
-    
-
-
-
-
-    
-
-    
-    <script type="text/javascript">
-        config.addInitialization(function() {
-            if (parent.handle_minwidth_hint) {
-                parent.handle_minwidth_hint(-1);
-            }
-        });
-    </script>
-
-    </head>
-    <body class="inbound">
-        
-    
-    
-    <div class="message mt-2 alert alert-danger">You are not allowed to access this dataset</div>
-
-
-    </body>
-</html>
-
-
-
-
-
-
-
-
-
-
-
-