changeset 23:2c6624cae3c5 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 4fade0f8199988fd9cf56cbcb49fa4b949b659ec
author goeckslab
date Sun, 25 Jan 2026 01:09:56 +0000
parents ccbcc012d78d
children
files split_data.py
diffstat 1 files changed, 64 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/split_data.py	Fri Jan 23 20:25:27 2026 +0000
+++ b/split_data.py	Sun Jan 25 01:09:56 2026 +0000
@@ -97,32 +97,77 @@
         group_column = None
 
     def _allocate_split_counts(n_total: int, probs: list) -> list:
-        """Allocate exact split counts using largest remainder rounding."""
+        """Allocate exact split counts using largest remainder rounding.
+
+        Ensures at least one sample per active split *after* proportional allocation,
+        by moving samples from other splits when possible. Tie-breaking for leftovers
+        and corrections prioritizes: train (0), test (2), validation (1).
+        """
         if n_total <= 0:
             return [0 for _ in probs]
 
         counts = [0 for _ in probs]
         active = [i for i, p in enumerate(probs) if p > 0]
-        remainder = n_total
+        if not active:
+            return counts
 
-        if active and n_total >= len(active):
-            for i in active:
-                counts[i] = 1
-            remainder -= len(active)
+        # If there are fewer samples than active splits, fill in order and return.
+        if n_total < len(active):
+            priority_order = [0, 2, 1]
+            ordered = [i for i in priority_order if i in active] + [
+                i for i in active if i not in priority_order
+            ]
+            remaining = n_total
+            for idx in ordered:
+                if remaining <= 0:
+                    break
+                counts[idx] = 1
+                remaining -= 1
+            return counts
+
+        probs_arr = np.array(probs, dtype=float)
+        total_prob = probs_arr.sum()
+        if total_prob <= 0:
+            return counts
+        probs_arr = probs_arr / total_prob
+        raw = n_total * probs_arr
+        floors = np.floor(raw).astype(int)
+        counts = floors.tolist()
 
-        if remainder > 0:
-            probs_arr = np.array(probs, dtype=float)
-            probs_arr = probs_arr / probs_arr.sum()
-            raw = remainder * probs_arr
-            floors = np.floor(raw).astype(int)
-            for i, value in enumerate(floors.tolist()):
-                counts[i] += value
-            leftover = remainder - int(floors.sum())
-            if leftover > 0 and active:
-                frac = raw - floors
-                order = sorted(active, key=lambda i: (-frac[i], i))
-                for i in range(leftover):
-                    counts[order[i % len(order)]] += 1
+        leftover = n_total - int(floors.sum())
+        if leftover > 0:
+            frac = raw - floors
+            priority_order = [0, 2, 1]
+            order = sorted(
+                active,
+                key=lambda i: (-frac[i], priority_order.index(i) if i in priority_order else 999),
+            )
+            for i in range(leftover):
+                counts[order[i % len(order)]] += 1
+
+        # Ensure at least one per active split by moving from other splits.
+        missing = [i for i in active if counts[i] == 0]
+        if missing:
+            priority_order = [0, 2, 1]
+            missing_ordered = [i for i in priority_order if i in missing] + [
+                i for i in missing if i not in priority_order
+            ]
+            for idx in missing_ordered:
+                donors = [i for i in active if counts[i] > 1 and i != idx]
+                if not donors:
+                    break
+                # Prefer taking from lower-priority splits first (val -> test -> train)
+                donor_priority = [1, 2, 0]
+                donors_sorted = sorted(
+                    donors,
+                    key=lambda i: (
+                        -counts[i],
+                        donor_priority.index(i) if i in donor_priority else 999,
+                    ),
+                )
+                donor = donors_sorted[0]
+                counts[donor] -= 1
+                counts[idx] += 1
 
         return counts