diff train_test_split.py @ 17:a01fa4e8fe4f draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 12:54:40 +0000
parents c9ddd20d25d0
children
line wrap: on
line diff
--- a/train_test_split.py	Thu Aug 11 09:52:07 2022 +0000
+++ b/train_test_split.py	Wed Aug 09 12:54:40 2023 +0000
@@ -1,8 +1,10 @@
 import argparse
 import json
 import warnings
+from distutils.version import LooseVersion as Version
 
 import pandas as pd
+from galaxy_ml import __version__ as galaxy_ml_version
 from galaxy_ml.model_validations import train_test_split
 from galaxy_ml.utils import get_cv, read_columns
 
@@ -69,7 +71,10 @@
         y = df.iloc[:, col_index].values
 
     # construct the cv splitter object
-    splitter, groups = get_cv(params["mode_selection"]["cv_selector"])
+    cv_selector = params["mode_selection"]["cv_selector"]
+    if Version(galaxy_ml_version) < Version("0.8.3"):
+        cv_selector.pop("n_stratification_bins", None)
+    splitter, groups = get_cv(cv_selector)
 
     total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
     if nth_split > total_n_splits: