comparison: image_workflow.py @ 15:d17e3a1b8659 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 15:45:49 +0000 |
| parents | bcfa2e234a80 |
| children | |
| 14:94cd9ac4a9b1 | 15:d17e3a1b8659 |
|---|---|
| 125 logger.info(f"Loaded metadata file: {self.args.csv_file}") | 125 logger.info(f"Loaded metadata file: {self.args.csv_file}") |
| 126 except Exception: | 126 except Exception: |
| 127 logger.error("Error loading metadata file", exc_info=True) | 127 logger.error("Error loading metadata file", exc_info=True) |
| 128 raise | 128 raise |
| 129 | 129 |
| 130 required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} | 130 label_col = self.args.target_column or LABEL_COLUMN_NAME |
| 131 missing = required - set(df.columns) | 131 image_col = self.args.image_column or IMAGE_PATH_COLUMN_NAME |
| 132 if missing: | 132 |
| 133 raise ValueError(f"Missing CSV columns: {', '.join(missing)}") | 133 # Remember the user-specified columns for reporting |
| 134 | 134 self.args.report_target_column = label_col |
| 135 try: | 135 self.args.report_image_column = image_col |
| 136 # Use relative paths that Ludwig can resolve from its internal working directory | 136 |
| 137 df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( | 137 missing_cols = [] |
| 138 lambda p: str(Path("images") / p) | 138 if label_col not in df.columns: |
| 139 ) | 139 missing_cols.append(label_col) |
| | 140 if image_col not in df.columns: |
| | 141 missing_cols.append(image_col) |
| | 142 if missing_cols: |
| | 143 raise ValueError( |
| | 144 f"Missing required column(s) in metadata: {', '.join(missing_cols)}. " |
| | 145 "Update the XML selections or rename your columns." |
| | 146 ) |
| | 147 |
| | 148 if label_col != LABEL_COLUMN_NAME: |
| | 149 df = df.rename(columns={label_col: LABEL_COLUMN_NAME}) |
| | 150 if image_col != IMAGE_PATH_COLUMN_NAME: |
| | 151 df = df.rename(columns={image_col: IMAGE_PATH_COLUMN_NAME}) |
| | 152 |
| | 153 try: |
| | 154 df = self._map_image_paths_with_search(df) |
| 140 except Exception: | 155 except Exception: |
| 141 logger.error("Error updating image paths", exc_info=True) | 156 logger.error("Error updating image paths", exc_info=True) |
| 142 raise | 157 raise |
| 143 | 158 |
| 144 if SPLIT_COLUMN_NAME in df.columns: | 159 if SPLIT_COLUMN_NAME in df.columns: |
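
The hunk above replaces the fixed `required` set with user-selected column names: the target and image columns fall back to the canonical `LABEL_COLUMN_NAME` / `IMAGE_PATH_COLUMN_NAME`, are validated against the DataFrame, and are then renamed to the canonical names the rest of the pipeline expects. A minimal standalone sketch of that normalize-then-rename pattern (the constants mirror the diff; `normalize_columns` itself is a hypothetical helper, not code from this repository):

```python
import pandas as pd

LABEL_COLUMN_NAME = "label"            # canonical names assumed by the pipeline
IMAGE_PATH_COLUMN_NAME = "image_path"


def normalize_columns(df, target_column=None, image_column=None):
    """Validate user-selected columns, then rename them to the canonical names."""
    label_col = target_column or LABEL_COLUMN_NAME
    image_col = image_column or IMAGE_PATH_COLUMN_NAME

    missing = [col for col in (label_col, image_col) if col not in df.columns]
    if missing:
        raise ValueError(f"Missing required column(s) in metadata: {', '.join(missing)}")

    # Rename only when the user picked non-canonical names.
    renames = {}
    if label_col != LABEL_COLUMN_NAME:
        renames[label_col] = LABEL_COLUMN_NAME
    if image_col != IMAGE_PATH_COLUMN_NAME:
        renames[image_col] = IMAGE_PATH_COLUMN_NAME
    return df.rename(columns=renames) if renames else df
```

Renaming up front means downstream code (path mapping, split handling, config generation) keeps addressing one canonical pair of columns regardless of what was selected in the tool XML.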

| 14:94cd9ac4a9b1 | 15:d17e3a1b8659 |
|---|---|
| 203 metadata = {} | 218 metadata = {} |
| 204 | 219 |
| 205 self.label_metadata = metadata | 220 self.label_metadata = metadata |
| 206 self.output_type_hint = "binary" if metadata.get("is_binary") else None | 221 self.output_type_hint = "binary" if metadata.get("is_binary") else None |
| 207 | 222 |
| | 223 def _map_image_paths_with_search(self, df: pd.DataFrame) -> pd.DataFrame: |
| | 224 """Map image identifiers to actual files by searching the extracted directory.""" |
| | 225 if not self.image_extract_dir: |
| | 226 raise RuntimeError("Image directory is not initialized.") |
| | 227 |
| | 228 # Build lookup maps for fast resolution by stem or full name |
| | 229 lookup_by_stem = {} |
| | 230 lookup_by_name = {} |
| | 231 for fpath in self.image_extract_dir.rglob("*"): |
| | 232 if fpath.is_file(): |
| | 233 stem_key = fpath.stem.lower() |
| | 234 name_key = fpath.name.lower() |
| | 235 # Prefer first encounter; warn on collisions |
| | 236 if stem_key in lookup_by_stem and lookup_by_stem[stem_key] != fpath: |
| | 237 logger.warning( |
| | 238 "Multiple files share the same stem '%s'. Using '%s'.", |
| | 239 stem_key, |
| | 240 lookup_by_stem[stem_key], |
| | 241 ) |
| | 242 else: |
| | 243 lookup_by_stem[stem_key] = fpath |
| | 244 if name_key in lookup_by_name and lookup_by_name[name_key] != fpath: |
| | 245 logger.warning( |
| | 246 "Multiple files share the same name '%s'. Using '%s'.", |
| | 247 name_key, |
| | 248 lookup_by_name[name_key], |
| | 249 ) |
| | 250 else: |
| | 251 lookup_by_name[name_key] = fpath |
| | 252 |
| | 253 resolved_paths = [] |
| | 254 missing_count = 0 |
| | 255 missing_samples = [] |
| | 256 |
| | 257 for raw in df[IMAGE_PATH_COLUMN_NAME]: |
| | 258 raw_str = str(raw) |
| | 259 name_key = Path(raw_str).name.lower() |
| | 260 stem_key = Path(raw_str).stem.lower() |
| | 261 resolved = lookup_by_name.get(name_key) or lookup_by_stem.get(stem_key) |
| | 262 |
| | 263 if resolved is None: |
| | 264 missing_count += 1 |
| | 265 missing_samples.append(raw_str) |
| | 266 resolved_paths.append(pd.NA) |
| | 267 continue |
| | 268 |
| | 269 try: |
| | 270 rel_path = resolved.relative_to(self.image_extract_dir) |
| | 271 except ValueError: |
| | 272 rel_path = resolved |
| | 273 resolved_paths.append(str(Path("images") / rel_path)) |
| | 274 |
| | 275 if missing_count: |
| | 276 logger.warning( |
| | 277 "Unable to locate %d image(s) from the metadata in the extracted images directory.", |
| | 278 missing_count, |
| | 279 ) |
| | 280 preview = ", ".join(missing_samples[:5]) |
| | 281 logger.warning("Missing samples (showing up to 5): %s", preview) |
| | 282 |
| | 283 df = df.copy() |
| | 284 df[IMAGE_PATH_COLUMN_NAME] = resolved_paths |
| | 285 df = df.dropna(subset=[IMAGE_PATH_COLUMN_NAME]).reset_index(drop=True) |
| | 286 return df |
| | 287 |
| 208 # Removed duplicate method | 288 # Removed duplicate method |
| 209 | 289 |
| 210 def _detect_image_dimensions(self) -> Tuple[int, int]: | 290 def _detect_image_dimensions(self) -> Tuple[int, int]: |
| 211 """Detect image dimensions from the first image in the dataset.""" | 291 """Detect image dimensions from the first image in the dataset.""" |
| 212 try: | 292 try: |
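
The new `_map_image_paths_with_search` method indexes every file under the extracted image directory twice, once by lowercase stem and once by lowercase file name, so each metadata row is resolved with two dictionary lookups instead of a directory scan per row; full-name matches win over stem matches, unresolved rows are dropped, and resolved paths are rewritten under an `images/` prefix. A condensed sketch of the same two-level lookup (collision warnings omitted; `build_lookups` and `resolve` are hypothetical helper names, not code from this repository):

```python
from pathlib import Path


def build_lookups(root: Path):
    """Index files under root by lowercase stem and by lowercase name.

    setdefault keeps the first file encountered, matching the diff's
    "prefer first encounter" behaviour (minus the collision warnings).
    """
    by_stem, by_name = {}, {}
    for fpath in root.rglob("*"):
        if fpath.is_file():
            by_stem.setdefault(fpath.stem.lower(), fpath)
            by_name.setdefault(fpath.name.lower(), fpath)
    return by_stem, by_name


def resolve(raw, by_stem, by_name):
    """Resolve one metadata identifier, trying the full name before the stem."""
    p = Path(str(raw))
    return by_name.get(p.name.lower()) or by_stem.get(p.stem.lower())
```

Falling back to the stem lets metadata that lists bare identifiers (e.g. `sample_01`) find `sample_01.png` without spelling out the extension.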

| 14:94cd9ac4a9b1 | 15:d17e3a1b8659 |
|---|---|
| 273 "image_resize": self.args.image_resize, | 353 "image_resize": self.args.image_resize, |
| 274 "image_zip": self.args.image_zip, | 354 "image_zip": self.args.image_zip, |
| 275 "threshold": self.args.threshold, | 355 "threshold": self.args.threshold, |
| 276 "label_metadata": self.label_metadata, | 356 "label_metadata": self.label_metadata, |
| 277 "output_type_hint": self.output_type_hint, | 357 "output_type_hint": self.output_type_hint, |
| | 358 "validation_metric": self.args.validation_metric, |
| | 359 "target_column": getattr(self.args, "report_target_column", LABEL_COLUMN_NAME), |
| | 360 "image_column": getattr(self.args, "report_image_column", IMAGE_PATH_COLUMN_NAME), |
| 278 } | 361 } |
| 279 yaml_str = self.backend.prepare_config(backend_args, split_cfg) | 362 yaml_str = self.backend.prepare_config(backend_args, split_cfg) |
| 280 | 363 |
| 281 config_file = self.temp_dir / TEMP_CONFIG_FILENAME | 364 config_file = self.temp_dir / TEMP_CONFIG_FILENAME |
| 282 config_file.write_text(yaml_str) | 365 config_file.write_text(yaml_str) |
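
Note the `getattr(..., default)` calls for the two new report keys: `report_target_column` and `report_image_column` are only assigned while the metadata is loaded (first hunk above), so the config builder falls back to the canonical names if they were never set. A quick illustration of that fallback (using `SimpleNamespace` as a stand-in for the parsed args):

```python
from types import SimpleNamespace

LABEL_COLUMN_NAME = "label"

args = SimpleNamespace()  # report_target_column was never assigned
# The third argument to getattr supplies a default instead of raising AttributeError.
assert getattr(args, "report_target_column", LABEL_COLUMN_NAME) == "label"

args.report_target_column = "diagnosis"
assert getattr(args, "report_target_column", LABEL_COLUMN_NAME) == "diagnosis"
```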

| 14:94cd9ac4a9b1 | 15:d17e3a1b8659 |
|---|---|
| 295 logger.error("Workflow execution failed", exc_info=True) | 378 logger.error("Workflow execution failed", exc_info=True) |
| 296 ran_ok = False | 379 ran_ok = False |
| 297 | 380 |
| 298 if ran_ok: | 381 if ran_ok: |
| 299 logger.info("Workflow completed successfully.") | 382 logger.info("Workflow completed successfully.") |
| | 383 # Convert predictions parquet → csv |
| | 384 self.backend.convert_parquet_to_csv(self.args.output_dir) |
| | 385 logger.info("Converted Parquet to CSV.") |
| 300 # Generate a very small set of plots to conserve disk space | 386 # Generate a very small set of plots to conserve disk space |
| 301 self.backend.generate_plots(self.args.output_dir) | 387 self.backend.generate_plots(self.args.output_dir) |
| 302 # Build HTML report (robust to missing metrics) | 388 # Build HTML report (robust to missing metrics) |
| 303 report_file = self.backend.generate_html_report( | 389 report_file = self.backend.generate_html_report( |
| 304 "Image Classification Results", | 390 "Image Classification Results", |
| 305 self.args.output_dir, | 391 self.args.output_dir, |
| 306 backend_args, | 392 backend_args, |
| 307 split_info, | 393 split_info, |
| 308 ) | 394 ) |
| 309 logger.info(f"HTML report generated at: {report_file}") | 395 logger.info(f"HTML report generated at: {report_file}") |
| 310 # Convert predictions parquet → csv | |
| 311 self.backend.convert_parquet_to_csv(self.args.output_dir) | |
| 312 logger.info("Converted Parquet to CSV.") | |
| 313 # Post-process cleanup to reduce disk footprint for subsequent tests | 396 # Post-process cleanup to reduce disk footprint for subsequent tests |
| 314 try: | 397 try: |
| 315 self._postprocess_cleanup(self.args.output_dir) | 398 self._postprocess_cleanup(self.args.output_dir) |
| 316 except Exception as cleanup_err: | 399 except Exception as cleanup_err: |
| 317 logger.warning(f"Cleanup step failed: {cleanup_err}") | 400 logger.warning(f"Cleanup step failed: {cleanup_err}") |
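
The last hunk also reorders the post-run steps: the Parquet→CSV conversion now runs immediately after a successful workflow, before plot and HTML report generation, rather than after the report (old lines 310-312). The backend helper itself is not shown in this diff; a minimal stand-in, assuming it simply rewrites each predictions Parquet file as a sibling CSV:

```python
from pathlib import Path

import pandas as pd


def convert_parquet_to_csv(output_dir):
    """Rewrite every .parquet file under output_dir as a sibling .csv (sketch)."""
    for pq in Path(output_dir).rglob("*.parquet"):
        # pd.read_parquet needs the pyarrow or fastparquet engine installed.
        pd.read_parquet(pq).to_csv(pq.with_suffix(".csv"), index=False)
```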
