Mercurial > repos > goeckslab > image_learner
comparison split_data.py @ 23:2c6624cae3c5 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
| author | goeckslab |
|---|---|
| date | Sun, 25 Jan 2026 01:09:56 +0000 |
| parents | 64872c48a21f |
| children |
comparison
equal
deleted
inserted
replaced
| 22:ccbcc012d78d | 23:2c6624cae3c5 |
|---|---|
| 95 group_column, | 95 group_column, |
| 96 ) | 96 ) |
| 97 group_column = None | 97 group_column = None |
| 98 | 98 |
| 99 def _allocate_split_counts(n_total: int, probs: list) -> list: | 99 def _allocate_split_counts(n_total: int, probs: list) -> list: |
| 100 """Allocate exact split counts using largest remainder rounding.""" | 100 """Allocate exact split counts using largest remainder rounding. |
| 101 | |
| 102 Ensures at least one sample per active split *after* proportional allocation, | |
| 103 by moving samples from other splits when possible. Tie-breaking for leftovers | |
| 104 and corrections prioritizes: train (0), test (2), validation (1). | |
| 105 """ | |
| 101 if n_total <= 0: | 106 if n_total <= 0: |
| 102 return [0 for _ in probs] | 107 return [0 for _ in probs] |
| 103 | 108 |
| 104 counts = [0 for _ in probs] | 109 counts = [0 for _ in probs] |
| 105 active = [i for i, p in enumerate(probs) if p > 0] | 110 active = [i for i, p in enumerate(probs) if p > 0] |
| 106 remainder = n_total | 111 if not active: |
| 107 | 112 return counts |
| 108 if active and n_total >= len(active): | 113 |
| 109 for i in active: | 114 # If there are fewer samples than active splits, fill in order and return. |
| 110 counts[i] = 1 | 115 if n_total < len(active): |
| 111 remainder -= len(active) | 116 priority_order = [0, 2, 1] |
| 112 | 117 ordered = [i for i in priority_order if i in active] + [ |
| 113 if remainder > 0: | 118 i for i in active if i not in priority_order |
| 114 probs_arr = np.array(probs, dtype=float) | 119 ] |
| 115 probs_arr = probs_arr / probs_arr.sum() | 120 remaining = n_total |
| 116 raw = remainder * probs_arr | 121 for idx in ordered: |
| 117 floors = np.floor(raw).astype(int) | 122 if remaining <= 0: |
| 118 for i, value in enumerate(floors.tolist()): | 123 break |
| 119 counts[i] += value | 124 counts[idx] = 1 |
| 120 leftover = remainder - int(floors.sum()) | 125 remaining -= 1 |
| 121 if leftover > 0 and active: | 126 return counts |
| 122 frac = raw - floors | 127 |
| 123 order = sorted(active, key=lambda i: (-frac[i], i)) | 128 probs_arr = np.array(probs, dtype=float) |
| 124 for i in range(leftover): | 129 total_prob = probs_arr.sum() |
| 125 counts[order[i % len(order)]] += 1 | 130 if total_prob <= 0: |
| 131 return counts | |
| 132 probs_arr = probs_arr / total_prob | |
| 133 raw = n_total * probs_arr | |
| 134 floors = np.floor(raw).astype(int) | |
| 135 counts = floors.tolist() | |
| 136 | |
| 137 leftover = n_total - int(floors.sum()) | |
| 138 if leftover > 0: | |
| 139 frac = raw - floors | |
| 140 priority_order = [0, 2, 1] | |
| 141 order = sorted( | |
| 142 active, | |
| 143 key=lambda i: (-frac[i], priority_order.index(i) if i in priority_order else 999), | |
| 144 ) | |
| 145 for i in range(leftover): | |
| 146 counts[order[i % len(order)]] += 1 | |
| 147 | |
| 148 # Ensure at least one per active split by moving from other splits. | |
| 149 missing = [i for i in active if counts[i] == 0] | |
| 150 if missing: | |
| 151 priority_order = [0, 2, 1] | |
| 152 missing_ordered = [i for i in priority_order if i in missing] + [ | |
| 153 i for i in missing if i not in priority_order | |
| 154 ] | |
| 155 for idx in missing_ordered: | |
| 156 donors = [i for i in active if counts[i] > 1 and i != idx] | |
| 157 if not donors: | |
| 158 break | |
| 159 # Take from the donor with the most samples; break ties toward lower-priority splits (val -> test -> train) | |
| 160 donor_priority = [1, 2, 0] | |
| 161 donors_sorted = sorted( | |
| 162 donors, | |
| 163 key=lambda i: ( | |
| 164 -counts[i], | |
| 165 donor_priority.index(i) if i in donor_priority else 999, | |
| 166 ), | |
| 167 ) | |
| 168 donor = donors_sorted[0] | |
| 169 counts[donor] -= 1 | |
| 170 counts[idx] += 1 | |
| 126 | 171 |
| 127 return counts | 172 return counts |
| 128 | 173 |
| 129 def _choose_split(counts: list, targets: list, active: list) -> int: | 174 def _choose_split(counts: list, targets: list, active: list) -> int: |
| 130 remaining = [targets[i] - counts[i] for i in range(len(targets))] | 175 remaining = [targets[i] - counts[i] for i in range(len(targets))] |
