comparison train_test_split.py @ 29:95d0f81e46e8 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
author bgruening
date Fri, 01 Nov 2019 17:14:31 -0400
parents
children eeaf989f1024
comparison
equal deleted inserted replaced
28:7696d389675c 29:95d0f81e46e8
1 import argparse
2 import json
3 import pandas as pd
4 import warnings
5
6 from galaxy_ml.model_validations import train_test_split
7 from galaxy_ml.utils import get_cv, read_columns
8
9
def _get_single_cv_split(params, array, infile_labels=None,
                         infile_groups=None):
    """ output (train, test) subset from a cv splitter

    Parameters
    ----------
    params : dict
        Galaxy tool inputs. ``params['mode_selection']['cv_selector']`` is
        modified in place: ``'groups_selector'`` is replaced by the loaded
        group array when ``infile_groups`` is given, and ``'target_input'``
        is popped when ``infile_labels`` is given.
    array : pandas DataFrame object
        The target dataset to split
    infile_labels : str
        File path to dataset containing target values
    infile_groups : str
        File path to dataset containing group values

    Returns
    -------
    train, test : pandas DataFrame
        Rows of ``array`` selected by the train / test indices of the
        ``nth_split``-th (1-based) split yielded by the cv splitter.

    Raises
    ------
    ValueError
        If ``nth_split`` is smaller than 1, larger than the number of splits
        reported by ``get_n_splits``, or if the splitter actually yields
        fewer splits than requested.
    """
    y = None
    groups = None

    nth_split = params['mode_selection']['nth_split']
    # `nth_split` is 1-based. A value < 1 can never match a fold and would
    # otherwise fall through the selection loop to the last fold.
    if nth_split < 1:
        raise ValueError("`nth_split` must be >= 1, but got `nth_split` "
                         "= {}".format(nth_split))

    cv_selector = params['mode_selection']['cv_selector']

    # read groups
    if infile_groups:
        groups_selector = cv_selector['groups_selector']
        header = 'infer' if groups_selector['header_g'] else None
        column_options = groups_selector['column_selector_options_g']
        column_option = column_options['selected_column_selector_option_g']
        if column_option in ['by_index_number', 'all_but_by_index_number',
                             'by_header_name', 'all_but_by_header_name']:
            c = column_options['col_g']
        else:
            c = None

        groups = read_columns(infile_groups, c=c, c_option=column_option,
                              sep='\t', header=header, parse_dates=True)
        groups = groups.ravel()

        # replace the file options with the loaded array; this dict is what
        # gets passed to `get_cv` below
        cv_selector['groups_selector'] = groups

    # read labels
    if infile_labels:
        target_input = cv_selector.pop('target_input')
        header = 'infer' if target_input['header1'] else None
        # only the first selected column is used; `- 1` turns the tool's
        # column number into a 0-based position for `iloc`
        col_index = target_input['col'][0] - 1
        df = pd.read_csv(infile_labels, sep='\t', header=header,
                         parse_dates=True)
        y = df.iloc[:, col_index].values

    # construct the cv splitter object
    splitter, groups = get_cv(cv_selector)

    total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
    if nth_split > total_n_splits:
        raise ValueError("Total number of splits is {}, but got `nth_split` "
                         "= {}".format(total_n_splits, nth_split))

    # advance the splitter to the requested (1-based) fold
    splits = splitter.split(array.values, y=y, groups=groups)
    for fold, (train_index, test_index) in enumerate(splits, start=1):
        if fold == nth_split:
            break
    else:
        # the splitter yielded fewer folds than requested; never fall back
        # to the last fold or leave the indices unbound
        raise ValueError("The splitter yielded fewer than {} splits"
                         .format(nth_split))

    train = array.iloc[train_index, :]
    test = array.iloc[test_index, :]

    return train, test
80
81
def main(inputs, infile_array, outfile_train, outfile_test,
         infile_labels=None, infile_groups=None):
    """Split a tab-separated dataset into a train part and a test part.

    Depending on ``params['mode_selection']['selected_mode']`` the split is
    made either by galaxy_ml's ``train_test_split`` or by taking one split
    of a cross-validation splitter (``_get_single_cv_split``). Both parts
    are written as tab-separated files without the index column.

    Parameters
    ----------
    inputs : str
        File path to the galaxy tool parameters (JSON)

    infile_array : str
        File path to the single tab-separated input array to split

    outfile_train : str
        File path to dataset containing train split

    outfile_test : str
        File path to dataset containing test split

    infile_labels : str
        File path to dataset containing labels

    infile_groups : str
        File path to dataset containing groups
    """
    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    input_header = params['header0']
    header = 'infer' if input_header else None
    array = pd.read_csv(infile_array, sep='\t', header=header,
                        parse_dates=True)

    mode_selection = params['mode_selection']

    # train test split
    if mode_selection['selected_mode'] == 'train_test_split':
        options = mode_selection['options']
        shuffle_selection = options.pop('shuffle_selection')
        options['shuffle'] = shuffle_selection['shuffle']
        if infile_labels:
            # separate name so the input array's `header` is not clobbered
            label_header = 'infer' if shuffle_selection['header1'] else None
            # first selected column, converted to a 0-based `iloc` position
            col_index = shuffle_selection['col'][0] - 1
            df = pd.read_csv(infile_labels, sep='\t', header=label_header,
                             parse_dates=True)
            options['labels'] = df.iloc[:, col_index].values

        train, test = train_test_split(array, **options)

    # cv splitter
    else:
        train, test = _get_single_cv_split(params, array,
                                           infile_labels=infile_labels,
                                           infile_groups=infile_groups)

    print("Input shape: %s" % repr(array.shape))
    print("Train shape: %s" % repr(train.shape))
    print("Test shape: %s" % repr(test.shape))
    train.to_csv(outfile_train, sep='\t', header=input_header, index=False)
    test.to_csv(outfile_test, sep='\t', header=input_header, index=False)
141
142
if __name__ == '__main__':
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True,
                         help="JSON file with the galaxy tool parameters")
    # `main` always reads the input array and writes both outputs. A missing
    # output path would not fail: `DataFrame.to_csv(None)` returns a string
    # instead of writing a file, which silently drops the result. So these
    # arguments are enforced here.
    aparser.add_argument("-X", "--infile_array", dest="infile_array",
                         required=True,
                         help="tab-separated dataset to split")
    aparser.add_argument("-y", "--infile_labels", dest="infile_labels",
                         help="tab-separated dataset containing labels")
    aparser.add_argument("-g", "--infile_groups", dest="infile_groups",
                         help="tab-separated dataset containing groups")
    aparser.add_argument("-o", "--outfile_train", dest="outfile_train",
                         required=True,
                         help="output path for the train split")
    aparser.add_argument("-t", "--outfile_test", dest="outfile_test",
                         required=True,
                         help="output path for the test split")
    args = aparser.parse_args()

    main(args.inputs, args.infile_array, args.outfile_train,
         args.outfile_test, infile_labels=args.infile_labels,
         infile_groups=args.infile_groups)