comparison train_test_eval.py @ 30:4b359039f09f draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:03:56 +0000
parents de360b57a5ab
children 1fe00785190d
comparison
equal deleted inserted replaced
29:de360b57a5ab 30:4b359039f09f
7 7
8 import joblib 8 import joblib
9 import numpy as np 9 import numpy as np
10 import pandas as pd 10 import pandas as pd
11 from galaxy_ml.model_validations import train_test_split 11 from galaxy_ml.model_validations import train_test_split
12 from galaxy_ml.utils import ( 12 from galaxy_ml.utils import (get_module, get_scoring, load_model,
13 get_module, 13 read_columns, SafeEval, try_get_attr)
14 get_scoring,
15 load_model,
16 read_columns,
17 SafeEval,
18 try_get_attr,
19 )
20 from scipy.io import mmread 14 from scipy.io import mmread
21 from sklearn import pipeline 15 from sklearn import pipeline
22 from sklearn.metrics.scorer import _check_multimetric_scoring 16 from sklearn.metrics.scorer import _check_multimetric_scoring
23 from sklearn.model_selection import _search, _validation 17 from sklearn.model_selection import _search, _validation
24 from sklearn.model_selection._validation import _score 18 from sklearn.model_selection._validation import _score
25 from sklearn.utils import indexable, safe_indexing 19 from sklearn.utils import indexable, safe_indexing
26
27 20
28 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") 21 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
29 setattr(_search, "_fit_and_score", _fit_and_score) 22 setattr(_search, "_fit_and_score", _fit_and_score)
30 setattr(_validation, "_fit_and_score", _fit_and_score) 23 setattr(_validation, "_fit_and_score", _fit_and_score)
31 24
260 infile2 = loaded_df[df_key] 253 infile2 = loaded_df[df_key]
261 else: 254 else:
262 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) 255 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
263 loaded_df[df_key] = infile2 256 loaded_df[df_key] = infile2
264 257
265 y = read_columns(infile2, 258 y = read_columns(
266 c=c, 259 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True
267 c_option=column_option, 260 )
268 sep='\t',
269 header=header,
270 parse_dates=True)
271 if len(y.shape) == 2 and y.shape[1] == 1: 261 if len(y.shape) == 2 and y.shape[1] == 1:
272 y = y.ravel() 262 y = y.ravel()
273 if input_type == "refseq_and_interval": 263 if input_type == "refseq_and_interval":
274 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) 264 estimator.set_params(data_batch_generator__features=y.ravel().tolist())
275 y = None 265 y = None
297 287
298 df_key = groups + repr(header) 288 df_key = groups + repr(header)
299 if df_key in loaded_df: 289 if df_key in loaded_df:
300 groups = loaded_df[df_key] 290 groups = loaded_df[df_key]
301 291
302 groups = read_columns(groups, 292 groups = read_columns(
303 c=c, 293 groups,
304 c_option=column_option, 294 c=c,
305 sep='\t', 295 c_option=column_option,
306 header=header, 296 sep="\t",
307 parse_dates=True) 297 header=header,
298 parse_dates=True,
299 )
308 groups = groups.ravel() 300 groups = groups.ravel()
309 301
310 # del loaded_df 302 # del loaded_df
311 del loaded_df 303 del loaded_df
312 304
369 else: 361 else:
370 raise ValueError( 362 raise ValueError(
371 "Stratified shuffle split is not " "applicable on empty target values!" 363 "Stratified shuffle split is not " "applicable on empty target values!"
372 ) 364 )
373 365
374 X_train, X_test, y_train, y_test, groups_train, _groups_test = train_test_split_none( 366 (
375 X, y, groups, **test_split_options 367 X_train,
376 ) 368 X_test,
369 y_train,
370 y_test,
371 groups_train,
372 _groups_test,
373 ) = train_test_split_none(X, y, groups, **test_split_options)
377 374
378 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"] 375 exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]
379 376
380 # handle validation (second) split 377 # handle validation (second) split
381 if exp_scheme == "train_val_test": 378 if exp_scheme == "train_val_test":