comparison train_test_split.py @ 19:d67dcd63f6cb draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author bgruening
date Tue, 13 Apr 2021 17:32:55 +0000
parents 9295c5f34630
children a5665e1b06b0
comparison
equal deleted inserted replaced
18:02d2be90dedb 19:d67dcd63f6cb
5 5
6 from galaxy_ml.model_validations import train_test_split 6 from galaxy_ml.model_validations import train_test_split
7 from galaxy_ml.utils import get_cv, read_columns 7 from galaxy_ml.utils import get_cv, read_columns
8 8
9 9
10 def _get_single_cv_split(params, array, infile_labels=None, 10 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None):
11 infile_groups=None): 11 """output (train, test) subset from a cv splitter
12 """ output (train, test) subset from a cv splitter
13 12
14 Parameters 13 Parameters
15 ---------- 14 ----------
16 params : dict 15 params : dict
17 Galaxy tool inputs 16 Galaxy tool inputs
23 File path to dataset containing group values 22 File path to dataset containing group values
24 """ 23 """
25 y = None 24 y = None
26 groups = None 25 groups = None
27 26
28 nth_split = params['mode_selection']['nth_split'] 27 nth_split = params["mode_selection"]["nth_split"]
29 28
30 # read groups 29 # read groups
31 if infile_groups: 30 if infile_groups:
32 header = 'infer' if (params['mode_selection']['cv_selector'] 31 header = "infer" if (params["mode_selection"]["cv_selector"]["groups_selector"]["header_g"]) else None
33 ['groups_selector']['header_g']) else None 32 column_option = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"][
34 column_option = (params['mode_selection']['cv_selector'] 33 "selected_column_selector_option_g"
35 ['groups_selector']['column_selector_options_g'] 34 ]
36 ['selected_column_selector_option_g']) 35 if column_option in [
37 if column_option in ['by_index_number', 'all_but_by_index_number', 36 "by_index_number",
38 'by_header_name', 'all_but_by_header_name']: 37 "all_but_by_index_number",
39 c = (params['mode_selection']['cv_selector']['groups_selector'] 38 "by_header_name",
40 ['column_selector_options_g']['col_g']) 39 "all_but_by_header_name",
40 ]:
41 c = params["mode_selection"]["cv_selector"]["groups_selector"]["column_selector_options_g"]["col_g"]
41 else: 42 else:
42 c = None 43 c = None
43 44
44 groups = read_columns(infile_groups, c=c, c_option=column_option, 45 groups = read_columns(
45 sep='\t', header=header, parse_dates=True) 46 infile_groups,
47 c=c,
48 c_option=column_option,
49 sep="\t",
50 header=header,
51 parse_dates=True,
52 )
46 groups = groups.ravel() 53 groups = groups.ravel()
47 54
48 params['mode_selection']['cv_selector']['groups_selector'] = groups 55 params["mode_selection"]["cv_selector"]["groups_selector"] = groups
49 56
50 # read labels 57 # read labels
51 if infile_labels: 58 if infile_labels:
52 target_input = (params['mode_selection'] 59 target_input = params["mode_selection"]["cv_selector"].pop("target_input")
53 ['cv_selector'].pop('target_input')) 60 header = "infer" if target_input["header1"] else None
54 header = 'infer' if target_input['header1'] else None 61 col_index = target_input["col"][0] - 1
55 col_index = target_input['col'][0] - 1 62 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
56 df = pd.read_csv(infile_labels, sep='\t', header=header,
57 parse_dates=True)
58 y = df.iloc[:, col_index].values 63 y = df.iloc[:, col_index].values
59 64
60 # construct the cv splitter object 65 # construct the cv splitter object
61 splitter, groups = get_cv(params['mode_selection']['cv_selector']) 66 splitter, groups = get_cv(params["mode_selection"]["cv_selector"])
62 67
63 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) 68 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
64 if nth_split > total_n_splits: 69 if nth_split > total_n_splits:
65 raise ValueError("Total number of splits is {}, but got `nth_split` " 70 raise ValueError("Total number of splits is {}, but got `nth_split` " "= {}".format(total_n_splits, nth_split))
66 "= {}".format(total_n_splits, nth_split))
67 71
68 i = 1 72 i = 1
69 for train_index, test_index in splitter.split(array.values, y=y, groups=groups): 73 for train_index, test_index in splitter.split(array.values, y=y, groups=groups):
70 # suppose nth_split >= 1 74 # suppose nth_split >= 1
71 if i == nth_split: 75 if i == nth_split:
77 test = array.iloc[test_index, :] 81 test = array.iloc[test_index, :]
78 82
79 return train, test 83 return train, test
80 84
81 85
82 def main(inputs, infile_array, outfile_train, outfile_test, 86 def main(
83 infile_labels=None, infile_groups=None): 87 inputs,
88 infile_array,
89 outfile_train,
90 outfile_test,
91 infile_labels=None,
92 infile_groups=None,
93 ):
84 """ 94 """
85 Parameter 95 Parameter
86 --------- 96 ---------
87 inputs : str 97 inputs : str
88 File path to galaxy tool parameter 98 File path to galaxy tool parameter
100 File path to dataset containing train split 110 File path to dataset containing train split
101 111
102 outfile_test : str 112 outfile_test : str
103 File path to dataset containing test split 113 File path to dataset containing test split
104 """ 114 """
105 warnings.simplefilter('ignore') 115 warnings.simplefilter("ignore")
106 116
107 with open(inputs, 'r') as param_handler: 117 with open(inputs, "r") as param_handler:
108 params = json.load(param_handler) 118 params = json.load(param_handler)
109 119
110 input_header = params['header0'] 120 input_header = params["header0"]
111 header = 'infer' if input_header else None 121 header = "infer" if input_header else None
112 array = pd.read_csv(infile_array, sep='\t', header=header, 122 array = pd.read_csv(infile_array, sep="\t", header=header, parse_dates=True)
113 parse_dates=True)
114 123
115 # train test split 124 # train test split
116 if params['mode_selection']['selected_mode'] == 'train_test_split': 125 if params["mode_selection"]["selected_mode"] == "train_test_split":
117 options = params['mode_selection']['options'] 126 options = params["mode_selection"]["options"]
118 shuffle_selection = options.pop('shuffle_selection') 127 shuffle_selection = options.pop("shuffle_selection")
119 options['shuffle'] = shuffle_selection['shuffle'] 128 options["shuffle"] = shuffle_selection["shuffle"]
120 if infile_labels: 129 if infile_labels:
121 header = 'infer' if shuffle_selection['header1'] else None 130 header = "infer" if shuffle_selection["header1"] else None
122 col_index = shuffle_selection['col'][0] - 1 131 col_index = shuffle_selection["col"][0] - 1
123 df = pd.read_csv(infile_labels, sep='\t', header=header, 132 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
124 parse_dates=True)
125 labels = df.iloc[:, col_index].values 133 labels = df.iloc[:, col_index].values
126 options['labels'] = labels 134 options["labels"] = labels
127 135
128 train, test = train_test_split(array, **options) 136 train, test = train_test_split(array, **options)
129 137
130 # cv splitter 138 # cv splitter
131 else: 139 else:
132 train, test = _get_single_cv_split(params, array, 140 train, test = _get_single_cv_split(params, array, infile_labels=infile_labels, infile_groups=infile_groups)
133 infile_labels=infile_labels,
134 infile_groups=infile_groups)
135 141
136 print("Input shape: %s" % repr(array.shape)) 142 print("Input shape: %s" % repr(array.shape))
137 print("Train shape: %s" % repr(train.shape)) 143 print("Train shape: %s" % repr(train.shape))
138 print("Test shape: %s" % repr(test.shape)) 144 print("Test shape: %s" % repr(test.shape))
139 train.to_csv(outfile_train, sep='\t', header=input_header, index=False) 145 train.to_csv(outfile_train, sep="\t", header=input_header, index=False)
140 test.to_csv(outfile_test, sep='\t', header=input_header, index=False) 146 test.to_csv(outfile_test, sep="\t", header=input_header, index=False)
141 147
142 148
143 if __name__ == '__main__': 149 if __name__ == "__main__":
144 aparser = argparse.ArgumentParser() 150 aparser = argparse.ArgumentParser()
145 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 151 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
146 aparser.add_argument("-X", "--infile_array", dest="infile_array") 152 aparser.add_argument("-X", "--infile_array", dest="infile_array")
147 aparser.add_argument("-y", "--infile_labels", dest="infile_labels") 153 aparser.add_argument("-y", "--infile_labels", dest="infile_labels")
148 aparser.add_argument("-g", "--infile_groups", dest="infile_groups") 154 aparser.add_argument("-g", "--infile_groups", dest="infile_groups")
149 aparser.add_argument("-o", "--outfile_train", dest="outfile_train") 155 aparser.add_argument("-o", "--outfile_train", dest="outfile_train")
150 aparser.add_argument("-t", "--outfile_test", dest="outfile_test") 156 aparser.add_argument("-t", "--outfile_test", dest="outfile_test")
151 args = aparser.parse_args() 157 args = aparser.parse_args()
152 158
153 main(args.inputs, args.infile_array, args.outfile_train, 159 main(
154 args.outfile_test, args.infile_labels, args.infile_groups) 160 args.inputs,
161 args.infile_array,
162 args.outfile_train,
163 args.outfile_test,
164 args.infile_labels,
165 args.infile_groups,
166 )