Mercurial > repos > goeckslab > image_learner
comparison split_data.py @ 12:bcfa2e234a80 draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
| author | goeckslab |
|---|---|
| date | Fri, 21 Nov 2025 15:58:13 +0000 |
| parents | |
| children | |
Comparison legend: equal · deleted · inserted · replaced
| 11:c5150cceab47 | 12:bcfa2e234a80 |
|---|---|
| 1 import argparse | |
| 2 import logging | |
| 3 from typing import Optional | |
| 4 | |
| 5 import numpy as np | |
| 6 import pandas as pd | |
| 7 from sklearn.model_selection import train_test_split | |
| 8 | |
# Module-level logger used by every function in this file. It uses the fixed
# name "ImageLearner" rather than __name__, presumably so these messages share
# the tool-wide logger configured elsewhere — confirm against the caller.
logger = logging.getLogger("ImageLearner")
| 10 | |
| 11 | |
def split_data_0_2(
    df: pd.DataFrame,
    split_column: str,
    validation_size: float = 0.1,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Carve a validation set (split=1) out of the training rows (split=0).

    The input's ``split_column`` is expected to contain only 0 (train) and
    2 (test). A ``validation_size`` fraction of the rows currently marked 0 is
    reassigned to 1; rows with any other split value are left untouched. The
    input DataFrame is not modified; a copy is returned.

    Args:
        df: Input data.
        split_column: Name of the column holding the split codes.
        validation_size: Fraction of the split=0 rows to move to validation.
            ``<= 0`` keeps every row as train; ``>= 1`` moves every split=0
            row to validation.
        random_state: Seed forwarded to ``train_test_split``.
        label_column: Optional label column. When present and the split=0 rows
            contain more than one class, the split is stratified on it; if the
            stratified split is rejected, a plain random split is used instead.

    Returns:
        A copy of ``df`` with ``split_column`` cast to int and the selected
        validation rows marked 1. If even the random split cannot be made
        (e.g. too few training rows), the split=0 rows are left as train.

    Raises:
        ValueError: If ``split_column`` contains missing or non-numeric values.
    """
    out = df.copy()
    numeric_split = pd.to_numeric(out[split_column], errors="coerce")
    missing = numeric_split.isna()
    if missing.any():
        # astype(int) cannot represent NaN; report the problem explicitly
        # instead of surfacing pandas' generic casting error.
        raise ValueError(
            f"Split column '{split_column}' contains {int(missing.sum())} "
            "missing or non-numeric value(s)"
        )
    out[split_column] = numeric_split.astype(int)

    idx_train = out.index[out[split_column] == 0].tolist()
    if not idx_train:
        logger.info("No rows with split=0; nothing to do.")
        return out

    stratify_arr = None
    if label_column and label_column in out.columns:
        label_counts = out.loc[idx_train, label_column].value_counts()
        if label_counts.size > 1:
            # train_test_split itself rejects classes that are too small to
            # stratify; that case is handled by the random fallback below.
            stratify_arr = out.loc[idx_train, label_column]
            logger.info("Using stratified split for validation set")
        else:
            logger.warning("Only one label class found; cannot stratify")

    if validation_size <= 0:
        logger.info("validation_size <= 0; keeping all as train.")
        return out
    if validation_size >= 1:
        logger.info("validation_size >= 1; moving all train → validation.")
        out.loc[idx_train, split_column] = 1
        return out

    split = None
    if stratify_arr is not None:
        try:
            split = train_test_split(
                idx_train,
                test_size=validation_size,
                random_state=random_state,
                stratify=stratify_arr,
            )
            logger.info("Successfully applied stratified split")
        except ValueError as e:
            logger.warning(f"Stratified split failed ({e}); falling back to random split.")

    if split is None:
        try:
            split = train_test_split(
                idx_train,
                test_size=validation_size,
                random_state=random_state,
                stratify=None,
            )
        except ValueError as e:
            # Too few training rows to carve out any validation rows; keep
            # them as train rather than crashing (mirrors the early returns).
            logger.warning(
                f"Random split failed ({e}); keeping all split=0 rows as train."
            )
            return out
        logger.info("Applied random (non-stratified) split")

    train_idx, val_idx = split
    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    out[split_column] = out[split_column].astype(int)
    return out
| 76 | |
| 77 | |
def _random_split(
    out: pd.DataFrame,
    split_column: str,
    split_probabilities,
    random_state: int,
) -> pd.DataFrame:
    """Assign splits 0/1/2 to ``out`` by a seeded shuffle, without stratification.

    Train and validation sizes are truncated to whole rows; any remainder goes
    to the test split (2). A private ``RandomState`` is used so the global
    NumPy RNG is not reseeded as a side effect. For a given seed it yields the
    same permutation as ``np.random.seed(seed); np.random.shuffle(...)``.
    """
    indices = out.index.tolist()
    np.random.RandomState(random_state).shuffle(indices)

    n_total = len(indices)
    n_train = int(n_total * split_probabilities[0])
    n_val = int(n_total * split_probabilities[1])

    out.loc[indices[:n_train], split_column] = 0
    out.loc[indices[n_train:n_train + n_val], split_column] = 1
    out.loc[indices[n_train + n_val:], split_column] = 2
    return out.astype({split_column: int})


def create_stratified_random_split(
    df: pd.DataFrame,
    split_column: str,
    split_probabilities: Optional[list] = None,
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Create a stratified random split when no split column exists.

    Every row of a copy of ``df`` gets ``split_column`` set to 0 (train),
    1 (validation) or 2 (test).

    When ``label_column`` is present and every class has at least
    ``len(split_probabilities)`` rows, a two-stage stratified
    ``train_test_split`` is used. The test rows are separated first. Validation
    rows are then taken from the remainder, with the fraction rescaled to
    ``val / (train + val)``. Otherwise the rows are assigned by a seeded
    shuffle. The same shuffle is used if sklearn rejects the stratified split,
    e.g. because a split would receive fewer rows than there are classes, or
    a probability is 0.

    Args:
        df: Input data; not modified.
        split_column: Name of the column to create/overwrite.
        split_probabilities: (train, val, test) fractions. Defaults to
            ``[0.7, 0.1, 0.2]``. Expected to sum to 1; not validated here.
        random_state: Seed for both the stratified and the random path.
        label_column: Optional column to stratify on.

    Returns:
        A copy of ``df`` with an int ``split_column``.
    """
    if split_probabilities is None:
        split_probabilities = [0.7, 0.1, 0.2]

    out = df.copy()
    out[split_column] = 0

    if not label_column or label_column not in out.columns:
        logger.warning(
            "No label column found; using random split without stratification"
        )
        return _random_split(out, split_column, split_probabilities, random_state)

    # Each class needs at least one sample per split to be stratifiable.
    label_counts = out[label_column].value_counts()
    min_samples_per_class = label_counts.min()
    min_samples_required = len(split_probabilities)
    # An empty value_counts gives min() == NaN, which would compare False below.
    if label_counts.empty or min_samples_per_class < min_samples_required:
        logger.warning(
            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
        )
        return _random_split(out, split_column, split_probabilities, random_state)

    logger.info("Using stratified random split for train/validation/test sets")
    try:
        # First split: separate the test set.
        train_val_idx, test_idx = train_test_split(
            out.index.tolist(),
            test_size=split_probabilities[2],
            random_state=random_state,
            stratify=out[label_column],
        )
        # Second split: rescale the validation fraction to the remaining rows.
        val_size_adjusted = split_probabilities[1] / (
            split_probabilities[0] + split_probabilities[1]
        )
        train_idx, val_idx = train_test_split(
            train_val_idx,
            test_size=val_size_adjusted,
            random_state=random_state,
            stratify=out.loc[train_val_idx, label_column],
        )
    except ValueError as e:
        logger.warning(f"Stratified split failed ({e}); falling back to random split.")
        return _random_split(out, split_column, split_probabilities, random_state)

    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    out.loc[test_idx, split_column] = 2

    logger.info("Successfully applied stratified random split")
    logger.info(
        f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
    )
    return out.astype({split_column: int})
| 168 | |
| 169 | |
class SplitProbAction(argparse.Action):
    """argparse action that validates a (train, val, test) probability triple.

    Rejects the option unless exactly three values are given, none is
    negative, and they sum to 1.0 within 1e-6. On success the values are
    stored on the namespace unchanged. Presumably registered with
    ``type=float, nargs=3`` — confirm at the ``add_argument`` call site.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Report a wrong count through the parser instead of an unpacking
        # traceback (only reachable if nargs is not fixed at 3).
        if len(values) != 3:
            parser.error(
                f"--split-probabilities expects exactly 3 values (train val test); "
                f"got {len(values)}"
            )
        train, val, test = values
        # Without this, e.g. "1.2 -0.1 -0.1" would pass the sum check below.
        if min(values) < 0:
            parser.error(
                f"--split-probabilities must be non-negative; "
                f"got {train:.3f}, {val:.3f}, {test:.3f}"
            )
        total = train + val + test
        if abs(total - 1.0) > 1e-6:
            parser.error(
                f"--split-probabilities must sum to 1.0; "
                f"got {train:.3f} + {val:.3f} + {test:.3f} = {total:.3f}"
            )
        setattr(namespace, self.dest, values)
