Mercurial > repos > bgruening > sklearn_build_pipeline
annotate simple_model_fit.py @ 16:840f29aad145 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 756f8be9c3cd437e131e6410cd625c24fe078e8c"
author | bgruening |
---|---|
date | Wed, 22 Jan 2020 08:02:37 -0500 |
parents | c33145a815ee |
children | 4de3d598c116 |
rev | line source |
---|---|
13
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
1 import argparse |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
2 import json |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
3 import pandas as pd |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
4 import pickle |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
5 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
6 from galaxy_ml.utils import load_model, read_columns |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
7 from sklearn.pipeline import Pipeline |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
8 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
9 |
14
c33145a815ee
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
bgruening
parents:
13
diff
changeset
|
# Worker count granted by Galaxy through the GALAXY_SLOTS environment
# variable; falls back to a single job when the variable is absent.
N_JOBS = int(__import__('os').getenv('GALAXY_SLOTS', 1))
c33145a815ee
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
bgruening
parents:
13
diff
changeset
|
11 |
c33145a815ee
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
bgruening
parents:
13
diff
changeset
|
12 |
c33145a815ee
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
bgruening
parents:
13
diff
changeset
|
13 # TODO import from galaxy_ml.utils in future versions |
c33145a815ee
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
bgruening
parents:
13
diff
changeset
|
def clean_params(estimator, n_jobs=None):
    """Clean unwanted hyperparameter settings.

    Every parameter that could trigger an unauthorized file write
    (``memory``, ``*__memory``, ``*_path``) is reset to ``None``.
    If ``n_jobs`` is not None, it is set into every ``n_jobs`` /
    ``*__n_jobs`` parameter. Every ``*callbacks`` parameter is checked
    against an allow-list of callback types.

    Parameters
    ----------
    estimator : object
        Estimator exposing sklearn-style ``get_params`` / ``set_params``.
    n_jobs : int or None
        Number of jobs to force into the estimator; ``None`` leaves the
        existing ``n_jobs`` settings untouched.

    Return
    ------
    Cleaned estimator object (the same object, modified in place)

    Raises
    ------
    ValueError
        If a callback's ``callback_type`` is not in the allow-list.
    """
    ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN',
                         'ReduceLROnPlateau', 'CSVLogger', 'None')

    # get_params() returns a fresh dict, so calling set_params() while
    # iterating over it is safe
    estimator_params = estimator.get_params()

    for name, p in estimator_params.items():
        # all potential unauthorized file write
        if name == 'memory' or name.endswith(('__memory', '_path')):
            estimator.set_params(**{name: None})
        elif n_jobs is not None and (name == 'n_jobs' or
                                     name.endswith('__n_jobs')):
            estimator.set_params(**{name: n_jobs})
        elif name.endswith('callbacks'):
            # an unset callbacks parameter (None) has nothing to validate;
            # iterating None directly would raise TypeError
            for cb in p or ():
                cb_type = cb['callback_selection']['callback_type']
                if cb_type not in ALLOWED_CALLBACKS:
                    raise ValueError(
                        "Prohibited callback type: %s!" % cb_type)

    return estimator
c33145a815ee
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
bgruening
parents:
13
diff
changeset
|
46 |
c33145a815ee
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
bgruening
parents:
13
diff
changeset
|
47 |
13
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
48 def _get_X_y(params, infile1, infile2): |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
49 """ read from inputs and output X and y |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
50 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
51 Parameters |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
52 ---------- |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
53 params : dict |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
54 Tool inputs parameter |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
55 infile1 : str |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
56 File path to dataset containing features |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
57 infile2 : str |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
58 File path to dataset containing target values |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
59 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
60 """ |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
61 # store read dataframe object |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
62 loaded_df = {} |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
63 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
64 input_type = params['input_options']['selected_input'] |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
65 # tabular input |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
66 if input_type == 'tabular': |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
67 header = 'infer' if params['input_options']['header1'] else None |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
68 column_option = (params['input_options']['column_selector_options_1'] |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
69 ['selected_column_selector_option']) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
70 if column_option in ['by_index_number', 'all_but_by_index_number', |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
71 'by_header_name', 'all_but_by_header_name']: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
72 c = params['input_options']['column_selector_options_1']['col1'] |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
73 else: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
74 c = None |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
75 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
76 df_key = infile1 + repr(header) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
77 df = pd.read_csv(infile1, sep='\t', header=header, |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
78 parse_dates=True) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
79 loaded_df[df_key] = df |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
80 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
81 X = read_columns(df, c=c, c_option=column_option).astype(float) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
82 # sparse input |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
83 elif input_type == 'sparse': |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
84 X = mmread(open(infile1, 'r')) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
85 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
86 # Get target y |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
87 header = 'infer' if params['input_options']['header2'] else None |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
88 column_option = (params['input_options']['column_selector_options_2'] |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
89 ['selected_column_selector_option2']) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
90 if column_option in ['by_index_number', 'all_but_by_index_number', |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
91 'by_header_name', 'all_but_by_header_name']: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
92 c = params['input_options']['column_selector_options_2']['col2'] |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
93 else: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
94 c = None |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
95 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
96 df_key = infile2 + repr(header) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
97 if df_key in loaded_df: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
98 infile2 = loaded_df[df_key] |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
99 else: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
100 infile2 = pd.read_csv(infile2, sep='\t', |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
101 header=header, parse_dates=True) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
102 loaded_df[df_key] = infile2 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
103 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
104 y = read_columns( |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
105 infile2, |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
106 c=c, |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
107 c_option=column_option, |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
108 sep='\t', |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
109 header=header, |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
110 parse_dates=True) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
111 if len(y.shape) == 2 and y.shape[1] == 1: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
112 y = y.ravel() |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
113 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
114 return X, y |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
115 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
116 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
117 def main(inputs, infile_estimator, infile1, infile2, out_object, |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
118 out_weights=None): |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
119 """ main |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
120 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
121 Parameters |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
122 ---------- |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
123 inputs : str |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
124 File path to galaxy tool parameter |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
125 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
126 infile_estimator : str |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
127 File paths of input estimator |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
128 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
129 infile1 : str |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
130 File path to dataset containing features |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
131 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
132 infile2 : str |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
133 File path to dataset containing target labels |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
134 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
135 out_object : str |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
136 File path for output of fitted model or skeleton |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
137 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
138 out_weights : str |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
139 File path for output of weights |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
140 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
141 """ |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
142 with open(inputs, 'r') as param_handler: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
143 params = json.load(param_handler) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
144 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
145 # load model |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
146 with open(infile_estimator, 'rb') as est_handler: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
147 estimator = load_model(est_handler) |
14
c33145a815ee
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
bgruening
parents:
13
diff
changeset
|
148 estimator = clean_params(estimator, n_jobs=N_JOBS) |
13
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
149 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
150 X_train, y_train = _get_X_y(params, infile1, infile2) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
151 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
152 estimator.fit(X_train, y_train) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
153 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
154 main_est = estimator |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
155 if isinstance(main_est, Pipeline): |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
156 main_est = main_est.steps[-1][-1] |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
157 if hasattr(main_est, 'model_') \ |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
158 and hasattr(main_est, 'save_weights'): |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
159 if out_weights: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
160 main_est.save_weights(out_weights) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
161 del main_est.model_ |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
162 del main_est.fit_params |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
163 del main_est.model_class_ |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
164 del main_est.validation_data |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
165 if getattr(main_est, 'data_generator_', None): |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
166 del main_est.data_generator_ |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
167 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
168 with open(out_object, 'wb') as output_handler: |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
169 pickle.dump(estimator, output_handler, |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
170 pickle.HIGHEST_PROTOCOL) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
171 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
172 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
173 if __name__ == '__main__': |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
174 aparser = argparse.ArgumentParser() |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
175 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
176 aparser.add_argument("-X", "--infile_estimator", dest="infile_estimator") |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
177 aparser.add_argument("-y", "--infile1", dest="infile1") |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
178 aparser.add_argument("-g", "--infile2", dest="infile2") |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
179 aparser.add_argument("-o", "--out_object", dest="out_object") |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
180 aparser.add_argument("-t", "--out_weights", dest="out_weights") |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
181 args = aparser.parse_args() |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
182 |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
183 main(args.inputs, args.infile_estimator, args.infile1, |
6d056a1b2249
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
184 args.infile2, args.out_object, args.out_weights) |