comparison train_test_eval.py @ 35:1e99cfb71f40 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author bgruening
date Tue, 13 Apr 2021 17:52:15 +0000
parents df579b31311d
children 999e07f0a9fa
comparison
equal deleted inserted replaced
34:7068b5fcd623 35:1e99cfb71f40
1 import argparse 1 import argparse
2 import joblib
3 import json 2 import json
4 import numpy as np
5 import os 3 import os
6 import pandas as pd
7 import pickle 4 import pickle
8 import warnings 5 import warnings
6
9 from itertools import chain 7 from itertools import chain
8
9 import joblib
10 import numpy as np
11 import pandas as pd
12 from galaxy_ml.model_validations import train_test_split
13 from galaxy_ml.utils import (
14 get_module,
15 get_scoring,
16 load_model,
17 read_columns,
18 SafeEval,
19 try_get_attr,
20 )
10 from scipy.io import mmread 21 from scipy.io import mmread
11 from sklearn.base import clone 22 from sklearn import pipeline
12 from sklearn import (cluster, compose, decomposition, ensemble,
13 feature_extraction, feature_selection,
14 gaussian_process, kernel_approximation, metrics,
15 model_selection, naive_bayes, neighbors,
16 pipeline, preprocessing, svm, linear_model,
17 tree, discriminant_analysis)
18 from sklearn.exceptions import FitFailedWarning
19 from sklearn.metrics.scorer import _check_multimetric_scoring 23 from sklearn.metrics.scorer import _check_multimetric_scoring
20 from sklearn.model_selection._validation import _score, cross_validate 24 from sklearn.model_selection._validation import _score
21 from sklearn.model_selection import _search, _validation 25 from sklearn.model_selection import _search, _validation
26 from sklearn.model_selection._validation import _score
22 from sklearn.utils import indexable, safe_indexing 27 from sklearn.utils import indexable, safe_indexing
23 28
24 from galaxy_ml.model_validations import train_test_split 29
25 from galaxy_ml.utils import (SafeEval, get_scoring, load_model, 30 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
26 read_columns, try_get_attr, get_module) 31 setattr(_search, "_fit_and_score", _fit_and_score)
27 32 setattr(_validation, "_fit_and_score", _fit_and_score)
28 33
29 _fit_and_score = try_get_attr('galaxy_ml.model_validations', '_fit_and_score') 34 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
30 setattr(_search, '_fit_and_score', _fit_and_score) 35 CACHE_DIR = os.path.join(os.getcwd(), "cached")
31 setattr(_validation, '_fit_and_score', _fit_and_score)
32
33 N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
34 CACHE_DIR = os.path.join(os.getcwd(), 'cached')
35 del os 36 del os
36 NON_SEARCHABLE = ('n_jobs', 'pre_dispatch', 'memory', '_path', 37 NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
37 'nthread', 'callbacks') 38 ALLOWED_CALLBACKS = (
38 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', 'ReduceLROnPlateau', 39 "EarlyStopping",
39 'CSVLogger', 'None') 40 "TerminateOnNaN",
41 "ReduceLROnPlateau",
42 "CSVLogger",
43 "None",
44 )
40 45
41 46
42 def _eval_swap_params(params_builder): 47 def _eval_swap_params(params_builder):
43 swap_params = {} 48 swap_params = {}
44 49
45 for p in params_builder['param_set']: 50 for p in params_builder["param_set"]:
46 swap_value = p['sp_value'].strip() 51 swap_value = p["sp_value"].strip()
47 if swap_value == '': 52 if swap_value == "":
48 continue 53 continue
49 54
50 param_name = p['sp_name'] 55 param_name = p["sp_name"]
51 if param_name.lower().endswith(NON_SEARCHABLE): 56 if param_name.lower().endswith(NON_SEARCHABLE):
52 warnings.warn("Warning: `%s` is not eligible for search and was " 57 warnings.warn(
53 "omitted!" % param_name) 58 "Warning: `%s` is not eligible for search and was "
59 "omitted!" % param_name
60 )
54 continue 61 continue
55 62
56 if not swap_value.startswith(':'): 63 if not swap_value.startswith(":"):
57 safe_eval = SafeEval(load_scipy=True, load_numpy=True) 64 safe_eval = SafeEval(load_scipy=True, load_numpy=True)
58 ev = safe_eval(swap_value) 65 ev = safe_eval(swap_value)
59 else: 66 else:
60 # Have `:` before search list, asks for estimator evaluatio 67 # Have `:` before search list, asks for estimator evaluatio
61 safe_eval_es = SafeEval(load_estimators=True) 68 safe_eval_es = SafeEval(load_estimators=True)
78 if arr is None: 85 if arr is None:
79 nones.append(idx) 86 nones.append(idx)
80 else: 87 else:
81 new_arrays.append(arr) 88 new_arrays.append(arr)
82 89
83 if kwargs['shuffle'] == 'None': 90 if kwargs["shuffle"] == "None":
84 kwargs['shuffle'] = None 91 kwargs["shuffle"] = None
85 92
86 group_names = kwargs.pop('group_names', None) 93 group_names = kwargs.pop("group_names", None)
87 94
88 if group_names is not None and group_names.strip(): 95 if group_names is not None and group_names.strip():
89 group_names = [name.strip() for name in 96 group_names = [name.strip() for name in group_names.split(",")]
90 group_names.split(',')]
91 new_arrays = indexable(*new_arrays) 97 new_arrays = indexable(*new_arrays)
92 groups = kwargs['labels'] 98 groups = kwargs["labels"]
93 n_samples = new_arrays[0].shape[0] 99 n_samples = new_arrays[0].shape[0]
94 index_arr = np.arange(n_samples) 100 index_arr = np.arange(n_samples)
95 test = index_arr[np.isin(groups, group_names)] 101 test = index_arr[np.isin(groups, group_names)]
96 train = index_arr[~np.isin(groups, group_names)] 102 train = index_arr[~np.isin(groups, group_names)]
97 rval = list(chain.from_iterable( 103 rval = list(
98 (safe_indexing(a, train), 104 chain.from_iterable(
99 safe_indexing(a, test)) for a in new_arrays)) 105 (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays
106 )
107 )
100 else: 108 else:
101 rval = train_test_split(*new_arrays, **kwargs) 109 rval = train_test_split(*new_arrays, **kwargs)
102 110
103 for pos in nones: 111 for pos in nones:
104 rval[pos * 2: 2] = [None, None] 112 rval[pos * 2: 2] = [None, None]
105 113
106 return rval 114 return rval
107 115
108 116
109 def main(inputs, infile_estimator, infile1, infile2, 117 def main(
110 outfile_result, outfile_object=None, 118 inputs,
111 outfile_weights=None, groups=None, 119 infile_estimator,
112 ref_seq=None, intervals=None, targets=None, 120 infile1,
113 fasta_path=None): 121 infile2,
122 outfile_result,
123 outfile_object=None,
124 outfile_weights=None,
125 groups=None,
126 ref_seq=None,
127 intervals=None,
128 targets=None,
129 fasta_path=None,
130 ):
114 """ 131 """
115 Parameter 132 Parameter
116 --------- 133 ---------
117 inputs : str 134 inputs : str
118 File path to galaxy tool parameter 135 File path to galaxy tool parameter
148 File path to dataset compressed target bed file 165 File path to dataset compressed target bed file
149 166
150 fasta_path : str 167 fasta_path : str
151 File path to dataset containing fasta file 168 File path to dataset containing fasta file
152 """ 169 """
153 warnings.simplefilter('ignore') 170 warnings.simplefilter("ignore")
154 171
155 with open(inputs, 'r') as param_handler: 172 with open(inputs, "r") as param_handler:
156 params = json.load(param_handler) 173 params = json.load(param_handler)
157 174
158 # load estimator 175 # load estimator
159 with open(infile_estimator, 'rb') as estimator_handler: 176 with open(infile_estimator, "rb") as estimator_handler:
160 estimator = load_model(estimator_handler) 177 estimator = load_model(estimator_handler)
161 178
162 # swap hyperparameter 179 # swap hyperparameter
163 swapping = params['experiment_schemes']['hyperparams_swapping'] 180 swapping = params["experiment_schemes"]["hyperparams_swapping"]
164 swap_params = _eval_swap_params(swapping) 181 swap_params = _eval_swap_params(swapping)
165 estimator.set_params(**swap_params) 182 estimator.set_params(**swap_params)
166 183
167 estimator_params = estimator.get_params() 184 estimator_params = estimator.get_params()
168 185
169 # store read dataframe object 186 # store read dataframe object
170 loaded_df = {} 187 loaded_df = {}
171 188
172 input_type = params['input_options']['selected_input'] 189 input_type = params["input_options"]["selected_input"]
173 # tabular input 190 # tabular input
174 if input_type == 'tabular': 191 if input_type == "tabular":
175 header = 'infer' if params['input_options']['header1'] else None 192 header = "infer" if params["input_options"]["header1"] else None
176 column_option = (params['input_options']['column_selector_options_1'] 193 column_option = params["input_options"]["column_selector_options_1"][
177 ['selected_column_selector_option']) 194 "selected_column_selector_option"
178 if column_option in ['by_index_number', 'all_but_by_index_number', 195 ]
179 'by_header_name', 'all_but_by_header_name']: 196 if column_option in [
180 c = params['input_options']['column_selector_options_1']['col1'] 197 "by_index_number",
198 "all_but_by_index_number",
199 "by_header_name",
200 "all_but_by_header_name",
201 ]:
202 c = params["input_options"]["column_selector_options_1"]["col1"]
181 else: 203 else:
182 c = None 204 c = None
183 205
184 df_key = infile1 + repr(header) 206 df_key = infile1 + repr(header)
185 df = pd.read_csv(infile1, sep='\t', header=header, 207 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
186 parse_dates=True)
187 loaded_df[df_key] = df 208 loaded_df[df_key] = df
188 209
189 X = read_columns(df, c=c, c_option=column_option).astype(float) 210 X = read_columns(df, c=c, c_option=column_option).astype(float)
190 # sparse input 211 # sparse input
191 elif input_type == 'sparse': 212 elif input_type == "sparse":
192 X = mmread(open(infile1, 'r')) 213 X = mmread(open(infile1, "r"))
193 214
194 # fasta_file input 215 # fasta_file input
195 elif input_type == 'seq_fasta': 216 elif input_type == "seq_fasta":
196 pyfaidx = get_module('pyfaidx') 217 pyfaidx = get_module("pyfaidx")
197 sequences = pyfaidx.Fasta(fasta_path) 218 sequences = pyfaidx.Fasta(fasta_path)
198 n_seqs = len(sequences.keys()) 219 n_seqs = len(sequences.keys())
199 X = np.arange(n_seqs)[:, np.newaxis] 220 X = np.arange(n_seqs)[:, np.newaxis]
200 for param in estimator_params.keys(): 221 for param in estimator_params.keys():
201 if param.endswith('fasta_path'): 222 if param.endswith("fasta_path"):
202 estimator.set_params( 223 estimator.set_params(**{param: fasta_path})
203 **{param: fasta_path})
204 break 224 break
205 else: 225 else:
206 raise ValueError( 226 raise ValueError(
207 "The selected estimator doesn't support " 227 "The selected estimator doesn't support "
208 "fasta file input! Please consider using " 228 "fasta file input! Please consider using "
209 "KerasGBatchClassifier with " 229 "KerasGBatchClassifier with "
210 "FastaDNABatchGenerator/FastaProteinBatchGenerator " 230 "FastaDNABatchGenerator/FastaProteinBatchGenerator "
211 "or having GenomeOneHotEncoder/ProteinOneHotEncoder " 231 "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
212 "in pipeline!") 232 "in pipeline!"
213 233 )
214 elif input_type == 'refseq_and_interval': 234
235 elif input_type == "refseq_and_interval":
215 path_params = { 236 path_params = {
216 'data_batch_generator__ref_genome_path': ref_seq, 237 "data_batch_generator__ref_genome_path": ref_seq,
217 'data_batch_generator__intervals_path': intervals, 238 "data_batch_generator__intervals_path": intervals,
218 'data_batch_generator__target_path': targets 239 "data_batch_generator__target_path": targets,
219 } 240 }
220 estimator.set_params(**path_params) 241 estimator.set_params(**path_params)
221 n_intervals = sum(1 for line in open(intervals)) 242 n_intervals = sum(1 for line in open(intervals))
222 X = np.arange(n_intervals)[:, np.newaxis] 243 X = np.arange(n_intervals)[:, np.newaxis]
223 244
224 # Get target y 245 # Get target y
225 header = 'infer' if params['input_options']['header2'] else None 246 header = "infer" if params["input_options"]["header2"] else None
226 column_option = (params['input_options']['column_selector_options_2'] 247 column_option = params["input_options"]["column_selector_options_2"][
227 ['selected_column_selector_option2']) 248 "selected_column_selector_option2"
228 if column_option in ['by_index_number', 'all_but_by_index_number', 249 ]
229 'by_header_name', 'all_but_by_header_name']: 250 if column_option in [
230 c = params['input_options']['column_selector_options_2']['col2'] 251 "by_index_number",
252 "all_but_by_index_number",
253 "by_header_name",
254 "all_but_by_header_name",
255 ]:
256 c = params["input_options"]["column_selector_options_2"]["col2"]
231 else: 257 else:
232 c = None 258 c = None
233 259
234 df_key = infile2 + repr(header) 260 df_key = infile2 + repr(header)
235 if df_key in loaded_df: 261 if df_key in loaded_df:
236 infile2 = loaded_df[df_key] 262 infile2 = loaded_df[df_key]
237 else: 263 else:
238 infile2 = pd.read_csv(infile2, sep='\t', 264 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
239 header=header, parse_dates=True)
240 loaded_df[df_key] = infile2 265 loaded_df[df_key] = infile2
241 266
242 y = read_columns( 267 y = read_columns(infile2,
243 infile2, 268 c=c,
244 c=c, 269 c_option=column_option,
245 c_option=column_option, 270 sep='\t',
246 sep='\t', 271 header=header,
247 header=header, 272 parse_dates=True)
248 parse_dates=True)
249 if len(y.shape) == 2 and y.shape[1] == 1: 273 if len(y.shape) == 2 and y.shape[1] == 1:
250 y = y.ravel() 274 y = y.ravel()
251 if input_type == 'refseq_and_interval': 275 if input_type == "refseq_and_interval":
252 estimator.set_params( 276 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
253 data_batch_generator__features=y.ravel().tolist())
254 y = None 277 y = None
255 # end y 278 # end y
256 279
257 # load groups 280 # load groups
258 if groups: 281 if groups:
259 groups_selector = (params['experiment_schemes']['test_split'] 282 groups_selector = (
260 ['split_algos']).pop('groups_selector') 283 params["experiment_schemes"]["test_split"]["split_algos"]
261 284 ).pop("groups_selector")
262 header = 'infer' if groups_selector['header_g'] else None 285
263 column_option = \ 286 header = "infer" if groups_selector["header_g"] else None
264 (groups_selector['column_selector_options_g'] 287 column_option = groups_selector["column_selector_options_g"][
265 ['selected_column_selector_option_g']) 288 "selected_column_selector_option_g"
266 if column_option in ['by_index_number', 'all_but_by_index_number', 289 ]
267 'by_header_name', 'all_but_by_header_name']: 290 if column_option in [
268 c = groups_selector['column_selector_options_g']['col_g'] 291 "by_index_number",
292 "all_but_by_index_number",
293 "by_header_name",
294 "all_but_by_header_name",
295 ]:
296 c = groups_selector["column_selector_options_g"]["col_g"]
269 else: 297 else:
270 c = None 298 c = None
271 299
272 df_key = groups + repr(header) 300 df_key = groups + repr(header)
273 if df_key in loaded_df: 301 if df_key in loaded_df:
274 groups = loaded_df[df_key] 302 groups = loaded_df[df_key]
275 303
276 groups = read_columns( 304 groups = read_columns(groups,
277 groups, 305 c=c,
278 c=c, 306 c_option=column_option,
279 c_option=column_option, 307 sep='\t',
280 sep='\t', 308 header=header,
281 header=header, 309 parse_dates=True)
282 parse_dates=True)
283 groups = groups.ravel() 310 groups = groups.ravel()
284 311
285 # del loaded_df 312 # del loaded_df
286 del loaded_df 313 del loaded_df
287 314
288 # handle memory 315 # handle memory
289 memory = joblib.Memory(location=CACHE_DIR, verbose=0) 316 memory = joblib.Memory(location=CACHE_DIR, verbose=0)
290 # cache iraps_core fits could increase search speed significantly 317 # cache iraps_core fits could increase search speed significantly
291 if estimator.__class__.__name__ == 'IRAPSClassifier': 318 if estimator.__class__.__name__ == "IRAPSClassifier":
292 estimator.set_params(memory=memory) 319 estimator.set_params(memory=memory)
293 else: 320 else:
294 # For iraps buried in pipeline 321 # For iraps buried in pipeline
295 new_params = {} 322 new_params = {}
296 for p, v in estimator_params.items(): 323 for p, v in estimator_params.items():
297 if p.endswith('memory'): 324 if p.endswith("memory"):
298 # for case of `__irapsclassifier__memory` 325 # for case of `__irapsclassifier__memory`
299 if len(p) > 8 and p[:-8].endswith('irapsclassifier'): 326 if len(p) > 8 and p[:-8].endswith("irapsclassifier"):
300 # cache iraps_core fits could increase search 327 # cache iraps_core fits could increase search
301 # speed significantly 328 # speed significantly
302 new_params[p] = memory 329 new_params[p] = memory
303 # security reason, we don't want memory being 330 # security reason, we don't want memory being
304 # modified unexpectedly 331 # modified unexpectedly
305 elif v: 332 elif v:
306 new_params[p] = None 333 new_params[p] = None
307 # handle n_jobs 334 # handle n_jobs
308 elif p.endswith('n_jobs'): 335 elif p.endswith("n_jobs"):
309 # For now, 1 CPU is suggested for iprasclassifier 336 # For now, 1 CPU is suggested for iprasclassifier
310 if len(p) > 8 and p[:-8].endswith('irapsclassifier'): 337 if len(p) > 8 and p[:-8].endswith("irapsclassifier"):
311 new_params[p] = 1 338 new_params[p] = 1
312 else: 339 else:
313 new_params[p] = N_JOBS 340 new_params[p] = N_JOBS
314 # for security reason, types of callback are limited 341 # for security reason, types of callback are limited
315 elif p.endswith('callbacks'): 342 elif p.endswith("callbacks"):
316 for cb in v: 343 for cb in v:
317 cb_type = cb['callback_selection']['callback_type'] 344 cb_type = cb["callback_selection"]["callback_type"]
318 if cb_type not in ALLOWED_CALLBACKS: 345 if cb_type not in ALLOWED_CALLBACKS:
319 raise ValueError( 346 raise ValueError("Prohibited callback type: %s!" % cb_type)
320 "Prohibited callback type: %s!" % cb_type)
321 347
322 estimator.set_params(**new_params) 348 estimator.set_params(**new_params)
323 349
324 # handle scorer, convert to scorer dict 350 # handle scorer, convert to scorer dict
325 scoring = params['experiment_schemes']['metrics']['scoring'] 351 # Check if scoring is specified
352 scoring = params["experiment_schemes"]["metrics"].get("scoring", None)
353 if scoring is not None:
354 # get_scoring() expects secondary_scoring to be a comma separated string (not a list)
355 # Check if secondary_scoring is specified
356 secondary_scoring = scoring.get("secondary_scoring", None)
357 if secondary_scoring is not None:
358 # If secondary_scoring is specified, convert the list into comman separated string
359 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])
326 scorer = get_scoring(scoring) 360 scorer = get_scoring(scoring)
327 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) 361 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer)
328 362
329 # handle test (first) split 363 # handle test (first) split
330 test_split_options = (params['experiment_schemes'] 364 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]
331 ['test_split']['split_algos']) 365
332 366 if test_split_options["shuffle"] == "group":
333 if test_split_options['shuffle'] == 'group': 367 test_split_options["labels"] = groups
334 test_split_options['labels'] = groups 368 if test_split_options["shuffle"] == "stratified":
335 if test_split_options['shuffle'] == 'stratified':
336 if y is not None: 369 if y is not None:
337 test_split_options['labels'] = y 370 test_split_options["labels"] = y
338 else: 371 else:
339 raise ValueError("Stratified shuffle split is not " 372 raise ValueError(
340 "applicable on empty target values!") 373 "Stratified shuffle split is not " "applicable on empty target values!"
341 374 )
342 X_train, X_test, y_train, y_test, groups_train, groups_test = \ 375
343 train_test_split_none(X, y, groups, **test_split_options) 376 X_train, X_test, y_train, y_test, groups_train, _groups_test = train_test_split_none(
344 377 X, y, groups, **test_split_options
345 exp_scheme = params['experiment_schemes']['selected_exp_scheme'] 378 )
379
380 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]
346 381
347 # handle validation (second) split 382 # handle validation (second) split
348 if exp_scheme == 'train_val_test': 383 if exp_scheme == "train_val_test":
349 val_split_options = (params['experiment_schemes'] 384 val_split_options = params["experiment_schemes"]["val_split"]["split_algos"]
350 ['val_split']['split_algos']) 385
351 386 if val_split_options["shuffle"] == "group":
352 if val_split_options['shuffle'] == 'group': 387 val_split_options["labels"] = groups_train
353 val_split_options['labels'] = groups_train 388 if val_split_options["shuffle"] == "stratified":
354 if val_split_options['shuffle'] == 'stratified':
355 if y_train is not None: 389 if y_train is not None:
356 val_split_options['labels'] = y_train 390 val_split_options["labels"] = y_train
357 else: 391 else:
358 raise ValueError("Stratified shuffle split is not " 392 raise ValueError(
359 "applicable on empty target values!") 393 "Stratified shuffle split is not "
360 394 "applicable on empty target values!"
361 X_train, X_val, y_train, y_val, groups_train, groups_val = \ 395 )
362 train_test_split_none(X_train, y_train, groups_train, 396
363 **val_split_options) 397 (
398 X_train,
399 X_val,
400 y_train,
401 y_val,
402 groups_train,
403 _groups_val,
404 ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)
364 405
365 # train and eval 406 # train and eval
366 if hasattr(estimator, 'validation_data'): 407 if hasattr(estimator, "validation_data"):
367 if exp_scheme == 'train_val_test': 408 if exp_scheme == "train_val_test":
368 estimator.fit(X_train, y_train, 409 estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
369 validation_data=(X_val, y_val)) 410 else:
370 else: 411 estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
371 estimator.fit(X_train, y_train,
372 validation_data=(X_test, y_test))
373 else: 412 else:
374 estimator.fit(X_train, y_train) 413 estimator.fit(X_train, y_train)
375 414
376 if hasattr(estimator, 'evaluate'): 415 if hasattr(estimator, "evaluate"):
377 scores = estimator.evaluate(X_test, y_test=y_test, 416 scores = estimator.evaluate(
378 scorer=scorer, 417 X_test, y_test=y_test, scorer=scorer, is_multimetric=True
379 is_multimetric=True) 418 )
380 else: 419 else:
381 scores = _score(estimator, X_test, y_test, scorer, 420 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True)
382 is_multimetric=True)
383 # handle output 421 # handle output
384 for name, score in scores.items(): 422 for name, score in scores.items():
385 scores[name] = [score] 423 scores[name] = [score]
386 df = pd.DataFrame(scores) 424 df = pd.DataFrame(scores)
387 df = df[sorted(df.columns)] 425 df = df[sorted(df.columns)]
388 df.to_csv(path_or_buf=outfile_result, sep='\t', 426 df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)
389 header=True, index=False)
390 427
391 memory.clear(warn=False) 428 memory.clear(warn=False)
392 429
393 if outfile_object: 430 if outfile_object:
394 main_est = estimator 431 main_est = estimator
395 if isinstance(estimator, pipeline.Pipeline): 432 if isinstance(estimator, pipeline.Pipeline):
396 main_est = estimator.steps[-1][-1] 433 main_est = estimator.steps[-1][-1]
397 434
398 if hasattr(main_est, 'model_') \ 435 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
399 and hasattr(main_est, 'save_weights'):
400 if outfile_weights: 436 if outfile_weights:
401 main_est.save_weights(outfile_weights) 437 main_est.save_weights(outfile_weights)
402 del main_est.model_ 438 if getattr(main_est, "model_", None):
403 del main_est.fit_params 439 del main_est.model_
404 del main_est.model_class_ 440 if getattr(main_est, "fit_params", None):
405 del main_est.validation_data 441 del main_est.fit_params
406 if getattr(main_est, 'data_generator_', None): 442 if getattr(main_est, "model_class_", None):
443 del main_est.model_class_
444 if getattr(main_est, "validation_data", None):
445 del main_est.validation_data
446 if getattr(main_est, "data_generator_", None):
407 del main_est.data_generator_ 447 del main_est.data_generator_
408 448
409 with open(outfile_object, 'wb') as output_handler: 449 with open(outfile_object, "wb") as output_handler:
410 pickle.dump(estimator, output_handler, 450 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)
411 pickle.HIGHEST_PROTOCOL) 451
412 452
413 453 if __name__ == "__main__":
414 if __name__ == '__main__':
415 aparser = argparse.ArgumentParser() 454 aparser = argparse.ArgumentParser()
416 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 455 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
417 aparser.add_argument("-e", "--estimator", dest="infile_estimator") 456 aparser.add_argument("-e", "--estimator", dest="infile_estimator")
418 aparser.add_argument("-X", "--infile1", dest="infile1") 457 aparser.add_argument("-X", "--infile1", dest="infile1")
419 aparser.add_argument("-y", "--infile2", dest="infile2") 458 aparser.add_argument("-y", "--infile2", dest="infile2")
425 aparser.add_argument("-b", "--intervals", dest="intervals") 464 aparser.add_argument("-b", "--intervals", dest="intervals")
426 aparser.add_argument("-t", "--targets", dest="targets") 465 aparser.add_argument("-t", "--targets", dest="targets")
427 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 466 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
428 args = aparser.parse_args() 467 args = aparser.parse_args()
429 468
430 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, 469 main(
431 args.outfile_result, outfile_object=args.outfile_object, 470 args.inputs,
432 outfile_weights=args.outfile_weights, groups=args.groups, 471 args.infile_estimator,
433 ref_seq=args.ref_seq, intervals=args.intervals, 472 args.infile1,
434 targets=args.targets, fasta_path=args.fasta_path) 473 args.infile2,
474 args.outfile_result,
475 outfile_object=args.outfile_object,
476 outfile_weights=args.outfile_weights,
477 groups=args.groups,
478 ref_seq=args.ref_seq,
479 intervals=args.intervals,
480 targets=args.targets,
481 fasta_path=args.fasta_path,
482 )