Mercurial > repos > goeckslab > multimodal_learner
diff split_logic.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children |
line wrap: on
line diff
# Contents of split_logic.py as added by this changeset
# (--- /dev/null Thu Jan 01 00:00:00 1970 +0000
#  +++ b/split_logic.py Tue Dec 09 23:49:47 2025 +0000).
"""Assign train/val/test membership to a dataset via a ``split`` column."""

import logging
from typing import List, Optional, Tuple

import pandas as pd
from sklearn.model_selection import train_test_split

logger = logging.getLogger(__name__)
SPLIT_COL = "split"
# Labels accepted in a pre-existing split column ("validation" becomes "val").
_VALID_SPLITS = frozenset({"train", "val", "validation", "test"})


def _can_stratify(y: pd.Series) -> bool:
    """Return True when *y* has at least two classes with >= 2 rows each."""
    return bool(y.nunique() >= 2 and (y.value_counts() >= 2).all())


def _holdout_split(
    positions: pd.Index,
    fraction: float,
    labels: pd.Series,
    random_seed: int,
) -> Tuple[pd.Index, pd.Index]:
    """Split row *positions* into ``(kept, held_out)``.

    Args:
        positions: Row positions to split.
        fraction: Share of *positions* to hold out (passed to
            ``train_test_split`` as ``test_size``).
        labels: Target values aligned positionally with *positions*; used
            for stratification when ``_can_stratify`` allows it.
        random_seed: Seed forwarded as ``random_state``.

    Returns:
        The kept positions and the held-out positions.

    Raises:
        ValueError: If ``train_test_split`` rejects *fraction* even without
            stratification.
    """
    if fraction == 0:
        # train_test_split refuses an empty hold-out, so short-circuit it.
        return positions, positions[:0]

    stratify = labels if _can_stratify(labels) else None
    try:
        kept, held_out = train_test_split(
            positions,
            test_size=fraction,
            random_state=random_seed,
            stratify=stratify,
        )
    except ValueError:
        if stratify is None:
            raise
        # Stratification is best-effort: it can still fail, e.g. when the
        # hold-out would contain fewer rows than there are classes.
        logger.warning(
            "Stratified split not possible; falling back to a random split"
        )
        kept, held_out = train_test_split(
            positions,
            test_size=fraction,
            random_state=random_seed,
            stratify=None,
        )
    return kept, held_out


def split_dataset(
    train_dataset: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    target_column: str,
    split_probabilities: List[float],
    validation_size: float,
    random_seed: int = 42,
) -> None:
    """Label every row of *train_dataset* in place via the ``split`` column.

    Rows whose target is NaN are dropped from *train_dataset* (in place)
    first. If a ``split`` column already exists, has at least one label and
    contains only ``train``/``val``/``validation``/``test`` (NaN ignored),
    it is kept, with ``validation`` normalised to ``val``. Otherwise the
    column is (re)generated:

    * *test_dataset* given: rows are split into ``train``/``val`` using
      *validation_size*. *test_dataset* itself is only checked for
      ``None``; it is never read or modified here.
    * *test_dataset* is ``None``: rows are split into
      ``train``/``val``/``test`` using *split_probabilities*.

    Splits are stratified on the target when every class has at least two
    rows, and fall back to a random split otherwise. Rows are tracked by
    position, so duplicate index labels do not cause mis-assignment.

    Args:
        train_dataset: Data to label; mutated in place.
        test_dataset: Optional external test set (used as a flag only).
        target_column: Name of the label column in *train_dataset*.
        split_probabilities: ``[train, val, test]`` fractions summing to
            1.0; only used when *test_dataset* is ``None``.
        validation_size: Validation fraction; only used when
            *test_dataset* is given.
        random_seed: Seed for reproducible splits.

    Raises:
        ValueError: If the target column is missing, no rows remain after
            dropping NaN targets, or the split fractions are invalid.
    """
    if target_column not in train_dataset.columns:
        raise ValueError(f"Target column '{target_column}' not found")

    # Drop NaN labels early
    rows_before = len(train_dataset)
    train_dataset.dropna(subset=[target_column], inplace=True)
    if train_dataset.empty:
        raise ValueError("No rows remain after dropping NaN targets")
    dropped = rows_before - len(train_dataset)
    if dropped:
        logger.warning("Dropped %d rows with NaN target", dropped)
    y = train_dataset[target_column]

    # Respect existing valid split column
    if SPLIT_COL in train_dataset.columns:
        existing = train_dataset[SPLIT_COL]
        labels = set(existing.dropna().unique())
        # An empty label set (column entirely NaN) is not a usable split.
        if labels and labels <= _VALID_SPLITS:
            unlabeled = int(existing.isna().sum())
            if unlabeled:
                logger.warning(
                    "%d rows have no value in the pre-existing 'split' column",
                    unlabeled,
                )
            train_dataset[SPLIT_COL] = existing.replace("validation", "val")
            logger.info(
                "Using pre-existing 'split' column: %s", sorted(labels)
            )
            return
        logger.warning(
            "Ignoring pre-existing 'split' column with unusable values %s; "
            "regenerating it",
            sorted(map(str, labels)),
        )

    # Work on row positions rather than index labels so that a non-unique
    # index cannot cause extra rows to be relabelled.
    positions = pd.RangeIndex(len(train_dataset))
    assigned = pd.Series("train", index=positions, dtype=object)

    if test_dataset is not None:
        _, val_pos = _holdout_split(positions, validation_size, y, random_seed)
        if len(val_pos):
            assigned.loc[val_pos] = "val"
        logger.info(
            "External test set → created val split (%s)",
            f"{validation_size:.0%}",
        )
    else:
        if len(split_probabilities) != 3:
            raise ValueError(
                "split_probabilities must contain exactly three values "
                "(train, val, test)"
            )
        p_train, p_val, p_test = split_probabilities
        if min(p_train, p_val, p_test) < 0:
            raise ValueError("split_probabilities must be non-negative")
        if abs(p_train + p_val + p_test - 1.0) > 1e-6:
            raise ValueError("split_probabilities must sum to 1.0")
        if p_train <= 0:
            raise ValueError(
                "split_probabilities must leave a non-empty training share"
            )

        tv_pos, test_pos = _holdout_split(positions, p_test, y, random_seed)
        # Validation share relative to what is left after removing test.
        rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
        train_pos, val_pos = _holdout_split(
            tv_pos, rel_val, y.iloc[tv_pos], random_seed
        )

        if len(val_pos):
            assigned.loc[val_pos] = "val"
        if len(test_pos):
            assigned.loc[test_pos] = "test"
        logger.info(
            "3-way split → train:%d, val:%d, test:%d",
            len(train_pos),
            len(val_pos),
            len(test_pos),
        )

    # Assign positionally (plain array) to bypass index alignment.
    train_dataset[SPLIT_COL] = assigned.to_numpy()
    logger.info(
        "Final split distribution:\n%s",
        train_dataset[SPLIT_COL].value_counts().sort_index(),
    )
