diff train_test_split.py @ 36:92e09b827300 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:43:23 +0000
parents 318484f56b6a
children 5af054432771
line wrap: on
line diff
--- a/train_test_split.py	Tue Apr 13 22:34:09 2021 +0000
+++ b/train_test_split.py	Sat May 01 01:43:23 2021 +0000
@@ -28,17 +28,23 @@
 
     # read groups
     if infile_groups:
-        header = "infer" if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) else None
-        column_option = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"][
-            "selected_column_selector_option_g"
-        ]
+        header = (
+            "infer"
+            if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"])
+            else None
+        )
+        column_option = params["mode_selection"]["cv_selector"]["groups_selector"][
+            "column_selector_options_g"
+        ]["selected_column_selector_option_g"]
         if column_option in [
             "by_index_number",
             "all_but_by_index_number",
             "by_header_name",
             "all_but_by_header_name",
         ]:
-            c = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"]
+            c = params["mode_selection"]["cv_selector"]["groups_selector"][
+                "column_selector_options_g"
+            ]["col_g"]
         else:
             c = None
 
@@ -67,7 +73,10 @@
 
     total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
     if nth_split > total_n_splits:
-        raise ValueError("Total number of splits is {}, but got `nth_split` " "= {}".format(total_n_splits, nth_split))
+        raise ValueError(
+            "Total number of splits is {}, but got `nth_split` "
+            "= {}".format(total_n_splits, nth_split)
+        )
 
     i = 1
     for train_index, test_index in splitter.split(array.values, y=y, groups=groups):
@@ -137,7 +146,9 @@
 
     # cv splitter
     else:
-        train, test = _get_single_cv_split(params, array, infile_labels=infile_labels, infile_groups=infile_groups)
+        train, test = _get_single_cv_split(
+            params, array, infile_labels=infile_labels, infile_groups=infile_groups
+        )
 
     print("Input shape: %s" % repr(array.shape))
     print("Train shape: %s" % repr(train.shape))