comparison ludwig_backend.py @ 19:c460abae83eb draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit b47f0fd63d8d5d18d602d45bb21ebbe36ba4fcfe
author goeckslab
date Thu, 18 Dec 2025 16:59:58 +0000
parents bbf30253c99f
children
comparison
equal deleted inserted replaced
18:bbf30253c99f 19:c460abae83eb
1 import inspect 1 import inspect
2 import json 2 import json
3 import logging 3 import logging
4 import os 4 import os
5 import zipfile
5 from pathlib import Path 6 from pathlib import Path
6 from typing import Any, Dict, List, Optional, Protocol, Tuple 7 from typing import Any, Dict, List, Optional, Protocol, Tuple
7 8
8 import pandas as pd 9 import pandas as pd
9 import pandas.api.types as ptypes 10 import pandas.api.types as ptypes
709 df = pd.read_parquet(parquet_path) 710 df = pd.read_parquet(parquet_path)
710 df.to_csv(csv_path, index=False) 711 df.to_csv(csv_path, index=False)
711 logger.info(f"Converted Parquet to CSV: {csv_path}") 712 logger.info(f"Converted Parquet to CSV: {csv_path}")
712 except Exception as e: 713 except Exception as e:
713 logger.error(f"Error converting Parquet to CSV: {e}") 714 logger.error(f"Error converting Parquet to CSV: {e}")
715
716 def _get_latest_experiment_dir(self, output_dir: Path) -> Optional[Path]:
717 """Return the most recent experiment_run* directory, if present."""
718 output_dir = Path(output_dir)
719 exp_dirs = sorted(
720 output_dir.glob("experiment_run*"),
721 key=lambda p: p.stat().st_mtime,
722 )
723 return exp_dirs[-1] if exp_dirs else None
724
def _extract_preprocessing_config(
    self, exp_dir: Path, config: Dict[str, Any]
) -> Tuple[Optional[Dict[str, Any]], Optional[Path]]:
    """Collect the preprocessing settings of the first input feature.

    The settings are looked up in the train-set metadata file under
    ``exp_dir / "model"`` first; when nothing is found there, the ``config``
    section of the experiment description file is used instead. Both files
    are always read when present, and read errors are logged as warnings.

    Args:
        exp_dir: Experiment directory holding the model and description file.
        config: Backend configuration; only ``label_column_data_path`` is read.

    Returns:
        A ``(preprocessing, label_path)`` tuple. ``preprocessing`` is always a
        dict (possibly empty); ``label_path`` is ``None`` when the config has
        no usable ``label_column_data_path``.
    """
    # Primary source: the first entry of "input_features" in the metadata.
    first_feature_meta: Dict[str, Any] = {}
    metadata_file = exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME
    if metadata_file.exists():
        try:
            with metadata_file.open("r", encoding="utf-8") as handle:
                metadata = json.load(handle)
            features = metadata.get("input_features") or []
            if features:
                first_feature_meta = features[0] or {}
        except Exception as exc:
            logger.warning("Unable to read train_set_metadata: %s", exc)

    # Secondary source: the "config" section of the description file.
    described_config: Dict[str, Any] = {}
    description_file = exp_dir / DESCRIPTION_FILE_NAME
    if description_file.exists():
        try:
            with description_file.open("r", encoding="utf-8") as handle:
                description = json.load(handle)
            if isinstance(description, dict):
                described_config = description.get("config", {})
            else:
                described_config = {}
        except Exception as exc:
            logger.warning("Unable to read description.json for preprocessing: %s", exc)

    settings: Any = {}
    if isinstance(first_feature_meta, dict):
        settings = first_feature_meta.get("preprocessing") or {}
    if described_config and not settings:
        try:
            first_described = described_config.get("input_features", [{}])[0]
            settings = first_described.get("preprocessing") or {}
        except Exception:
            settings = {}

    if isinstance(settings, dict):
        # Missing height/width fall back to the inferred maximum dimensions.
        for dimension in ("height", "width"):
            inferred = settings.get(f"infer_image_max_{dimension}")
            if not settings.get(dimension) and inferred:
                settings[dimension] = inferred
    else:
        settings = {}

    # The label CSV location is handed back for downstream sampling.
    try:
        configured_label_path = config.get("label_column_data_path")
        label_path = Path(configured_label_path) if configured_label_path else None
    except Exception:
        label_path = None

    return settings, label_path
780
781 def _find_last_conv_layer(self, encoder_obj: Any) -> Optional[Any]:
782 """Identify the last Conv2d layer within the encoder."""
783 try:
784 import torch.nn as nn
785 except Exception:
786 return None
787
788 target_model = encoder_obj
789 if hasattr(encoder_obj, "model"):
790 target_model = encoder_obj.model
791
792 try:
793 modules = list(target_model.named_modules())
794 except Exception:
795 return None
796
797 for _, module in reversed(modules):
798 if isinstance(module, nn.Conv2d):
799 return module
800 return None
801
802 def _generate_gradcam_heatmaps(
803 self,
804 exp_dir: Path,
805 config: Dict[str, Any],
806 output_type: Optional[str],
807 ) -> Dict[str, Any]:
808 """Compute Grad-CAM overlays for convolutional encoders, when possible."""
809 result: Dict[str, Any] = {
810 "status": "skipped",
811 "reason": "",
812 "preview_paths": [],
813 "zip_path": None,
814 "dir_path": None,
815 }
816
817 try:
818 import numpy as np
819 import torch
820 import torch.nn.functional as F
821 from matplotlib import cm
822 from PIL import Image
823 from ludwig.api import LudwigModel
824 except Exception as exc:
825 result["reason"] = f"Missing dependency for Grad-CAM: {exc}"
826 return result
827
828 exp_dir = Path(exp_dir)
829 model_dir = exp_dir / "model"
830 if not model_dir.exists():
831 result["reason"] = "Model directory not found; skipping Grad-CAM."
832 return result
833
834 preprocessing, label_path = self._extract_preprocessing_config(exp_dir, config)
835 height = preprocessing.get("height")
836 width = preprocessing.get("width")
837 if not height or not width:
838 result["reason"] = "Image resize/height not found in Ludwig preprocessing."
839 return result
840
841 label_csv = label_path if label_path and label_path.exists() else None
842 if not label_csv:
843 result["reason"] = "Prepared label CSV not available for Grad-CAM sampling."
844 return result
845
846 try:
847 df_all = pd.read_csv(label_csv)
848 except Exception as exc:
849 result["reason"] = f"Could not read prepared CSV: {exc}"
850 return result
851
852 if IMAGE_PATH_COLUMN_NAME not in df_all.columns:
853 result["reason"] = "Image column missing from prepared CSV; cannot build Grad-CAM inputs."
854 return result
855
856 # Prefer test split; otherwise fall back to the full dataset
857 df_candidates = df_all
858 if SPLIT_COLUMN_NAME in df_all.columns:
859 try:
860 df_candidates = df_all[df_all[SPLIT_COLUMN_NAME] == 2]
861 if df_candidates.empty:
862 df_candidates = df_all
863 except Exception:
864 df_candidates = df_all
865
866 # Cap the number of samples
867 df_candidates = df_candidates.head(12)
868 if df_candidates.empty:
869 result["reason"] = "No samples available for Grad-CAM generation."
870 return result
871
872 try:
873 ludwig_model = LudwigModel.load(str(model_dir))
874 except Exception as exc:
875 result["reason"] = f"Unable to load LudwigModel for Grad-CAM: {exc}"
876 return result
877
878 base_model = getattr(ludwig_model, "model", None)
879 if base_model is None:
880 result["reason"] = "Ludwig model missing underlying torch model."
881 return result
882
883 image_feature_name = None
884 image_feature = None
885 try:
886 for name, feat in getattr(base_model, "input_features", {}).items():
887 if hasattr(feat, "encoder_obj"):
888 image_feature_name = name
889 image_feature = feat
890 break
891 except Exception:
892 image_feature_name = None
893
894 if not image_feature or not image_feature_name:
895 result["reason"] = "Image input feature not found; skipping Grad-CAM."
896 return result
897
898 target_layer = self._find_last_conv_layer(getattr(image_feature, "encoder_obj", None))
899 if target_layer is None:
900 result["reason"] = "No convolutional layer detected in the encoder (heatmaps unsupported)."
901 return result
902
903 standardize = preprocessing.get("standardize_image")
904 mean = preprocessing.get("mean") or preprocessing.get("img_mean")
905 std = preprocessing.get("std") or preprocessing.get("img_std")
906 encoder_obj = getattr(image_feature, "encoder_obj", None)
907 if hasattr(encoder_obj, "normalize_mean") and encoder_obj.normalize_mean:
908 mean = encoder_obj.normalize_mean
909 if hasattr(encoder_obj, "normalize_std") and encoder_obj.normalize_std:
910 std = encoder_obj.normalize_std
911 if isinstance(standardize, str) and standardize.lower() == "imagenet1k":
912 mean = [0.485, 0.456, 0.406]
913 std = [0.229, 0.224, 0.225]
914 if mean is None or std is None:
915 result["reason"] = "Normalization parameters (mean/std) not found in the saved encoder; skipping heatmaps to avoid mismatch."
916 return result
917
918 output_feature_name = LABEL_COLUMN_NAME
919 try:
920 if getattr(base_model, "output_features", None):
921 output_feature_name = next(iter(base_model.output_features.keys()))
922 except Exception:
923 output_feature_name = LABEL_COLUMN_NAME
924
925 device = torch.device("cpu")
926 try:
927 base_model.to(device)
928 base_model.eval()
929 except Exception:
930 logger.debug("Could not move model to CPU for Grad-CAM; continuing on default device.")
931
932 heatmap_dir = exp_dir / "feature_importance_examples"
933 heatmap_dir.mkdir(parents=True, exist_ok=True)
934
935 def _load_tensor(image_path: Path) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]:
936 try:
937 img = Image.open(image_path).convert("RGB")
938 except Exception:
939 return None, None
940 resized = img.resize((int(width), int(height)))
941 arr = np.asarray(resized).astype("float32") / 255.0
942 arr = np.transpose(arr, (2, 0, 1))
943 tensor = torch.from_numpy(arr)
944 try:
945 mean_tensor = torch.tensor(mean).view(-1, 1, 1)
946 std_tensor = torch.tensor(std).view(-1, 1, 1)
947 tensor = (tensor - mean_tensor) / std_tensor
948 except Exception:
949 return None, None
950 return tensor.unsqueeze(0).to(device), resized
951
952 generated: List[Path] = []
953 pairs: List[Tuple[Path, Path]] = []
954 image_root = label_csv.parent
955
956 for _, row in df_candidates.iterrows():
957 raw_path = row.get(IMAGE_PATH_COLUMN_NAME)
958 if not isinstance(raw_path, str):
959 continue
960 abs_path = (image_root / raw_path).resolve()
961 if not abs_path.exists():
962 continue
963
964 tensor, resized_img = _load_tensor(abs_path)
965 if tensor is None or resized_img is None:
966 continue
967
968 activations: List[torch.Tensor] = []
969 gradients: List[torch.Tensor] = []
970
971 def _fwd_hook(_module, _inp, output):
972 activations.append(output)
973
974 def _bwd_hook(_module, _grad_in, grad_out):
975 if grad_out and isinstance(grad_out[0], torch.Tensor):
976 gradients.append(grad_out[0])
977
978 handle_fwd = target_layer.register_forward_hook(_fwd_hook)
979 try:
980 handle_bwd = target_layer.register_full_backward_hook(_bwd_hook)
981 except Exception:
982 handle_bwd = target_layer.register_backward_hook(_bwd_hook)
983
984 try:
985 base_model.zero_grad(set_to_none=True)
986 with torch.enable_grad():
987 outputs = base_model({image_feature_name: tensor})
988
989 logits = None
990 if isinstance(outputs, dict):
991 feature_out = outputs.get(output_feature_name)
992 if isinstance(feature_out, dict):
993 logits = feature_out.get("logits") or feature_out.get("logit")
994 elif isinstance(feature_out, torch.Tensor):
995 logits = feature_out
996
997 # Ludwig 0.10+ uses namespaced keys: "<feature>::logits"
998 if logits is None:
999 ns_key = f"{output_feature_name}::logits"
1000 if isinstance(outputs.get(ns_key), torch.Tensor):
1001 logits = outputs[ns_key]
1002
1003 # Fallback: a top-level logits tensor
1004 if logits is None and isinstance(outputs.get("logits"), torch.Tensor):
1005 logits = outputs.get("logits")
1006
1007 if logits is None:
1008 raise ValueError("Could not locate logits for Grad-CAM.")
1009
1010 if logits.dim() == 1:
1011 target_logit = logits.unsqueeze(0)
1012 else:
1013 target_class = 0
1014 if output_type != "regression" and logits.shape[-1] > 1:
1015 target_class = int(torch.argmax(logits, dim=-1).item())
1016 target_logit = logits[:, target_class]
1017
1018 target_logit.sum().backward()
1019
1020 if not activations or not gradients:
1021 raise ValueError("Missing activations or gradients for Grad-CAM.")
1022
1023 act = activations[-1]
1024 grad = gradients[-1]
1025 weights = grad.mean(dim=(2, 3), keepdim=True)
1026 cam = (weights * act).sum(dim=1)
1027 cam = torch.relu(cam)
1028 cam = cam.squeeze(0)
1029 cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(int(height), int(width)), mode="bilinear", align_corners=False)
1030 cam = cam.squeeze().detach().cpu().numpy()
1031 if cam.max() > 0:
1032 cam = cam / cam.max()
1033 heatmap_rgba = np.uint8(cm.get_cmap("jet")(cam) * 255)
1034 heatmap_img = Image.fromarray(heatmap_rgba).convert("RGBA").resize(resized_img.size)
1035 overlay = Image.blend(resized_img.convert("RGBA"), heatmap_img, alpha=0.45)
1036
1037 stem = Path(raw_path).stem
1038 out_path = heatmap_dir / f"{stem}_gradcam.png"
1039 overlay.save(out_path)
1040 orig_path = heatmap_dir / f"{stem}_original.png"
1041 try:
1042 resized_img.save(orig_path)
1043 except Exception:
1044 orig_path = None
1045
1046 generated.append(out_path)
1047 if orig_path:
1048 pairs.append((orig_path, out_path))
1049 except Exception as exc:
1050 logger.warning("Grad-CAM failed for %s: %s", raw_path, exc)
1051 finally:
1052 try:
1053 handle_fwd.remove()
1054 handle_bwd.remove()
1055 except Exception:
1056 pass
1057
1058 if not generated:
1059 result["reason"] = "No heatmaps were generated (model may be non-convolutional or preprocessing missing)."
1060 return result
1061
1062 zip_path = exp_dir / "feature_importance_examples.zip"
1063 try:
1064 with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
1065 for png in generated:
1066 zf.write(png, png.name)
1067 except Exception as exc:
1068 logger.warning("Failed to create Grad-CAM zip: %s", exc)
1069
1070 result.update(
1071 {
1072 "status": "generated",
1073 "preview_paths": generated[:6],
1074 "pairs": pairs[:6],
1075 "zip_path": zip_path if zip_path.exists() else None,
1076 "dir_path": heatmap_dir,
1077 }
1078 )
1079 return result
714 1080
715 @staticmethod 1081 @staticmethod
716 def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]: 1082 def _extract_metric_series(stats: Dict[str, Any], split: str, prefer: Optional[str] = None) -> Tuple[Optional[str], Optional[List[float]]]:
717 """Pull the first numeric metric list we can find for the requested split.""" 1083 """Pull the first numeric metric list we can find for the requested split."""
718 if not isinstance(stats, dict): 1084 if not isinstance(stats, dict):
1471 except Exception as e: 1837 except Exception as e:
1472 logger.warning(f"Could not build Predictions vs GT table: {e}") 1838 logger.warning(f"Could not build Predictions vs GT table: {e}")
1473 1839
1474 tab3_content = test_metrics_html + preds_section 1840 tab3_content = test_metrics_html + preds_section
1475 1841
1842 gradcam_info = self._generate_gradcam_heatmaps(exp_dir, config, output_type)
1843
1476 if output_type == "regression" and train_stats_path.exists(): 1844 if output_type == "regression" and train_stats_path.exists():
1477 try: 1845 try:
1478 test_plots = build_regression_test_plots(str(train_stats_path)) 1846 test_plots = build_regression_test_plots(str(train_stats_path))
1479 tab3_content = append_plot_blocks(tab3_content, test_plots) 1847 tab3_content = append_plot_blocks(tab3_content, test_plots)
1480 if test_plots: 1848 if test_plots:
1528 tab3_content = append_plot_blocks(tab3_content, test_conf_plots) 1896 tab3_content = append_plot_blocks(tab3_content, test_conf_plots)
1529 logger.info("Added test prediction confidence plot") 1897 logger.info("Added test prediction confidence plot")
1530 except Exception as e: 1898 except Exception as e:
1531 logger.warning(f"Could not generate test diagnostics: {e}") 1899 logger.warning(f"Could not generate test diagnostics: {e}")
1532 1900
1901 if gradcam_info.get("status") == "generated":
1902 tab3_content += "<h2 style='text-align: center;'>Grad-CAM Heatmaps</h2>"
1903 for orig_path, heat_path in gradcam_info.get("pairs", [])[:4]:
1904 try:
1905 display_name = Path(str(orig_path)).name
1906 if display_name.endswith("_original.png"):
1907 display_name = display_name[: -len("_original.png")]
1908 b64_orig = encode_image_to_base64(str(orig_path))
1909 b64_heat = encode_image_to_base64(str(heat_path))
1910 tab3_content += (
1911 "<div class='plot' style='margin-bottom:15px;text-align:center;display:flex;gap:12px;justify-content:center;flex-wrap:wrap;'>"
1912 f"<div><div style='font-weight:600;margin-bottom:4px;'>{display_name}</div>"
1913 f"<img src='data:image/png;base64,{b64_orig}' style='max-width:320px;max-height:320px;border:1px solid #ddd;' /></div>"
1914 f"<div><div style='font-weight:600;margin-bottom:4px;'>Grad-CAM</div>"
1915 f"<img src='data:image/png;base64,{b64_heat}' style='max-width:320px;max-height:320px;border:1px solid #ddd;' /></div>"
1916 "</div>"
1917 )
1918 except Exception as exc:
1919 logger.debug("Could not embed Grad-CAM pair %s / %s: %s", orig_path, heat_path, exc)
1920
1533 # Add static TEST PNGs (with default dedupe/exclusions) 1921 # Add static TEST PNGs (with default dedupe/exclusions)
1534 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) 1922 tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
1535 modal_html = get_metrics_help_modal() 1923 modal_html = get_metrics_help_modal()
1536 html += tabbed_html + modal_html + get_html_closing() 1924 html += tabbed_html + modal_html + get_html_closing()
1537 1925