comparison base_model_trainer.py @ 17:c5c324ac29fc draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
author goeckslab
date Sat, 06 Dec 2025 14:20:36 +0000
parents 4fee4504646e
children
--- base_model_trainer.py (16:4fee4504646e)
+++ base_model_trainer.py (17:c5c324ac29fc)
@@ -7 +7 @@
 import joblib
 import numpy as np
 import pandas as pd
 from feature_help_modal import get_feature_metrics_help_modal
 from feature_importance import FeatureImportanceAnalyzer
-from sklearn.metrics import average_precision_score
+from sklearn.metrics import (
+    accuracy_score,
+    average_precision_score,
+    confusion_matrix,
+    f1_score,
+    matthews_corrcoef,
+    precision_score,
+    recall_score,
+    roc_auc_score,
+)
 from utils import (
     add_hr_to_html,
     add_plot_to_html,
     build_tabbed_html,
     encode_image_to_base64,
@@ -385 +394 @@
 
     def encode_image_to_base64(self, img_path: str) -> str:
         with open(img_path, "rb") as img_file:
             return base64.b64encode(img_file.read()).decode("utf-8")
 
+    def _build_dataset_overview(self):
+        """
+        Build an HTML table of label counts, with labels as rows and the
+        splits (Train / Validation / Test) as columns. Each cell shows the
+        count and its percentage of that split. Returns an empty string for
+        regression tasks or when no label data is available.
+        """
+        if self.task_type != "classification":
+            return ""
+
+        def _safe_series(obj):
+            try:
+                return pd.Series(obj).reset_index(drop=True)
+            except Exception:
+                return None
+
+        def _get_from_config(keys):
+            if self.exp is None:
+                return None
+            for key in keys:
+                try:
+                    val = self.exp.get_config(key)
+                except Exception:
+                    val = getattr(self.exp, key, None)
+                if val is not None:
+                    return val
+            return None
+
+        # Prefer PyCaret-configured splits; fall back to raw inputs.
+        X_train = _get_from_config(["X_train_transformed", "X_train"])
+        y_train = _get_from_config(["y_train_transformed", "y_train"])
+        y_test_cfg = _get_from_config(["y_test_transformed", "y_test"])
+
+        if y_train is None and self.data is not None and self.target in self.data.columns:
+            y_train = self.data[self.target]
+
+        y_train_series = _safe_series(y_train)
+
+        # Build a cross-validation generator to derive a validation subset size.
+        cv_gen = self._get_cv_generator(y_train_series)
+        y_train_fold = y_train_series
+        y_val_fold = None
+        if cv_gen is not None and y_train_series is not None:
+            try:
+                # Use the first fold to approximate Train/Validation split sizes.
+                splitter = cv_gen.split(
+                    pd.DataFrame(X_train).reset_index(drop=True)
+                    if X_train is not None
+                    else y_train_series,
+                    y_train_series,
+                )
+                train_idx, val_idx = next(iter(splitter))
+                y_train_fold = y_train_series.iloc[train_idx].reset_index(drop=True)
+                y_val_fold = y_train_series.iloc[val_idx].reset_index(drop=True)
+            except Exception as exc:
+                LOG.warning("Could not derive validation split for dataset overview: %s", exc)
+
+        # Test labels: prefer the PyCaret transformed holdout (single-file
+        # runs) or the external test file.
+        if self.test_data is not None:
+            if y_test_cfg is not None:
+                y_test = y_test_cfg
+            elif self.target in self.test_data.columns:
+                y_test = self.test_data[self.target]
+            else:
+                y_test = None
+        else:
+            y_test = y_test_cfg
+
+        split_map = {
+            "Train": _safe_series(y_train_fold),
+            "Validation": _safe_series(y_val_fold),
+            "Test": _safe_series(y_test),
+        }
+        available = {k: v for k, v in split_map.items() if v is not None and not v.empty}
+        if not available:
+            return ""
+
+        # Collect all labels across available splits (including NaN).
+        label_pool = pd.concat(available.values(), ignore_index=True)
+        labels = pd.unique(label_pool)
+
+        def _count_for_label(series, label):
+            if series is None or series.empty:
+                return None, None
+            total = len(series)
+            if pd.isna(label):
+                cnt = series.isna().sum()
+            else:
+                cnt = (series == label).sum()
+            return int(cnt), total
+
+        rows = []
+        for label in labels:
+            row = ["NaN" if pd.isna(label) else str(label)]
+            for split_name in ["Train", "Validation", "Test"]:
+                cnt, total = _count_for_label(split_map.get(split_name), label)
+                if cnt is None or total is None:
+                    cell = "—"
+                else:
+                    pct = (cnt / total * 100) if total else 0
+                    cell = f"{cnt} ({pct:.1f}%)"
+                row.append(cell)
+            rows.append(row)
+
+        df = pd.DataFrame(rows, columns=["Label", "Train", "Validation", "Test"])
+        df.sort_values("Label", inplace=True)
+
+        return (
+            "<h2>Dataset Overview</h2>"
+            + '<div class="table-wrapper">'
+            + df.to_html(
+                index=False,
+                classes=["table", "sortable", "table-dataset-overview"],
+            )
+            + "</div>"
+        )
+
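The cell text above reduces to a count-and-percentage pattern per split. A minimal, self-contained sketch with toy labels (all names below are illustrative, not part of the diff):

    import pandas as pd

    y = pd.Series(["yes", "yes", "no", None])
    total = len(y)
    for label in pd.unique(y):
        cnt = y.isna().sum() if pd.isna(label) else (y == label).sum()
        print(label, f"{int(cnt)} ({cnt / total * 100:.1f}%)")
    # yes 2 (50.0%) / no 1 (25.0%) / None 1 (25.0%)
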
+    def _predict_with_thresholds(self, X, y_true):
+        """
+        Generate predictions/probabilities for a split, respecting an
+        optional probability threshold for binary tasks. Returns a dict with
+        y_true, y_pred, y_scores (positive-class probs when available),
+        pos_label, and neg_label.
+        """
+        if X is None or y_true is None:
+            return None
+
+        y_true_series = pd.Series(y_true).reset_index(drop=True)
+        classes = list(getattr(self.best_model, "classes_", []))
+        if not classes:
+            try:
+                classes = pd.unique(y_true_series).tolist()
+            except Exception:
+                classes = []
+        if len(classes) > 1:
+            try:
+                pos_idx = classes.index(1)
+            except Exception:
+                pos_idx = 1
+        else:
+            pos_idx = 0
+        pos_idx = min(pos_idx, len(classes) - 1) if classes else 0
+        pos_label = (
+            classes[pos_idx]
+            if len(classes) > pos_idx and pos_idx >= 0
+            else (classes[-1] if classes else 1)
+        )
+        neg_label = None
+        if len(classes) >= 2:
+            neg_candidates = [c for c in classes if c != pos_label]
+            if neg_candidates:
+                neg_label = neg_candidates[0]
+
+        prob_thresh = getattr(self, "probability_threshold", None)
+        y_scores = None
+        try:
+            proba = self.best_model.predict_proba(X)
+            y_scores = np.asarray(proba) if proba is not None else None
+        except Exception:
+            y_scores = None
+
+        try:
+            if (
+                prob_thresh is not None
+                and not getattr(self.exp, "is_multiclass", False)
+                and y_scores is not None
+                and y_scores.ndim == 2
+                and y_scores.shape[1] > 1
+            ):
+                pos_idx = min(pos_idx, y_scores.shape[1] - 1)
+                neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0
+                if neg_label is None and len(classes) > neg_idx:
+                    neg_label = classes[neg_idx]
+                y_pred = np.where(
+                    y_scores[:, pos_idx] >= prob_thresh,
+                    pos_label,
+                    neg_label if neg_label is not None else 0,
+                )
+                y_scores = y_scores[:, pos_idx]
+            else:
+                y_pred = self.best_model.predict(X)
+                if (
+                    not getattr(self.exp, "is_multiclass", False)
+                    and y_scores is not None
+                    and y_scores.ndim == 2
+                    and y_scores.shape[1] > 1
+                ):
+                    pos_idx = min(pos_idx, y_scores.shape[1] - 1)
+                    y_scores = y_scores[:, pos_idx]
+        except Exception as exc:
+            LOG.warning(
+                "Falling back to raw predict while computing performance summary: %s",
+                exc,
+            )
+            try:
+                y_pred = self.best_model.predict(X)
+            except Exception as exc_inner:
+                LOG.warning(
+                    "Unable to score split after fallback prediction: %s",
+                    exc_inner,
+                )
+                return None
+            y_scores = None
+
+        y_pred_series = pd.Series(y_pred).reset_index(drop=True)
+        if y_scores is not None:
+            y_scores = np.asarray(y_scores)
+            if y_scores.ndim > 1 and y_scores.shape[1] == 1:
+                y_scores = y_scores.ravel()
+            if getattr(self.exp, "is_multiclass", False) and y_scores.ndim > 1:
+                # Avoid passing multiclass score matrices to ROC/PR utilities.
+                y_scores = None
+
+        return {
+            "y_true": y_true_series,
+            "y_pred": y_pred_series,
+            "y_scores": y_scores,
+            "pos_label": pos_label,
+            "neg_label": neg_label,
+        }
+
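For binary tasks the thresholding step is a vectorised comparison against the positive-class probability column. A toy sketch of that step (probability values fabricated for illustration):

    import numpy as np

    proba = np.array([[0.8, 0.2], [0.3, 0.7], [0.55, 0.45]])
    pos_idx, prob_thresh = 1, 0.4
    y_pred = np.where(proba[:, pos_idx] >= prob_thresh, "pos", "neg")
    print(y_pred)  # ['neg' 'pos' 'pos']
    y_scores = proba[:, pos_idx]  # kept for ROC/PR curves
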
+    def _get_cv_generator(self, y_series):
+        """
+        Build a cross-validation splitter that mirrors the experiment's
+        configuration. Returns None when CV is disabled or not applicable.
+        """
+        if self.task_type != "classification":
+            return None
+
+        if getattr(self, "cross_validation", None) is False:
+            return None
+
+        try:
+            cfg_gen = self.exp.get_config("fold_generator")
+            if cfg_gen is not None:
+                return cfg_gen
+        except Exception:
+            cfg_gen = None
+
+        folds = (
+            getattr(self, "cross_validation_folds", None)
+            or self.setup_params.get("fold")
+            or getattr(self.exp, "fold", None)
+            or 10
+        )
+        try:
+            folds = int(folds)
+        except Exception:
+            folds = 10
+
+        try:
+            y_series = pd.Series(y_series).reset_index(drop=True)
+        except Exception:
+            y_series = None
+        if y_series is None or y_series.empty:
+            return None
+
+        if folds < 2:
+            return None
+        if len(y_series) < folds:
+            folds = len(y_series)
+            if folds < 2:
+                return None
+
+        try:
+            from sklearn.model_selection import KFold, StratifiedKFold
+
+            # task_type is always "classification" past the early return
+            # above, so the KFold branch below is purely defensive.
+            if self.task_type == "classification":
+                return StratifiedKFold(
+                    n_splits=folds,
+                    shuffle=True,
+                    random_state=self.random_seed,
+                )
+            return KFold(
+                n_splits=folds,
+                shuffle=True,
+                random_state=self.random_seed,
+            )
+        except Exception as exc:
+            LOG.warning("Could not build CV generator: %s", exc)
+            return None
+
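When no PyCaret fold_generator is available, the splitter built here behaves like plain scikit-learn stratified CV. A runnable sketch, assuming a seed of 42 in place of self.random_seed:

    import pandas as pd
    from sklearn.model_selection import StratifiedKFold

    y = pd.Series([0, 0, 0, 1, 1, 1])
    cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
    for train_idx, val_idx in cv.split(pd.DataFrame({"x": range(6)}), y):
        print(len(train_idx), len(val_idx))  # 4 2, three times
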
+    def _get_cross_validated_predictions(self, X, y):
+        """
+        Generate cross-validated predictions for the validation split so we
+        can report validation metrics for the selected best model.
+        """
+        if self.task_type != "classification":
+            return None
+        if getattr(self, "cross_validation", None) is False:
+            return None
+        if X is None or y is None:
+            return None
+
+        try:
+            from sklearn.model_selection import cross_val_predict
+        except Exception as exc:
+            LOG.warning("cross_val_predict unavailable: %s", exc)
+            return None
+
+        y_series = pd.Series(y).reset_index(drop=True)
+        if y_series.empty:
+            return None
+
+        cv_gen = self._get_cv_generator(y_series)
+        if cv_gen is None:
+            return None
+
+        X_df = pd.DataFrame(X).reset_index(drop=True)
+        if len(X_df) != len(y_series):
+            X_df = X_df.iloc[: len(y_series)].reset_index(drop=True)
+
+        classes = list(getattr(self.best_model, "classes_", []))
+        if len(classes) > 1:
+            try:
+                pos_idx = classes.index(1)
+            except Exception:
+                pos_idx = 1
+        else:
+            pos_idx = 0
+        pos_idx = min(pos_idx, len(classes) - 1) if classes else 0
+        pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+        neg_label = None
+        if len(classes) >= 2:
+            neg_candidates = [c for c in classes if c != pos_label]
+            if neg_candidates:
+                neg_label = neg_candidates[0]
+
+        prob_thresh = getattr(self, "probability_threshold", None)
+        n_jobs = getattr(self, "n_jobs", None)
+
+        y_scores = None
+        if not getattr(self.exp, "is_multiclass", False):
+            try:
+                proba = cross_val_predict(
+                    self.best_model,
+                    X_df,
+                    y_series,
+                    cv=cv_gen,
+                    method="predict_proba",
+                    n_jobs=n_jobs,
+                )
+                y_scores = np.asarray(proba)
+            except Exception as exc:
+                LOG.debug("Could not compute CV probabilities: %s", exc)
+
+        y_pred = None
+        if (
+            prob_thresh is not None
+            and not getattr(self.exp, "is_multiclass", False)
+            and y_scores is not None
+            and y_scores.ndim == 2
+            and y_scores.shape[1] > 1
+        ):
+            pos_idx = min(pos_idx, y_scores.shape[1] - 1)
+            neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0
+            if neg_label is None and len(classes) > neg_idx:
+                neg_label = classes[neg_idx]
+            y_pred = np.where(
+                y_scores[:, pos_idx] >= prob_thresh,
+                pos_label,
+                neg_label if neg_label is not None else 0,
+            )
+            y_scores = y_scores[:, pos_idx]
+        else:
+            try:
+                y_pred = cross_val_predict(
+                    self.best_model,
+                    X_df,
+                    y_series,
+                    cv=cv_gen,
+                    method="predict",
+                    n_jobs=n_jobs,
+                )
+            except Exception as exc:
+                LOG.warning(
+                    "Could not compute cross-validated predictions: %s",
+                    exc,
+                )
+                return None
+            if (
+                not getattr(self.exp, "is_multiclass", False)
+                and y_scores is not None
+                and y_scores.ndim == 2
+                and y_scores.shape[1] > 1
+            ):
+                pos_idx = min(pos_idx, y_scores.shape[1] - 1)
+                y_scores = y_scores[:, pos_idx]
+
+        if y_scores is not None and getattr(self.exp, "is_multiclass", False):
+            y_scores = None
+
+        return {
+            "y_true": y_series,
+            "y_pred": pd.Series(y_pred).reset_index(drop=True),
+            "y_scores": y_scores,
+            "pos_label": pos_label,
+            "neg_label": neg_label,
+        }
+
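cross_val_predict with method="predict_proba" returns out-of-fold probabilities in a single array, which is what lets the method above apply the same threshold logic used for the other splits. A small end-to-end sketch with a stand-in estimator (LogisticRegression is illustrative, not the tool's best model):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import StratifiedKFold, cross_val_predict

    X, y = make_classification(n_samples=60, random_state=0)
    cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
    proba = cross_val_predict(
        LogisticRegression(max_iter=1000), X, y, cv=cv, method="predict_proba"
    )
    print(proba.shape)  # (60, 2): one out-of-fold probability row per sample
    y_pred = np.where(proba[:, 1] >= 0.5, 1, 0)
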
+    def _get_split_predictions_for_report(self):
+        """
+        Collect predictions/probabilities for the Train/Validation/Test
+        splits so the performance table can show consistent metrics across
+        splits.
+        """
+        if self.task_type != "classification":
+            return {}
+
+        def _get_from_config(keys):
+            for key in keys:
+                try:
+                    val = self.exp.get_config(key)
+                except Exception:
+                    val = getattr(self.exp, key, None)
+                if val is not None:
+                    return val
+            return None
+
+        X_train = _get_from_config(["X_train_transformed", "X_train"])
+        y_train = _get_from_config(["y_train_transformed", "y_train"])
+        X_holdout = _get_from_config(["X_test_transformed", "X_test"])
+        y_holdout = _get_from_config(["y_test_transformed", "y_test"])
+
+        predictions = {}
+
+        # Train metrics (best model on training data)
+        if X_train is not None and y_train is not None:
+            try:
+                train_preds = self._predict_with_thresholds(X_train, y_train)
+                if train_preds is not None:
+                    predictions["Train"] = train_preds
+            except Exception as exc:
+                LOG.warning(
+                    "Could not score Train split for performance summary: %s",
+                    exc,
+                )
+
+        # Validation metrics via cross-validation on training data
+        try:
+            val_preds = self._get_cross_validated_predictions(X_train, y_train)
+            if val_preds is not None:
+                predictions["Validation"] = val_preds
+        except Exception as exc:
+            LOG.warning(
+                "Could not score Validation split for performance summary: %s",
+                exc,
+            )
+
+        # Test metrics (holdout from a single file, or the provided test file)
+        X_test = X_holdout
+        y_test = y_holdout
+        if (X_test is None or y_test is None) and self.test_data is not None:
+            try:
+                X_test = self.test_data.drop(columns=[self.target])
+                y_test = self.test_data[self.target]
+            except Exception as exc:
+                LOG.warning(
+                    "Could not prepare external test data for performance summary: %s",
+                    exc,
+                )
+
+        if X_test is not None and y_test is not None:
+            try:
+                test_preds = self._predict_with_thresholds(X_test, y_test)
+                if test_preds is not None:
+                    predictions["Test"] = test_preds
+            except Exception as exc:
+                LOG.warning(
+                    "Could not score Test split for performance summary: %s",
+                    exc,
+                )
+        return predictions
+
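Every split maps to the same bundle shape returned by _predict_with_thresholds, so downstream metric code can treat splits uniformly. A toy consumer (bundle values fabricated for illustration):

    import pandas as pd

    bundle = {
        "y_true": pd.Series([1, 0, 1, 1]),
        "y_pred": pd.Series([1, 0, 0, 1]),
        "y_scores": None,
        "pos_label": 1,
        "neg_label": 0,
    }
    acc = (bundle["y_true"] == bundle["y_pred"]).mean()
    print(f"{acc:.3f}")  # 0.750
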
+    def _compute_metric_value(self, metric_name, preds, split_name):
+        """
+        Compute a single metric for a given split prediction bundle.
+        """
+        if preds is None:
+            return None
+
+        y_true = preds["y_true"]
+        y_pred = preds["y_pred"]
+        y_scores = preds.get("y_scores")
+        pos_label = preds.get("pos_label")
+        neg_label = preds.get("neg_label")
+        is_multiclass = getattr(self.exp, "is_multiclass", False)
+
+        def _format_binary_labels(series):
+            if pos_label is None:
+                return series
+            try:
+                return (series == pos_label).astype(int)
+            except Exception:
+                return series
+
+        try:
+            if metric_name == "Accuracy":
+                return accuracy_score(y_true, y_pred)
+            if metric_name == "ROC-AUC":
+                if y_scores is None:
+                    return None
+                y_true_bin = _format_binary_labels(y_true)
+                if len(pd.unique(y_true_bin)) < 2:
+                    return None
+                return roc_auc_score(y_true_bin, y_scores)
+            if metric_name == "Precision":
+                if is_multiclass:
+                    return precision_score(
+                        y_true, y_pred, average="weighted", zero_division=0
+                    )
+                try:
+                    return precision_score(
+                        y_true, y_pred, pos_label=pos_label, zero_division=0
+                    )
+                except Exception:
+                    return precision_score(
+                        y_true, y_pred, average="weighted", zero_division=0
+                    )
+            if metric_name == "Recall":
+                if is_multiclass:
+                    return recall_score(
+                        y_true, y_pred, average="weighted", zero_division=0
+                    )
+                try:
+                    return recall_score(
+                        y_true, y_pred, pos_label=pos_label, zero_division=0
+                    )
+                except Exception:
+                    return recall_score(
+                        y_true, y_pred, average="weighted", zero_division=0
+                    )
+            if metric_name == "F1-Score":
+                if is_multiclass:
+                    return f1_score(
+                        y_true, y_pred, average="weighted", zero_division=0
+                    )
+                try:
+                    return f1_score(
+                        y_true, y_pred, pos_label=pos_label, zero_division=0
+                    )
+                except Exception:
+                    return f1_score(
+                        y_true, y_pred, average="weighted", zero_division=0
+                    )
+            if metric_name == "PR-AUC":
+                if y_scores is None:
+                    return None
+                y_true_bin = _format_binary_labels(y_true)
+                if len(pd.unique(y_true_bin)) < 2:
+                    return None
+                return average_precision_score(y_true_bin, y_scores)
+            if metric_name == "Specificity":
+                labels = pd.unique(pd.concat([y_true, y_pred], ignore_index=True))
+                if len(labels) != 2:
+                    return None
+                if pos_label is None or pos_label not in labels:
+                    pos_label = labels[1]
+                neg_candidates = [lbl for lbl in labels if lbl != pos_label]
+                neg_label_final = (
+                    neg_label
+                    if neg_label in labels
+                    else (neg_candidates[0] if neg_candidates else None)
+                )
+                if neg_label_final is None:
+                    return None
+                cm = confusion_matrix(
+                    y_true, y_pred, labels=[neg_label_final, pos_label]
+                )
+                if cm.shape != (2, 2):
+                    return None
+                tn, fp, fn, tp = cm.ravel()
+                denom = tn + fp
+                return (tn / denom) if denom else None
+            if metric_name == "MCC":
+                return matthews_corrcoef(y_true, y_pred)
+        except Exception as exc:
+            LOG.warning(
+                "Could not compute %s for %s split: %s",
+                metric_name,
+                split_name,
+                exc,
+            )
+            return None
+        return None
+
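Specificity is the only metric above that is not a direct scikit-learn scorer; it is read off the confusion matrix as TN / (TN + FP). A standalone check with toy labels:

    from sklearn.metrics import confusion_matrix

    y_true = ["neg", "neg", "pos", "pos", "neg"]
    y_pred = ["neg", "pos", "pos", "pos", "neg"]
    tn, fp, fn, tp = confusion_matrix(
        y_true, y_pred, labels=["neg", "pos"]
    ).ravel()
    print(tn / (tn + fp))  # 0.666..., i.e. 2 of 3 negatives kept
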
+    def _build_performance_summary_table(self):
+        """
+        Build a Train/Validation/Test metrics table for classification
+        tasks. Returns an empty string when metrics are unavailable or not
+        applicable.
+        """
+        if self.task_type != "classification":
+            return ""
+
+        split_predictions = self._get_split_predictions_for_report()
+        validation_best_row = None
+        try:
+            if isinstance(self.results, pd.DataFrame) and not self.results.empty:
+                validation_best_row = self.results.iloc[0]
+        except Exception:
+            validation_best_row = None
+
+        if not split_predictions and validation_best_row is None:
+            return ""
+
+        metric_names = [
+            "Accuracy",
+            "ROC-AUC",
+            "Precision",
+            "Recall",
+            "F1-Score",
+            "PR-AUC",
+            "Specificity",
+            "MCC",
+        ]
+
+        validation_column_map = {
+            "Accuracy": ["Accuracy"],
+            "ROC-AUC": ["ROC-AUC", "AUC"],
+            "Precision": ["Precision", "Prec.", "Prec"],
+            "Recall": ["Recall"],
+            "F1-Score": ["F1-Score", "F1"],
+            "PR-AUC": ["PR-AUC", "PR-AUC-Weighted", "PRC"],
+            "Specificity": ["Specificity"],
+            "MCC": ["MCC"],
+        }
+
+        def _fmt(value):
+            if value is None:
+                return "—"
+            try:
+                if isinstance(value, (float, np.floating)) and (
+                    np.isnan(value) or np.isinf(value)
+                ):
+                    return "—"
+                return f"{value:.3f}"
+            except Exception:
+                return str(value)
+
+        def _validation_metric(metric_name):
+            if validation_best_row is None:
+                return None
+            cols = validation_column_map.get(metric_name, [])
+            for col in cols:
+                if col in validation_best_row:
+                    try:
+                        return validation_best_row[col]
+                    except Exception:
+                        return None
+            return None
+
+        rows = []
+        for metric in metric_names:
+            row = [metric]
+            # Train
+            train_val = self._compute_metric_value(
+                metric, split_predictions.get("Train"), "Train"
+            )
+            row.append(_fmt(train_val))
+
+            # Validation: take the first row of the Train & Validation
+            # Summary, falling back to the computed CV predictions.
+            val_val = _validation_metric(metric)
+            if val_val is None:
+                val_val = self._compute_metric_value(
+                    metric, split_predictions.get("Validation"), "Validation"
+                )
+            row.append(_fmt(val_val))
+
+            # Test
+            test_val = self._compute_metric_value(
+                metric, split_predictions.get("Test"), "Test"
+            )
+            row.append(_fmt(test_val))
+            rows.append(row)
+
+        df = pd.DataFrame(rows, columns=["Metric", "Train", "Validation", "Test"])
+        return (
+            "<h2>Model Performance Summary</h2>"
+            + '<div class="table-wrapper">'
+            + df.to_html(
+                index=False,
+                classes=["table", "sortable", "table-perf-summary"],
+            )
+            + "</div>"
+        )
+
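The final rendering step is plain pandas; to_html accepts a list of CSS classes, which is how the table-perf-summary styling hook gets attached. A sketch with fabricated metric values:

    import pandas as pd

    rows = [["Accuracy", "0.912", "0.874", "0.881"]]
    df = pd.DataFrame(rows, columns=["Metric", "Train", "Validation", "Test"])
    html = df.to_html(
        index=False, classes=["table", "sortable", "table-perf-summary"]
    )
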
     def _resolve_plot_callable(self, key, fig_or_fn, section):
         """
         Safely execute stored plot callables so a single failure does not
         abort the entire HTML report generation.
         """
@@ -519 +1215 @@
         # 5) Header
         header = f"<h2>Best Model: {best_model_name}</h2>"
 
         # — Validation Summary & Configuration —
         val_df = self.results.copy()
+        dataset_overview_html = self._build_dataset_overview()
+        performance_summary_html = self._build_performance_summary_table()
         # mapping raw plot keys to user-friendly titles
         plot_title_map = {
             "learning": "Learning Curve",
             "vc": "Validation Curve",
             "calibration": "Calibration Curve",
             "dimension": "Dimensionality Reduction",
-            "manifold": "Manifold Learning",
+            "manifold": "t-SNE",
             "rfe": "Recursive Feature Elimination",
             "threshold": "Threshold Plot",
             "percentage_above_below": "Percentage Above vs. Below Cutoff",
-            "class_report": "Classification Report",
+            "class_report": "Per-Class Metrics",
             "pr_auc": "Precision-Recall AUC",
             "roc_auc": "Receiver Operating Characteristic AUC",
             "residuals": "Residuals Distribution",
             "error": "Prediction Error Distribution",
         }
@@ -558 +1256 @@
             + '<div class="table-wrapper">'
             + tuning_df.to_html(index=False, classes="table sortable")
             + "</div>"
         )
 
-        summary_html += (
-            "<h2>Setup Parameters</h2>"
+        config_html = (
+            header
+            + dataset_overview_html
+            + performance_summary_html
+            + "<h2>Setup Parameters</h2>"
             + '<div class="table-wrapper">'
-            + df_setup.to_html(index=False, classes="table sortable")
+            + df_setup.to_html(
+                index=False,
+                classes=["table", "sortable", "table-setup-params"],
+            )
             + "</div>"
             # — Hyperparameters
             + "<h2>Best Model Hyperparameters</h2>"
             + '<div class="table-wrapper">'
             + pd.DataFrame(
                 self.best_model.get_params().items(),
                 columns=["Parameter", "Value"]
-            ).to_html(index=False, classes="table sortable")
+            ).to_html(
+                index=False,
+                classes=["table", "sortable", "table-hyperparams"],
+            )
             + "</div>"
         )
 
         # choose summary plots based on task type
         if self.task_type == "classification":
             summary_plots = [
+                "threshold",
                 "learning",
+                "calibration",
+                "rfe",
                 "vc",
-                "calibration",
                 "dimension",
                 "manifold",
-                "rfe",
-                "threshold",
                 "percentage_above_below",
             ]
         else:
             summary_plots = ["learning", "vc", "parameter", "residuals"]
 
@@ -647 +1354 @@
         if self.task_type == "regression":
             test_order = ["residuals"]
         else:
             test_order = [
                 "confusion_matrix",
+                "class_report",
                 "roc_auc",
                 "pr_auc",
                 "lift_curve",
                 "cumulative_precision",
             ]
+        rendered_test_plots = set()
         for key in test_order:
             fig_or_fn = self.explainer_plots.pop(key, None)
             if fig_or_fn is not None:
                 fig = self._resolve_plot_callable(
                     key, fig_or_fn, section="test/explainer"
                 )
                 if fig is None:
                     continue
+                rendered_test_plots.add(key)
                 title = plot_title_map.get(
                     key, key.replace("_", " ").title()
                 )
                 test_html += (
                     f"<h2>{title}</h2>" + add_plot_to_html(fig)
@@ -677 +1387 @@
                 name in {
                     "pr_auc",
                     "class_report",
                 }
             ):
+                if name in rendered_test_plots:
+                    continue
                 title = plot_title_map.get(
                     name, name.replace("_", " ").title()
                 )
                 b64 = encode_image_to_base64(path)
                 test_html += (
748 ("Features used in SHAP", fi_analyzer.shap_used_features) 1460 ("Features used in SHAP", fi_analyzer.shap_used_features)
749 ) 1461 )
750 if cap_rows: 1462 if cap_rows:
751 cap_table = ( 1463 cap_table = (
752 "<div class='table-wrapper'>" 1464 "<div class='table-wrapper'>"
753 "<table class='table sortable'>" 1465 "<table class='table sortable table-fi-scope'>"
754 "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>" 1466 "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>"
755 "<tbody>" 1467 "<tbody>"
756 + "".join( 1468 + "".join(
757 f"<tr><td>{label}</td><td>{value}</td></tr>" 1469 f"<tr><td>{label}</td><td>{value}</td></tr>"
758 for label, value in cap_rows 1470 for label, value in cap_rows
@@ -801 +1513 @@
             + add_hr_to_html()
         )
         # 7) Assemble final HTML (three tabs)
         html = get_html_template()
         html += "<h1>Tabular Learner Model Report</h1>"
-        html += build_tabbed_html(summary_html, test_html, feature_html)
+        html += build_tabbed_html(
+            summary_html,
+            test_html,
+            feature_html,
+            explainer_html=None,
+            config_html=config_html,
+        )
         html += get_feature_metrics_help_modal()
         html += get_html_closing()
 
         # 8) Write out
         (Path(self.output_dir) / "comparison_result.html").write_text(
@@ -821 +1539 @@
 
     def generate_plots_explainer(self):
         raise NotImplementedError("Subclasses should implement this method")
 
     def generate_tree_plots(self):
+        from explainerdashboard.explainers import RandomForestExplainer
         from sklearn.ensemble import (
             RandomForestClassifier, RandomForestRegressor
         )
         from xgboost import XGBClassifier, XGBRegressor
-        from explainerdashboard.explainers import RandomForestExplainer
 
         LOG.info("Generating tree plots")
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed