Mercurial > repos > goeckslab > multimodal_learner
annotate split_logic.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children |
| rev | line source |
|---|---|
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1 import logging |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
2 from typing import List, Optional |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
3 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
4 import pandas as pd |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
5 from sklearn.model_selection import train_test_split |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
6 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
7 logger = logging.getLogger(__name__) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
8 SPLIT_COL = "split" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
9 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
10 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
11 def _can_stratify(y: pd.Series) -> bool: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
12 return y.nunique() >= 2 and (y.value_counts() >= 2).all() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
13 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
14 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
def _holdout_split(positions, targets, test_size, random_seed):
    """Split row positions into ``(kept, held_out)`` lists.

    Stratifies on *targets* when ``_can_stratify`` allows it, and falls back
    to a plain random split if scikit-learn rejects the stratified request
    (for example when the held-out part would contain fewer rows than there
    are classes). A ``test_size`` of 0 holds nothing out instead of raising.

    Args:
        positions: Row positions (ints) to split.
        targets: ``pd.Series`` of target values, aligned with *positions*
            by order.
        test_size: Fraction (or count) forwarded to ``train_test_split``.
        random_seed: Seed forwarded as ``random_state``.
    """
    if test_size == 0:
        # train_test_split refuses a zero-sized hold-out; treat it as "none".
        return list(positions), []

    if _can_stratify(targets):
        try:
            return train_test_split(
                positions,
                test_size=test_size,
                random_state=random_seed,
                stratify=targets.to_numpy(),
            )
        except ValueError as err:
            logger.warning(
                "Stratified split failed (%s); falling back to a random split",
                err,
            )

    return train_test_split(
        positions, test_size=test_size, random_state=random_seed
    )


def split_dataset(
    train_dataset: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    target_column: str,
    split_probabilities: List[float],
    validation_size: float,
    random_seed: int = 42,
) -> None:
    """Assign every row of *train_dataset* to a split, in place.

    Rows whose target is NaN are dropped from *train_dataset* first. The
    result is written to the ``SPLIT_COL`` column using the labels
    ``"train"``, ``"val"`` and ``"test"``.

    * If ``SPLIT_COL`` already exists and holds only recognised labels, it is
      kept (``"validation"`` is normalised to ``"val"``) and nothing else
      happens.
    * If *test_dataset* is given, only a train/val split is created, holding
      out ``validation_size`` of the rows for validation.
    * Otherwise a train/val/test split is created from
      ``split_probabilities``.

    Splits are stratified on the target when possible.

    Args:
        train_dataset: Data to split; modified in place.
        test_dataset: External test set. Only checked for ``None`` here.
        target_column: Name of the label column in *train_dataset*.
        split_probabilities: ``[train, val, test]`` fractions summing to 1.0;
            used only when *test_dataset* is ``None``.
        validation_size: Validation fraction; used only when *test_dataset*
            is provided.
        random_seed: Seed for reproducible splits.

    Raises:
        ValueError: If the target column is missing, no rows remain after
            dropping NaN targets, or ``split_probabilities`` is malformed.
    """
    if target_column not in train_dataset.columns:
        raise ValueError(f"Target column '{target_column}' not found")

    # Drop NaN labels early
    before = len(train_dataset)
    train_dataset.dropna(subset=[target_column], inplace=True)
    if len(train_dataset) == 0:
        raise ValueError("No rows remain after dropping NaN targets")
    if before != len(train_dataset):
        logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target")
    y = train_dataset[target_column]

    # Respect existing valid split column
    if SPLIT_COL in train_dataset.columns:
        unique = set(train_dataset[SPLIT_COL].dropna().unique())
        valid = {"train", "val", "validation", "test"}
        # An all-NaN column yields an empty set, which is a subset of anything
        # but assigns no row to any split, so it must not count as valid.
        if unique and unique.issubset(valid):
            train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val")
            logger.info(f"Using pre-existing 'split' column: {sorted(unique)}")
            return
        logger.warning(
            "Ignoring pre-existing '%s' column with unusable values; "
            "generating a new split",
            SPLIT_COL,
        )

    # Work on row positions rather than index labels so that a non-unique
    # index cannot leak rows across splits or misalign the stratify targets.
    n_rows = len(train_dataset)
    all_positions = list(range(n_rows))
    labels = ["train"] * n_rows
    test_pos: List[int] = []

    if test_dataset is not None:
        train_pos, val_pos = _holdout_split(
            all_positions, y, validation_size, random_seed
        )
        logger.info(f"External test set → created val split ({validation_size:.0%})")
    else:
        if len(split_probabilities) != 3:
            raise ValueError(
                "split_probabilities must contain exactly three values "
                "(train, val, test)"
            )
        p_train, p_val, p_test = split_probabilities
        if min(p_train, p_val, p_test) < 0:
            raise ValueError("split_probabilities must be non-negative")
        if abs(p_train + p_val + p_test - 1.0) > 1e-6:
            raise ValueError("split_probabilities must sum to 1.0")

        tv_pos, test_pos = _holdout_split(all_positions, y, p_test, random_seed)
        # Validation share relative to what is left after removing the test rows.
        rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
        train_pos, val_pos = _holdout_split(
            tv_pos, y.iloc[tv_pos], rel_val, random_seed
        )
        logger.info(f"3-way split → train:{len(train_pos)}, val:{len(val_pos)}, test:{len(test_pos)}")

    for pos in val_pos:
        labels[pos] = "val"
    for pos in test_pos:
        labels[pos] = "test"
    train_dataset[SPLIT_COL] = labels

    logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}")
