comparison train_test_split.py @ 0:59e8b4328c82 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:40:10 +0000
parents
children f93f0cdbaf18
comparison
equal deleted inserted replaced
-1:000000000000 0:59e8b4328c82
1 import argparse
2 import json
3 import warnings
4
5 import pandas as pd
6 from galaxy_ml.model_validations import train_test_split
7 from galaxy_ml.utils import get_cv, read_columns
8
9
10 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None):
11 """output (train, test) subset from a cv splitter
12
13 Parameters
14 ----------
15 params : dict
16 Galaxy tool inputs
17 array : pandas DataFrame object
18 The target dataset to split
19 infile_labels : str
20 File path to dataset containing target values
21 infile_groups : str
22 File path to dataset containing group values
23 """
24 y = None
25 groups = None
26
27 nth_split = params["mode_selection"]["nth_split"]
28
29 # read groups
30 if infile_groups:
31 header = "infer" if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) else None
32 column_option = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"][
33 "selected_column_selector_option_g"
34 ]
35 if column_option in [
36 "by_index_number",
37 "all_but_by_index_number",
38 "by_header_name",
39 "all_but_by_header_name",
40 ]:
41 c = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"]
42 else:
43 c = None
44
45 groups = read_columns(
46 infile_groups,
47 c=c,
48 c_option=column_option,
49 sep="\t",
50 header=header,
51 parse_dates=True,
52 )
53 groups = groups.ravel()
54
55 params["mode_selection"]["cv_selector"]["groups_selector"] = groups
56
57 # read labels
58 if infile_labels:
59 target_input = params["mode_selection"]["cv_selector"].pop("target_input")
60 header = "infer" if target_input["header1"] else None
61 col_index = target_input["col"][0] - 1
62 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
63 y = df.iloc[:, col_index].values
64
65 # construct the cv splitter object
66 splitter, groups = get_cv(params["mode_selection"]["cv_selector"])
67
68 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
69 if nth_split > total_n_splits:
70 raise ValueError("Total number of splits is {}, but got `nth_split` " "= {}".format(total_n_splits, nth_split))
71
72 i = 1
73 for train_index, test_index in splitter.split(array.values, y=y, groups=groups):
74 # suppose nth_split >= 1
75 if i == nth_split:
76 break
77 else:
78 i += 1
79
80 train = array.iloc[train_index, :]
81 test = array.iloc[test_index, :]
82
83 return train, test
84
85
def main(
    inputs,
    infile_array,
    outfile_train,
    outfile_test,
    infile_labels=None,
    infile_groups=None,
):
    """Split a tab-separated dataset into train/test subsets and write both out.

    Parameters
    ----------
    inputs : str
        File path to the JSON file holding the Galaxy tool parameters.
    infile_array : str
        File path to the tab-separated dataset to split.
    outfile_train : str
        File path the train subset is written to (tab-separated, no index).
    outfile_test : str
        File path the test subset is written to (tab-separated, no index).
    infile_labels : str, optional
        File path to a tab-separated dataset containing labels.
    infile_groups : str, optional
        File path to a tab-separated dataset containing groups.
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as fh:
        params = json.load(fh)

    # The same flag controls reading the input header and writing the outputs' header.
    keep_header = params["header0"]
    array = pd.read_csv(
        infile_array,
        sep="\t",
        header="infer" if keep_header else None,
        parse_dates=True,
    )

    mode = params["mode_selection"]
    if mode["selected_mode"] != "train_test_split":
        # Any other mode takes a single fold from a cross-validation splitter.
        train, test = _get_single_cv_split(
            params,
            array,
            infile_labels=infile_labels,
            infile_groups=infile_groups,
        )
    else:
        options = mode["options"]
        shuffle_selection = options.pop("shuffle_selection")
        options["shuffle"] = shuffle_selection["shuffle"]
        if infile_labels:
            labels_header = "infer" if shuffle_selection["header1"] else None
            label_col = shuffle_selection["col"][0] - 1
            labels_df = pd.read_csv(infile_labels, sep="\t", header=labels_header, parse_dates=True)
            options["labels"] = labels_df.iloc[:, label_col].values
        train, test = train_test_split(array, **options)

    for title, frame in (("Input", array), ("Train", train), ("Test", test)):
        print("%s shape: %s" % (title, repr(frame.shape)))

    train.to_csv(outfile_train, sep="\t", header=keep_header, index=False)
    test.to_csv(outfile_test, sep="\t", header=keep_header, index=False)
147
148
if __name__ == "__main__":
    # Command-line entry point: only --inputs is mandatory; every other
    # option is a file path that defaults to None when omitted.
    cli = argparse.ArgumentParser()
    cli.add_argument("-i", "--inputs", dest="inputs", required=True)
    for short_flag, opt_name in (
        ("-X", "infile_array"),
        ("-y", "infile_labels"),
        ("-g", "infile_groups"),
        ("-o", "outfile_train"),
        ("-t", "outfile_test"),
    ):
        cli.add_argument(short_flag, "--" + opt_name, dest=opt_name)
    opts = cli.parse_args()

    main(
        opts.inputs,
        opts.infile_array,
        opts.outfile_train,
        opts.outfile_test,
        infile_labels=opts.infile_labels,
        infile_groups=opts.infile_groups,
    )