Mercurial > repos > bgruening > ml_visualization_ex
diff train_test_split.py @ 14:9c19cf3c4ea0 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:08:43 +0000 |
parents | 4a229a7ad638 |
children |
line wrap: on
line diff
--- a/train_test_split.py Thu Aug 11 09:29:32 2022 +0000 +++ b/train_test_split.py Wed Aug 09 13:08:43 2023 +0000 @@ -1,8 +1,10 @@ import argparse import json import warnings +from distutils.version import LooseVersion as Version import pandas as pd +from galaxy_ml import __version__ as galaxy_ml_version from galaxy_ml.model_validations import train_test_split from galaxy_ml.utils import get_cv, read_columns @@ -69,7 +71,10 @@ y = df.iloc[:, col_index].values # construct the cv splitter object - splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) + cv_selector = params["mode_selection"]["cv_selector"] + if Version(galaxy_ml_version) < Version("0.8.3"): + cv_selector.pop("n_stratification_bins", None) + splitter, groups = get_cv(cv_selector) total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) if nth_split > total_n_splits: