Mercurial > repos > goeckslab > image_learner
annotate ludwig_backend.py @ 19:c460abae83eb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
| author | goeckslab |
|---|---|
| date | Thu, 18 Dec 2025 16:59:58 +0000 |
| parents | bbf30253c99f |
| children |
| rev | line source |
|---|---|
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1 import inspect |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
2 import json |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
3 import logging |
|
16
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
4 import os |
|
19
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
5 import zipfile |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
6 from pathlib import Path |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
7 from typing import Any, Dict, List, Optional, Protocol, Tuple |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
8 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
9 import pandas as pd |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
10 import pandas.api.types as ptypes |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
11 import yaml |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
12 from constants import ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
13 IMAGE_PATH_COLUMN_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
14 LABEL_COLUMN_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
15 MODEL_ENCODER_TEMPLATES, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
16 SPLIT_COLUMN_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
17 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
18 from html_structure import ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
19 build_tabbed_html, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
20 encode_image_to_base64, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
21 format_config_table_html, |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
22 format_dataset_overview_table, |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
23 format_stats_table_html, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
24 format_test_merged_stats_table_html, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
25 format_train_val_stats_table_html, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
26 get_html_closing, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
27 get_html_template, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
28 get_metrics_help_modal, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
29 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
30 from ludwig.globals import ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
31 DESCRIPTION_FILE_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
32 PREDICTIONS_PARQUET_FILE_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
33 TEST_STATISTICS_FILE_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
34 TRAIN_SET_METADATA_FILE_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
35 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
36 from ludwig.utils.data_utils import get_split_path |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
37 from metaformer_setup import get_visualizations_registry, META_DEFAULT_CFGS |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
38 from plotly_plots import ( |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
39 build_binary_threshold_plot, |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
40 build_classification_plots, |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
41 build_multiclass_metric_plots, |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
42 build_prediction_diagnostics, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
43 build_regression_test_plots, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
44 build_regression_train_val_plots, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
45 build_train_validation_plots, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
46 ) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
47 from utils import detect_output_type, extract_metrics_from_json |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
48 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
# Module-level logger registered under the name "ImageLearner"; the backend
# code in this file reports warnings and progress through it.
logger = logging.getLogger("ImageLearner")
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
50 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
51 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
class Backend(Protocol):
    """Structural interface that a machine learning backend is expected to satisfy.

    Only signatures are declared here; every method body is a stub, so the
    concrete behaviour is defined by whichever class implements the protocol.
    """

    def prepare_config(
        self, config_params: Dict[str, Any], split_config: Dict[str, Any]
    ) -> str:
        """Turn the parameter and split mappings into a configuration string.

        NOTE(review): the format of the returned string is decided by the
        implementing backend; it is not fixed by this protocol.
        """
        ...

    def run_experiment(
        self, dataset_path: Path, config_path: Path, output_dir: Path, random_seed: int
    ) -> None:
        """Run an experiment for the given dataset, config file, output directory and seed."""
        ...

    def generate_plots(self, output_dir: Path) -> None:
        """Generate plots for ``output_dir`` (what is read or written is backend-specific)."""
        ...

    def generate_html_report(
        self, title: str, output_dir: str, config: Dict[str, Any], split_info: str
    ) -> Path:
        """Produce an HTML report and return a ``Path`` for it.

        NOTE(review): ``output_dir`` is annotated as ``str`` here while the
        other methods take ``Path`` -- confirm against the implementations.
        """
        ...
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
82 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
83 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
84 class LudwigDirectBackend: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
85 """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
86 |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
87 _torchvision_patched = False |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
88 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
89 def _detect_image_dimensions(self, image_zip_path: str) -> Tuple[int, int]: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
90 """Detect image dimensions from the first image in the dataset.""" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
91 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
92 import zipfile |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
93 from PIL import Image |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
94 import io |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
95 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
96 # Check if image_zip is provided |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
97 if not image_zip_path: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
98 logger.warning("No image zip provided, using default 224x224") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
99 return 224, 224 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
100 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
101 # Extract first image to detect dimensions |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
102 with zipfile.ZipFile(image_zip_path, 'r') as z: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
103 image_files = [f for f in z.namelist() if f.lower().endswith(('.png', '.jpg', '.jpeg'))] |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
104 if not image_files: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
105 logger.warning("No image files found in zip, using default 224x224") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
106 return 224, 224 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
107 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
108 # Check first image |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
109 with z.open(image_files[0]) as f: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
110 img = Image.open(io.BytesIO(f.read())) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
111 width, height = img.size |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
112 logger.info(f"Detected image dimensions: {width}x{height}") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
113 return height, width # Return as (height, width) to match encoder config |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
114 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
115 except Exception as e: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
116 logger.warning(f"Error detecting image dimensions: {e}, using default 224x224") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
117 return 224, 224 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
118 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
119 def prepare_config( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
120 self, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
121 config_params: Dict[str, Any], |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
122 split_config: Dict[str, Any], |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
123 ) -> str: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
124 logger.info("LudwigDirectBackend: Preparing YAML configuration.") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
125 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
126 model_name = config_params.get("model_name", "resnet18") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
127 use_pretrained = config_params.get("use_pretrained", False) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
128 fine_tune = config_params.get("fine_tune", False) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
129 if use_pretrained: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
130 trainable = bool(fine_tune) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
131 else: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
132 trainable = True |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
133 epochs = config_params.get("epochs", 10) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
134 batch_size = config_params.get("batch_size") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
135 num_processes = config_params.get("preprocessing_num_processes", 1) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
136 early_stop = config_params.get("early_stop", None) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
137 learning_rate = config_params.get("learning_rate") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
138 learning_rate = "auto" if learning_rate is None else float(learning_rate) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
139 raw_encoder = MODEL_ENCODER_TEMPLATES.get(model_name, model_name) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
140 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
141 # --- MetaFormer detection and config logic --- |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
142 def _is_metaformer(name: str) -> bool: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
143 return isinstance(name, str) and name.startswith( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
144 ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
145 "identityformer_", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
146 "randformer_", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
147 "poolformerv2_", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
148 "convformer_", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
149 "caformer_", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
150 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
151 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
152 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
153 # Check if this is a MetaFormer model (either direct name or in custom_model) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
154 is_metaformer = ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
155 _is_metaformer(model_name) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
156 or (isinstance(raw_encoder, dict) and "custom_model" in raw_encoder and _is_metaformer(raw_encoder["custom_model"])) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
157 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
158 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
159 metaformer_resize: Optional[Tuple[int, int]] = None |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
160 metaformer_channels = 3 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
161 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
162 if is_metaformer: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
163 # Handle MetaFormer models |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
164 custom_model = None |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
165 if isinstance(raw_encoder, dict) and "custom_model" in raw_encoder: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
166 custom_model = raw_encoder["custom_model"] |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
167 else: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
168 custom_model = model_name |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
169 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
170 logger.info(f"DETECTED MetaFormer model: {custom_model}") |
|
16
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
171 # Stash the model name for patched Stacked2DCNN in case Ludwig drops custom_model from kwargs |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
172 try: |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
173 from MetaFormer.metaformer_stacked_cnn import set_current_metaformer_model |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
174 |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
175 set_current_metaformer_model(custom_model) |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
176 except Exception: |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
177 logger.debug("Could not set current MetaFormer model hint; proceeding without global override") |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
178 # Also pass via environment to survive process boundaries (e.g., Ray workers) |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
179 os.environ["GLEAM_META_FORMER_MODEL"] = custom_model |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
180 cfg_channels, cfg_height, cfg_width = 3, 224, 224 |
|
16
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
181 model_cfg = {} |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
182 if META_DEFAULT_CFGS: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
183 model_cfg = META_DEFAULT_CFGS.get(custom_model, {}) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
184 input_size = model_cfg.get("input_size") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
185 if isinstance(input_size, (list, tuple)) and len(input_size) == 3: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
186 cfg_channels, cfg_height, cfg_width = ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
187 int(input_size[0]), |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
188 int(input_size[1]), |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
189 int(input_size[2]), |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
190 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
191 |
|
16
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
192 weights_url = None |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
193 if isinstance(model_cfg, dict): |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
194 weights_url = model_cfg.get("url") |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
195 logger.info( |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
196 "MetaFormer cfg lookup: model=%s has_cfg=%s url=%s use_pretrained=%s", |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
197 custom_model, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
198 bool(model_cfg), |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
199 weights_url, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
200 use_pretrained, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
201 ) |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
202 if use_pretrained and not weights_url: |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
203 logger.warning( |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
204 "MetaFormer pretrained requested for %s but no URL found in default cfgs; model will be randomly initialized", |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
205 custom_model, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
206 ) |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
207 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
208 resize_value = config_params.get("image_resize") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
209 if resize_value and resize_value != "original": |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
210 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
211 dimensions = resize_value.split("x") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
212 if len(dimensions) == 2: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
213 target_height, target_width = int(dimensions[0]), int(dimensions[1]) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
214 if target_height <= 0 or target_width <= 0: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
215 raise ValueError( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
216 f"Image resize must be positive integers, received {resize_value}." |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
217 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
218 logger.info(f"MetaFormer explicit resize: {target_height}x{target_width}") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
219 else: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
220 raise ValueError(resize_value) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
221 except (ValueError, IndexError): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
222 logger.warning( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
223 "Invalid image resize format '%s'; falling back to model default %sx%s", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
224 resize_value, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
225 cfg_height, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
226 cfg_width, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
227 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
228 target_height, target_width = cfg_height, cfg_width |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
229 else: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
230 image_zip_path = config_params.get("image_zip", "") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
231 detected_height, detected_width = self._detect_image_dimensions(image_zip_path) |
|
16
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
232 target_height, target_width = detected_height, detected_width |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
233 if use_pretrained and (detected_height, detected_width) != (cfg_height, cfg_width): |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
234 logger.info( |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
235 "MetaFormer pretrained weights expect %sx%s; proceeding with detected %sx%s", |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
236 cfg_height, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
237 cfg_width, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
238 detected_height, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
239 detected_width, |
|
8729f69e9207
planemo upload for repository https://github.com/goeckslab/gleam.git commit bb4bcdc888d73bbfd85d78ce8999a1080fe813ff
goeckslab
parents:
15
diff
changeset
|
240 ) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
241 if target_height <= 0 or target_width <= 0: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
242 raise ValueError( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
243 f"Invalid detected image dimensions for MetaFormer: {target_height}x{target_width}." |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
244 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
245 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
246 metaformer_channels = cfg_channels |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
247 metaformer_resize = (target_height, target_width) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
248 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
249 encoder_config = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
250 "type": "stacked_cnn", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
251 "height": target_height, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
252 "width": target_width, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
253 "num_channels": metaformer_channels, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
254 "output_size": 128, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
255 "use_pretrained": use_pretrained, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
256 "trainable": trainable, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
257 "custom_model": custom_model, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
258 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
259 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
260 elif isinstance(raw_encoder, dict): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
261 # Handle image resize for regular encoders |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
262 # Note: Standard encoders like ResNet don't support height/width parameters |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
263 # Resize will be handled at the preprocessing level by Ludwig |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
264 if config_params.get("image_resize") and config_params["image_resize"] != "original": |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
265 logger.info(f"Resize requested: {config_params['image_resize']} for standard encoder. Resize will be handled at preprocessing level.") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
266 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
267 encoder_config = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
268 **raw_encoder, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
269 "use_pretrained": use_pretrained, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
270 "trainable": trainable, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
271 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
272 else: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
273 encoder_config = {"type": raw_encoder} |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
274 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
275 # Set a human-friendly architecture string for reporting |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
276 arch_display = None |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
277 if is_metaformer and custom_model: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
278 arch_display = str(custom_model) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
279 elif isinstance(raw_encoder, dict): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
280 enc_type = raw_encoder.get("type") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
281 enc_variant = raw_encoder.get("model_variant") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
282 if enc_type: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
283 base = str(enc_type).replace("_", " ").title() |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
284 arch_display = f"{base} {enc_variant}" if enc_variant is not None else base |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
285 else: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
286 arch_display = str(raw_encoder).replace("_", " ").title() |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
287 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
288 if not arch_display: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
289 arch_display = str(model_name) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
290 config_params["architecture"] = arch_display |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
291 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
292 batch_size_cfg = batch_size or "auto" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
293 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
294 label_column_path = config_params.get("label_column_data_path") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
295 label_series = None |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
296 label_metadata_hint = config_params.get("label_metadata") or {} |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
297 output_type_hint = config_params.get("output_type_hint") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
298 num_unique_labels = int(label_metadata_hint.get("num_unique", 2)) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
299 numeric_binary_labels = bool(label_metadata_hint.get("is_numeric_binary", False)) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
300 likely_regression = bool(label_metadata_hint.get("likely_regression", False)) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
301 if label_column_path is not None and Path(label_column_path).exists(): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
302 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
303 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
304 non_na = label_series.dropna() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
305 if not non_na.empty: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
306 num_unique_labels = non_na.nunique() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
307 is_numeric = ptypes.is_numeric_dtype(label_series.dtype) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
308 numeric_binary_labels = is_numeric and num_unique_labels == 2 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
309 likely_regression = ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
310 is_numeric and not numeric_binary_labels and num_unique_labels > 10 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
311 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
312 if numeric_binary_labels: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
313 logger.info( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
314 "Detected numeric binary labels in '%s'; configuring Ludwig for binary classification.", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
315 LABEL_COLUMN_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
316 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
317 except Exception as e: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
318 logger.warning(f"Could not read label column for task detection: {e}") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
319 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
320 if output_type_hint == "binary": |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
321 num_unique_labels = 2 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
322 numeric_binary_labels = numeric_binary_labels or bool( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
323 label_metadata_hint.get("is_numeric", False) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
324 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
325 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
326 if numeric_binary_labels: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
327 task_type = "classification" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
328 elif likely_regression: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
329 task_type = "regression" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
330 else: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
331 task_type = "classification" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
332 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
333 if task_type == "regression" and numeric_binary_labels: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
334 logger.warning( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
335 "Numeric binary labels detected but regression task chosen; forcing classification to avoid invalid Ludwig config." |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
336 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
337 task_type = "classification" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
338 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
339 config_params["task_type"] = task_type |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
340 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
341 image_feat: Dict[str, Any] = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
342 "name": IMAGE_PATH_COLUMN_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
343 "type": "image", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
344 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
345 # Set preprocessing dimensions FIRST for MetaFormer models |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
346 if is_metaformer: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
347 if metaformer_resize is None: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
348 metaformer_resize = (224, 224) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
349 height, width = metaformer_resize |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
350 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
351 # CRITICAL: Set preprocessing dimensions FIRST for MetaFormer models |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
352 # This is essential for MetaFormer models to work properly |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
353 if "preprocessing" not in image_feat: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
354 image_feat["preprocessing"] = {} |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
355 image_feat["preprocessing"]["height"] = height |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
356 image_feat["preprocessing"]["width"] = width |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
357 # Use infer_image_dimensions=True to allow Ludwig to read images for validation |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
358 # but set explicit max dimensions to control the output size |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
359 image_feat["preprocessing"]["infer_image_dimensions"] = True |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
360 image_feat["preprocessing"]["infer_image_max_height"] = height |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
361 image_feat["preprocessing"]["infer_image_max_width"] = width |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
362 image_feat["preprocessing"]["num_channels"] = metaformer_channels |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
363 image_feat["preprocessing"]["resize_method"] = "interpolate" # Use interpolation for better quality |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
364 image_feat["preprocessing"]["standardize_image"] = "imagenet1k" # Use ImageNet standardization |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
365 # Force Ludwig to respect our dimensions by setting additional parameters |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
366 image_feat["preprocessing"]["requires_equal_dimensions"] = False |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
367 logger.info(f"Set preprocessing dimensions for MetaFormer: {height}x{width} (infer_dimensions=True with max dimensions to allow validation)") |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
368 config_params["image_size"] = f"{height}x{width}" |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
369 # Now set the encoder configuration |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
370 image_feat["encoder"] = encoder_config |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
371 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
372 if config_params.get("augmentation") is not None: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
373 image_feat["augmentation"] = config_params["augmentation"] |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
374 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
375 # Add resize configuration for standard encoders (ResNet, etc.) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
376 # FIXED: MetaFormer models now respect user dimensions completely |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
377 # Previously there was a double resize issue where MetaFormer would force 224x224 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
378 # Now both MetaFormer and standard encoders respect user's resize choice |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
379 if (not is_metaformer) and config_params.get("image_resize") and config_params["image_resize"] != "original": |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
380 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
381 dimensions = config_params["image_resize"].split("x") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
382 if len(dimensions) == 2: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
383 height, width = int(dimensions[0]), int(dimensions[1]) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
384 if height <= 0 or width <= 0: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
385 raise ValueError( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
386 f"Image resize must be positive integers, received {config_params['image_resize']}." |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
387 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
388 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
389 # Add resize to preprocessing for standard encoders |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
390 if "preprocessing" not in image_feat: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
391 image_feat["preprocessing"] = {} |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
392 image_feat["preprocessing"]["height"] = height |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
393 image_feat["preprocessing"]["width"] = width |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
394 # Use infer_image_dimensions=True to allow Ludwig to read images for validation |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
395 # but set explicit max dimensions to control the output size |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
396 image_feat["preprocessing"]["infer_image_dimensions"] = True |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
397 image_feat["preprocessing"]["infer_image_max_height"] = height |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
398 image_feat["preprocessing"]["infer_image_max_width"] = width |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
399 logger.info(f"Added resize preprocessing: {height}x{width} for standard encoder with infer_image_dimensions=True and max dimensions") |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
400 config_params["image_size"] = f"{height}x{width}" |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
401 except (ValueError, IndexError): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
402 logger.warning(f"Invalid image resize format: {config_params['image_resize']}, skipping resize preprocessing") |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
403 elif not is_metaformer: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
404 # No explicit resize provided; keep for reporting purposes |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
405 config_params.setdefault("image_size", "original") |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
406 |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
407 def _resolve_validation_metric( |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
408 task: str, requested: Optional[str], output_feature: Dict[str, Any] |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
409 ) -> Optional[str]: |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
410 """ |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
411 Pick a validation metric that Ludwig will accept for the resolved task/output. |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
412 If the requested metric is invalid, fall back to a safe option or omit it entirely. |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
413 """ |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
414 default_map = { |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
415 "regression": "mean_squared_error", |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
416 "binary": "roc_auc", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
417 "category": "accuracy", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
418 } |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
419 allowed_map = { |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
420 "regression": { |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
421 "mean_absolute_error", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
422 "mean_squared_error", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
423 "root_mean_squared_error", |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
424 "root_mean_squared_percentage_error", |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
425 "loss", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
426 }, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
427 "binary": { |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
428 "roc_auc", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
429 "accuracy", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
430 "precision", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
431 "recall", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
432 "specificity", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
433 "loss", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
434 }, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
435 "category": { |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
436 "accuracy", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
437 "balanced_accuracy", |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
438 "hits_at_k", |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
439 "loss", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
440 }, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
441 } |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
442 alias_map = { |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
443 "regression": { |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
444 "mae": "mean_absolute_error", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
445 "mse": "mean_squared_error", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
446 "rmse": "root_mean_squared_error", |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
447 "rmspe": "root_mean_squared_percentage_error", |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
448 }, |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
449 "category": {}, |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
450 "binary": { |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
451 "roc_auc": "roc_auc", |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
452 }, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
453 } |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
454 |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
455 default_metric = default_map.get(task) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
456 metric = requested or default_metric |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
457 if metric is None: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
458 return None |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
459 |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
460 metric = alias_map.get(task, {}).get(metric, metric) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
461 |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
462 # Prefer Ludwig's own metric registry when available; intersect with known-safe sets. |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
463 registry_metrics = None |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
464 try: |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
465 from ludwig.features.feature_registries import output_type_registry |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
466 |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
467 feature_cls = output_type_registry.get(output_feature.get("type")) |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
468 if feature_cls: |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
469 feature_obj = feature_cls(feature=output_feature) |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
470 metrics_attr = getattr(feature_obj, "metric_functions", None) or getattr( |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
471 feature_obj, "metrics", None |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
472 ) |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
473 if isinstance(metrics_attr, dict): |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
474 registry_metrics = set(metrics_attr.keys()) |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
475 except Exception as exc: |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
476 logger.debug( |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
477 "Could not inspect Ludwig metrics for output type %s: %s", |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
478 output_feature.get("type"), |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
479 exc, |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
480 ) |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
481 |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
482 allowed = set(allowed_map.get(task, set())) |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
483 if registry_metrics: |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
484 # Only keep metrics that Ludwig actually exposes for this output type; |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
485 # if the intersection is empty, fall back to the registry set. |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
486 intersected = allowed.intersection(registry_metrics) |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
487 allowed = intersected or registry_metrics |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
488 |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
489 if allowed and metric not in allowed: |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
490 fallback_candidates = [ |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
491 default_metric if default_metric in allowed else None, |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
492 "loss" if "loss" in allowed else None, |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
493 next(iter(allowed), None), |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
494 ] |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
495 fallback = next((m for m in fallback_candidates if m in allowed), None) |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
496 if requested: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
497 logger.warning( |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
498 "Validation metric '%s' is not supported for %s outputs; %s", |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
499 requested, |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
500 task, |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
501 (f"using '{fallback}' instead." if fallback else "omitting validation_metric."), |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
502 ) |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
503 metric = fallback |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
504 |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
505 return metric |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
506 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
507 if task_type == "regression": |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
508 output_feat = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
509 "name": LABEL_COLUMN_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
510 "type": "number", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
511 "decoder": {"type": "regressor"}, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
512 "loss": {"type": "mean_squared_error"}, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
513 } |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
514 val_metric = _resolve_validation_metric( |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
515 "regression", |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
516 config_params.get("validation_metric"), |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
517 output_feat, |
|
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
518 ) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
519 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
520 else: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
521 if num_unique_labels == 2: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
522 output_feat = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
523 "name": LABEL_COLUMN_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
524 "type": "binary", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
525 "loss": {"type": "binary_weighted_cross_entropy"}, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
526 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
527 if config_params.get("threshold") is not None: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
528 output_feat["threshold"] = float(config_params["threshold"]) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
529 else: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
530 output_feat = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
531 "name": LABEL_COLUMN_NAME, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
532 "type": "category", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
533 "loss": {"type": "softmax_cross_entropy"}, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
534 } |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
535 val_metric = _resolve_validation_metric( |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
536 "binary" if num_unique_labels == 2 else "category", |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
537 config_params.get("validation_metric"), |
|
18
bbf30253c99f
planemo upload for repository https://github.com/goeckslab/gleam.git commit c155acd3616dfac920b17653179d7bc38ba48e39
goeckslab
parents:
17
diff
changeset
|
538 output_feat, |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
539 ) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
540 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
541 # Propagate the resolved validation metric (including any task-based fallback or alias normalization) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
542 config_params["validation_metric"] = val_metric |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
543 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
544 conf: Dict[str, Any] = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
545 "model_type": "ecd", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
546 "input_features": [image_feat], |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
547 "output_features": [output_feat], |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
548 "combiner": {"type": "concat"}, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
549 "trainer": { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
550 "epochs": epochs, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
551 "early_stop": early_stop, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
552 "batch_size": batch_size_cfg, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
553 "learning_rate": learning_rate, |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
554 # set validation_metric when provided |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
555 **({"validation_metric": val_metric} if val_metric else {}), |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
556 }, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
557 "preprocessing": { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
558 "split": split_config, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
559 "num_processes": num_processes, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
560 "in_memory": False, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
561 }, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
562 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
563 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
564 logger.debug("LudwigDirectBackend: Config dict built.") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
565 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
566 yaml_str = yaml.dump(conf, sort_keys=False, indent=2) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
567 logger.info("LudwigDirectBackend: YAML config generated.") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
568 return yaml_str |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
569 except Exception: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
570 logger.error( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
571 "LudwigDirectBackend: Failed to serialize YAML.", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
572 exc_info=True, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
573 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
574 raise |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
575 |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
def _patch_torchvision_download(self) -> None:
    """
    Disable hash verification for pre-trained weight downloads via torch.hub.

    Torchvision weight downloads sometimes fail checksum validation behind
    corporate proxies that rewrite binaries. Skip hash checking to allow
    pre-trained weights to load in those environments.

    The patch replaces ``torch.hub.load_state_dict_from_url`` and
    ``torch.hub.download_url_to_file`` with thin wrappers that always pass
    ``check_hash=False`` / ``hash_prefix=None`` to the original functions.
    It is applied at most once per process, guarded by the class-level flag
    ``LudwigDirectBackend._torchvision_patched``.

    Returns:
        None. Failures while patching are logged as warnings, not raised.

    NOTE(review): with hash checking off, downloaded weights are no longer
    integrity-verified; confirm this trade-off is acceptable wherever the
    tool is deployed.
    """
    # Already patched in this process: wrapping again would stack wrappers.
    if LudwigDirectBackend._torchvision_patched:
        return
    try:
        import torch.hub as torch_hub

        # Keep references to the real implementations so the wrappers can
        # delegate to them.
        original = torch_hub.load_state_dict_from_url
        original_download = torch_hub.download_url_to_file

        def _no_hash(
            url,
            model_dir=None,
            map_location=None,
            progress=True,
            check_hash=False,
            file_name=None,
            **kwargs,
        ):
            # The caller's ``check_hash`` is deliberately ignored. Extra
            # keyword arguments (e.g. ``weights_only`` in newer PyTorch) are
            # forwarded untouched so the wrapper does not raise TypeError
            # for parameters it does not know about.
            return original(
                url,
                model_dir=model_dir,
                map_location=map_location,
                progress=progress,
                check_hash=False,
                file_name=file_name,
                **kwargs,
            )

        def _download_no_hash(url, dst, hash_prefix=None, progress=True):
            # Torchvision's download_url_to_file signature does not accept
            # check_hash in older versions; dropping the hash prefix is what
            # disables verification here.
            return original_download(url, dst, hash_prefix=None, progress=progress)

        torch_hub.load_state_dict_from_url = _no_hash  # type: ignore[assignment]
        torch_hub.download_url_to_file = _download_no_hash  # type: ignore[assignment]
        LudwigDirectBackend._torchvision_patched = True
        logger.info("Disabled torchvision weight hash verification to avoid proxy-corrupted downloads.")
    except Exception as exc:
        # Best-effort: a missing or incompatible torch must not abort the run.
        logger.warning("Could not patch torchvision download hash check: %s", exc)
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
610 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
def run_experiment(
    self,
    dataset_path: Path,
    config_path: Path,
    output_dir: Path,
    random_seed: int = 42,
) -> None:
    """Run a Ludwig experiment through ``ludwig.experiment.experiment_cli``.

    Args:
        dataset_path: Dataset location; forwarded to Ludwig as a string.
        config_path: Ludwig config location; forwarded as a string.
        output_dir: Directory handed to Ludwig as ``output_directory``.
            It is created (including parents) if it does not exist yet.
        random_seed: Seed forwarded to Ludwig unchanged.

    Raises:
        RuntimeError: If ``experiment_cli`` cannot be imported, or if the
            call raises ``TypeError`` (reported as an argument mismatch).
        Exception: Any other error raised during the run is logged and
            re-raised unchanged.
    """
    logger.info("LudwigDirectBackend: Starting experiment execution.")

    # Relax strict hash validation of torchvision weights, which commonly
    # fails in proxied environments.
    self._patch_torchvision_download()

    # Import lazily so a missing Ludwig install surfaces as a RuntimeError
    # at call time.
    try:
        from ludwig.experiment import experiment_cli
    except ImportError as import_err:
        logger.error(
            "LudwigDirectBackend: Could not import experiment_cli.",
            exc_info=True,
        )
        raise RuntimeError("Ludwig import failed.") from import_err

    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        cli_kwargs = {
            "dataset": str(dataset_path),
            "config": str(config_path),
            "output_directory": str(output_dir),
            "random_seed": random_seed,
            "skip_preprocessing": True,
        }
        experiment_cli(**cli_kwargs)
        logger.info(
            f"LudwigDirectBackend: Experiment completed. Results in {output_dir}"
        )
    except TypeError as type_err:
        logger.error(
            "LudwigDirectBackend: Argument mismatch in experiment_cli call.",
            exc_info=True,
        )
        raise RuntimeError("Ludwig argument error.") from type_err
    except Exception:
        # Log with a troubleshooting hint, then let the original error
        # propagate to the caller untouched.
        logger.error(
            "LudwigDirectBackend: Experiment execution error. "
            "If this relates to validation_metric, confirm the XML task selection "
            "passes a metric that matches the inferred task type.",
            exc_info=True,
        )
        raise
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
660 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
def get_training_process(self, output_dir) -> Optional[Dict[str, Any]]:
    """Read training-progress fields from the most recent Ludwig run.

    The newest ``experiment_run*`` entry (by modification time) under
    ``output_dir`` is selected and its ``model/training_progress.json``
    is parsed.

    Args:
        output_dir: Directory containing the ``experiment_run*`` entries;
            anything accepted by ``Path`` works.

    Returns:
        A dict with the keys ``learning_rate``, ``batch_size`` and
        ``epoch`` (each ``None`` when absent from the file); ``None`` when
        no run directory or progress file exists; an empty dict when the
        file exists but cannot be read or parsed.
    """
    output_dir = Path(output_dir)

    run_dirs = list(output_dir.glob("experiment_run*"))
    run_dirs.sort(key=lambda run: run.stat().st_mtime)
    if not run_dirs:
        logger.warning(f"No experiment run directories found in {output_dir}")
        return None

    # After the sort, the last entry is the most recently modified run.
    progress_file = run_dirs[-1] / "model" / "training_progress.json"
    if not progress_file.exists():
        logger.warning(f"No training_progress.json found in {progress_file}")
        return None

    wanted_fields = ("learning_rate", "batch_size", "epoch")
    try:
        with progress_file.open("r", encoding="utf-8") as handle:
            progress = json.load(handle)
        return {field: progress.get(field) for field in wanted_fields}
    except Exception as e:
        # Best effort: an unreadable progress file must not abort the caller.
        logger.warning(f"Failed to read training progress info: {e}")
        return {}
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
689 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
def convert_parquet_to_csv(self, output_dir: Path):
    """Write a CSV copy of the latest run's predictions Parquet file.

    The newest ``experiment_run*`` entry (by modification time) under
    ``output_dir`` is selected. If it contains the predictions Parquet
    file, the file is loaded with pandas and saved next to it as
    ``predictions.csv`` without the index column.

    Args:
        output_dir: Directory containing the ``experiment_run*`` entries.

    Returns:
        None. Missing inputs and conversion failures are only logged.
    """
    output_dir = Path(output_dir)

    run_dirs = list(output_dir.glob("experiment_run*"))
    run_dirs.sort(key=lambda run: run.stat().st_mtime)
    if not run_dirs:
        logger.warning(f"No experiment run dirs found in {output_dir}")
        return

    latest_run = run_dirs[-1]
    parquet_path = latest_run / PREDICTIONS_PARQUET_FILE_NAME
    csv_path = latest_run / "predictions.csv"

    # Nothing to convert when the run produced no predictions file.
    if not parquet_path.exists():
        logger.info(f"Predictions parquet file not found at {parquet_path}, skipping conversion")
        return

    try:
        predictions = pd.read_parquet(parquet_path)
        predictions.to_csv(csv_path, index=False)
        logger.info(f"Converted Parquet to CSV: {csv_path}")
    except Exception as e:
        # Conversion is best effort; report the failure and carry on.
        logger.error(f"Error converting Parquet to CSV: {e}")
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
715 |
|
19
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
def _get_latest_experiment_dir(self, output_dir: Path) -> Optional[Path]:
    """Locate the newest ``experiment_run*`` entry under ``output_dir``.

    Args:
        output_dir: Directory to search; anything accepted by ``Path``.

    Returns:
        The matching path with the greatest modification time, or ``None``
        when nothing matches the ``experiment_run*`` pattern.
    """
    candidates = list(Path(output_dir).glob("experiment_run*"))
    if not candidates:
        return None
    # Stable ascending sort by mtime: the newest entry ends up last.
    candidates.sort(key=lambda entry: entry.stat().st_mtime)
    return candidates[-1]
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
724 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
def _extract_preprocessing_config(
    self, exp_dir: Path, config: Dict[str, Any]
) -> Tuple[Optional[Dict[str, Any]], Optional[Path]]:
    """Collect preprocessing settings and the label path for a run.

    Preprocessing settings are taken from the first ``input_features``
    entry of the run's train-set metadata file (under ``model/``). When
    that yields nothing, the first ``input_features`` entry of the
    ``config`` section of the run's description file is used instead.

    Args:
        exp_dir: Experiment run directory holding the Ludwig artifacts.
        config: Mapping that may carry ``label_column_data_path``.

    Returns:
        A ``(preprocessing, label_path)`` tuple. ``preprocessing`` is a
        dict (possibly empty). ``label_path`` is a ``Path`` built from
        ``config["label_column_data_path"]``, or ``None`` when that value
        is missing, falsy, or cannot be read.
    """
    # Primary source: the train-set metadata stored alongside the model.
    first_feature: Dict[str, Any] = {}
    metadata_file = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME
    if metadata_file.exists():
        try:
            with metadata_file.open("r", encoding="utf-8") as handle:
                metadata = json.load(handle)
            features = metadata.get("input_features") or []
            if features:
                first_feature = features[0] or {}
        except Exception as exc:
            logger.warning("Unable to read train_set_metadata: %s", exc)

    # Secondary source: the "config" section of the description file.
    description_config: Dict[str, Any] = {}
    description_file = exp_dir / DESCRIPTION_FILE_NAME
    if description_file.exists():
        try:
            with description_file.open("r", encoding="utf-8") as handle:
                description = json.load(handle)
            if isinstance(description, dict):
                description_config = description.get("config", {})
            else:
                description_config = {}
        except Exception as exc:
            logger.warning("Unable to read description.json for preprocessing: %s", exc)

    settings = {}
    if isinstance(first_feature, dict):
        settings = first_feature.get("preprocessing") or {}
    if not settings and description_config:
        try:
            fallback_features = description_config.get("input_features", [{}])
            settings = fallback_features[0].get("preprocessing") or {}
        except Exception:
            # Unexpected structure in the description config: no settings.
            settings = {}

    # Use the inferred maximum dimensions when explicit ones are absent.
    if isinstance(settings, dict):
        dimension_fallbacks = (
            ("height", "infer_image_max_height"),
            ("width", "infer_image_max_width"),
        )
        for dimension, inferred_key in dimension_fallbacks:
            if not settings.get(dimension) and settings.get(inferred_key):
                settings[dimension] = settings.get(inferred_key)

    # Keep the label path around for downstream sampling.
    label_path = None
    try:
        configured_label_path = config.get("label_column_data_path")
        if configured_label_path:
            label_path = Path(configured_label_path)
    except Exception:
        label_path = None

    if not isinstance(settings, dict):
        settings = {}
    return settings, label_path
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
780 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
781 def _find_last_conv_layer(self, encoder_obj: Any) -> Optional[Any]: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
782 """Identify the last Conv2d layer within the encoder.""" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
783 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
784 import torch.nn as nn |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
785 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
786 return None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
787 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
788 target_model = encoder_obj |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
789 if hasattr(encoder_obj, "model"): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
790 target_model = encoder_obj.model |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
791 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
792 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
793 modules = list(target_model.named_modules()) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
794 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
795 return None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
796 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
797 for _, module in reversed(modules): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
798 if isinstance(module, nn.Conv2d): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
799 return module |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
800 return None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
801 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
802 def _generate_gradcam_heatmaps( |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
803 self, |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
804 exp_dir: Path, |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
805 config: Dict[str, Any], |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
806 output_type: Optional[str], |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
807 ) -> Dict[str, Any]: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
808 """Compute Grad-CAM overlays for convolutional encoders, when possible.""" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
809 result: Dict[str, Any] = { |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
810 "status": "skipped", |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
811 "reason": "", |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
812 "preview_paths": [], |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
813 "zip_path": None, |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
814 "dir_path": None, |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
815 } |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
816 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
817 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
818 import numpy as np |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
819 import torch |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
820 import torch.nn.functional as F |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
821 from matplotlib import cm |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
822 from PIL import Image |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
823 from ludwig.api import LudwigModel |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
824 except Exception as exc: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
825 result["reason"] = f"Missing dependency for Grad-CAM: {exc}" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
826 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
827 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
828 exp_dir = Path(exp_dir) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
829 model_dir = exp_dir / "model" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
830 if not model_dir.exists(): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
831 result["reason"] = "Model directory not found; skipping Grad-CAM." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
832 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
833 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
834 preprocessing, label_path = self._extract_preprocessing_config(exp_dir, config) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
835 height = preprocessing.get("height") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
836 width = preprocessing.get("width") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
837 if not height or not width: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
838 result["reason"] = "Image resize/height not found in Ludwig preprocessing." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
839 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
840 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
841 label_csv = label_path if label_path and label_path.exists() else None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
842 if not label_csv: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
843 result["reason"] = "Prepared label CSV not available for Grad-CAM sampling." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
844 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
845 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
846 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
847 df_all = pd.read_csv(label_csv) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
848 except Exception as exc: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
849 result["reason"] = f"Could not read prepared CSV: {exc}" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
850 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
851 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
852 if IMAGE_PATH_COLUMN_NAME not in df_all.columns: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
853 result["reason"] = "Image column missing from prepared CSV; cannot build Grad-CAM inputs." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
854 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
855 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
856 # Prefer test split; otherwise fall back to the full dataset |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
857 df_candidates = df_all |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
858 if SPLIT_COLUMN_NAME in df_all.columns: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
859 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
860 df_candidates = df_all[df_all[SPLIT_COLUMN_NAME] == 2] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
861 if df_candidates.empty: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
862 df_candidates = df_all |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
863 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
864 df_candidates = df_all |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
865 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
866 # Cap the number of samples |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
867 df_candidates = df_candidates.head(12) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
868 if df_candidates.empty: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
869 result["reason"] = "No samples available for Grad-CAM generation." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
870 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
871 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
872 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
873 ludwig_model = LudwigModel.load(str(model_dir)) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
874 except Exception as exc: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
875 result["reason"] = f"Unable to load LudwigModel for Grad-CAM: {exc}" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
876 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
877 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
878 base_model = getattr(ludwig_model, "model", None) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
879 if base_model is None: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
880 result["reason"] = "Ludwig model missing underlying torch model." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
881 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
882 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
883 image_feature_name = None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
884 image_feature = None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
885 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
886 for name, feat in getattr(base_model, "input_features", {}).items(): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
887 if hasattr(feat, "encoder_obj"): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
888 image_feature_name = name |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
889 image_feature = feat |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
890 break |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
891 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
892 image_feature_name = None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
893 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
894 if not image_feature or not image_feature_name: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
895 result["reason"] = "Image input feature not found; skipping Grad-CAM." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
896 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
897 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
898 target_layer = self._find_last_conv_layer(getattr(image_feature, "encoder_obj", None)) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
899 if target_layer is None: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
900 result["reason"] = "No convolutional layer detected in the encoder (heatmaps unsupported)." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
901 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
902 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
903 standardize = preprocessing.get("standardize_image") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
904 mean = preprocessing.get("mean") or preprocessing.get("img_mean") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
905 std = preprocessing.get("std") or preprocessing.get("img_std") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
906 encoder_obj = getattr(image_feature, "encoder_obj", None) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
907 if hasattr(encoder_obj, "normalize_mean") and encoder_obj.normalize_mean: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
908 mean = encoder_obj.normalize_mean |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
909 if hasattr(encoder_obj, "normalize_std") and encoder_obj.normalize_std: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
910 std = encoder_obj.normalize_std |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
911 if isinstance(standardize, str) and standardize.lower() == "imagenet1k": |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
912 mean = [0.485, 0.456, 0.406] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
913 std = [0.229, 0.224, 0.225] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
914 if mean is None or std is None: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
915 result["reason"] = "Normalization parameters (mean/std) not found in the saved encoder; skipping heatmaps to avoid mismatch." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
916 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
917 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
918 output_feature_name = LABEL_COLUMN_NAME |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
919 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
920 if getattr(base_model, "output_features", None): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
921 output_feature_name = next(iter(base_model.output_features.keys())) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
922 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
923 output_feature_name = LABEL_COLUMN_NAME |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
924 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
925 device = torch.device("cpu") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
926 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
927 base_model.to(device) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
928 base_model.eval() |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
929 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
930 logger.debug("Could not move model to CPU for Grad-CAM; continuing on default device.") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
931 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
932 heatmap_dir = exp_dir / "feature_importance_examples" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
933 heatmap_dir.mkdir(parents=True, exist_ok=True) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
934 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
def _load_tensor(image_path: Path) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]:
    """Load one image and convert it into a normalized, batched model input.

    The image is decoded as RGB, resized to ``(width, height)``, divided by
    255, transposed from channel-last to channel-first, and normalized with
    the ``mean`` / ``std`` values captured from the enclosing scope.

    Args:
        image_path: Location of the image file to read.

    Returns:
        ``(tensor, resized)`` where ``tensor`` carries a leading batch axis
        of size 1 and has been moved to ``device``, and ``resized`` is the
        resized RGB PIL image. ``(None, None)`` is returned when the file
        cannot be opened/decoded or when normalization fails (for example
        because ``mean`` / ``std`` are unavailable).
    """
    try:
        # The context manager releases the underlying file handle right away,
        # including for multi-frame formats; convert() hands back a separate,
        # fully loaded image that stays valid after the source is closed.
        with Image.open(image_path) as source_img:
            img = source_img.convert("RGB")
    except Exception as exc:
        # Best-effort: an unreadable image is skipped, but leave a trace.
        logger.debug("Grad-CAM could not read image %s: %s", image_path, exc)
        return None, None

    resized = img.resize((int(width), int(height)))
    arr = np.asarray(resized).astype("float32") / 255.0
    arr = np.transpose(arr, (2, 0, 1))
    tensor = torch.from_numpy(arr)
    try:
        # Pin the statistics to float32 so that float64 inputs (e.g. NumPy
        # values) cannot promote the model input to float64.
        mean_tensor = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        std_tensor = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)
        tensor = (tensor - mean_tensor) / std_tensor
    except Exception as exc:
        logger.debug("Grad-CAM could not normalize image %s: %s", image_path, exc)
        return None, None
    return tensor.unsqueeze(0).to(device), resized
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
951 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
952 generated: List[Path] = [] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
953 pairs: List[Tuple[Path, Path]] = [] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
954 image_root = label_csv.parent |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
955 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
956 for _, row in df_candidates.iterrows(): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
957 raw_path = row.get(IMAGE_PATH_COLUMN_NAME) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
958 if not isinstance(raw_path, str): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
959 continue |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
960 abs_path = (image_root / raw_path).resolve() |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
961 if not abs_path.exists(): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
962 continue |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
963 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
964 tensor, resized_img = _load_tensor(abs_path) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
965 if tensor is None or resized_img is None: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
966 continue |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
967 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
968 activations: List[torch.Tensor] = [] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
969 gradients: List[torch.Tensor] = [] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
970 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
def _fwd_hook(_module, _inp, output):
    """Forward hook: remember the hooked layer's output for the current image.

    The module and input arguments are ignored; only ``output`` is kept, by
    appending it to the ``activations`` list of the enclosing scope.
    """
    activations.append(output)
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
973 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
def _bwd_hook(_module, _grad_in, grad_out):
    """Backward hook: keep the first output-gradient of the hooked layer.

    Nothing is recorded when ``grad_out`` is empty or when its first entry is
    not a tensor; otherwise that entry is appended to the ``gradients`` list
    of the enclosing scope.
    """
    if not grad_out:
        return
    leading_grad = grad_out[0]
    if isinstance(leading_grad, torch.Tensor):
        gradients.append(leading_grad)
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
977 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
978 handle_fwd = target_layer.register_forward_hook(_fwd_hook) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
979 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
980 handle_bwd = target_layer.register_full_backward_hook(_bwd_hook) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
981 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
982 handle_bwd = target_layer.register_backward_hook(_bwd_hook) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
983 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
984 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
985 base_model.zero_grad(set_to_none=True) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
986 with torch.enable_grad(): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
987 outputs = base_model({image_feature_name: tensor}) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
988 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
989 logits = None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
990 if isinstance(outputs, dict): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
991 feature_out = outputs.get(output_feature_name) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
992 if isinstance(feature_out, dict): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
993 logits = feature_out.get("logits") or feature_out.get("logit") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
994 elif isinstance(feature_out, torch.Tensor): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
995 logits = feature_out |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
996 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
997 # Ludwig 0.10+ uses namespaced keys: "<feature>::logits" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
998 if logits is None: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
999 ns_key = f"{output_feature_name}::logits" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1000 if isinstance(outputs.get(ns_key), torch.Tensor): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1001 logits = outputs[ns_key] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1002 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1003 # Fallback: a top-level logits tensor |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1004 if logits is None and isinstance(outputs.get("logits"), torch.Tensor): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1005 logits = outputs.get("logits") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1006 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1007 if logits is None: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1008 raise ValueError("Could not locate logits for Grad-CAM.") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1009 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1010 if logits.dim() == 1: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1011 target_logit = logits.unsqueeze(0) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1012 else: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1013 target_class = 0 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1014 if output_type != "regression" and logits.shape[-1] > 1: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1015 target_class = int(torch.argmax(logits, dim=-1).item()) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1016 target_logit = logits[:, target_class] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1017 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1018 target_logit.sum().backward() |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1019 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1020 if not activations or not gradients: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1021 raise ValueError("Missing activations or gradients for Grad-CAM.") |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1022 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1023 act = activations[-1] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1024 grad = gradients[-1] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1025 weights = grad.mean(dim=(2, 3), keepdim=True) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1026 cam = (weights * act).sum(dim=1) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1027 cam = torch.relu(cam) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1028 cam = cam.squeeze(0) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1029 cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(int(height), int(width)), mode="bilinear", align_corners=False) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1030 cam = cam.squeeze().detach().cpu().numpy() |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1031 if cam.max() > 0: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1032 cam = cam / cam.max() |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1033 heatmap_rgba = np.uint8(cm.get_cmap("jet")(cam) * 255) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1034 heatmap_img = Image.fromarray(heatmap_rgba).convert("RGBA").resize(resized_img.size) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1035 overlay = Image.blend(resized_img.convert("RGBA"), heatmap_img, alpha=0.45) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1036 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1037 stem = Path(raw_path).stem |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1038 out_path = heatmap_dir / f"{stem}_gradcam.png" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1039 overlay.save(out_path) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1040 orig_path = heatmap_dir / f"{stem}_original.png" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1041 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1042 resized_img.save(orig_path) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1043 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1044 orig_path = None |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1045 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1046 generated.append(out_path) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1047 if orig_path: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1048 pairs.append((orig_path, out_path)) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1049 except Exception as exc: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1050 logger.warning("Grad-CAM failed for %s: %s", raw_path, exc) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1051 finally: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1052 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1053 handle_fwd.remove() |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1054 handle_bwd.remove() |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1055 except Exception: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1056 pass |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1057 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1058 if not generated: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1059 result["reason"] = "No heatmaps were generated (model may be non-convolutional or preprocessing missing)." |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1060 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1061 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1062 zip_path = exp_dir / "feature_importance_examples.zip" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1063 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1064 with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1065 for png in generated: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1066 zf.write(png, png.name) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1067 except Exception as exc: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1068 logger.warning("Failed to create Grad-CAM zip: %s", exc) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1069 |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1070 result.update( |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1071 { |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1072 "status": "generated", |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1073 "preview_paths": generated[:6], |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1074 "pairs": pairs[:6], |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1075 "zip_path": zip_path if zip_path.exists() else None, |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1076 "dir_path": heatmap_dir, |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1077 } |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1078 ) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1079 return result |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1080 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
@staticmethod
def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]:
    """Return one ``(metric_name, values)`` series for ``split``.

    Candidate series are collected in a fixed order: entries of the
    ``"combined"`` mapping first, then the entries of every other
    dict-valued feature mapping in the order they appear. A series
    qualifies when it is a list holding at least one ``int``/``float``.

    Args:
        stats: Mapping of split name to per-feature metric mappings.
        split: Key of the split to look up in ``stats``.
        prefer: Optional metric name; the first candidate carrying this
            name is returned when one exists.

    Returns:
        The preferred candidate if found, otherwise the first candidate,
        otherwise ``(None, None)`` when nothing qualifies.
    """
    if not isinstance(stats, dict):
        return None, None

    per_split = stats.get(split, {})
    if not isinstance(per_split, dict):
        return None, None

    def _numeric_series(metric_map: Dict[str, Any]) -> List[Tuple[str, List[float]]]:
        # Keep only list-valued entries that contain at least one number.
        return [
            (name, series)
            for name, series in metric_map.items()
            if isinstance(series, list)
            and any(isinstance(item, (int, float)) for item in series)
        ]

    candidates: List[Tuple[str, List[float]]] = []

    # "combined" metrics take precedence over per-feature metrics.
    combined_map = per_split.get("combined")
    if isinstance(combined_map, dict):
        candidates.extend(_numeric_series(combined_map))

    for feature, metric_map in per_split.items():
        if feature != "combined" and isinstance(metric_map, dict):
            candidates.extend(_numeric_series(metric_map))

    if not candidates:
        return None, None

    if prefer:
        preferred = next((pair for pair in candidates if pair[0] == prefer), None)
        if preferred is not None:
            return preferred

    return candidates[0]
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1111 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
def generate_plots(self, output_dir: Path) -> None:
    """Generate Ludwig visualizations (train/val + test) for the latest experiment run.

    The most recently modified ``experiment_run*`` directory under
    ``output_dir`` is selected. Plots are written to
    ``<run>/visualizations/train`` and ``<run>/visualizations/test``.
    Every visualization is attempted independently: a failure is logged
    as a warning and the remaining plots are still attempted.

    Args:
        output_dir: Directory that contains the ``experiment_run*``
            directories.

    Returns:
        None. Returns early (after logging a warning) when no experiment
        run directory exists or when the visualization registry is empty.
    """
    logger.info("Generating Ludwig visualizations (train/val + test)…")

    # Train/validation visualizations
    train_plots = {
        "learning_curves",
    }

    # Test visualizations (multi-class transparency)
    test_plots = {
        "confusion_matrix",
        "compare_performance",
        "compare_classifiers_multiclass_multimetric",
        "frequency_vs_f1",
        "confidence_thresholding",
        "confidence_thresholding_data_vs_acc",
        "confidence_thresholding_data_vs_acc_subset",
        "confidence_thresholding_data_vs_acc_subset_per_class",
        # Binary-only visualizations will still be attempted; multi-class replacements handled elsewhere
        "binary_threshold_vs_metric",
        "roc_curves",
        "precision_recall_curves",
        "calibration_1_vs_all",
        "calibration_multiclass",
    }

    output_dir = Path(output_dir)
    exp_dirs = sorted(
        output_dir.glob("experiment_run*"),
        key=lambda p: p.stat().st_mtime,
    )
    if not exp_dirs:
        logger.warning(f"No experiment run dirs found in {output_dir}")
        return
    # Sorted by modification time, so the last entry is the newest run.
    exp_dir = exp_dirs[-1]

    viz_dir = exp_dir / "visualizations"
    train_viz = viz_dir / "train"
    test_viz = viz_dir / "test"
    # parents=True also creates viz_dir itself when it is missing.
    train_viz.mkdir(parents=True, exist_ok=True)
    test_viz.mkdir(parents=True, exist_ok=True)

    def _check(p: Path) -> Optional[str]:
        """Return ``str(p)`` when the path exists, else ``None``."""
        return str(p) if p.exists() else None

    def _load_json(path: Any) -> Any:
        """Best-effort JSON read; log and return ``None`` on failure."""
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            logger.warning(f"Could not read {path}: {e}")
            return None

    training_stats = _check(exp_dir / "training_statistics.json")
    test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
    gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)

    # Experiment description; stays empty when missing or unreadable.
    cfg: Dict[str, Any] = {}
    desc = exp_dir / DESCRIPTION_FILE_NAME
    if desc.exists():
        loaded = _load_json(desc)
        if isinstance(loaded, dict):
            cfg = loaded

    dataset_path = None
    split_file = None
    dataset_value = cfg.get("dataset")
    # Only resolve real, non-empty path strings: Path("") is "." and always
    # exists, which would otherwise be passed on as the ground-truth file.
    if isinstance(dataset_value, str) and dataset_value:
        dataset_path = _check(Path(dataset_value))
        split_file = _check(Path(get_split_path(dataset_value)))
    model_name = cfg.get("model_name", "model")

    try:
        output_feature = cfg["config"]["output_features"][0]["name"]
    except (KeyError, IndexError, TypeError):
        # Description missing or not shaped as expected; fall back below.
        output_feature = ""
    if not output_feature and test_stats:
        stats = _load_json(test_stats)
        if isinstance(stats, dict):
            # Fall back to the first top-level key of the test statistics.
            output_feature = next(iter(stats), "")

    # First existing candidate wins; order encodes preference.
    prob_candidates = [
        exp_dir / f"{LABEL_COLUMN_NAME}_probabilities.csv",
        exp_dir / f"{output_feature}_probabilities.csv" if output_feature else None,
        exp_dir / "probabilities.csv",
        exp_dir / "predictions.csv",
        exp_dir / PREDICTIONS_PARQUET_FILE_NAME,
    ]
    probs_path = next(
        (str(cand) for cand in prob_candidates if cand is not None and cand.exists()),
        None,
    )

    viz_registry = get_visualizations_registry()
    if not viz_registry:
        logger.warning(
            "Ludwig visualizations registry not available; train/test PNGs will be skipped."
        )
        return

    base_kwargs = {
        "training_statistics": [training_stats] if training_stats else [],
        "test_statistics": [test_stats] if test_stats else [],
        "probabilities": [probs_path] if probs_path else [],
        "output_feature_name": output_feature,
        "ground_truth_split": 2,
        "top_n_classes": [20],
        "top_k": 3,
        "metrics": ["f1", "precision", "recall", "accuracy"],
        "positive_label": 0,
        "ground_truth_metadata": gt_metadata,
        "ground_truth": dataset_path,
        "split_file": split_file,
        "output_directory": None,  # set per plot below
        "normalize": False,
        "file_format": "png",
        "model_names": [model_name],
    }
    for viz_name, viz_func in viz_registry.items():
        if viz_name in train_plots:
            viz_dir_plot = train_viz
        elif viz_name in test_plots:
            viz_dir_plot = test_viz
        else:
            continue

        try:
            # Build per-viz kwargs based on the function signature to avoid unexpected args
            sig_params = inspect.signature(viz_func).parameters
            call_kwargs = {
                k: v
                for k, v in base_kwargs.items()
                if k in sig_params and v is not None
            }
            if "output_directory" in sig_params:
                call_kwargs["output_directory"] = str(viz_dir_plot)

            viz_func(**call_kwargs)
            logger.info(f"✔ Generated {viz_name}")
        except Exception as e:
            # Deliberately broad: one failing plot must not stop the others.
            logger.warning(f"✘ Skipped {viz_name}: {e}")
    logger.info(f"All visualizations written to {viz_dir}")
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1250 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1251 def generate_html_report( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1252 self, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1253 title: str, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1254 output_dir: str, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1255 config: dict, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1256 split_info: str, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1257 ) -> Path: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1258 """Assemble an HTML report from visualizations under train_val/ and test/ folders.""" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1259 cwd = Path.cwd() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1260 report_name = title.lower().replace(" ", "_") + "_report.html" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1261 report_path = cwd / report_name |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1262 output_dir = Path(output_dir) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1263 output_type = None |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1264 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1265 exp_dirs = sorted( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1266 output_dir.glob("experiment_run*"), |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1267 key=lambda p: p.stat().st_mtime, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1268 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1269 if not exp_dirs: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1270 raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1271 exp_dir = exp_dirs[-1] |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1272 train_set_metadata_path = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1273 label_metadata_path = config.get("label_column_data_path") |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1274 if label_metadata_path: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1275 label_metadata_path = Path(label_metadata_path) |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1276 dataset_path_from_desc: Optional[Path] = None |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1277 |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1278 # Pull additional config details from description.json if available |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1279 config_for_summary = dict(config) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1280 if "target_column" not in config_for_summary or not config_for_summary.get("target_column"): |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1281 config_for_summary["target_column"] = LABEL_COLUMN_NAME |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1282 desc_path = exp_dir / DESCRIPTION_FILE_NAME |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1283 if desc_path.exists(): |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1284 try: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1285 with open(desc_path, "r") as f: |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1286 desc_json = json.load(f) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1287 desc_cfg = desc_json.get("config", {}) if isinstance(desc_json, dict) else {} |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1288 encoder_cfg = ( |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1289 desc_cfg.get("input_features", [{}])[0].get("encoder", {}) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1290 if isinstance(desc_cfg.get("input_features", [{}]), list) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1291 else {} |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1292 ) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1293 output_cfg = ( |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1294 desc_cfg.get("output_features", [{}])[0] |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1295 if isinstance(desc_cfg.get("output_features", [{}]), list) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1296 else {} |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1297 ) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1298 trainer_cfg = desc_cfg.get("trainer", {}) if isinstance(desc_cfg, dict) else {} |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1299 loss_cfg = output_cfg.get("loss", {}) if isinstance(output_cfg, dict) else {} |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1300 opt_cfg = trainer_cfg.get("optimizer", {}) if isinstance(trainer_cfg, dict) else {} |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1301 clip_cfg = trainer_cfg.get("gradient_clipping", {}) if isinstance(trainer_cfg, dict) else {} |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1302 |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1303 arch_type = encoder_cfg.get("type") |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1304 arch_variant = encoder_cfg.get("model_variant") |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1305 arch_custom = encoder_cfg.get("custom_model") |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1306 arch_name = None |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1307 if arch_custom: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1308 arch_name = str(arch_custom) |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1309 if arch_type: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1310 arch_base = str(arch_type).replace("_", " ").title() |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1311 arch_type_name = ( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1312 f"{arch_base} {arch_variant}" if arch_variant is not None else arch_base |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1313 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1314 # Prefer explicit custom model names (e.g., MetaFormer) but fall back to encoder type |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1315 arch_name = arch_name or arch_type_name |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1316 if not arch_name and config.get("model_name"): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1317 # As a last resort, show the user-selected model name (handles custom/MetaFormer cases) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1318 arch_name = str(config.get("model_name")) |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1319 |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1320 summary_fields = { |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1321 "architecture": arch_name, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1322 "model_variant": arch_variant, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1323 "pretrained": encoder_cfg.get("use_pretrained"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1324 "trainable": encoder_cfg.get("trainable"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1325 "target_column": output_cfg.get("column"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1326 "task_type": output_cfg.get("type"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1327 "validation_metric": trainer_cfg.get("validation_metric"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1328 "loss_function": loss_cfg.get("type"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1329 "threshold": output_cfg.get("threshold"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1330 "total_epochs": trainer_cfg.get("epochs"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1331 "early_stop": trainer_cfg.get("early_stop"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1332 "batch_size": trainer_cfg.get("batch_size"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1333 "optimizer": opt_cfg.get("type"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1334 "learning_rate": trainer_cfg.get("learning_rate"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1335 "random_seed": desc_cfg.get("random_seed") or config.get("random_seed"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1336 "use_mixed_precision": trainer_cfg.get("use_mixed_precision"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1337 "gradient_clipping": clip_cfg.get("clipglobalnorm"), |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1338 } |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1339 for k, v in summary_fields.items(): |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1340 if v is None: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1341 continue |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1342 # Do not override user-passed target/image column names in config |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1343 if k in {"target_column", "image_column"} and config_for_summary.get(k): |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1344 continue |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1345 config_for_summary.setdefault(k, v) |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1346 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1347 dataset_field = None |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1348 if isinstance(desc_json, dict): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1349 dataset_field = desc_json.get("dataset") or desc_cfg.get("dataset") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1350 if dataset_field: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1351 try: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1352 dataset_path_from_desc = Path(dataset_field) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1353 except TypeError: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1354 dataset_path_from_desc = None |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1355 if dataset_path_from_desc and (not label_metadata_path or not label_metadata_path.exists()): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1356 label_metadata_path = dataset_path_from_desc |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1357 except Exception as e: # pragma: no cover - defensive |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1358 logger.warning(f"Could not merge description.json into config summary: {e}") |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1359 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1360 base_viz_dir = exp_dir / "visualizations" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1361 train_viz_dir = base_viz_dir / "train" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1362 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1363 html = get_html_template() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1364 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1365 # Extra CSS & JS: center Plotly and enable CSV download for predictions table |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1366 html += """ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1367 <style> |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1368 /* Center Plotly figures (both wrapper and native classes) */ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1369 .plotly-center { display: flex; justify-content: center; } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1370 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1371 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1372 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1373 /* Download button for predictions table */ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1374 .download-btn { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1375 padding: 8px 12px; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1376 border: 1px solid #4CAF50; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1377 background: #4CAF50; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1378 color: white; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1379 border-radius: 6px; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1380 cursor: pointer; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1381 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1382 .download-btn:hover { filter: brightness(0.95); } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1383 .preds-controls { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1384 display: flex; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1385 justify-content: flex-end; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1386 gap: 8px; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1387 margin: 8px 0; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1388 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1389 </style> |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1390 <script> |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1391 function tableToCSV(table){ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1392 const rows = Array.from(table.querySelectorAll('tr')); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1393 return rows.map(row => |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1394 Array.from(row.querySelectorAll('th,td')).map(cell => { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1395 let text = cell.innerText.replace(/\\r?\\n|\\r/g,' ').trim(); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1396 if (text.includes('"') || text.includes(',')) { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1397 text = '"' + text.replace(/"/g,'""') + '"'; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1398 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1399 return text; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1400 }).join(',') |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1401 ).join('\\n'); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1402 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1403 document.addEventListener('DOMContentLoaded', function(){ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1404 const btn = document.getElementById('downloadPredsCsv'); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1405 if(btn){ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1406 btn.addEventListener('click', function(){ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1407 const tbl = document.querySelector('.predictions-table'); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1408 if(!tbl){ alert('Predictions table not found.'); return; } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1409 const csv = tableToCSV(tbl); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1410 const blob = new Blob([csv], {type: 'text/csv;charset=utf-8;'}); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1411 const url = URL.createObjectURL(blob); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1412 const a = document.createElement('a'); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1413 a.href = url; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1414 a.download = 'ground_truth_vs_predictions.csv'; |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1415 document.body.appendChild(a); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1416 a.click(); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1417 document.body.removeChild(a); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1418 URL.revokeObjectURL(url); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1419 }); |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1420 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1421 }); |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1422 </script> |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1423 """ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1424 html += f"<h1>{title}</h1>" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1425 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
def append_plot_blocks(tab_html: str, plots: List[Dict[str, str]], title_suffix: str = "") -> str:
    """Append Plotly blocks to a tab with consistent markup.

    Args:
        tab_html: HTML accumulated so far for the tab.
        plots: Plot descriptors; each one is read through its ``"title"``
            and ``"html"`` keys. Both values are inserted verbatim (no
            escaping is applied here).
        title_suffix: Text appended to every plot title. A falsy value
            (``""`` or ``None``) adds nothing.

    Returns:
        ``tab_html`` followed by one centered ``<h2>`` heading and one
        ``<div class='plotly-center'>`` wrapper per plot, in input order.
        ``tab_html`` is returned unchanged when ``plots`` is empty or None.
    """
    if not plots:
        return tab_html
    suffix = title_suffix or ""
    # Build every fragment first and join once: repeatedly doing ``+=`` on
    # the growing tab string re-copies everything appended so far.
    blocks = [
        f"<h2 style='text-align: center;'>{plot['title']}{suffix}</h2>"
        f"<div class='plotly-center'>{plot['html']}</div>"
        for plot in plots
    ]
    return tab_html + "".join(blocks)
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1437 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1438 def build_dataset_overview( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1439 label_metadata: Optional[Path], |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1440 output_type: Optional[str], |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1441 split_probabilities: Optional[List[float]], |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1442 label_split_counts: Optional[List[Dict[str, int]]] = None, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1443 split_counts: Optional[Dict[int, int]] = None, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1444 fallback_dataset: Optional[Path] = None, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1445 ) -> str: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1446 """Summarize dataset distribution across splits using the actual split config.""" |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1447 if label_split_counts: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1448 # Use the actual counts captured during data prep instead of heuristics. |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1449 return format_dataset_overview_table(label_split_counts, regression_mode=False) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1450 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1451 if output_type == "regression" and split_counts: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1452 rows = [ |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1453 {"split": "train", "count": int(split_counts.get(0, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1454 {"split": "validation", "count": int(split_counts.get(1, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1455 {"split": "test", "count": int(split_counts.get(2, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1456 ] |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1457 return format_dataset_overview_table(rows, regression_mode=True) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1458 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1459 candidate_paths: List[Path] = [] |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1460 if label_metadata and label_metadata.exists(): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1461 candidate_paths.append(label_metadata) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1462 if fallback_dataset and fallback_dataset.exists(): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1463 candidate_paths.append(fallback_dataset) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1464 if not candidate_paths: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1465 return format_dataset_overview_table([]) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1466 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
def _normalize_split_probabilities(
    probs: Optional[List[float]],
) -> Optional[List[float]]:
    """Rescale three split probabilities so that they sum to 1.

    Args:
        probs: Candidate probabilities; exactly three entries are expected.
            Entries may be anything ``float()`` accepts.

    Returns:
        The three values divided by their sum, or ``None`` when the input
        is empty/None, does not have exactly three entries, contains a
        value that cannot be converted to ``float``, contains a NaN,
        infinite or negative value, or sums to zero.
    """
    import math  # local import: keeps this helper independent of file-level imports

    if not probs or len(probs) != 3:
        return None
    try:
        values = [float(p) for p in probs]
    except (TypeError, ValueError):
        return None
    # NaN/inf would slip past the ``total <= 0`` check below and come back
    # as NaN fractions; a negative entry would come back as a negative fraction.
    if any(not math.isfinite(v) or v < 0 for v in values):
        return None
    total = sum(values)
    if total <= 0:
        return None
    return [v / total for v in values]
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1480 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
def _split_counts_from_column(df: pd.DataFrame) -> Dict[int, int]:
    """Count rows per split id using the frame's split column.

    The column is looked up under ``SPLIT_COLUMN_NAME`` (defined elsewhere
    in this file). Values are coerced to numbers; entries that cannot be
    parsed are dropped, and the remaining values are cast to ``int``.

    Args:
        df: Frame that may contain the split column.

    Returns:
        Mapping of integer split id to number of rows with that id, or an
        empty dict when the column is absent or holds no numeric values.
    """
    if SPLIT_COLUMN_NAME not in df.columns:
        return {}
    split_series = pd.to_numeric(df[SPLIT_COLUMN_NAME], errors="coerce").dropna()
    if split_series.empty:
        return {}
    counts = split_series.astype(int).value_counts()
    # Build the dict explicitly so keys and values are plain Python ints,
    # matching the declared Dict[int, int] return type.
    return {int(split_id): int(n) for split_id, n in counts.items()}
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1491 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
def _split_counts_from_probs(total: int, probs: List[float]) -> Dict[int, int]:
    """Estimate per-split row counts from a total and split probabilities.

    Args:
        total: Number of rows to distribute.
        probs: Probabilities indexed as ``[train, validation, ...]``; only
            the first two entries are read.

    Returns:
        ``{0: train, 1: validation, 2: test}``. Train and validation are
        ``total * probability`` truncated toward zero; test receives the
        remainder, floored at 0.
    """
    train_n = int(total * probs[0])
    val_n = int(total * probs[1])
    # Test takes whatever is left so the three counts add up to ``total``
    # whenever train + validation do not exceed it.
    test_n = max(0, total - train_n - val_n)
    return {0: train_n, 1: val_n, 2: test_n}
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1497 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1498 fallback_rows: Optional[List[Dict[str, int]]] = None |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1499 for meta_path in candidate_paths: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1500 try: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1501 df_labels = pd.read_csv(meta_path) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1502 probs = _normalize_split_probabilities(split_probabilities) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1503 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1504 # Regression (or missing label column): only need split counts |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1505 if output_type == "regression" or LABEL_COLUMN_NAME not in df_labels.columns: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1506 split_counts_found = _split_counts_from_column(df_labels) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1507 if split_counts_found: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1508 rows = [ |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1509 {"split": "train", "count": int(split_counts_found.get(0, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1510 {"split": "validation", "count": int(split_counts_found.get(1, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1511 {"split": "test", "count": int(split_counts_found.get(2, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1512 ] |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1513 return format_dataset_overview_table(rows, regression_mode=True) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1514 if probs and fallback_rows is None: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1515 split_counts_found = _split_counts_from_probs(len(df_labels), probs) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1516 fallback_rows = [ |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1517 {"split": "train", "count": int(split_counts_found.get(0, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1518 {"split": "validation", "count": int(split_counts_found.get(1, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1519 {"split": "test", "count": int(split_counts_found.get(2, 0))}, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1520 ] |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1521 continue |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1522 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1523 # Classification: prefer actual split assignments; fall back to configured probabilities |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1524 if SPLIT_COLUMN_NAME in df_labels.columns: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1525 df_counts = df_labels[[LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]].copy() |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1526 df_counts[SPLIT_COLUMN_NAME] = pd.to_numeric( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1527 df_counts[SPLIT_COLUMN_NAME], errors="coerce" |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1528 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1529 df_counts = df_counts.dropna(subset=[SPLIT_COLUMN_NAME]) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1530 if df_counts.empty: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1531 continue |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1532 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1533 df_counts[SPLIT_COLUMN_NAME] = df_counts[SPLIT_COLUMN_NAME].astype(int) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1534 df_counts = df_counts.dropna(subset=[LABEL_COLUMN_NAME]) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1535 counts = ( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1536 df_counts.groupby([LABEL_COLUMN_NAME, SPLIT_COLUMN_NAME]) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1537 .size() |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1538 .unstack(fill_value=0) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1539 .sort_index() |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1540 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1541 rows = [] |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1542 for lbl, row in counts.iterrows(): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1543 rows.append( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1544 { |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1545 "label": str(lbl), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1546 "train": int(row.get(0, 0)), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1547 "validation": int(row.get(1, 0)), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1548 "test": int(row.get(2, 0)), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1549 } |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1550 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1551 return format_dataset_overview_table(rows) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1552 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1553 if probs: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1554 label_series = df_labels[LABEL_COLUMN_NAME].dropna() |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1555 label_counts = label_series.value_counts().sort_index() |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1556 if label_counts.empty: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1557 continue |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1558 rows = [] |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1559 for lbl, count in label_counts.items(): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1560 train_n = int(count * probs[0]) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1561 val_n = int(count * probs[1]) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1562 test_n = max(0, count - train_n - val_n) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1563 rows.append( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1564 { |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1565 "label": str(lbl), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1566 "train": train_n, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1567 "validation": val_n, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1568 "test": test_n, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1569 } |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1570 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1571 fallback_rows = fallback_rows or rows |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1572 except Exception as exc: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1573 logger.warning("Failed to build dataset overview from %s: %s", meta_path, exc) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1574 continue |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1575 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1576 if fallback_rows: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1577 return format_dataset_overview_table(fallback_rows, regression_mode=output_type == "regression") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1578 return format_dataset_overview_table([]) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1579 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1580 metrics_html = "" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1581 train_val_metrics_html = "" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1582 test_metrics_html = "" |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1583 output_type = None |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1584 train_stats_path = exp_dir / "training_statistics.json" |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1585 test_stats_path = exp_dir / TEST_STATISTICS_FILE_NAME |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1586 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1587 if train_stats_path.exists() and test_stats_path.exists(): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1588 with open(train_stats_path) as f: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1589 train_stats = json.load(f) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1590 with open(test_stats_path) as f: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1591 test_stats = json.load(f) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1592 output_type = detect_output_type(test_stats) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1593 metrics_html = format_stats_table_html(train_stats, test_stats, output_type) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1594 train_val_metrics_html = format_train_val_stats_table_html( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1595 train_stats, test_stats |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1596 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1597 test_metrics_html = format_test_merged_stats_table_html( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1598 extract_metrics_from_json(train_stats, test_stats, output_type)[ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1599 "test" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1600 ], output_type |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1601 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1602 except Exception as e: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1603 logger.warning( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1604 f"Could not load stats for HTML report: {type(e).__name__}: {e}" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1605 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1606 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1607 if not output_type: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1608 # Fallback to configured task type when stats are unavailable (e.g., failed run). |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1609 output_type = ( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1610 str(config_for_summary.get("task_type")).lower() |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1611 if config_for_summary.get("task_type") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1612 else None |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1613 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1614 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1615 dataset_overview_html = build_dataset_overview( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1616 label_metadata_path, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1617 output_type, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1618 config.get("split_probabilities"), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1619 config.get("label_split_counts"), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1620 config.get("split_counts"), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1621 dataset_path_from_desc, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1622 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1623 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1624 config_html = "" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1625 training_progress = self.get_training_process(output_dir) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1626 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1627 config_html = format_config_table_html( |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1628 config_for_summary, split_info, training_progress, output_type |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1629 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1630 except Exception as e: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1631 logger.warning(f"Could not load config for HTML report: {e}") |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1632 config_html = ( |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1633 "<h2 style='text-align: center;'>Model and Training Summary</h2>" |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1634 "<p style='text-align:center; color:#666;'>Configuration details unavailable.</p>" |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1635 ) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1636 if not config_html: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1637 config_html = ( |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1638 "<h2 style='text-align: center;'>Model and Training Summary</h2>" |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1639 "<p style='text-align:center; color:#666;'>No configuration details found.</p>" |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1640 ) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1641 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1642 # ---------- image rendering with exclusions ---------- |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1643 def render_img_section( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1644 title: str, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1645 dir_path: Path, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1646 output_type: str = None, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1647 exclude_names: Optional[set] = None, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1648 ) -> str: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1649 if not dir_path.exists(): |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1650 return "" |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1651 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1652 exclude_names = exclude_names or set() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1653 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1654 # Search recursively because Ludwig can nest figures under per-feature folders |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1655 imgs = list(dir_path.rglob("*.png")) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1656 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1657 # Exclude standard confusion matrices and duplicate top-N entropy variants (keep only the top-5 entropy matrix); ROC curves are kept by default |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1658 default_exclude = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1659 # "roc_curves.png", # Remove ROC curves from test tab |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1660 "confusion_matrix__label_top5.png", # Remove standard confusion matrix |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1661 "confusion_matrix__label_top10.png", # Remove duplicate |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1662 "confusion_matrix__label_top6.png", # Remove duplicate |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1663 "confusion_matrix_entropy__label_top10.png", # Keep only top5 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1664 "confusion_matrix_entropy__label_top6.png", # Keep only top5 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1665 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1666 title_is_test = title.lower().startswith("test") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1667 if title_is_test and output_type == "binary": |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1668 default_exclude.update( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1669 { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1670 "confusion_matrix__label_top2.png", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1671 "confusion_matrix_entropy__label_top2.png", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1672 "roc_curves_from_prediction_statistics.png", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1673 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1674 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1675 elif title_is_test and output_type == "category": |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1676 default_exclude.update( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1677 { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1678 "compare_classifiers_multiclass_multimetric__label_best10.png", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1679 "compare_classifiers_multiclass_multimetric__label_sorted.png", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1680 "compare_classifiers_multiclass_multimetric__label_worst10.png", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1681 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1682 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1683 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1684 imgs = [ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1685 img |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1686 for img in imgs |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1687 if img.name not in default_exclude |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1688 and img.name not in exclude_names |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1689 and not ( |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1690 "learning_curves" in img.stem |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1691 and "loss" in img.stem |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1692 and "label" in img.stem |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1693 ) |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1694 ] |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1695 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1696 if not imgs: |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1697 return "" |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1698 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1699 # Sort images by name for consistent ordering (works with string and numeric labels) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1700 imgs = sorted(imgs, key=lambda x: x.name) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1701 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1702 html_section = "" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1703 custom_titles = { |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1704 "compare_classifiers_multiclass_multimetric__label_top10": "Metric Comparison by Label", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1705 "compare_classifiers_performance_from_prob": "Label Metric Comparison by Probability", |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1706 } |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1707 for img in imgs: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1708 b64 = encode_image_to_base64(str(img)) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1709 default_title = img.stem.replace("_", " ").title() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1710 img_title = custom_titles.get(img.stem, default_title) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1711 html_section += ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1712 f"<h2 style='text-align: center;'>{img_title}</h2>" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1713 f'<div class="plot" style="margin-bottom:20px;text-align:center;">' |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1714 f'<img src="data:image/png;base64,{b64}" ' |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1715 f'style="max-width:90%;max-height:600px;border:1px solid #ddd;" />' |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1716 f"</div>" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1717 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1718 return html_section |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1719 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1720 # Show dataset overview, performance first, then config |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1721 predictions_csv_path = exp_dir / "predictions.csv" |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1722 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1723 tab1_content = dataset_overview_html + metrics_html + config_html |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1724 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1725 tab2_content = train_val_metrics_html |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1726 # Preload binary threshold plot so it appears first in Train/Val tab |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1727 threshold_plot = None |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1728 threshold_value = ( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1729 config_for_summary.get("threshold") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1730 if config_for_summary.get("threshold") is not None |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1731 else config.get("threshold") |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1732 ) |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1733 if threshold_value is None and output_type == "binary": |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1734 threshold_value = 0.5 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1735 if output_type == "binary" and predictions_csv_path.exists(): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1736 try: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1737 threshold_plot = build_binary_threshold_plot( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1738 str(predictions_csv_path), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1739 label_data_path=str(config.get("label_column_data_path")) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1740 if config.get("label_column_data_path") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1741 else None, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1742 split_value=1, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1743 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1744 except Exception as e: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1745 logger.warning(f"Could not generate validation threshold plot: {e}") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1746 |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1747 if train_stats_path.exists(): |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1748 try: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1749 if output_type == "regression": |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1750 tv_plots = build_regression_train_val_plots(str(train_stats_path)) |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1751 tab2_content = append_plot_blocks(tab2_content, tv_plots) |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1752 else: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1753 tv_plots = build_train_validation_plots(str(train_stats_path)) |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1754 # Add threshold plot first, then other train/val plots |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1755 if threshold_plot: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1756 tab2_content = append_plot_blocks(tab2_content, [threshold_plot]) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1757 # Only append once; avoid duplicates if added elsewhere |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1758 threshold_plot = None |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1759 tab2_content = append_plot_blocks(tab2_content, tv_plots) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1760 if threshold_plot or tv_plots: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1761 logger.info( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1762 f"Added {len(tv_plots) + (1 if threshold_plot else 0)} train/val diagnostic plots" |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1763 ) |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1764 except Exception as e: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1765 logger.warning(f"Could not generate train/val plots: {e}") |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1766 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1767 # Only include training PNGs for regression; classification is handled by filtered Plotly plots |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1768 if output_type == "regression": |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1769 tab2_content += render_img_section( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1770 "Training and Validation Visualizations", |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1771 train_viz_dir, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1772 output_type, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1773 exclude_names={ |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1774 "compare_classifiers_performance_from_prob.png", |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1775 "roc_curves_from_prediction_statistics.png", |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1776 "precision_recall_curves_from_prediction_statistics.png", |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1777 "precision_recall_curve.png", |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1778 }, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1779 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1780 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1781 # Validation diagnostics (calibration/threshold) from predictions.csv, using split=1 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1782 if output_type in ("binary", "category") and predictions_csv_path.exists(): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1783 try: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1784 val_diag_plots = build_prediction_diagnostics( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1785 str(predictions_csv_path), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1786 label_data_path=str(config.get("label_column_data_path")) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1787 if config.get("label_column_data_path") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1788 else None, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1789 split_value=1, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1790 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1791 val_conf_plots = [ |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1792 p for p in val_diag_plots if "Prediction Confidence Distribution" in p.get("title", "") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1793 ] |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1794 tab2_content = append_plot_blocks( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1795 tab2_content, val_conf_plots, " (Validation)" |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1796 ) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1797 except Exception as e: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1798 logger.warning(f"Could not generate validation diagnostics: {e}") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1799 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1800 # --- Predictions vs Ground Truth table (REGRESSION ONLY) --- |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1801 preds_section = "" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1802 parquet_path = exp_dir / PREDICTIONS_PARQUET_FILE_NAME |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1803 if output_type == "regression" and parquet_path.exists(): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1804 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1805 # 1) load predictions from Parquet |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1806 df_preds = pd.read_parquet(parquet_path).reset_index(drop=True) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1807 # assume the column containing your model's prediction is named "prediction" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1808 # or contains that substring: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1809 pred_col = next( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1810 (c for c in df_preds.columns if "prediction" in c.lower()), |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1811 None, |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1812 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1813 if pred_col is None: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1814 raise ValueError("No prediction column found in Parquet output") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1815 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1816 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1817 # 2) load ground truth for the test split from prepared CSV |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1818 df_all = pd.read_csv(config["label_column_data_path"]) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1819 df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1820 LABEL_COLUMN_NAME |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1821 ].reset_index(drop=True) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1822 # 3) concatenate side-by-side |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1823 df_table = pd.concat([df_gt, df_pred], axis=1) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1824 df_table.columns = [LABEL_COLUMN_NAME, "prediction"] |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1825 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1826 # 4) render as HTML |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1827 preds_html = df_table.to_html(index=False, classes="predictions-table") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1828 preds_section = ( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1829 "<h2 style='text-align: center;'>Ground Truth vs. Predictions</h2>" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1830 "<div class='preds-controls'>" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1831 "<button id='downloadPredsCsv' class='download-btn'>Download CSV</button>" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1832 "</div>" |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1833 "<div class='scroll-rows-30' style='overflow-x:auto; overflow-y:auto; max-height:350px; margin-bottom:20px;'>" |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1834 + preds_html |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1835 + "</div>" |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1836 ) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1837 except Exception as e: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1838 logger.warning(f"Could not build Predictions vs GT table: {e}") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1839 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1840 tab3_content = test_metrics_html + preds_section |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1841 |
|
19
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1842 gradcam_info = self._generate_gradcam_heatmaps(exp_dir, config, output_type) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1843 |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1844 if output_type == "regression" and train_stats_path.exists(): |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1845 try: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1846 test_plots = build_regression_test_plots(str(train_stats_path)) |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1847 tab3_content = append_plot_blocks(tab3_content, test_plots) |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1848 if test_plots: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1849 logger.info(f"Generated {len(test_plots)} regression test plots") |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1850 except Exception as e: |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1851 logger.warning(f"Could not generate regression test plots: {e}") |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1852 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1853 if output_type in ("binary", "category") and test_stats_path.exists(): |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1854 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1855 interactive_plots = build_classification_plots( |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1856 str(test_stats_path), |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1857 str(train_stats_path) if train_stats_path.exists() else None, |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1858 metadata_csv_path=str(label_metadata_path) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1859 if label_metadata_path and label_metadata_path.exists() |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1860 else None, |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1861 train_set_metadata_path=str(train_set_metadata_path) |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1862 if train_set_metadata_path.exists() |
|
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1863 else None, |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1864 threshold=threshold_value, |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1865 ) |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1866 tab3_content = append_plot_blocks(tab3_content, interactive_plots) |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1867 if interactive_plots: |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1868 logger.info(f"Generated {len(interactive_plots)} interactive Plotly plots") |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1869 except Exception as e: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1870 logger.warning(f"Could not generate Plotly plots: {e}") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1871 |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1872 # Multi-class transparency plots from test stats (replace ROC/PR for multi-class) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1873 if output_type == "category" and test_stats_path.exists(): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1874 try: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1875 multi_curves = build_multiclass_metric_plots(str(test_stats_path)) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1876 tab3_content = append_plot_blocks(tab3_content, multi_curves) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1877 if multi_curves: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1878 logger.info("Added multi-class per-class metric plots to test tab") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1879 except Exception as e: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1880 logger.warning(f"Could not generate multi-class metric plots: {e}") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1881 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1882 # Test diagnostics (confidence histogram) from predictions.csv, using split=2 |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1883 if predictions_csv_path.exists(): |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1884 try: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1885 test_diag_plots = build_prediction_diagnostics( |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1886 str(predictions_csv_path), |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1887 label_data_path=str(config.get("label_column_data_path")) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1888 if config.get("label_column_data_path") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1889 else None, |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1890 split_value=2, |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1891 ) |
|
17
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1892 test_conf_plots = [ |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1893 p for p in test_diag_plots if "Prediction Confidence Distribution" in p.get("title", "") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1894 ] |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1895 if test_conf_plots: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1896 tab3_content = append_plot_blocks(tab3_content, test_conf_plots) |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1897 logger.info("Added test prediction confidence plot") |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1898 except Exception as e: |
|
db9be962dc13
planemo upload for repository https://github.com/goeckslab/gleam.git commit 9db874612b0c3e4f53d639459fe789b762660cd6
goeckslab
parents:
16
diff
changeset
|
1899 logger.warning(f"Could not generate test diagnostics: {e}") |
|
15
d17e3a1b8659
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
goeckslab
parents:
12
diff
changeset
|
1900 |
|
19
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1901 if gradcam_info.get("status") == "generated": |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1902 tab3_content += "<h2 style='text-align: center;'>Grad-CAM Heatmaps</h2>" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1903 for orig_path, heat_path in gradcam_info.get("pairs", [])[:4]: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1904 try: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1905 display_name = Path(str(orig_path)).name |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1906 if display_name.endswith("_original.png"): |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1907 display_name = display_name[: -len("_original.png")] |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1908 b64_orig = encode_image_to_base64(str(orig_path)) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1909 b64_heat = encode_image_to_base64(str(heat_path)) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1910 tab3_content += ( |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1911 "<div class='plot' style='margin-bottom:15px;text-align:center;display:flex;gap:12px;justify-content:center;flex-wrap:wrap;'>" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1912 f"<div><div style='font-weight:600;margin-bottom:4px;'>{display_name}</div>" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1913 f"<img src='data:image/png;base64,{b64_orig}' style='max-width:320px;max-height:320px;border:1px solid #ddd;' /></div>" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1914 f"<div><div style='font-weight:600;margin-bottom:4px;'>Grad-CAM</div>" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1915 f"<img src='data:image/png;base64,{b64_heat}' style='max-width:320px;max-height:320px;border:1px solid #ddd;' /></div>" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1916 "</div>" |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1917 ) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1918 except Exception as exc: |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1919 logger.debug("Could not embed Grad-CAM pair %s / %s: %s", orig_path, heat_path, exc) |
|
c460abae83eb
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
goeckslab
parents:
18
diff
changeset
|
1920 |
|
12
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1921 # Add static TEST PNGs (with default dedupe/exclusions) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1922 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1923 modal_html = get_metrics_help_modal() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1924 html += tabbed_html + modal_html + get_html_closing() |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1925 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1926 try: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1927 with open(report_path, "w") as f: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1928 f.write(html) |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1929 logger.info(f"HTML report generated at: {report_path}") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1930 except Exception as e: |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1931 logger.error(f"Failed to write HTML report: {e}") |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1932 raise |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1933 |
|
bcfa2e234a80
planemo upload for repository https://github.com/goeckslab/gleam.git commit 96bab8325992d16fcaad8e0a4dc4c62b00e2abc2
goeckslab
parents:
diff
changeset
|
1934 return report_path |
