Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 7:801a8b6973fb draft
planemo upload for repository https://github.com/goeckslab/gleam.git commit 67df782ea551181e1d240d463764016ba528eba9
| field | value |
|---|---|
| author | goeckslab |
| date | Fri, 08 Aug 2025 13:06:28 +0000 |
| parents | 09904b1f61f5 |
| children | 85e6f4b2ad18 |
Comparison legend: equal | deleted | inserted | replaced
| 6:09904b1f61f5 | 7:801a8b6973fb |
|---|---|
| 7 import tempfile | 7 import tempfile |
| 8 import zipfile | 8 import zipfile |
| 9 from pathlib import Path | 9 from pathlib import Path |
| 10 from typing import Any, Dict, Optional, Protocol, Tuple | 10 from typing import Any, Dict, Optional, Protocol, Tuple |
| 11 | 11 |
| 12 import numpy as np | |
| 12 import pandas as pd | 13 import pandas as pd |
| 13 import pandas.api.types as ptypes | 14 import pandas.api.types as ptypes |
| 14 import yaml | 15 import yaml |
| 15 from constants import ( | 16 from constants import ( |
| 16 IMAGE_PATH_COLUMN_NAME, | 17 IMAGE_PATH_COLUMN_NAME, |
| 416 | 417 |
| 417 | 418 |
| 418 def split_data_0_2( | 419 def split_data_0_2( |
| 419 df: pd.DataFrame, | 420 df: pd.DataFrame, |
| 420 split_column: str, | 421 split_column: str, |
| 421 validation_size: float = 0.15, | 422 validation_size: float = 0.1, |
| 422 random_state: int = 42, | 423 random_state: int = 42, |
| 423 label_column: Optional[str] = None, | 424 label_column: Optional[str] = None, |
| 424 ) -> pd.DataFrame: | 425 ) -> pd.DataFrame: |
| 425 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" | 426 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" |
| 426 out = df.copy() | 427 out = df.copy() |
| 429 idx_train = out.index[out[split_column] == 0].tolist() | 430 idx_train = out.index[out[split_column] == 0].tolist() |
| 430 | 431 |
| 431 if not idx_train: | 432 if not idx_train: |
| 432 logger.info("No rows with split=0; nothing to do.") | 433 logger.info("No rows with split=0; nothing to do.") |
| 433 return out | 434 return out |
| 435 | |
| 436 # Always use stratify if possible | |
| 434 stratify_arr = None | 437 stratify_arr = None |
| 435 if label_column and label_column in out.columns: | 438 if label_column and label_column in out.columns: |
| 436 label_counts = out.loc[idx_train, label_column].value_counts() | 439 label_counts = out.loc[idx_train, label_column].value_counts() |
| 437 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1: | 440 if label_counts.size > 1: |
| 441 # Force stratify even with fewer samples - adjust validation_size if needed | |
| 442 min_samples_per_class = label_counts.min() | |
| 443 if min_samples_per_class * validation_size < 1: | |
| 444 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size | |
| 445 adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class) | |
| 446 if adjusted_validation_size != validation_size: | |
| 447 validation_size = adjusted_validation_size | |
| 448 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") | |
| 438 stratify_arr = out.loc[idx_train, label_column] | 449 stratify_arr = out.loc[idx_train, label_column] |
| 450 logger.info("Using stratified split for validation set") | |
| 439 else: | 451 else: |
| 440 logger.warning( | 452 logger.warning("Only one label class found; cannot stratify") |
| 441 "Cannot stratify (too few labels); splitting without stratify." | 453 |
| 442 ) | |
| 443 if validation_size <= 0: | 454 if validation_size <= 0: |
| 444 logger.info("validation_size <= 0; keeping all as train.") | 455 logger.info("validation_size <= 0; keeping all as train.") |
| 445 return out | 456 return out |
| 446 if validation_size >= 1: | 457 if validation_size >= 1: |
| 447 logger.info("validation_size >= 1; moving all train → validation.") | 458 logger.info("validation_size >= 1; moving all train → validation.") |
| 448 out.loc[idx_train, split_column] = 1 | 459 out.loc[idx_train, split_column] = 1 |
| 449 return out | 460 return out |
| 461 | |
| 462 # Always try stratified split first | |
| 450 try: | 463 try: |
| 451 train_idx, val_idx = train_test_split( | 464 train_idx, val_idx = train_test_split( |
| 452 idx_train, | 465 idx_train, |
| 453 test_size=validation_size, | 466 test_size=validation_size, |
| 454 random_state=random_state, | 467 random_state=random_state, |
| 455 stratify=stratify_arr, | 468 stratify=stratify_arr, |
| 456 ) | 469 ) |
| 470 logger.info("Successfully applied stratified split") | |
| 457 except ValueError as e: | 471 except ValueError as e: |
| 458 logger.warning(f"Stratified split failed ({e}); retrying without stratify.") | 472 logger.warning(f"Stratified split failed ({e}); falling back to random split.") |
| 459 train_idx, val_idx = train_test_split( | 473 train_idx, val_idx = train_test_split( |
| 460 idx_train, | 474 idx_train, |
| 461 test_size=validation_size, | 475 test_size=validation_size, |
| 462 random_state=random_state, | 476 random_state=random_state, |
| 463 stratify=None, | 477 stratify=None, |
| 464 ) | 478 ) |
| 479 | |
| 465 out.loc[train_idx, split_column] = 0 | 480 out.loc[train_idx, split_column] = 0 |
| 466 out.loc[val_idx, split_column] = 1 | 481 out.loc[val_idx, split_column] = 1 |
| 467 out[split_column] = out[split_column].astype(int) | 482 out[split_column] = out[split_column].astype(int) |
| 468 return out | 483 return out |
| 484 | |
| 485 | |
| 486 def create_stratified_random_split( | |
| 487 df: pd.DataFrame, | |
| 488 split_column: str, | |
| 489 split_probabilities: list = [0.7, 0.1, 0.2], | |
| 490 random_state: int = 42, | |
| 491 label_column: Optional[str] = None, | |
| 492 ) -> pd.DataFrame: | |
| 493 """Create a stratified random split when no split column exists.""" | |
| 494 out = df.copy() | |
| 495 | |
| 496 # initialize split column | |
| 497 out[split_column] = 0 | |
| 498 | |
| 499 if not label_column or label_column not in out.columns: | |
| 500 logger.warning("No label column found; using random split without stratification") | |
| 501 # fall back to simple random assignment | |
| 502 indices = out.index.tolist() | |
| 503 np.random.seed(random_state) | |
| 504 np.random.shuffle(indices) | |
| 505 | |
| 506 n_total = len(indices) | |
| 507 n_train = int(n_total * split_probabilities[0]) | |
| 508 n_val = int(n_total * split_probabilities[1]) | |
| 509 | |
| 510 out.loc[indices[:n_train], split_column] = 0 | |
| 511 out.loc[indices[n_train:n_train + n_val], split_column] = 1 | |
| 512 out.loc[indices[n_train + n_val:], split_column] = 2 | |
| 513 | |
| 514 return out.astype({split_column: int}) | |
| 515 | |
| 516 # check if stratification is possible | |
| 517 label_counts = out[label_column].value_counts() | |
| 518 min_samples_per_class = label_counts.min() | |
| 519 | |
| 520 # ensure we have enough samples for stratification: | |
| 521 # Each class must have at least as many samples as the number of splits, | |
| 522 # so that each split can receive at least one sample per class. | |
| 523 min_samples_required = len(split_probabilities) | |
| 524 if min_samples_per_class < min_samples_required: | |
| 525 logger.warning( | |
| 526 f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split" | |
| 527 ) | |
| 528 # fall back to simple random assignment | |
| 529 indices = out.index.tolist() | |
| 530 np.random.seed(random_state) | |
| 531 np.random.shuffle(indices) | |
| 532 | |
| 533 n_total = len(indices) | |
| 534 n_train = int(n_total * split_probabilities[0]) | |
| 535 n_val = int(n_total * split_probabilities[1]) | |
| 536 | |
| 537 out.loc[indices[:n_train], split_column] = 0 | |
| 538 out.loc[indices[n_train:n_train + n_val], split_column] = 1 | |
| 539 out.loc[indices[n_train + n_val:], split_column] = 2 | |
| 540 | |
| 541 return out.astype({split_column: int}) | |
| 542 | |
| 543 logger.info("Using stratified random split for train/validation/test sets") | |
| 544 | |
| 545 # first split: separate test set | |
| 546 train_val_idx, test_idx = train_test_split( | |
| 547 out.index.tolist(), | |
| 548 test_size=split_probabilities[2], | |
| 549 random_state=random_state, | |
| 550 stratify=out[label_column], | |
| 551 ) | |
| 552 | |
| 553 # second split: separate training and validation from remaining data | |
| 554 val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1]) | |
| 555 train_idx, val_idx = train_test_split( | |
| 556 train_val_idx, | |
| 557 test_size=val_size_adjusted, | |
| 558 random_state=random_state, | |
| 559 stratify=out.loc[train_val_idx, label_column], | |
| 560 ) | |
| 561 | |
| 562 # assign split values | |
| 563 out.loc[train_idx, split_column] = 0 | |
| 564 out.loc[val_idx, split_column] = 1 | |
| 565 out.loc[test_idx, split_column] = 2 | |
| 566 | |
| 567 logger.info("Successfully applied stratified random split") | |
| 568 logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") | |
| 569 | |
| 570 return out.astype({split_column: int}) | |
| 469 | 571 |
| 470 | 572 |
| 471 class Backend(Protocol): | 573 class Backend(Protocol): |
| 472 """Interface for a machine learning backend.""" | 574 """Interface for a machine learning backend.""" |
| 473 | 575 |
| 1087 raise | 1189 raise |
| 1088 | 1190 |
| 1089 if SPLIT_COLUMN_NAME in df.columns: | 1191 if SPLIT_COLUMN_NAME in df.columns: |
| 1090 df, split_config, split_info = self._process_fixed_split(df) | 1192 df, split_config, split_info = self._process_fixed_split(df) |
| 1091 else: | 1193 else: |
| 1092 logger.info("No split column; using random split") | 1194 logger.info("No split column; creating stratified random split") |
| 1195 df = create_stratified_random_split( | |
| 1196 df=df, | |
| 1197 split_column=SPLIT_COLUMN_NAME, | |
| 1198 split_probabilities=self.args.split_probabilities, | |
| 1199 random_state=self.args.random_seed, | |
| 1200 label_column=LABEL_COLUMN_NAME, | |
| 1201 ) | |
| 1093 split_config = { | 1202 split_config = { |
| 1094 "type": "random", | 1203 "type": "fixed", |
| 1095 "probabilities": self.args.split_probabilities, | 1204 "column": SPLIT_COLUMN_NAME, |
| 1096 } | 1205 } |
| 1097 split_info = ( | 1206 split_info = ( |
| 1098 f"No split column in CSV. Used random split: " | 1207 f"No split column in CSV. Created stratified random split: " |
| 1099 f"{[int(p * 100) for p in self.args.split_probabilities]}% " | 1208 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
| 1100 f"for train/val/test." | 1209 f"for train/val/test with balanced label distribution." |
| 1101 ) | 1210 ) |
| 1102 | 1211 |
| 1103 final_csv = self.temp_dir / TEMP_CSV_FILENAME | 1212 final_csv = self.temp_dir / TEMP_CSV_FILENAME |
| 1104 try: | 1213 try: |
| 1105 | 1214 |
| 1137 ) | 1246 ) |
| 1138 split_info = ( | 1247 split_info = ( |
| 1139 "Detected a split column (with values 0 and 2) in the input CSV. " | 1248 "Detected a split column (with values 0 and 2) in the input CSV. " |
| 1140 f"Used this column as a base and reassigned " | 1249 f"Used this column as a base and reassigned " |
| 1141 f"{self.args.validation_size * 100:.1f}% " | 1250 f"{self.args.validation_size * 100:.1f}% " |
| 1142 "of the training set (originally labeled 0) to validation (labeled 1)." | 1251 "of the training set (originally labeled 0) to validation (labeled 1) using stratified sampling." |
| 1143 ) | 1252 ) |
| 1144 logger.info("Applied custom 0/2 split.") | 1253 logger.info("Applied custom 0/2 split.") |
| 1145 elif unique.issubset({0, 1, 2}): | 1254 elif unique.issubset({0, 1, 2}): |
| 1146 split_info = "Used user-defined split column from CSV." | 1255 split_info = "Used user-defined split column from CSV." |
| 1147 logger.info("Using fixed split as-is.") | 1256 logger.info("Using fixed split as-is.") |
| 1317 help="Where to write outputs", | 1426 help="Where to write outputs", |
| 1318 ) | 1427 ) |
| 1319 parser.add_argument( | 1428 parser.add_argument( |
| 1320 "--validation-size", | 1429 "--validation-size", |
| 1321 type=float, | 1430 type=float, |
| 1322 default=0.15, | 1431 default=0.1, |
| 1323 help="Fraction for validation (0.0–1.0)", | 1432 help="Fraction for validation (0.0–1.0)", |
| 1324 ) | 1433 ) |
| 1325 parser.add_argument( | 1434 parser.add_argument( |
| 1326 "--preprocessing-num-processes", | 1435 "--preprocessing-num-processes", |
| 1327 type=int, | 1436 type=int, |
