Mercurial > repos > goeckslab > multimodal_learner
comparison split_logic.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:375c36923da1 |
|---|---|
| 1 import logging | |
| 2 from typing import List, Optional | |
| 3 | |
| 4 import pandas as pd | |
| 5 from sklearn.model_selection import train_test_split | |
| 6 | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Column written by split_dataset (and honored when already present and valid)
# holding each row's partition label: "train", "val" or "test".
SPLIT_COL: str = "split"
| 9 | |
| 10 | |
| 11 def _can_stratify(y: pd.Series) -> bool: | |
| 12 return y.nunique() >= 2 and (y.value_counts() >= 2).all() | |
| 13 | |
| 14 | |
def _existing_split_is_usable(train_dataset: pd.DataFrame) -> bool:
    """Decide whether a pre-existing SPLIT_COL can be kept; normalise it if so.

    Returns True, after rewriting "validation" to "val" in place, when the
    column exists, has at least one non-missing value, and every non-missing
    value is one of "train", "val", "validation", "test". Otherwise returns
    False, logging a warning if the column exists but is unusable.
    """
    if SPLIT_COL not in train_dataset.columns:
        return False
    present = set(train_dataset[SPLIT_COL].dropna().unique())
    allowed = {"train", "val", "validation", "test"}
    if not present:
        # The empty set is a subset of anything, so without this check an
        # all-missing column would be "respected" and no row assigned.
        logger.warning(
            f"Pre-existing '{SPLIT_COL}' column has no values; generating a new split"
        )
        return False
    unexpected = present - allowed
    if unexpected:
        # str() so that mixed-type garbage values can still be sorted/logged.
        logger.warning(
            f"Pre-existing '{SPLIT_COL}' column has unexpected values "
            f"{sorted(map(str, unexpected))}; generating a new split"
        )
        return False
    train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val")
    n_missing = int(train_dataset[SPLIT_COL].isna().sum())
    if n_missing:
        logger.warning(
            f"{n_missing} rows have no value in the pre-existing '{SPLIT_COL}' column"
        )
    logger.info(f"Using pre-existing 'split' column: {sorted(present)}")
    return True


def _fractional_split(
    positions: List[int],
    fraction: float,
    y: pd.Series,
    random_seed: int,
) -> tuple:
    """Split row *positions* into a ``(kept, held_out)`` pair of lists.

    ``fraction`` is the proportion of *positions* moved to ``held_out``.
    The degenerate proportions <= 0 and >= 1 (and an empty *positions*)
    are handled here, since ``train_test_split`` rejects them. The split
    is stratified on ``y.iloc[positions]`` when ``_can_stratify`` allows
    it; if scikit-learn still rejects the stratified split (e.g. the
    held-out part is smaller than the number of classes), it is retried
    without stratification.
    """
    if not positions or fraction <= 0:
        return list(positions), []
    if fraction >= 1:
        return [], list(positions)

    subset_y = y.iloc[positions]
    stratify = subset_y if _can_stratify(subset_y) else None
    try:
        kept, held_out = train_test_split(
            positions, test_size=fraction,
            random_state=random_seed, stratify=stratify,
        )
    except ValueError:
        if stratify is None:
            raise
        logger.warning(
            f"Stratified split with test_size={fraction} failed; "
            "falling back to a non-stratified split"
        )
        kept, held_out = train_test_split(
            positions, test_size=fraction, random_state=random_seed,
        )
    return list(kept), list(held_out)


def split_dataset(
    train_dataset: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    target_column: str,
    split_probabilities: List[float],
    validation_size: float,
    random_seed: int = 42,
) -> None:
    """Assign every row of *train_dataset* to "train", "val" or "test" in place.

    Rows whose *target_column* is NaN are first dropped from *train_dataset*
    (in place). The per-row labels are then written to the SPLIT_COL column:

    * If SPLIT_COL already exists, has at least one value, and contains only
      "train"/"val"/"validation"/"test", it is kept ("validation" is
      rewritten to "val") and the function returns.
    * Otherwise, if *test_dataset* is not None, a *validation_size* fraction
      of rows becomes "val" and the rest "train". Only the presence of
      *test_dataset* is checked; it is not read or modified here.
    * Otherwise *split_probabilities* = (train, val, test) drives a 3-way split.

    Splits are stratified on the target when every class has >= 2 rows,
    with a fallback to a random split if scikit-learn rejects stratification.

    Args:
        train_dataset: Frame to split; mutated in place.
        test_dataset: Optional external test frame (presence check only).
        target_column: Name of the label column in *train_dataset*.
        split_probabilities: Three non-negative proportions summing to 1.0;
            used only when *test_dataset* is None.
        validation_size: Proportion of rows to mark "val" when *test_dataset*
            is given.
        random_seed: Seed forwarded to ``train_test_split``.

    Raises:
        ValueError: If *target_column* is missing, no rows remain after
            dropping NaN targets, or *split_probabilities* is malformed.
    """
    if target_column not in train_dataset.columns:
        raise ValueError(f"Target column '{target_column}' not found")

    # Drop NaN labels early (in place, so the caller's frame shrinks too).
    before = len(train_dataset)
    train_dataset.dropna(subset=[target_column], inplace=True)
    if len(train_dataset) == 0:
        raise ValueError("No rows remain after dropping NaN targets")
    if before != len(train_dataset):
        logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target")

    if _existing_split_is_usable(train_dataset):
        return

    y = train_dataset[target_column]
    # Work with row positions rather than index labels: ``.loc`` with a
    # non-unique index would tag every row sharing a label.
    all_positions = list(range(len(train_dataset)))
    labels = ["train"] * len(all_positions)

    if test_dataset is not None:
        _, val_pos = _fractional_split(all_positions, validation_size, y, random_seed)
        for pos in val_pos:
            labels[pos] = "val"
        logger.info(f"External test set → created val split ({validation_size:.0%})")
    else:
        if len(split_probabilities) != 3:
            raise ValueError(
                "split_probabilities must contain exactly three values (train, val, test)"
            )
        p_train, p_val, p_test = split_probabilities
        if min(p_train, p_val, p_test) < 0:
            raise ValueError("split_probabilities must be non-negative")
        if abs(p_train + p_val + p_test - 1.0) > 1e-6:
            raise ValueError("split_probabilities must sum to 1.0")

        tv_pos, test_pos = _fractional_split(all_positions, p_test, y, random_seed)
        # Re-express the val share relative to what remains after removing test.
        tv_total = p_train + p_val
        rel_val = p_val / tv_total if tv_total > 0 else 0.0
        train_pos, val_pos = _fractional_split(tv_pos, rel_val, y, random_seed)
        for pos in val_pos:
            labels[pos] = "val"
        for pos in test_pos:
            labels[pos] = "test"
        logger.info(
            f"3-way split → train:{len(train_pos)}, val:{len(val_pos)}, test:{len(test_pos)}"
        )

    # Positional assignment: list length equals the row count.
    train_dataset[SPLIT_COL] = labels
    logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}")
