goeckslab / pycaret_predict: comparison of base_model_trainer.py @ 17:c5c324ac29fc (draft, default branch, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
| author | goeckslab |
|---|---|
| date | Sat, 06 Dec 2025 14:20:36 +0000 |
| parents | 4fee4504646e |
| children | |
| 16:4fee4504646e (old) | 17:c5c324ac29fc (new) |
|---|---|
| 7 import joblib | 7 import joblib |
| 8 import numpy as np | 8 import numpy as np |
| 9 import pandas as pd | 9 import pandas as pd |
| 10 from feature_help_modal import get_feature_metrics_help_modal | 10 from feature_help_modal import get_feature_metrics_help_modal |
| 11 from feature_importance import FeatureImportanceAnalyzer | 11 from feature_importance import FeatureImportanceAnalyzer |
| 12 from sklearn.metrics import average_precision_score | 12 from sklearn.metrics import ( |
| 13 accuracy_score, | |
| 14 average_precision_score, | |
| 15 confusion_matrix, | |
| 16 f1_score, | |
| 17 matthews_corrcoef, | |
| 18 precision_score, | |
| 19 recall_score, | |
| 20 roc_auc_score, | |
| 21 ) | |
| 13 from utils import ( | 22 from utils import ( |
| 14 add_hr_to_html, | 23 add_hr_to_html, |
| 15 add_plot_to_html, | 24 add_plot_to_html, |
| 16 build_tabbed_html, | 25 build_tabbed_html, |
| 17 encode_image_to_base64, | 26 encode_image_to_base64, |
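
The widened `sklearn.metrics` import feeds the Train/Validation/Test performance summary added later in this revision. A minimal, self-contained sketch of the kinds of calls involved, using toy arrays rather than the trainer's data; note that specificity has to be derived from a confusion matrix with an explicit label order, since scikit-learn offers no named specificity scorer:

```python
# Standalone sketch using only the metrics imported above; the arrays are toy
# data for illustration, not anything produced by base_model_trainer.py.
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    matthews_corrcoef,
)

y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_pred = np.array([0, 1, 1, 1, 0, 0, 1, 0])

print("Accuracy:", accuracy_score(y_true, y_pred))
print("F1-Score:", f1_score(y_true, y_pred, pos_label=1, zero_division=0))
print("MCC:", matthews_corrcoef(y_true, y_pred))

# scikit-learn has no named specificity scorer, so it is derived from a
# confusion matrix with labels ordered [negative, positive], as in
# _compute_metric_value later in this revision.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
specificity = tn / (tn + fp) if (tn + fp) else None
print("Specificity:", specificity)
```
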
| 385 | 394 |
| 386 def encode_image_to_base64(self, img_path: str) -> str: | 395 def encode_image_to_base64(self, img_path: str) -> str: |
| 387 with open(img_path, "rb") as img_file: | 396 with open(img_path, "rb") as img_file: |
| 388 return base64.b64encode(img_file.read()).decode("utf-8") | 397 return base64.b64encode(img_file.read()).decode("utf-8") |
| 389 | 398 |
| 399 def _build_dataset_overview(self): | |
| 400 """ | |
| 401 Build an HTML table showing label counts with labels as rows and splits | |
| 402 (Train / Validation / Test) as columns. Each cell shows count and | |
| 403 percentage of that split. Returns empty string for regression or when | |
| 404 no label data is available. | |
| 405 """ | |
| 406 if self.task_type != "classification": | |
| 407 return "" | |
| 408 | |
| 409 def _safe_series(obj): | |
| 410 try: | |
| 411 return pd.Series(obj).reset_index(drop=True) | |
| 412 except Exception: | |
| 413 return None | |
| 414 | |
| 415 def _get_from_config(keys): | |
| 416 if self.exp is None: | |
| 417 return None | |
| 418 for key in keys: | |
| 419 try: | |
| 420 val = self.exp.get_config(key) | |
| 421 except Exception: | |
| 422 val = getattr(self.exp, key, None) | |
| 423 if val is not None: | |
| 424 return val | |
| 425 return None | |
| 426 | |
| 427 # Prefer PyCaret-configured splits; fall back to raw inputs. | |
| 428 X_train = _get_from_config(["X_train_transformed", "X_train"]) | |
| 429 y_train = _get_from_config(["y_train_transformed", "y_train"]) | |
| 430 y_test_cfg = _get_from_config(["y_test_transformed", "y_test"]) | |
| 431 | |
| 432 if y_train is None and self.data is not None and self.target in self.data.columns: | |
| 433 y_train = self.data[self.target] | |
| 434 | |
| 435 y_train_series = _safe_series(y_train) | |
| 436 | |
| 437 # Build a cross-validation generator to derive a validation subset size. | |
| 438 cv_gen = self._get_cv_generator(y_train_series) | |
| 439 y_train_fold = y_train_series | |
| 440 y_val_fold = None | |
| 441 if cv_gen is not None and y_train_series is not None: | |
| 442 try: | |
| 443 # Use the first fold to approximate Train/Validation split sizes. | |
| 444 splitter = cv_gen.split( | |
| 445 pd.DataFrame(X_train).reset_index(drop=True) | |
| 446 if X_train is not None | |
| 447 else y_train_series, | |
| 448 y_train_series, | |
| 449 ) | |
| 450 train_idx, val_idx = next(iter(splitter)) | |
| 451 y_train_fold = y_train_series.iloc[train_idx].reset_index(drop=True) | |
| 452 y_val_fold = y_train_series.iloc[val_idx].reset_index(drop=True) | |
| 453 except Exception as exc: | |
| 454 LOG.warning("Could not derive validation split for dataset overview: %s", exc) | |
| 455 | |
| 456 # Test labels: prefer PyCaret transformed holdout (single file) or external test. | |
| 457 if self.test_data is not None: | |
| 458 if y_test_cfg is not None: | |
| 459 y_test = y_test_cfg | |
| 460 elif self.target in self.test_data.columns: | |
| 461 y_test = self.test_data[self.target] | |
| 462 else: | |
| 463 y_test = None | |
| 464 else: | |
| 465 y_test = y_test_cfg | |
| 466 | |
| 467 split_map = { | |
| 468 "Train": _safe_series(y_train_fold), | |
| 469 "Validation": _safe_series(y_val_fold), | |
| 470 "Test": _safe_series(y_test), | |
| 471 } | |
| 472 available = {k: v for k, v in split_map.items() if v is not None and not v.empty} | |
| 473 if not available: | |
| 474 return "" | |
| 475 | |
| 476 # Collect all labels across available splits (including NaN) | |
| 477 label_pool = pd.concat( | |
| 478 available.values(), ignore_index=True | |
| 479 ) | |
| 480 labels = pd.unique(label_pool) | |
| 481 | |
| 482 def _count_for_label(series, label): | |
| 483 if series is None or series.empty: | |
| 484 return None, None | |
| 485 total = len(series) | |
| 486 if pd.isna(label): | |
| 487 cnt = series.isna().sum() | |
| 488 else: | |
| 489 cnt = (series == label).sum() | |
| 490 return int(cnt), total | |
| 491 | |
| 492 rows = [] | |
| 493 for label in labels: | |
| 494 row = ["NaN" if pd.isna(label) else str(label)] | |
| 495 for split_name in ["Train", "Validation", "Test"]: | |
| 496 cnt, total = _count_for_label(split_map.get(split_name), label) | |
| 497 if cnt is None or total is None: | |
| 498 cell = "—" | |
| 499 else: | |
| 500 pct = (cnt / total * 100) if total else 0 | |
| 501 cell = f"{cnt} ({pct:.1f}%)" | |
| 502 row.append(cell) | |
| 503 rows.append(row) | |
| 504 | |
| 505 df = pd.DataFrame(rows, columns=["Label", "Train", "Validation", "Test"]) | |
| 506 df.sort_values("Label", inplace=True) | |
| 507 | |
| 508 return ( | |
| 509 "<h2>Dataset Overview</h2>" | |
| 510 + '<div class="table-wrapper">' | |
| 511 + df.to_html( | |
| 512 index=False, | |
| 513 classes=["table", "sortable", "table-dataset-overview"], | |
| 514 ) | |
| 515 + "</div>" | |
| 516 ) | |
| 517 | |
| 518 def _predict_with_thresholds(self, X, y_true): | |
| 519 """ | |
| 520 Generate predictions/probabilities for a split, respecting an optional | |
| 521 probability threshold for binary tasks. Returns a dict with y_true, | |
| 522 y_pred, y_scores (positive-class probs when available), pos_label, | |
| 523 and neg_label. | |
| 524 """ | |
| 525 if X is None or y_true is None: | |
| 526 return None | |
| 527 | |
| 528 y_true_series = pd.Series(y_true).reset_index(drop=True) | |
| 529 classes = list(getattr(self.best_model, "classes_", [])) | |
| 530 if not classes: | |
| 531 try: | |
| 532 classes = pd.unique(y_true_series).tolist() | |
| 533 except Exception: | |
| 534 classes = [] | |
| 535 if len(classes) > 1: | |
| 536 try: | |
| 537 pos_idx = classes.index(1) | |
| 538 except Exception: | |
| 539 pos_idx = 1 | |
| 540 else: | |
| 541 pos_idx = 0 | |
| 542 pos_idx = min(pos_idx, len(classes) - 1) if classes else 0 | |
| 543 pos_label = ( | |
| 544 classes[pos_idx] | |
| 545 if len(classes) > pos_idx and pos_idx >= 0 | |
| 546 else (classes[-1] if classes else 1) | |
| 547 ) | |
| 548 neg_label = None | |
| 549 if len(classes) >= 2: | |
| 550 neg_candidates = [c for c in classes if c != pos_label] | |
| 551 if neg_candidates: | |
| 552 neg_label = neg_candidates[0] | |
| 553 | |
| 554 prob_thresh = getattr(self, "probability_threshold", None) | |
| 555 y_scores = None | |
| 556 try: | |
| 557 proba = self.best_model.predict_proba(X) | |
| 558 y_scores = np.asarray(proba) if proba is not None else None | |
| 559 except Exception: | |
| 560 y_scores = None | |
| 561 | |
| 562 try: | |
| 563 if ( | |
| 564 prob_thresh is not None | |
| 565 and not getattr(self.exp, "is_multiclass", False) | |
| 566 and y_scores is not None | |
| 567 and y_scores.ndim == 2 | |
| 568 and y_scores.shape[1] > 1 | |
| 569 ): | |
| 570 pos_idx = min(pos_idx, y_scores.shape[1] - 1) | |
| 571 neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 | |
| 572 if neg_label is None and len(classes) > neg_idx: | |
| 573 neg_label = classes[neg_idx] | |
| 574 y_pred = np.where( | |
| 575 y_scores[:, pos_idx] >= prob_thresh, | |
| 576 pos_label, | |
| 577 neg_label if neg_label is not None else 0, | |
| 578 ) | |
| 579 y_scores = y_scores[:, pos_idx] | |
| 580 else: | |
| 581 y_pred = self.best_model.predict(X) | |
| 582 if ( | |
| 583 not getattr(self.exp, "is_multiclass", False) | |
| 584 and y_scores is not None | |
| 585 and y_scores.ndim == 2 | |
| 586 and y_scores.shape[1] > 1 | |
| 587 ): | |
| 588 pos_idx = min(pos_idx, y_scores.shape[1] - 1) | |
| 589 y_scores = y_scores[:, pos_idx] | |
| 590 except Exception as exc: | |
| 591 LOG.warning( | |
| 592 "Falling back to raw predict while computing performance summary: %s", | |
| 593 exc, | |
| 594 ) | |
| 595 try: | |
| 596 y_pred = self.best_model.predict(X) | |
| 597 except Exception as exc_inner: | |
| 598 LOG.warning( | |
| 599 "Unable to score split after fallback prediction: %s", | |
| 600 exc_inner, | |
| 601 ) | |
| 602 return None | |
| 603 y_scores = None | |
| 604 | |
| 605 y_pred_series = pd.Series(y_pred).reset_index(drop=True) | |
| 606 if y_scores is not None: | |
| 607 y_scores = np.asarray(y_scores) | |
| 608 if y_scores.ndim > 1 and y_scores.shape[1] == 1: | |
| 609 y_scores = y_scores.ravel() | |
| 610 if getattr(self.exp, "is_multiclass", False) and y_scores.ndim > 1: | |
| 611 # Avoid passing multiclass score matrices to ROC/PR utilities | |
| 612 y_scores = None | |
| 613 | |
| 614 return { | |
| 615 "y_true": y_true_series, | |
| 616 "y_pred": y_pred_series, | |
| 617 "y_scores": y_scores, | |
| 618 "pos_label": pos_label, | |
| 619 "neg_label": neg_label, | |
| 620 } | |
| 621 | |
| 622 def _get_cv_generator(self, y_series): | |
| 623 """ | |
| 624 Build a cross-validation splitter that mirrors the experiment's | |
| 625 configuration. Returns None when CV is disabled or not applicable. | |
| 626 """ | |
| 627 if self.task_type != "classification": | |
| 628 return None | |
| 629 | |
| 630 if getattr(self, "cross_validation", None) is False: | |
| 631 return None | |
| 632 | |
| 633 try: | |
| 634 cfg_gen = self.exp.get_config("fold_generator") | |
| 635 if cfg_gen is not None: | |
| 636 return cfg_gen | |
| 637 except Exception: | |
| 638 cfg_gen = None | |
| 639 | |
| 640 folds = ( | |
| 641 getattr(self, "cross_validation_folds", None) | |
| 642 or self.setup_params.get("fold") | |
| 643 or getattr(self.exp, "fold", None) | |
| 644 or 10 | |
| 645 ) | |
| 646 try: | |
| 647 folds = int(folds) | |
| 648 except Exception: | |
| 649 folds = 10 | |
| 650 | |
| 651 try: | |
| 652 y_series = pd.Series(y_series).reset_index(drop=True) | |
| 653 except Exception: | |
| 654 y_series = None | |
| 655 if y_series is None or y_series.empty: | |
| 656 return None | |
| 657 | |
| 658 if folds < 2: | |
| 659 return None | |
| 660 if len(y_series) < folds: | |
| 661 folds = len(y_series) | |
| 662 if folds < 2: | |
| 663 return None | |
| 664 | |
| 665 try: | |
| 666 from sklearn.model_selection import KFold, StratifiedKFold | |
| 667 | |
| 668 if self.task_type == "classification": | |
| 669 return StratifiedKFold( | |
| 670 n_splits=folds, | |
| 671 shuffle=True, | |
| 672 random_state=self.random_seed, | |
| 673 ) | |
| 674 return KFold( | |
| 675 n_splits=folds, | |
| 676 shuffle=True, | |
| 677 random_state=self.random_seed, | |
| 678 ) | |
| 679 except Exception as exc: | |
| 680 LOG.warning("Could not build CV generator: %s", exc) | |
| 681 return None | |
| 682 | |
| 683 def _get_cross_validated_predictions(self, X, y): | |
| 684 """ | |
| 685 Generate cross-validated predictions for the validation split so we | |
| 686 can report validation metrics for the selected best model. | |
| 687 """ | |
| 688 if self.task_type != "classification": | |
| 689 return None | |
| 690 if getattr(self, "cross_validation", None) is False: | |
| 691 return None | |
| 692 if X is None or y is None: | |
| 693 return None | |
| 694 | |
| 695 try: | |
| 696 from sklearn.model_selection import cross_val_predict | |
| 697 except Exception as exc: | |
| 698 LOG.warning("cross_val_predict unavailable: %s", exc) | |
| 699 return None | |
| 700 | |
| 701 y_series = pd.Series(y).reset_index(drop=True) | |
| 702 if y_series.empty: | |
| 703 return None | |
| 704 | |
| 705 cv_gen = self._get_cv_generator(y_series) | |
| 706 if cv_gen is None: | |
| 707 return None | |
| 708 | |
| 709 X_df = pd.DataFrame(X).reset_index(drop=True) | |
| 710 if len(X_df) != len(y_series): | |
| 711 X_df = X_df.iloc[: len(y_series)].reset_index(drop=True) | |
| 712 | |
| 713 classes = list(getattr(self.best_model, "classes_", [])) | |
| 714 if len(classes) > 1: | |
| 715 try: | |
| 716 pos_idx = classes.index(1) | |
| 717 except Exception: | |
| 718 pos_idx = 1 | |
| 719 else: | |
| 720 pos_idx = 0 | |
| 721 pos_idx = min(pos_idx, len(classes) - 1) if classes else 0 | |
| 722 pos_label = ( | |
| 723 classes[pos_idx] if len(classes) > pos_idx else 1 | |
| 724 ) | |
| 725 neg_label = None | |
| 726 if len(classes) >= 2: | |
| 727 neg_candidates = [c for c in classes if c != pos_label] | |
| 728 if neg_candidates: | |
| 729 neg_label = neg_candidates[0] | |
| 730 | |
| 731 prob_thresh = getattr(self, "probability_threshold", None) | |
| 732 n_jobs = getattr(self, "n_jobs", None) | |
| 733 | |
| 734 y_scores = None | |
| 735 if not getattr(self.exp, "is_multiclass", False): | |
| 736 try: | |
| 737 proba = cross_val_predict( | |
| 738 self.best_model, | |
| 739 X_df, | |
| 740 y_series, | |
| 741 cv=cv_gen, | |
| 742 method="predict_proba", | |
| 743 n_jobs=n_jobs, | |
| 744 ) | |
| 745 y_scores = np.asarray(proba) | |
| 746 except Exception as exc: | |
| 747 LOG.debug("Could not compute CV probabilities: %s", exc) | |
| 748 | |
| 749 y_pred = None | |
| 750 if ( | |
| 751 prob_thresh is not None | |
| 752 and not getattr(self.exp, "is_multiclass", False) | |
| 753 and y_scores is not None | |
| 754 and y_scores.ndim == 2 | |
| 755 and y_scores.shape[1] > 1 | |
| 756 ): | |
| 757 pos_idx = min(pos_idx, y_scores.shape[1] - 1) | |
| 758 neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 | |
| 759 if neg_label is None and len(classes) > neg_idx: | |
| 760 neg_label = classes[neg_idx] | |
| 761 y_pred = np.where( | |
| 762 y_scores[:, pos_idx] >= prob_thresh, | |
| 763 pos_label, | |
| 764 neg_label if neg_label is not None else 0, | |
| 765 ) | |
| 766 y_scores = y_scores[:, pos_idx] | |
| 767 else: | |
| 768 try: | |
| 769 y_pred = cross_val_predict( | |
| 770 self.best_model, | |
| 771 X_df, | |
| 772 y_series, | |
| 773 cv=cv_gen, | |
| 774 method="predict", | |
| 775 n_jobs=n_jobs, | |
| 776 ) | |
| 777 except Exception as exc: | |
| 778 LOG.warning( | |
| 779 "Could not compute cross-validated predictions: %s", | |
| 780 exc, | |
| 781 ) | |
| 782 return None | |
| 783 if ( | |
| 784 not getattr(self.exp, "is_multiclass", False) | |
| 785 and y_scores is not None | |
| 786 and y_scores.ndim == 2 | |
| 787 and y_scores.shape[1] > 1 | |
| 788 ): | |
| 789 pos_idx = min(pos_idx, y_scores.shape[1] - 1) | |
| 790 y_scores = y_scores[:, pos_idx] | |
| 791 | |
| 792 if y_scores is not None and getattr(self.exp, "is_multiclass", False): | |
| 793 y_scores = None | |
| 794 | |
| 795 return { | |
| 796 "y_true": y_series, | |
| 797 "y_pred": pd.Series(y_pred).reset_index(drop=True), | |
| 798 "y_scores": y_scores, | |
| 799 "pos_label": pos_label, | |
| 800 "neg_label": neg_label, | |
| 801 } | |
| 802 | |
| 803 def _get_split_predictions_for_report(self): | |
| 804 """ | |
| 805 Collect predictions/probabilities for Train/Validation/Test splits so the | |
| 806 performance table can show consistent metrics across splits. | |
| 807 """ | |
| 808 if self.task_type != "classification": | |
| 809 return {} | |
| 810 | |
| 811 def _get_from_config(keys): | |
| 812 for key in keys: | |
| 813 try: | |
| 814 val = self.exp.get_config(key) | |
| 815 except Exception: | |
| 816 val = getattr(self.exp, key, None) | |
| 817 if val is not None: | |
| 818 return val | |
| 819 return None | |
| 820 | |
| 821 X_train = _get_from_config(["X_train_transformed", "X_train"]) | |
| 822 y_train = _get_from_config(["y_train_transformed", "y_train"]) | |
| 823 X_holdout = _get_from_config(["X_test_transformed", "X_test"]) | |
| 824 y_holdout = _get_from_config(["y_test_transformed", "y_test"]) | |
| 825 | |
| 826 predictions = {} | |
| 827 | |
| 828 # Train metrics (best model on training data) | |
| 829 if X_train is not None and y_train is not None: | |
| 830 try: | |
| 831 train_preds = self._predict_with_thresholds(X_train, y_train) | |
| 832 if train_preds is not None: | |
| 833 predictions["Train"] = train_preds | |
| 834 except Exception as exc: | |
| 835 LOG.warning( | |
| 836 "Could not score Train split for performance summary: %s", | |
| 837 exc, | |
| 838 ) | |
| 839 | |
| 840 # Validation metrics via cross-validation on training data | |
| 841 try: | |
| 842 val_preds = self._get_cross_validated_predictions(X_train, y_train) | |
| 843 if val_preds is not None: | |
| 844 predictions["Validation"] = val_preds | |
| 845 except Exception as exc: | |
| 846 LOG.warning( | |
| 847 "Could not score Validation split for performance summary: %s", | |
| 848 exc, | |
| 849 ) | |
| 850 | |
| 851 # Test metrics (holdout from single file, or provided test file) | |
| 852 X_test = X_holdout | |
| 853 y_test = y_holdout | |
| 854 if (X_test is None or y_test is None) and self.test_data is not None: | |
| 855 try: | |
| 856 X_test = self.test_data.drop(columns=[self.target]) | |
| 857 y_test = self.test_data[self.target] | |
| 858 except Exception as exc: | |
| 859 LOG.warning( | |
| 860 "Could not prepare external test data for performance summary: %s", | |
| 861 exc, | |
| 862 ) | |
| 863 | |
| 864 if X_test is not None and y_test is not None: | |
| 865 try: | |
| 866 test_preds = self._predict_with_thresholds(X_test, y_test) | |
| 867 if test_preds is not None: | |
| 868 predictions["Test"] = test_preds | |
| 869 except Exception as exc: | |
| 870 LOG.warning( | |
| 871 "Could not score Test split for performance summary: %s", | |
| 872 exc, | |
| 873 ) | |
| 874 return predictions | |
| 875 | |
| 876 def _compute_metric_value(self, metric_name, preds, split_name): | |
| 877 """ | |
| 878 Compute a single metric for a given split prediction bundle. | |
| 879 """ | |
| 880 if preds is None: | |
| 881 return None | |
| 882 | |
| 883 y_true = preds["y_true"] | |
| 884 y_pred = preds["y_pred"] | |
| 885 y_scores = preds.get("y_scores") | |
| 886 pos_label = preds.get("pos_label") | |
| 887 neg_label = preds.get("neg_label") | |
| 888 is_multiclass = getattr(self.exp, "is_multiclass", False) | |
| 889 | |
| 890 def _format_binary_labels(series): | |
| 891 if pos_label is None: | |
| 892 return series | |
| 893 try: | |
| 894 return (series == pos_label).astype(int) | |
| 895 except Exception: | |
| 896 return series | |
| 897 | |
| 898 try: | |
| 899 if metric_name == "Accuracy": | |
| 900 return accuracy_score(y_true, y_pred) | |
| 901 if metric_name == "ROC-AUC": | |
| 902 if y_scores is None: | |
| 903 return None | |
| 904 y_true_bin = _format_binary_labels(y_true) | |
| 905 if len(pd.unique(y_true_bin)) < 2: | |
| 906 return None | |
| 907 return roc_auc_score(y_true_bin, y_scores) | |
| 908 if metric_name == "Precision": | |
| 909 if is_multiclass: | |
| 910 return precision_score( | |
| 911 y_true, y_pred, average="weighted", zero_division=0 | |
| 912 ) | |
| 913 try: | |
| 914 return precision_score( | |
| 915 y_true, y_pred, pos_label=pos_label, zero_division=0 | |
| 916 ) | |
| 917 except Exception: | |
| 918 return precision_score( | |
| 919 y_true, y_pred, average="weighted", zero_division=0 | |
| 920 ) | |
| 921 if metric_name == "Recall": | |
| 922 if is_multiclass: | |
| 923 return recall_score( | |
| 924 y_true, y_pred, average="weighted", zero_division=0 | |
| 925 ) | |
| 926 try: | |
| 927 return recall_score( | |
| 928 y_true, y_pred, pos_label=pos_label, zero_division=0 | |
| 929 ) | |
| 930 except Exception: | |
| 931 return recall_score( | |
| 932 y_true, y_pred, average="weighted", zero_division=0 | |
| 933 ) | |
| 934 if metric_name == "F1-Score": | |
| 935 if is_multiclass: | |
| 936 return f1_score( | |
| 937 y_true, y_pred, average="weighted", zero_division=0 | |
| 938 ) | |
| 939 try: | |
| 940 return f1_score( | |
| 941 y_true, y_pred, pos_label=pos_label, zero_division=0 | |
| 942 ) | |
| 943 except Exception: | |
| 944 return f1_score( | |
| 945 y_true, y_pred, average="weighted", zero_division=0 | |
| 946 ) | |
| 947 if metric_name == "PR-AUC": | |
| 948 if y_scores is None: | |
| 949 return None | |
| 950 y_true_bin = _format_binary_labels(y_true) | |
| 951 if len(pd.unique(y_true_bin)) < 2: | |
| 952 return None | |
| 953 return average_precision_score(y_true_bin, y_scores) | |
| 954 if metric_name == "Specificity": | |
| 955 labels = pd.unique(pd.concat([y_true, y_pred], ignore_index=True)) | |
| 956 if len(labels) != 2: | |
| 957 return None | |
| 958 if pos_label is None or pos_label not in labels: | |
| 959 pos_label = labels[1] | |
| 960 neg_candidates = [lbl for lbl in labels if lbl != pos_label] | |
| 961 neg_label_final = ( | |
| 962 neg_label if neg_label in labels else (neg_candidates[0] if neg_candidates else None) | |
| 963 ) | |
| 964 if neg_label_final is None: | |
| 965 return None | |
| 966 cm = confusion_matrix( | |
| 967 y_true, y_pred, labels=[neg_label_final, pos_label] | |
| 968 ) | |
| 969 if cm.shape != (2, 2): | |
| 970 return None | |
| 971 tn, fp, fn, tp = cm.ravel() | |
| 972 denom = tn + fp | |
| 973 return (tn / denom) if denom else None | |
| 974 if metric_name == "MCC": | |
| 975 return matthews_corrcoef(y_true, y_pred) | |
| 976 except Exception as exc: | |
| 977 LOG.warning( | |
| 978 "Could not compute %s for %s split: %s", | |
| 979 metric_name, | |
| 980 split_name, | |
| 981 exc, | |
| 982 ) | |
| 983 return None | |
| 984 return None | |
| 985 | |
| 986 def _build_performance_summary_table(self): | |
| 987 """ | |
| 988 Build a Train/Validation/Test metrics table for classification tasks. | |
| 989 Returns empty string when metrics are unavailable or not applicable. | |
| 990 """ | |
| 991 if self.task_type != "classification": | |
| 992 return "" | |
| 993 | |
| 994 split_predictions = self._get_split_predictions_for_report() | |
| 995 validation_best_row = None | |
| 996 try: | |
| 997 if isinstance(self.results, pd.DataFrame) and not self.results.empty: | |
| 998 validation_best_row = self.results.iloc[0] | |
| 999 except Exception: | |
| 1000 validation_best_row = None | |
| 1001 | |
| 1002 if not split_predictions and validation_best_row is None: | |
| 1003 return "" | |
| 1004 | |
| 1005 metric_names = [ | |
| 1006 "Accuracy", | |
| 1007 "ROC-AUC", | |
| 1008 "Precision", | |
| 1009 "Recall", | |
| 1010 "F1-Score", | |
| 1011 "PR-AUC", | |
| 1012 "Specificity", | |
| 1013 "MCC", | |
| 1014 ] | |
| 1015 | |
| 1016 validation_column_map = { | |
| 1017 "Accuracy": ["Accuracy"], | |
| 1018 "ROC-AUC": ["ROC-AUC", "AUC"], | |
| 1019 "Precision": ["Precision", "Prec.", "Prec"], | |
| 1020 "Recall": ["Recall"], | |
| 1021 "F1-Score": ["F1-Score", "F1"], | |
| 1022 "PR-AUC": ["PR-AUC", "PR-AUC-Weighted", "PRC"], | |
| 1023 "Specificity": ["Specificity"], | |
| 1024 "MCC": ["MCC"], | |
| 1025 } | |
| 1026 | |
| 1027 def _fmt(value): | |
| 1028 if value is None: | |
| 1029 return "—" | |
| 1030 try: | |
| 1031 if isinstance(value, (float, np.floating)) and ( | |
| 1032 np.isnan(value) or np.isinf(value) | |
| 1033 ): | |
| 1034 return "—" | |
| 1035 return f"{value:.3f}" | |
| 1036 except Exception: | |
| 1037 return str(value) | |
| 1038 | |
| 1039 def _validation_metric(metric_name): | |
| 1040 if validation_best_row is None: | |
| 1041 return None | |
| 1042 cols = validation_column_map.get(metric_name, []) | |
| 1043 for col in cols: | |
| 1044 if col in validation_best_row: | |
| 1045 try: | |
| 1046 return validation_best_row[col] | |
| 1047 except Exception: | |
| 1048 return None | |
| 1049 return None | |
| 1050 | |
| 1051 rows = [] | |
| 1052 for metric in metric_names: | |
| 1053 row = [metric] | |
| 1054 # Train | |
| 1055 train_val = self._compute_metric_value( | |
| 1056 metric, split_predictions.get("Train"), "Train" | |
| 1057 ) | |
| 1058 row.append(_fmt(train_val)) | |
| 1059 | |
| 1060 # Validation from Train & Validation Summary first row; fallback to computed CV. | |
| 1061 val_val = _validation_metric(metric) | |
| 1062 if val_val is None: | |
| 1063 val_val = self._compute_metric_value( | |
| 1064 metric, split_predictions.get("Validation"), "Validation" | |
| 1065 ) | |
| 1066 row.append(_fmt(val_val)) | |
| 1067 | |
| 1068 # Test | |
| 1069 test_val = self._compute_metric_value( | |
| 1070 metric, split_predictions.get("Test"), "Test" | |
| 1071 ) | |
| 1072 row.append(_fmt(test_val)) | |
| 1073 rows.append(row) | |
| 1074 | |
| 1075 df = pd.DataFrame(rows, columns=["Metric", "Train", "Validation", "Test"]) | |
| 1076 return ( | |
| 1077 "<h2>Model Performance Summary</h2>" | |
| 1078 + '<div class="table-wrapper">' | |
| 1079 + df.to_html( | |
| 1080 index=False, | |
| 1081 classes=["table", "sortable", "table-perf-summary"], | |
| 1082 ) | |
| 1083 + "</div>" | |
| 1084 ) | |
| 1085 | |
| 390 def _resolve_plot_callable(self, key, fig_or_fn, section): | 1086 def _resolve_plot_callable(self, key, fig_or_fn, section): |
| 391 """ | 1087 """ |
| 392 Safely execute stored plot callables so a single failure does not | 1088 Safely execute stored plot callables so a single failure does not |
| 393 abort the entire HTML report generation. | 1089 abort the entire HTML report generation. |
| 394 """ | 1090 """ |
| 519 # 5) Header | 1215 # 5) Header |
| 520 header = f"<h2>Best Model: {best_model_name}</h2>" | 1216 header = f"<h2>Best Model: {best_model_name}</h2>" |
| 521 | 1217 |
| 522 # — Validation Summary & Configuration — | 1218 # — Validation Summary & Configuration — |
| 523 val_df = self.results.copy() | 1219 val_df = self.results.copy() |
| 1220 dataset_overview_html = self._build_dataset_overview() | |
| 1221 performance_summary_html = self._build_performance_summary_table() | |
| 524 # mapping raw plot keys to user-friendly titles | 1222 # mapping raw plot keys to user-friendly titles |
| 525 plot_title_map = { | 1223 plot_title_map = { |
| 526 "learning": "Learning Curve", | 1224 "learning": "Learning Curve", |
| 527 "vc": "Validation Curve", | 1225 "vc": "Validation Curve", |
| 528 "calibration": "Calibration Curve", | 1226 "calibration": "Calibration Curve", |
| 529 "dimension": "Dimensionality Reduction", | 1227 "dimension": "Dimensionality Reduction", |
| 530 "manifold": "Manifold Learning", | 1228 "manifold": "t-SNE", |
| 531 "rfe": "Recursive Feature Elimination", | 1229 "rfe": "Recursive Feature Elimination", |
| 532 "threshold": "Threshold Plot", | 1230 "threshold": "Threshold Plot", |
| 533 "percentage_above_below": "Percentage Above vs. Below Cutoff", | 1231 "percentage_above_below": "Percentage Above vs. Below Cutoff", |
| 534 "class_report": "Classification Report", | 1232 "class_report": "Per-Class Metrics", |
| 535 "pr_auc": "Precision-Recall AUC", | 1233 "pr_auc": "Precision-Recall AUC", |
| 536 "roc_auc": "Receiver Operating Characteristic AUC", | 1234 "roc_auc": "Receiver Operating Characteristic AUC", |
| 537 "residuals": "Residuals Distribution", | 1235 "residuals": "Residuals Distribution", |
| 538 "error": "Prediction Error Distribution", | 1236 "error": "Prediction Error Distribution", |
| 539 } | 1237 } |
| 558 + '<div class="table-wrapper">' | 1256 + '<div class="table-wrapper">' |
| 559 + tuning_df.to_html(index=False, classes="table sortable") | 1257 + tuning_df.to_html(index=False, classes="table sortable") |
| 560 + "</div>" | 1258 + "</div>" |
| 561 ) | 1259 ) |
| 562 | 1260 |
| 563 summary_html += ( | 1261 config_html = ( |
| 564 "<h2>Setup Parameters</h2>" | 1262 header |
| 1263 + dataset_overview_html | |
| 1264 + performance_summary_html | |
| 1265 + "<h2>Setup Parameters</h2>" | |
| 565 + '<div class="table-wrapper">' | 1266 + '<div class="table-wrapper">' |
| 566 + df_setup.to_html(index=False, classes="table sortable") | 1267 + df_setup.to_html( |
| 1268 index=False, | |
| 1269 classes=["table", "sortable", "table-setup-params"], | |
| 1270 ) | |
| 567 + "</div>" | 1271 + "</div>" |
| 568 # — Hyperparameters | 1272 # — Hyperparameters |
| 569 + "<h2>Best Model Hyperparameters</h2>" | 1273 + "<h2>Best Model Hyperparameters</h2>" |
| 570 + '<div class="table-wrapper">' | 1274 + '<div class="table-wrapper">' |
| 571 + pd.DataFrame( | 1275 + pd.DataFrame( |
| 572 self.best_model.get_params().items(), | 1276 self.best_model.get_params().items(), |
| 573 columns=["Parameter", "Value"] | 1277 columns=["Parameter", "Value"] |
| 574 ).to_html(index=False, classes="table sortable") | 1278 ).to_html( |
| 1279 index=False, | |
| 1280 classes=["table", "sortable", "table-hyperparams"], | |
| 1281 ) | |
| 575 + "</div>" | 1282 + "</div>" |
| 576 ) | 1283 ) |
| 577 | 1284 |
| 578 # choose summary plots based on task type | 1285 # choose summary plots based on task type |
| 579 if self.task_type == "classification": | 1286 if self.task_type == "classification": |
| 580 summary_plots = [ | 1287 summary_plots = [ |
| 1288 "threshold", | |
| 581 "learning", | 1289 "learning", |
| 1290 "calibration", | |
| 1291 "rfe", | |
| 582 "vc", | 1292 "vc", |
| 583 "calibration", | |
| 584 "dimension", | 1293 "dimension", |
| 585 "manifold", | 1294 "manifold", |
| 586 "rfe", | |
| 587 "threshold", | |
| 588 "percentage_above_below", | 1295 "percentage_above_below", |
| 589 ] | 1296 ] |
| 590 else: | 1297 else: |
| 591 summary_plots = ["learning", "vc", "parameter", "residuals"] | 1298 summary_plots = ["learning", "vc", "parameter", "residuals"] |
| 592 | 1299 |
| 647 if self.task_type == "regression": | 1354 if self.task_type == "regression": |
| 648 test_order = ["residuals"] | 1355 test_order = ["residuals"] |
| 649 else: | 1356 else: |
| 650 test_order = [ | 1357 test_order = [ |
| 651 "confusion_matrix", | 1358 "confusion_matrix", |
| 1359 "class_report", | |
| 652 "roc_auc", | 1360 "roc_auc", |
| 653 "pr_auc", | 1361 "pr_auc", |
| 654 "lift_curve", | 1362 "lift_curve", |
| 655 "cumulative_precision", | 1363 "cumulative_precision", |
| 656 ] | 1364 ] |
| 1365 rendered_test_plots = set() | |
| 657 for key in test_order: | 1366 for key in test_order: |
| 658 fig_or_fn = self.explainer_plots.pop(key, None) | 1367 fig_or_fn = self.explainer_plots.pop(key, None) |
| 659 if fig_or_fn is not None: | 1368 if fig_or_fn is not None: |
| 660 fig = self._resolve_plot_callable( | 1369 fig = self._resolve_plot_callable( |
| 661 key, fig_or_fn, section="test/explainer" | 1370 key, fig_or_fn, section="test/explainer" |
| 662 ) | 1371 ) |
| 663 if fig is None: | 1372 if fig is None: |
| 664 continue | 1373 continue |
| 1374 rendered_test_plots.add(key) | |
| 665 title = plot_title_map.get( | 1375 title = plot_title_map.get( |
| 666 key, key.replace("_", " ").title() | 1376 key, key.replace("_", " ").title() |
| 667 ) | 1377 ) |
| 668 test_html += ( | 1378 test_html += ( |
| 669 f"<h2>{title}</h2>" + add_plot_to_html(fig) | 1379 f"<h2>{title}</h2>" + add_plot_to_html(fig) |
| 677 name in { | 1387 name in { |
| 678 "pr_auc", | 1388 "pr_auc", |
| 679 "class_report", | 1389 "class_report", |
| 680 } | 1390 } |
| 681 ): | 1391 ): |
| 1392 if name in rendered_test_plots: | |
| 1393 continue | |
| 682 title = plot_title_map.get( | 1394 title = plot_title_map.get( |
| 683 name, name.replace("_", " ").title() | 1395 name, name.replace("_", " ").title() |
| 684 ) | 1396 ) |
| 685 b64 = encode_image_to_base64(path) | 1397 b64 = encode_image_to_base64(path) |
| 686 test_html += ( | 1398 test_html += ( |
| 748 ("Features used in SHAP", fi_analyzer.shap_used_features) | 1460 ("Features used in SHAP", fi_analyzer.shap_used_features) |
| 749 ) | 1461 ) |
| 750 if cap_rows: | 1462 if cap_rows: |
| 751 cap_table = ( | 1463 cap_table = ( |
| 752 "<div class='table-wrapper'>" | 1464 "<div class='table-wrapper'>" |
| 753 "<table class='table sortable'>" | 1465 "<table class='table sortable table-fi-scope'>" |
| 754 "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>" | 1466 "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>" |
| 755 "<tbody>" | 1467 "<tbody>" |
| 756 + "".join( | 1468 + "".join( |
| 757 f"<tr><td>{label}</td><td>{value}</td></tr>" | 1469 f"<tr><td>{label}</td><td>{value}</td></tr>" |
| 758 for label, value in cap_rows | 1470 for label, value in cap_rows |
| 801 + add_hr_to_html() | 1513 + add_hr_to_html() |
| 802 ) | 1514 ) |
| 803 # 7) Assemble final HTML (three tabs) | 1515 # 7) Assemble final HTML (three tabs) |
| 804 html = get_html_template() | 1516 html = get_html_template() |
| 805 html += "<h1>Tabular Learner Model Report</h1>" | 1517 html += "<h1>Tabular Learner Model Report</h1>" |
| 806 html += build_tabbed_html(summary_html, test_html, feature_html) | 1518 html += build_tabbed_html( |
| 1519 summary_html, | |
| 1520 test_html, | |
| 1521 feature_html, | |
| 1522 explainer_html=None, | |
| 1523 config_html=config_html, | |
| 1524 ) | |
| 807 html += get_feature_metrics_help_modal() | 1525 html += get_feature_metrics_help_modal() |
| 808 html += get_html_closing() | 1526 html += get_html_closing() |
| 809 | 1527 |
| 810 # 8) Write out | 1528 # 8) Write out |
| 811 (Path(self.output_dir) / "comparison_result.html").write_text( | 1529 (Path(self.output_dir) / "comparison_result.html").write_text( |
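
For orientation only, a hypothetical and much-reduced `build_tabbed_html` showing how the extra `explainer_html` and `config_html` keyword arguments could map onto optional tabs. The real helper lives in the repository's `utils` module; its tab titles, markup, and JavaScript are not part of this diff:

```python
# Hypothetical illustration only: the real build_tabbed_html lives in the
# repository's utils module; tab titles and markup here are invented.
def build_tabbed_html(summary_html, test_html, feature_html,
                      explainer_html=None, config_html=None):
    tabs = [
        ("Summary", summary_html),
        ("Test Results", test_html),
        ("Feature Importance", feature_html),
    ]
    if explainer_html:
        tabs.append(("Explainer", explainer_html))
    if config_html:
        tabs.append(("Configuration", config_html))
    buttons, panes = [], []
    for i, (title, body) in enumerate(tabs):
        active = " active" if i == 0 else ""
        buttons.append(
            f'<button class="tab-btn{active}" data-tab="tab-{i}">{title}</button>'
        )
        panes.append(f'<div id="tab-{i}" class="tab-pane{active}">{body}</div>')
    return '<div class="tabs">' + "".join(buttons) + "".join(panes) + "</div>"
```
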
| 821 | 1539 |
| 822 def generate_plots_explainer(self): | 1540 def generate_plots_explainer(self): |
| 823 raise NotImplementedError("Subclasses should implement this method") | 1541 raise NotImplementedError("Subclasses should implement this method") |
| 824 | 1542 |
| 825 def generate_tree_plots(self): | 1543 def generate_tree_plots(self): |
| 1544 from explainerdashboard.explainers import RandomForestExplainer | |
| 826 from sklearn.ensemble import ( | 1545 from sklearn.ensemble import ( |
| 827 RandomForestClassifier, RandomForestRegressor | 1546 RandomForestClassifier, RandomForestRegressor |
| 828 ) | 1547 ) |
| 829 from xgboost import XGBClassifier, XGBRegressor | 1548 from xgboost import XGBClassifier, XGBRegressor |
| 830 from explainerdashboard.explainers import RandomForestExplainer | |
| 831 | 1549 |
| 832 LOG.info("Generating tree plots") | 1550 LOG.info("Generating tree plots") |
| 833 X_test = self.exp.X_test_transformed.copy() | 1551 X_test = self.exp.X_test_transformed.copy() |
| 834 y_test = self.exp.y_test_transformed | 1552 y_test = self.exp.y_test_transformed |
| 835 | 1553 |
