Mercurial > repos > goeckslab > image_learner
comparison ludwig_backend.py @ 19:c460abae83eb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
| author | goeckslab |
|---|---|
| date | Thu, 18 Dec 2025 16:59:58 +0000 |
| parents | bbf30253c99f |
| children |
comparison
equal
deleted
inserted
replaced
| 18:bbf30253c99f | 19:c460abae83eb |
|---|---|
| 1 import inspect | 1 import inspect |
| 2 import json | 2 import json |
| 3 import logging | 3 import logging |
| 4 import os | 4 import os |
| 5 import zipfile | |
| 5 from pathlib import Path | 6 from pathlib import Path |
| 6 from typing import Any, Dict, List, Optional, Protocol, Tuple | 7 from typing import Any, Dict, List, Optional, Protocol, Tuple |
| 7 | 8 |
| 8 import pandas as pd | 9 import pandas as pd |
| 9 import pandas.api.types as ptypes | 10 import pandas.api.types as ptypes |
| 709 df = pd.read_parquet(parquet_path) | 710 df = pd.read_parquet(parquet_path) |
| 710 df.to_csv(csv_path, index=False) | 711 df.to_csv(csv_path, index=False) |
| 711 logger.info(f"Converted Parquet to CSV: {csv_path}") | 712 logger.info(f"Converted Parquet to CSV: {csv_path}") |
| 712 except Exception as e: | 713 except Exception as e: |
| 713 logger.error(f"Error converting Parquet to CSV: {e}") | 714 logger.error(f"Error converting Parquet to CSV: {e}") |
| 715 | |
| 716 def _get_latest_experiment_dir(self, output_dir: Path) -> Optional[Path]: | |
| 717 """Return the most recent experiment_run* directory, if present.""" | |
| 718 output_dir = Path(output_dir) | |
| 719 exp_dirs = sorted( | |
| 720 output_dir.glob("experiment_run*"), | |
| 721 key=lambda p: p.stat().st_mtime, | |
| 722 ) | |
| 723 return exp_dirs[-1] if exp_dirs else None | |
| 724 | |
def _extract_preprocessing_config(
    self, exp_dir: Path, config: Dict[str, Any]
) -> Tuple[Optional[Dict[str, Any]], Optional[Path]]:
    """Collect image preprocessing settings saved alongside an experiment.

    The first entry of ``input_features`` in the training-set metadata file is
    consulted first; when it yields nothing, the first input feature of the
    ``config`` section of the description file is used instead. Missing
    ``height``/``width`` values are filled from ``infer_image_max_height`` /
    ``infer_image_max_width`` when those are present.

    Args:
        exp_dir: Experiment directory holding ``model/`` and the description file.
        config: Backend configuration; ``label_column_data_path`` is read from it.

    Returns:
        ``(preprocessing, label_path)`` where ``preprocessing`` is always a
        dict (possibly empty) and ``label_path`` is a ``Path`` or ``None``.
    """
    # --- primary source: train-set metadata ---------------------------------
    first_feature: Dict[str, Any] = {}
    metadata_file = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME
    if metadata_file.exists():
        try:
            with metadata_file.open("r", encoding="utf-8") as handle:
                metadata = json.load(handle)
            features = metadata.get("input_features") or []
            if features:
                first_feature = features[0] or {}
        except Exception as exc:
            logger.warning("Unable to read train_set_metadata: %s", exc)

    # --- secondary source: description config -------------------------------
    described_config: Dict[str, Any] = {}
    description_file = exp_dir / DESCRIPTION_FILE_NAME
    if description_file.exists():
        try:
            with description_file.open("r", encoding="utf-8") as handle:
                description = json.load(handle)
            if isinstance(description, dict):
                described_config = description.get("config", {})
            else:
                described_config = {}
        except Exception as exc:
            logger.warning("Unable to read description.json for preprocessing: %s", exc)

    settings: Any = {}
    if isinstance(first_feature, dict):
        settings = first_feature.get("preprocessing") or {}
    if not settings and described_config:
        try:
            described_features = described_config.get("input_features", [{}])
            settings = described_features[0].get("preprocessing") or {}
        except Exception:
            settings = {}

    # The label CSV location is passed through for downstream sampling.
    try:
        configured_path = config.get("label_column_data_path")
        label_path = Path(configured_path) if configured_path else None
    except Exception:
        label_path = None

    if not isinstance(settings, dict):
        return {}, label_path

    # Use the inferred maximum dimensions when explicit ones are absent.
    for dimension, inferred_key in (
        ("height", "infer_image_max_height"),
        ("width", "infer_image_max_width"),
    ):
        if not settings.get(dimension) and settings.get(inferred_key):
            settings[dimension] = settings.get(inferred_key)

    return settings, label_path
| 780 | |
| 781 def _find_last_conv_layer(self, encoder_obj: Any) -> Optional[Any]: | |
| 782 """Identify the last Conv2d layer within the encoder.""" | |
| 783 try: | |
| 784 import torch.nn as nn | |
| 785 except Exception: | |
| 786 return None | |
| 787 | |
| 788 target_model = encoder_obj | |
| 789 if hasattr(encoder_obj, "model"): | |
| 790 target_model = encoder_obj.model | |
| 791 | |
| 792 try: | |
| 793 modules = list(target_model.named_modules()) | |
| 794 except Exception: | |
| 795 return None | |
| 796 | |
| 797 for _, module in reversed(modules): | |
| 798 if isinstance(module, nn.Conv2d): | |
| 799 return module | |
| 800 return None | |
| 801 | |
| 802 def _generate_gradcam_heatmaps( | |
| 803 self, | |
| 804 exp_dir: Path, | |
| 805 config: Dict[str, Any], | |
| 806 output_type: Optional[str], | |
| 807 ) -> Dict[str, Any]: | |
| 808 """Compute Grad-CAM overlays for convolutional encoders, when possible.""" | |
| 809 result: Dict[str, Any] = { | |
| 810 "status": "skipped", | |
| 811 "reason": "", | |
| 812 "preview_paths": [], | |
| 813 "zip_path": None, | |
| 814 "dir_path": None, | |
| 815 } | |
| 816 | |
| 817 try: | |
| 818 import numpy as np | |
| 819 import torch | |
| 820 import torch.nn.functional as F | |
| 821 from matplotlib import cm | |
| 822 from PIL import Image | |
| 823 from ludwig.api import LudwigModel | |
| 824 except Exception as exc: | |
| 825 result["reason"] = f"Missing dependency for Grad-CAM: {exc}" | |
| 826 return result | |
| 827 | |
| 828 exp_dir = Path(exp_dir) | |
| 829 model_dir = exp_dir / "model" | |
| 830 if not model_dir.exists(): | |
| 831 result["reason"] = "Model directory not found; skipping Grad-CAM." | |
| 832 return result | |
| 833 | |
| 834 preprocessing, label_path = self._extract_preprocessing_config(exp_dir, config) | |
| 835 height = preprocessing.get("height") | |
| 836 width = preprocessing.get("width") | |
| 837 if not height or not width: | |
| 838 result["reason"] = "Image resize/height not found in Ludwig preprocessing." | |
| 839 return result | |
| 840 | |
| 841 label_csv = label_path if label_path and label_path.exists() else None | |
| 842 if not label_csv: | |
| 843 result["reason"] = "Prepared label CSV not available for Grad-CAM sampling." | |
| 844 return result | |
| 845 | |
| 846 try: | |
| 847 df_all = pd.read_csv(label_csv) | |
| 848 except Exception as exc: | |
| 849 result["reason"] = f"Could not read prepared CSV: {exc}" | |
| 850 return result | |
| 851 | |
| 852 if IMAGE_PATH_COLUMN_NAME not in df_all.columns: | |
| 853 result["reason"] = "Image column missing from prepared CSV; cannot build Grad-CAM inputs." | |
| 854 return result | |
| 855 | |
| 856 # Prefer test split; otherwise fall back to the full dataset | |
| 857 df_candidates = df_all | |
| 858 if SPLIT_COLUMN_NAME in df_all.columns: | |
| 859 try: | |
| 860 df_candidates = df_all[df_all[SPLIT_COLUMN_NAME] == 2] | |
| 861 if df_candidates.empty: | |
| 862 df_candidates = df_all | |
| 863 except Exception: | |
| 864 df_candidates = df_all | |
| 865 | |
| 866 # Cap the number of samples | |
| 867 df_candidates = df_candidates.head(12) | |
| 868 if df_candidates.empty: | |
| 869 result["reason"] = "No samples available for Grad-CAM generation." | |
| 870 return result | |
| 871 | |
| 872 try: | |
| 873 ludwig_model = LudwigModel.load(str(model_dir)) | |
| 874 except Exception as exc: | |
| 875 result["reason"] = f"Unable to load LudwigModel for Grad-CAM: {exc}" | |
| 876 return result | |
| 877 | |
| 878 base_model = getattr(ludwig_model, "model", None) | |
| 879 if base_model is None: | |
| 880 result["reason"] = "Ludwig model missing underlying torch model." | |
| 881 return result | |
| 882 | |
| 883 image_feature_name = None | |
| 884 image_feature = None | |
| 885 try: | |
| 886 for name, feat in getattr(base_model, "input_features", {}).items(): | |
| 887 if hasattr(feat, "encoder_obj"): | |
| 888 image_feature_name = name | |
| 889 image_feature = feat | |
| 890 break | |
| 891 except Exception: | |
| 892 image_feature_name = None | |
| 893 | |
| 894 if not image_feature or not image_feature_name: | |
| 895 result["reason"] = "Image input feature not found; skipping Grad-CAM." | |
| 896 return result | |
| 897 | |
| 898 target_layer = self._find_last_conv_layer(getattr(image_feature, "encoder_obj", None)) | |
| 899 if target_layer is None: | |
| 900 result["reason"] = "No convolutional layer detected in the encoder (heatmaps unsupported)." | |
| 901 return result | |
| 902 | |
| 903 standardize = preprocessing.get("standardize_image") | |
| 904 mean = preprocessing.get("mean") or preprocessing.get("img_mean") | |
| 905 std = preprocessing.get("std") or preprocessing.get("img_std") | |
| 906 encoder_obj = getattr(image_feature, "encoder_obj", None) | |
| 907 if hasattr(encoder_obj, "normalize_mean") and encoder_obj.normalize_mean: | |
| 908 mean = encoder_obj.normalize_mean | |
| 909 if hasattr(encoder_obj, "normalize_std") and encoder_obj.normalize_std: | |
| 910 std = encoder_obj.normalize_std | |
| 911 if isinstance(standardize, str) and standardize.lower() == "imagenet1k": | |
| 912 mean = [0.485, 0.456, 0.406] | |
| 913 std = [0.229, 0.224, 0.225] | |
| 914 if mean is None or std is None: | |
| 915 result["reason"] = "Normalization parameters (mean/std) not found in the saved encoder; skipping heatmaps to avoid mismatch." | |
| 916 return result | |
| 917 | |
| 918 output_feature_name = LABEL_COLUMN_NAME | |
| 919 try: | |
| 920 if getattr(base_model, "output_features", None): | |
| 921 output_feature_name = next(iter(base_model.output_features.keys())) | |
| 922 except Exception: | |
| 923 output_feature_name = LABEL_COLUMN_NAME | |
| 924 | |
| 925 device = torch.device("cpu") | |
| 926 try: | |
| 927 base_model.to(device) | |
| 928 base_model.eval() | |
| 929 except Exception: | |
| 930 logger.debug("Could not move model to CPU for Grad-CAM; continuing on default device.") | |
| 931 | |
| 932 heatmap_dir = exp_dir / "feature_importance_examples" | |
| 933 heatmap_dir.mkdir(parents=True, exist_ok=True) | |
| 934 | |
| 935 def _load_tensor(image_path: Path) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]: | |
| 936 try: | |
| 937 img = Image.open(image_path).convert("RGB") | |
| 938 except Exception: | |
| 939 return None, None | |
| 940 resized = img.resize((int(width), int(height))) | |
| 941 arr = np.asarray(resized).astype("float32") / 255.0 | |
| 942 arr = np.transpose(arr, (2, 0, 1)) | |
| 943 tensor = torch.from_numpy(arr) | |
| 944 try: | |
| 945 mean_tensor = torch.tensor(mean).view(-1, 1, 1) | |
| 946 std_tensor = torch.tensor(std).view(-1, 1, 1) | |
| 947 tensor = (tensor - mean_tensor) / std_tensor | |
| 948 except Exception: | |
| 949 return None, None | |
| 950 return tensor.unsqueeze(0).to(device), resized | |
| 951 | |
| 952 generated: List[Path] = [] | |
| 953 pairs: List[Tuple[Path, Path]] = [] | |
| 954 image_root = label_csv.parent | |
| 955 | |
| 956 for _, row in df_candidates.iterrows(): | |
| 957 raw_path = row.get(IMAGE_PATH_COLUMN_NAME) | |
| 958 if not isinstance(raw_path, str): | |
| 959 continue | |
| 960 abs_path = (image_root / raw_path).resolve() | |
| 961 if not abs_path.exists(): | |
| 962 continue | |
| 963 | |
| 964 tensor, resized_img = _load_tensor(abs_path) | |
| 965 if tensor is None or resized_img is None: | |
| 966 continue | |
| 967 | |
| 968 activations: List[torch.Tensor] = [] | |
| 969 gradients: List[torch.Tensor] = [] | |
| 970 | |
| 971 def _fwd_hook(_module, _inp, output): | |
| 972 activations.append(output) | |
| 973 | |
| 974 def _bwd_hook(_module, _grad_in, grad_out): | |
| 975 if grad_out and isinstance(grad_out[0], torch.Tensor): | |
| 976 gradients.append(grad_out[0]) | |
| 977 | |
| 978 handle_fwd = target_layer.register_forward_hook(_fwd_hook) | |
| 979 try: | |
| 980 handle_bwd = target_layer.register_full_backward_hook(_bwd_hook) | |
| 981 except Exception: | |
| 982 handle_bwd = target_layer.register_backward_hook(_bwd_hook) | |
| 983 | |
| 984 try: | |
| 985 base_model.zero_grad(set_to_none=True) | |
| 986 with torch.enable_grad(): | |
| 987 outputs = base_model({image_feature_name: tensor}) | |
| 988 | |
| 989 logits = None | |
| 990 if isinstance(outputs, dict): | |
| 991 feature_out = outputs.get(output_feature_name) | |
| 992 if isinstance(feature_out, dict): | |
| 993 logits = feature_out.get("logits") or feature_out.get("logit") | |
| 994 elif isinstance(feature_out, torch.Tensor): | |
| 995 logits = feature_out | |
| 996 | |
| 997 # Ludwig 0.10+ uses namespaced keys: "<feature>::logits" | |
| 998 if logits is None: | |
| 999 ns_key = f"{output_feature_name}::logits" | |
| 1000 if isinstance(outputs.get(ns_key), torch.Tensor): | |
| 1001 logits = outputs[ns_key] | |
| 1002 | |
| 1003 # Fallback: a top-level logits tensor | |
| 1004 if logits is None and isinstance(outputs.get("logits"), torch.Tensor): | |
| 1005 logits = outputs.get("logits") | |
| 1006 | |
| 1007 if logits is None: | |
| 1008 raise ValueError("Could not locate logits for Grad-CAM.") | |
| 1009 | |
| 1010 if logits.dim() == 1: | |
| 1011 target_logit = logits.unsqueeze(0) | |
| 1012 else: | |
| 1013 target_class = 0 | |
| 1014 if output_type != "regression" and logits.shape[-1] > 1: | |
| 1015 target_class = int(torch.argmax(logits, dim=-1).item()) | |
| 1016 target_logit = logits[:, target_class] | |
| 1017 | |
| 1018 target_logit.sum().backward() | |
| 1019 | |
| 1020 if not activations or not gradients: | |
| 1021 raise ValueError("Missing activations or gradients for Grad-CAM.") | |
| 1022 | |
| 1023 act = activations[-1] | |
| 1024 grad = gradients[-1] | |
| 1025 weights = grad.mean(dim=(2, 3), keepdim=True) | |
| 1026 cam = (weights * act).sum(dim=1) | |
| 1027 cam = torch.relu(cam) | |
| 1028 cam = cam.squeeze(0) | |
| 1029 cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(int(height), int(width)), mode="bilinear", align_corners=False) | |
| 1030 cam = cam.squeeze().detach().cpu().numpy() | |
| 1031 if cam.max() > 0: | |
| 1032 cam = cam / cam.max() | |
| 1033 heatmap_rgba = np.uint8(cm.get_cmap("jet")(cam) * 255) | |
| 1034 heatmap_img = Image.fromarray(heatmap_rgba).convert("RGBA").resize(resized_img.size) | |
| 1035 overlay = Image.blend(resized_img.convert("RGBA"), heatmap_img, alpha=0.45) | |
| 1036 | |
| 1037 stem = Path(raw_path).stem | |
| 1038 out_path = heatmap_dir / f"{stem}_gradcam.png" | |
| 1039 overlay.save(out_path) | |
| 1040 orig_path = heatmap_dir / f"{stem}_original.png" | |
| 1041 try: | |
| 1042 resized_img.save(orig_path) | |
| 1043 except Exception: | |
| 1044 orig_path = None | |
| 1045 | |
| 1046 generated.append(out_path) | |
| 1047 if orig_path: | |
| 1048 pairs.append((orig_path, out_path)) | |
| 1049 except Exception as exc: | |
| 1050 logger.warning("Grad-CAM failed for %s: %s", raw_path, exc) | |
| 1051 finally: | |
| 1052 try: | |
| 1053 handle_fwd.remove() | |
| 1054 handle_bwd.remove() | |
| 1055 except Exception: | |
| 1056 pass | |
| 1057 | |
| 1058 if not generated: | |
| 1059 result["reason"] = "No heatmaps were generated (model may be non-convolutional or preprocessing missing)." | |
| 1060 return result | |
| 1061 | |
| 1062 zip_path = exp_dir / "feature_importance_examples.zip" | |
| 1063 try: | |
| 1064 with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: | |
| 1065 for png in generated: | |
| 1066 zf.write(png, png.name) | |
| 1067 except Exception as exc: | |
| 1068 logger.warning("Failed to create Grad-CAM zip: %s", exc) | |
| 1069 | |
| 1070 result.update( | |
| 1071 { | |
| 1072 "status": "generated", | |
| 1073 "preview_paths": generated[:6], | |
| 1074 "pairs": pairs[:6], | |
| 1075 "zip_path": zip_path if zip_path.exists() else None, | |
| 1076 "dir_path": heatmap_dir, | |
| 1077 } | |
| 1078 ) | |
| 1079 return result | |
| 714 | 1080 |
| 715 @staticmethod | 1081 @staticmethod |
| 716 def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]: | 1082 def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]: |
| 717 """Pull the first numeric metric list we can find for the requested split.""" | 1083 """Pull the first numeric metric list we can find for the requested split.""" |
| 718 if not isinstance(stats, dict): | 1084 if not isinstance(stats, dict): |
| 1471 except Exception as e: | 1837 except Exception as e: |
| 1472 logger.warning(f"Could not build Predictions vs GT table: {e}") | 1838 logger.warning(f"Could not build Predictions vs GT table: {e}") |
| 1473 | 1839 |
| 1474 tab3_content = test_metrics_html + preds_section | 1840 tab3_content = test_metrics_html + preds_section |
| 1475 | 1841 |
| 1842 gradcam_info = self._generate_gradcam_heatmaps(exp_dir, config, output_type) | |
| 1843 | |
| 1476 if output_type == "regression" and train_stats_path.exists(): | 1844 if output_type == "regression" and train_stats_path.exists(): |
| 1477 try: | 1845 try: |
| 1478 test_plots = build_regression_test_plots(str(train_stats_path)) | 1846 test_plots = build_regression_test_plots(str(train_stats_path)) |
| 1479 tab3_content = append_plot_blocks(tab3_content, test_plots) | 1847 tab3_content = append_plot_blocks(tab3_content, test_plots) |
| 1480 if test_plots: | 1848 if test_plots: |
| 1528 tab3_content = append_plot_blocks(tab3_content, test_conf_plots) | 1896 tab3_content = append_plot_blocks(tab3_content, test_conf_plots) |
| 1529 logger.info("Added test prediction confidence plot") | 1897 logger.info("Added test prediction confidence plot") |
| 1530 except Exception as e: | 1898 except Exception as e: |
| 1531 logger.warning(f"Could not generate test diagnostics: {e}") | 1899 logger.warning(f"Could not generate test diagnostics: {e}") |
| 1532 | 1900 |
| 1901 if gradcam_info.get("status") == "generated": | |
| 1902 tab3_content += "<h2 style='text-align: center;'>Grad-CAM Heatmaps</h2>" | |
| 1903 for orig_path, heat_path in gradcam_info.get("pairs", [])[:4]: | |
| 1904 try: | |
| 1905 display_name = Path(str(orig_path)).name | |
| 1906 if display_name.endswith("_original.png"): | |
| 1907 display_name = display_name[: -len("_original.png")] | |
| 1908 b64_orig = encode_image_to_base64(str(orig_path)) | |
| 1909 b64_heat = encode_image_to_base64(str(heat_path)) | |
| 1910 tab3_content += ( | |
| 1911 "<div class='plot' style='margin-bottom:15px;text-align:center;display:flex;gap:12px;justify-content:center;flex-wrap:wrap;'>" | |
| 1912 f"<div><div style='font-weight:600;margin-bottom:4px;'>{display_name}</div>" | |
| 1913 f"<img src='data:image/png;base64,{b64_orig}' style='max-width:320px;max-height:320px;border:1px solid #ddd;' /></div>" | |
| 1914 f"<div><div style='font-weight:600;margin-bottom:4px;'>Grad-CAM</div>" | |
| 1915 f"<img src='data:image/png;base64,{b64_heat}' style='max-width:320px;max-height:320px;border:1px solid #ddd;' /></div>" | |
| 1916 "</div>" | |
| 1917 ) | |
| 1918 except Exception as exc: | |
| 1919 logger.debug("Could not embed Grad-CAM pair %s / %s: %s", orig_path, heat_path, exc) | |
| 1920 | |
| 1533 # Add static TEST PNGs (with default dedupe/exclusions) | 1921 # Add static TEST PNGs (with default dedupe/exclusions) |
| 1534 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) | 1922 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) |
| 1535 modal_html = get_metrics_help_modal() | 1923 modal_html = get_metrics_help_modal() |
| 1536 html += tabbed_html + modal_html + get_html_closing() | 1924 html += tabbed_html + modal_html + get_html_closing() |
| 1537 | 1925 |
