Mercurial > repos > goeckslab > image_learner
comparison image_learner_cli.py @ 7:801a8b6973fb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 67df782ea551181e1d240d463764016ba528eba9
author | goeckslab |
---|---|
date | Fri, 08 Aug 2025 13:06:28 +0000 |
parents | 09904b1f61f5 |
children |
comparison
equal
deleted
inserted
replaced
6:09904b1f61f5 | 7:801a8b6973fb |
---|---|
7 import tempfile | 7 import tempfile |
8 import zipfile | 8 import zipfile |
9 from pathlib import Path | 9 from pathlib import Path |
10 from typing import Any, Dict, Optional, Protocol, Tuple | 10 from typing import Any, Dict, Optional, Protocol, Tuple |
11 | 11 |
12 import numpy as np | |
12 import pandas as pd | 13 import pandas as pd |
13 import pandas.api.types as ptypes | 14 import pandas.api.types as ptypes |
14 import yaml | 15 import yaml |
15 from constants import ( | 16 from constants import ( |
16 IMAGE_PATH_COLUMN_NAME, | 17 IMAGE_PATH_COLUMN_NAME, |
416 | 417 |
417 | 418 |
418 def split_data_0_2( | 419 def split_data_0_2( |
419 df: pd.DataFrame, | 420 df: pd.DataFrame, |
420 split_column: str, | 421 split_column: str, |
421 validation_size: float = 0.15, | 422 validation_size: float = 0.1, |
422 random_state: int = 42, | 423 random_state: int = 42, |
423 label_column: Optional[str] = None, | 424 label_column: Optional[str] = None, |
424 ) -> pd.DataFrame: | 425 ) -> pd.DataFrame: |
425 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" | 426 """Given a DataFrame whose split_column only contains {0,2}, re-assign a portion of the 0s to become 1s (validation).""" |
426 out = df.copy() | 427 out = df.copy() |
429 idx_train = out.index[out[split_column] == 0].tolist() | 430 idx_train = out.index[out[split_column] == 0].tolist() |
430 | 431 |
431 if not idx_train: | 432 if not idx_train: |
432 logger.info("No rows with split=0; nothing to do.") | 433 logger.info("No rows with split=0; nothing to do.") |
433 return out | 434 return out |
435 | |
436 # Always use stratify if possible | |
434 stratify_arr = None | 437 stratify_arr = None |
435 if label_column and label_column in out.columns: | 438 if label_column and label_column in out.columns: |
436 label_counts = out.loc[idx_train, label_column].value_counts() | 439 label_counts = out.loc[idx_train, label_column].value_counts() |
437 if label_counts.size > 1 and (label_counts.min() * validation_size) >= 1: | 440 if label_counts.size > 1: |
441 # Force stratify even with fewer samples - adjust validation_size if needed | |
442 min_samples_per_class = label_counts.min() | |
443 if min_samples_per_class * validation_size < 1: | |
444 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size | |
445 adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class) | |
446 if adjusted_validation_size != validation_size: | |
447 validation_size = adjusted_validation_size | |
448 logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") | |
438 stratify_arr = out.loc[idx_train, label_column] | 449 stratify_arr = out.loc[idx_train, label_column] |
450 logger.info("Using stratified split for validation set") | |
439 else: | 451 else: |
440 logger.warning( | 452 logger.warning("Only one label class found; cannot stratify") |
441 "Cannot stratify (too few labels); splitting without stratify." | 453 |
442 ) | |
443 if validation_size <= 0: | 454 if validation_size <= 0: |
444 logger.info("validation_size <= 0; keeping all as train.") | 455 logger.info("validation_size <= 0; keeping all as train.") |
445 return out | 456 return out |
446 if validation_size >= 1: | 457 if validation_size >= 1: |
447 logger.info("validation_size >= 1; moving all train → validation.") | 458 logger.info("validation_size >= 1; moving all train → validation.") |
448 out.loc[idx_train, split_column] = 1 | 459 out.loc[idx_train, split_column] = 1 |
449 return out | 460 return out |
461 | |
462 # Always try stratified split first | |
450 try: | 463 try: |
451 train_idx, val_idx = train_test_split( | 464 train_idx, val_idx = train_test_split( |
452 idx_train, | 465 idx_train, |
453 test_size=validation_size, | 466 test_size=validation_size, |
454 random_state=random_state, | 467 random_state=random_state, |
455 stratify=stratify_arr, | 468 stratify=stratify_arr, |
456 ) | 469 ) |
470 logger.info("Successfully applied stratified split") | |
457 except ValueError as e: | 471 except ValueError as e: |
458 logger.warning(f"Stratified split failed ({e}); retrying without stratify.") | 472 logger.warning(f"Stratified split failed ({e}); falling back to random split.") |
459 train_idx, val_idx = train_test_split( | 473 train_idx, val_idx = train_test_split( |
460 idx_train, | 474 idx_train, |
461 test_size=validation_size, | 475 test_size=validation_size, |
462 random_state=random_state, | 476 random_state=random_state, |
463 stratify=None, | 477 stratify=None, |
464 ) | 478 ) |
479 | |
465 out.loc[train_idx, split_column] = 0 | 480 out.loc[train_idx, split_column] = 0 |
466 out.loc[val_idx, split_column] = 1 | 481 out.loc[val_idx, split_column] = 1 |
467 out[split_column] = out[split_column].astype(int) | 482 out[split_column] = out[split_column].astype(int) |
468 return out | 483 return out |
484 | |
485 | |
def create_stratified_random_split(
    df: pd.DataFrame,
    split_column: str,
    split_probabilities: tuple = (0.7, 0.1, 0.2),
    random_state: int = 42,
    label_column: Optional[str] = None,
) -> pd.DataFrame:
    """Create a stratified random split when no split column exists.

    Assigns every row of *df* a code in *split_column*: 0 = train,
    1 = validation, 2 = test. Stratifies on *label_column* when each class
    has at least one sample per split; otherwise (or when no usable label
    column is given) falls back to a plain random assignment.

    Args:
        df: Input frame; a copy is returned, the original is not mutated.
        split_column: Name of the column that receives the 0/1/2 codes.
        split_probabilities: Train/validation/test fractions. NOTE: the
            default is a tuple rather than a list so the default object
            cannot be mutated across calls (mutable-default pitfall);
            callers passing lists are unaffected.
        random_state: Seed used for both the shuffle and the sklearn splits,
            so results are reproducible.
        label_column: Column to stratify on; optional.

    Returns:
        A copy of *df* with *split_column* holding integer codes 0/1/2.
    """
    out = df.copy()

    # initialize split column
    out[split_column] = 0

    def _random_assignment() -> pd.DataFrame:
        # Shared non-stratified fallback: shuffle all indices once and
        # slice them into train/val/test by the requested proportions.
        indices = out.index.tolist()
        np.random.seed(random_state)
        np.random.shuffle(indices)

        n_total = len(indices)
        n_train = int(n_total * split_probabilities[0])
        n_val = int(n_total * split_probabilities[1])

        out.loc[indices[:n_train], split_column] = 0
        out.loc[indices[n_train:n_train + n_val], split_column] = 1
        out.loc[indices[n_train + n_val:], split_column] = 2

        return out.astype({split_column: int})

    if not label_column or label_column not in out.columns:
        logger.warning("No label column found; using random split without stratification")
        return _random_assignment()

    # check if stratification is possible:
    # Each class must have at least as many samples as the number of splits,
    # so that each split can receive at least one sample per class.
    label_counts = out[label_column].value_counts()
    min_samples_per_class = label_counts.min()
    min_samples_required = len(split_probabilities)
    if min_samples_per_class < min_samples_required:
        logger.warning(
            f"Insufficient samples per class for stratification (min: {min_samples_per_class}, required: {min_samples_required}); using random split"
        )
        return _random_assignment()

    logger.info("Using stratified random split for train/validation/test sets")

    # first split: separate test set
    train_val_idx, test_idx = train_test_split(
        out.index.tolist(),
        test_size=split_probabilities[2],
        random_state=random_state,
        stratify=out[label_column],
    )

    # second split: separate training and validation from remaining data.
    # The validation fraction is rescaled because it now applies to the
    # train+val remainder rather than the whole frame.
    val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=val_size_adjusted,
        random_state=random_state,
        stratify=out.loc[train_val_idx, label_column],
    )

    # assign split values
    out.loc[train_idx, split_column] = 0
    out.loc[val_idx, split_column] = 1
    out.loc[test_idx, split_column] = 2

    logger.info("Successfully applied stratified random split")
    logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

    return out.astype({split_column: int})
469 | 571 |
470 | 572 |
471 class Backend(Protocol): | 573 class Backend(Protocol): |
472 """Interface for a machine learning backend.""" | 574 """Interface for a machine learning backend.""" |
473 | 575 |
1087 raise | 1189 raise |
1088 | 1190 |
1089 if SPLIT_COLUMN_NAME in df.columns: | 1191 if SPLIT_COLUMN_NAME in df.columns: |
1090 df, split_config, split_info = self._process_fixed_split(df) | 1192 df, split_config, split_info = self._process_fixed_split(df) |
1091 else: | 1193 else: |
1092 logger.info("No split column; using random split") | 1194 logger.info("No split column; creating stratified random split") |
1195 df = create_stratified_random_split( | |
1196 df=df, | |
1197 split_column=SPLIT_COLUMN_NAME, | |
1198 split_probabilities=self.args.split_probabilities, | |
1199 random_state=self.args.random_seed, | |
1200 label_column=LABEL_COLUMN_NAME, | |
1201 ) | |
1093 split_config = { | 1202 split_config = { |
1094 "type": "random", | 1203 "type": "fixed", |
1095 "probabilities": self.args.split_probabilities, | 1204 "column": SPLIT_COLUMN_NAME, |
1096 } | 1205 } |
1097 split_info = ( | 1206 split_info = ( |
1098 f"No split column in CSV. Used random split: " | 1207 f"No split column in CSV. Created stratified random split: " |
1099 f"{[int(p * 100) for p in self.args.split_probabilities]}% " | 1208 f"{[int(p * 100) for p in self.args.split_probabilities]}% " |
1100 f"for train/val/test." | 1209 f"for train/val/test with balanced label distribution." |
1101 ) | 1210 ) |
1102 | 1211 |
1103 final_csv = self.temp_dir / TEMP_CSV_FILENAME | 1212 final_csv = self.temp_dir / TEMP_CSV_FILENAME |
1104 try: | 1213 try: |
1105 | 1214 |
1137 ) | 1246 ) |
1138 split_info = ( | 1247 split_info = ( |
1139 "Detected a split column (with values 0 and 2) in the input CSV. " | 1248 "Detected a split column (with values 0 and 2) in the input CSV. " |
1140 f"Used this column as a base and reassigned " | 1249 f"Used this column as a base and reassigned " |
1141 f"{self.args.validation_size * 100:.1f}% " | 1250 f"{self.args.validation_size * 100:.1f}% " |
1142 "of the training set (originally labeled 0) to validation (labeled 1)." | 1251 "of the training set (originally labeled 0) to validation (labeled 1) using stratified sampling." |
1143 ) | 1252 ) |
1144 logger.info("Applied custom 0/2 split.") | 1253 logger.info("Applied custom 0/2 split.") |
1145 elif unique.issubset({0, 1, 2}): | 1254 elif unique.issubset({0, 1, 2}): |
1146 split_info = "Used user-defined split column from CSV." | 1255 split_info = "Used user-defined split column from CSV." |
1147 logger.info("Using fixed split as-is.") | 1256 logger.info("Using fixed split as-is.") |
1317 help="Where to write outputs", | 1426 help="Where to write outputs", |
1318 ) | 1427 ) |
1319 parser.add_argument( | 1428 parser.add_argument( |
1320 "--validation-size", | 1429 "--validation-size", |
1321 type=float, | 1430 type=float, |
1322 default=0.15, | 1431 default=0.1, |
1323 help="Fraction for validation (0.0–1.0)", | 1432 help="Fraction for validation (0.0–1.0)", |
1324 ) | 1433 ) |
1325 parser.add_argument( | 1434 parser.add_argument( |
1326 "--preprocessing-num-processes", | 1435 "--preprocessing-num-processes", |
1327 type=int, | 1436 type=int, |