annotate train_test_split.py @ 4:aff86cee673c draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:48:19 +0000
parents c3bafda50176
children c48ffc96fe79
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
1 import argparse
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
2 import json
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
3 import warnings
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
4
4
aff86cee673c "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
bgruening
parents: 3
diff changeset
5 import pandas as pd
0
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
6 from galaxy_ml.model_validations import train_test_split
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
7 from galaxy_ml.utils import get_cv, read_columns
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
8
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
9
3
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None):
    """output (train, test) subset from a cv splitter

    Parameters
    ----------
    params : dict
        Galaxy tool inputs. ``params["mode_selection"]["cv_selector"]`` is
        mutated in place: when ``infile_groups`` is given its
        ``groups_selector`` entry is replaced by the loaded group values, and
        when ``infile_labels`` is given its ``target_input`` entry is popped,
        before the dict is handed to ``get_cv``.
    array : pandas DataFrame object
        The target dataset to split
    infile_labels : str
        File path to dataset containing target values
    infile_groups : str
        File path to dataset containing group values

    Returns
    -------
    train, test : pandas DataFrame
        Rows of ``array`` selected by the ``nth_split``-th (1-based) split
        produced by the splitter.

    Raises
    ------
    ValueError
        If ``nth_split`` is smaller than 1, larger than the number of splits
        reported by the splitter, or larger than the number of splits the
        splitter actually yields.
    """
    y = None
    groups = None

    mode_selection = params["mode_selection"]
    # Same dict object as params["mode_selection"]["cv_selector"]; edits
    # below are therefore visible to the caller, as before.
    cv_selector = mode_selection["cv_selector"]
    nth_split = mode_selection["nth_split"]
    # Splits are numbered from 1. Previously a value < 1 never matched inside
    # the selection loop and the *last* split was returned silently.
    if nth_split < 1:
        raise ValueError("`nth_split` must be >= 1, but got {}".format(nth_split))

    # read groups
    if infile_groups:
        groups_selector = cv_selector["groups_selector"]
        header = "infer" if (groups_selector["header_g"]) else None
        column_selector = groups_selector["column_selector_options_g"]
        column_option = column_selector["selected_column_selector_option_g"]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = column_selector["col_g"]
        else:
            c = None

        groups = read_columns(
            infile_groups,
            c=c,
            c_option=column_option,
            sep="\t",
            header=header,
            parse_dates=True,
        )
        groups = groups.ravel()

        # get_cv reads the group values from this slot instead of a file spec
        cv_selector["groups_selector"] = groups

    # read labels
    if infile_labels:
        # popped so get_cv does not receive the label-file options
        target_input = cv_selector.pop("target_input")
        header = "infer" if target_input["header1"] else None
        col_index = target_input["col"][0] - 1  # Galaxy columns are 1-based
        df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
        y = df.iloc[:, col_index].values

    # construct the cv splitter object
    splitter, groups = get_cv(cv_selector)

    total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
    if nth_split > total_n_splits:
        raise ValueError("Total number of splits is {}, but got `nth_split` = {}".format(total_n_splits, nth_split))

    # Walk the (lazy) split generator up to the requested 1-based split.
    for i, (train_index, test_index) in enumerate(splitter.split(array.values, y=y, groups=groups), start=1):
        if i == nth_split:
            break
    else:
        # The splitter yielded fewer splits than get_n_splits reported; do not
        # fall back to the last split (or crash on an unbound index).
        raise ValueError("The splitter yielded fewer than `nth_split` = {} splits".format(nth_split))

    train = array.iloc[train_index, :]
    test = array.iloc[test_index, :]

    return train, test
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
84
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
85
3
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
def main(
    inputs,
    infile_array,
    outfile_train,
    outfile_test,
    infile_labels=None,
    infile_groups=None,
):
    """Split a tab-separated dataset into a train file and a test file.

    Parameter
    ---------
    inputs : str
        File path to the JSON file holding the Galaxy tool parameters.

    infile_array : str
        File path to the tab-separated input dataset to split.

    outfile_train : str
        File path the train subset is written to (tab-separated, no index).

    outfile_test : str
        File path the test subset is written to (tab-separated, no index).

    infile_labels : str
        Optional file path to a tab-separated dataset containing labels.

    infile_groups : str
        Optional file path to a tab-separated dataset containing groups
        (only used by the cv-splitter mode).
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    input_header = params["header0"]
    array = pd.read_csv(
        infile_array,
        sep="\t",
        header="infer" if input_header else None,
        parse_dates=True,
    )

    mode_selection = params["mode_selection"]
    if mode_selection["selected_mode"] != "train_test_split":
        # cv splitter: pick a single (train, test) split
        train, test = _get_single_cv_split(params, array, infile_labels=infile_labels, infile_groups=infile_groups)
    else:
        # train test split: flatten the shuffle options into keyword args
        split_kwargs = mode_selection["options"]
        shuffle_selection = split_kwargs.pop("shuffle_selection")
        split_kwargs["shuffle"] = shuffle_selection["shuffle"]

        if infile_labels:
            labels_header = "infer" if shuffle_selection["header1"] else None
            label_col = shuffle_selection["col"][0] - 1  # Galaxy columns are 1-based
            labels_frame = pd.read_csv(infile_labels, sep="\t", header=labels_header, parse_dates=True)
            split_kwargs["labels"] = labels_frame.iloc[:, label_col].values

        train, test = train_test_split(array, **split_kwargs)

    print(f"Input shape: {array.shape!r}")
    print(f"Train shape: {train.shape!r}")
    print(f"Test shape: {test.shape!r}")

    # The output files keep a header row only if the input had one.
    train.to_csv(outfile_train, sep="\t", header=input_header, index=False)
    test.to_csv(outfile_test, sep="\t", header=input_header, index=False)
0
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
147
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
148
3
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
149 if __name__ == "__main__":
0
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
150 aparser = argparse.ArgumentParser()
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
151 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
152 aparser.add_argument("-X", "--infile_array", dest="infile_array")
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
153 aparser.add_argument("-y", "--infile_labels", dest="infile_labels")
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
154 aparser.add_argument("-g", "--infile_groups", dest="infile_groups")
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
155 aparser.add_argument("-o", "--outfile_train", dest="outfile_train")
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
156 aparser.add_argument("-t", "--outfile_test", dest="outfile_test")
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
157 args = aparser.parse_args()
13226b2ddfb4 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
bgruening
parents:
diff changeset
158
3
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
159 main(
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
160 args.inputs,
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
161 args.infile_array,
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
162 args.outfile_train,
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
163 args.outfile_test,
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
164 args.infile_labels,
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
165 args.infile_groups,
c3bafda50176 "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
bgruening
parents: 0
diff changeset
166 )