Mercurial > repos > bgruening > sklearn_searchcv
comparison search_model_validation.xml @ 5:0987bc3904a0 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2a058459e6daf0486871f93845f00fdb4a4eaca1
author | bgruening |
---|---|
date | Sat, 29 Sep 2018 07:26:39 -0400 |
parents | 2e6540c11251 |
children | 7509d7059040 |
comparison
equal
deleted
inserted
replaced
4:2e6540c11251 | 5:0987bc3904a0 |
---|---|
23 import pandas | 23 import pandas |
24 import skrebate | 24 import skrebate |
25 from sklearn import model_selection | 25 from sklearn import model_selection |
26 from sklearn.exceptions import FitFailedWarning | 26 from sklearn.exceptions import FitFailedWarning |
27 | 27 |
28 execfile("$__tool_directory__/sk_whitelist.py") | 28 with open("$__tool_directory__/sk_whitelist.json", "r") as f: |
29 execfile("$__tool_directory__/utils.py", globals()) | 29 sk_whitelist = json.load(f) |
30 exec(open("$__tool_directory__/utils.py").read(), globals()) | |
30 | 31 |
31 warnings.simplefilter('ignore') | 32 warnings.simplefilter('ignore') |
32 | 33 |
33 input_json_path = sys.argv[1] | 34 input_json_path = sys.argv[1] |
34 with open(input_json_path, "r") as param_handler: | 35 with open(input_json_path, "r") as param_handler: |
93 options['refit'] = 'primary' | 94 options['refit'] = 'primary' |
94 if 'pre_dispatch' in options and options['pre_dispatch'] == '': | 95 if 'pre_dispatch' in options and options['pre_dispatch'] == '': |
95 options['pre_dispatch'] = None | 96 options['pre_dispatch'] = None |
96 | 97 |
97 with open(infile_pipeline, 'rb') as pipeline_handler: | 98 with open(infile_pipeline, 'rb') as pipeline_handler: |
98 pipeline = SafePickler.load(pipeline_handler) | 99 pipeline = load_model(pipeline_handler) |
99 | 100 |
100 search_params = get_search_params(params_builder) | 101 search_params = get_search_params(params_builder) |
101 searcher = optimizers(pipeline, search_params, **options) | 102 searcher = optimizers(pipeline, search_params, **options) |
102 | 103 |
103 warnings.simplefilter('always', FitFailedWarning) | 104 if options['error_score'] == 'raise': |
104 with warnings.catch_warnings(record=True) as w: | 105 searcher.fit(X, y) |
105 try: | 106 else: |
106 searcher.fit(X, y) | 107 warnings.simplefilter('always', FitFailedWarning) |
107 except ValueError: | 108 with warnings.catch_warnings(record=True) as w: |
108 pass | 109 try: |
109 for warning in w: | 110 searcher.fit(X, y) |
110 print(repr(warning.message)) | 111 except ValueError: |
112 pass | |
113 for warning in w: | |
114 print(repr(warning.message)) | |
111 | 115 |
112 cv_result = pandas.DataFrame(searcher.cv_results_) | 116 cv_result = pandas.DataFrame(searcher.cv_results_) |
113 cv_result.to_csv(path_or_buf=outfile_result, sep='\t', header=True, index=False) | 117 cv_result.to_csv(path_or_buf=outfile_result, sep='\t', header=True, index=False) |
114 | 118 |
115 #if $save: | 119 #if $save: |
167 <param name="header1" value="true" /> | 171 <param name="header1" value="true" /> |
168 <param name="selected_column_selector_option" value="all_columns"/> | 172 <param name="selected_column_selector_option" value="all_columns"/> |
169 <param name="infile2" value="regression_y.tabular" ftype="tabular"/> | 173 <param name="infile2" value="regression_y.tabular" ftype="tabular"/> |
170 <param name="header2" value="true" /> | 174 <param name="header2" value="true" /> |
171 <param name="selected_column_selector_option2" value="all_columns"/> | 175 <param name="selected_column_selector_option2" value="all_columns"/> |
172 <output name="outfile_result" > | 176 <output name="outfile_result"> |
173 <assert_contents> | 177 <assert_contents> |
174 <has_text_matching expression="[^/d]+0.7938837807353147[^/d]+{u'estimator__C': 1, u'preprocessing_2__k': 9}[^/d]+1" /> | 178 <has_n_columns n="13"/> |
175 <has_text text="0.0"/> | 179 <has_text text="0.7938837807353147"/> |
176 </assert_contents> | 180 <has_text text="{'estimator__C': 1, 'preprocessing_2__k': 9}"/> |
177 </output> | 181 </assert_contents> |
182 </output> | |
183 </test> | |
184 <test expect_failure="true"> | |
185 <param name="selected_search_scheme" value="GridSearchCV"/> | |
186 <param name="infile_pipeline" value="pipeline01" ftype="zip"/> | |
187 <conditional name="search_param_selector"> | |
188 <param name="search_p" value="C: [1, 10, 100, 1000]"/> | |
189 <param name="selected_param_type" value="final_estimator_p"/> | |
190 </conditional> | |
191 <conditional name="search_param_selector"> | |
192 <param name="search_p" value="k: [-1, 3, 5, 7, 9]"/> | |
193 <param name="selected_param_type" value="prep_2_p"/> | |
194 </conditional> | |
195 <param name="error_score" value="true"/> | |
196 <param name="infile1" value="regression_X.tabular" ftype="tabular"/> | |
197 <param name="header1" value="true" /> | |
198 <param name="selected_column_selector_option" value="all_columns"/> | |
199 <param name="infile2" value="regression_y.tabular" ftype="tabular"/> | |
200 <param name="header2" value="true" /> | |
201 <param name="selected_column_selector_option2" value="all_columns"/> | |
178 </test> | 202 </test> |
179 <test> | 203 <test> |
180 <param name="selected_search_scheme" value="RandomizedSearchCV"/> | 204 <param name="selected_search_scheme" value="RandomizedSearchCV"/> |
181 <param name="infile_pipeline" value="pipeline01" ftype="zip"/> | 205 <param name="infile_pipeline" value="pipeline01" ftype="zip"/> |
182 <conditional name="search_param_selector"> | 206 <conditional name="search_param_selector"> |
370 <param name="header1" value="true" /> | 394 <param name="header1" value="true" /> |
371 <param name="selected_column_selector_option" value="all_columns"/> | 395 <param name="selected_column_selector_option" value="all_columns"/> |
372 <param name="infile2" value="regression_y.tabular" ftype="tabular"/> | 396 <param name="infile2" value="regression_y.tabular" ftype="tabular"/> |
373 <param name="header2" value="true" /> | 397 <param name="header2" value="true" /> |
374 <param name="selected_column_selector_option2" value="all_columns"/> | 398 <param name="selected_column_selector_option2" value="all_columns"/> |
375 <output name="outfile_estimator" file="searchCV01" compare="sim_size" delta="1"/> | 399 <output name="outfile_estimator" file="searchCV02" compare="sim_size" delta="1"/> |
376 </test> | 400 </test> |
377 <test> | 401 <test> |
378 <param name="selected_search_scheme" value="GridSearchCV"/> | 402 <param name="selected_search_scheme" value="GridSearchCV"/> |
379 <param name="infile_pipeline" value="pipeline03" ftype="zip"/> | 403 <param name="infile_pipeline" value="pipeline03" ftype="zip"/> |
380 <conditional name="search_param_selector"> | 404 <conditional name="search_param_selector"> |
394 <param name="header2" value="true" /> | 418 <param name="header2" value="true" /> |
395 <param name="selected_column_selector_option2" value="all_columns"/> | 419 <param name="selected_column_selector_option2" value="all_columns"/> |
396 <output name="outfile_result" > | 420 <output name="outfile_result" > |
397 <assert_contents> | 421 <assert_contents> |
398 <has_n_columns n="13" /> | 422 <has_n_columns n="13" /> |
399 <has_text text="0.05366527890058046"/> | 423 <has_text text="0.09003449195911103"/> |
400 </assert_contents> | 424 </assert_contents> |
401 </output> | 425 </output> |
402 </test> | 426 </test> |
403 <test> | 427 <test> |
404 <param name="selected_search_scheme" value="GridSearchCV"/> | 428 <param name="selected_search_scheme" value="GridSearchCV"/> |