diff train_test_split.py @ 9:0a3f113397b2 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author bgruening
date Tue, 13 Apr 2021 17:29:01 +0000
parents 0a7f2b9e1fcb
children f5e7df4f7975
line wrap: on
line diff
--- a/train_test_split.py	Thu Oct 01 20:02:43 2020 +0000
+++ b/train_test_split.py	Tue Apr 13 17:29:01 2021 +0000
@@ -7,9 +7,8 @@
 from galaxy_ml.utils import get_cv, read_columns
 
 
-def _get_single_cv_split(params, array, infile_labels=None,
-                         infile_groups=None):
-    """ output (train, test) subset from a cv splitter
+def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None):
+    """output (train, test) subset from a cv splitter
 
     Parameters
     ----------
@@ -25,45 +24,50 @@
     y = None
     groups = None
 
-    nth_split = params['mode_selection']['nth_split']
+    nth_split = params["mode_selection"]["nth_split"]
 
     # read groups
     if infile_groups:
-        header = 'infer' if (params['mode_selection']['cv_selector']
-                             ['groups_selector']['header_g']) else None
-        column_option = (params['mode_selection']['cv_selector']
-                         ['groups_selector']['column_selector_options_g']
-                         ['selected_column_selector_option_g'])
-        if column_option in ['by_index_number', 'all_but_by_index_number',
-                             'by_header_name', 'all_but_by_header_name']:
-            c = (params['mode_selection']['cv_selector']['groups_selector']
-                 ['column_selector_options_g']['col_g'])
+        header = "infer" if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) else None
+        column_option = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"][
+            "selected_column_selector_option_g"
+        ]
+        if column_option in [
+            "by_index_number",
+            "all_but_by_index_number",
+            "by_header_name",
+            "all_but_by_header_name",
+        ]:
+            c = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"]
         else:
             c = None
 
-        groups = read_columns(infile_groups, c=c, c_option=column_option,
-                              sep='\t', header=header, parse_dates=True)
+        groups = read_columns(
+            infile_groups,
+            c=c,
+            c_option=column_option,
+            sep="\t",
+            header=header,
+            parse_dates=True,
+        )
         groups = groups.ravel()
 
-        params['mode_selection']['cv_selector']['groups_selector'] = groups
+        params["mode_selection"]["cv_selector"]["groups_selector"] = groups
 
     # read labels
     if infile_labels:
-        target_input = (params['mode_selection']
-                        ['cv_selector'].pop('target_input'))
-        header = 'infer' if target_input['header1'] else None
-        col_index = target_input['col'][0] - 1
-        df = pd.read_csv(infile_labels, sep='\t', header=header,
-                         parse_dates=True)
+        target_input = params["mode_selection"]["cv_selector"].pop("target_input")
+        header = "infer" if target_input["header1"] else None
+        col_index = target_input["col"][0] - 1
+        df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
         y = df.iloc[:, col_index].values
 
     # construct the cv splitter object
-    splitter, groups = get_cv(params['mode_selection']['cv_selector'])
+    splitter, groups = get_cv(params["mode_selection"]["cv_selector"])
 
     total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
     if nth_split > total_n_splits:
-        raise ValueError("Total number of splits is {}, but got `nth_split` "
-                         "= {}".format(total_n_splits, nth_split))
+        raise ValueError("Total number of splits is {}, but got `nth_split` " "= {}".format(total_n_splits, nth_split))
 
     i = 1
     for train_index, test_index in splitter.split(array.values, y=y, groups=groups):
@@ -79,8 +83,14 @@
     return train, test
 
 
-def main(inputs, infile_array, outfile_train, outfile_test,
-         infile_labels=None, infile_groups=None):
+def main(
+    inputs,
+    infile_array,
+    outfile_train,
+    outfile_test,
+    infile_labels=None,
+    infile_groups=None,
+):
     """
     Parameter
     ---------
@@ -102,45 +112,41 @@
     outfile_test : str
         File path to dataset containing test split
     """
-    warnings.simplefilter('ignore')
+    warnings.simplefilter("ignore")
 
-    with open(inputs, 'r') as param_handler:
+    with open(inputs, "r") as param_handler:
         params = json.load(param_handler)
 
-    input_header = params['header0']
-    header = 'infer' if input_header else None
-    array = pd.read_csv(infile_array, sep='\t', header=header,
-                        parse_dates=True)
+    input_header = params["header0"]
+    header = "infer" if input_header else None
+    array = pd.read_csv(infile_array, sep="\t", header=header, parse_dates=True)
 
     # train test split
-    if params['mode_selection']['selected_mode'] == 'train_test_split':
-        options = params['mode_selection']['options']
-        shuffle_selection = options.pop('shuffle_selection')
-        options['shuffle'] = shuffle_selection['shuffle']
+    if params["mode_selection"]["selected_mode"] == "train_test_split":
+        options = params["mode_selection"]["options"]
+        shuffle_selection = options.pop("shuffle_selection")
+        options["shuffle"] = shuffle_selection["shuffle"]
         if infile_labels:
-            header = 'infer' if shuffle_selection['header1'] else None
-            col_index = shuffle_selection['col'][0] - 1
-            df = pd.read_csv(infile_labels, sep='\t', header=header,
-                             parse_dates=True)
+            header = "infer" if shuffle_selection["header1"] else None
+            col_index = shuffle_selection["col"][0] - 1
+            df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
             labels = df.iloc[:, col_index].values
-            options['labels'] = labels
+            options["labels"] = labels
 
         train, test = train_test_split(array, **options)
 
     # cv splitter
     else:
-        train, test = _get_single_cv_split(params, array,
-                                           infile_labels=infile_labels,
-                                           infile_groups=infile_groups)
+        train, test = _get_single_cv_split(params, array, infile_labels=infile_labels, infile_groups=infile_groups)
 
     print("Input shape: %s" % repr(array.shape))
     print("Train shape: %s" % repr(train.shape))
     print("Test shape: %s" % repr(test.shape))
-    train.to_csv(outfile_train, sep='\t', header=input_header, index=False)
-    test.to_csv(outfile_test, sep='\t', header=input_header, index=False)
+    train.to_csv(outfile_train, sep="\t", header=input_header, index=False)
+    test.to_csv(outfile_test, sep="\t", header=input_header, index=False)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-X", "--infile_array", dest="infile_array")
@@ -150,5 +156,11 @@
     aparser.add_argument("-t", "--outfile_test", dest="outfile_test")
     args = aparser.parse_args()
 
-    main(args.inputs, args.infile_array, args.outfile_train,
-         args.outfile_test, args.infile_labels, args.infile_groups)
+    main(
+        args.inputs,
+        args.infile_array,
+        args.outfile_train,
+        args.outfile_test,
+        args.infile_labels,
+        args.infile_groups,
+    )