# HG changeset patch
# User goeckslab
# Date 1757371135 0
# Node ID ba45bc057d705902d214b1fdf1239cf6c5017414
# Parent 0afd970bd8aeb2a10d9470572b276eadc1b2c973
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
diff -r 0afd970bd8ae -r ba45bc057d70 feature_importance.py
--- a/feature_importance.py Fri Aug 22 21:13:44 2025 +0000
+++ b/feature_importance.py Mon Sep 08 22:38:55 2025 +0000
@@ -23,7 +23,6 @@
exp=None,
best_model=None,
):
-
self.task_type = task_type
self.output_dir = output_dir
self.exp = exp
@@ -40,8 +39,8 @@
LOG.info("Data loaded from memory")
else:
self.target_col = target_col
- self.data = pd.read_csv(data_path, sep=None, engine='python')
- self.data.columns = self.data.columns.str.replace('.', '_')
+ self.data = pd.read_csv(data_path, sep=None, engine="python")
+ self.data.columns = self.data.columns.str.replace(".", "_")
self.data = self.data.fillna(self.data.median(numeric_only=True))
self.target = self.data.columns[int(target_col) - 1]
self.exp = (
@@ -53,63 +52,58 @@
self.plots = {}
def setup_pycaret(self):
- if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup:
+ if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup:
LOG.info("Experiment already set up. Skipping PyCaret setup.")
return
LOG.info("Initializing PyCaret")
setup_params = {
- 'target': self.target,
- 'session_id': 123,
- 'html': True,
- 'log_experiment': False,
- 'system_log': False,
+ "target": self.target,
+ "session_id": 123,
+ "html": True,
+ "log_experiment": False,
+ "system_log": False,
}
self.exp.setup(self.data, **setup_params)
def save_tree_importance(self):
- model = self.best_model or self.exp.get_config('best_model')
- processed_features = self.exp.get_config('X_transformed').columns
+ model = self.best_model or self.exp.get_config("best_model")
+ processed_features = self.exp.get_config("X_transformed").columns
- # Try feature_importances_ or coef_ if available
importances = None
model_type = model.__class__.__name__
- self.tree_model_name = model_type # Store the model name for reporting
+ self.tree_model_name = model_type
- if hasattr(model, 'feature_importances_'):
+ if hasattr(model, "feature_importances_"):
importances = model.feature_importances_
- elif hasattr(model, 'coef_'):
- # For linear models, flatten coef_ and take abs (importance as magnitude)
+ elif hasattr(model, "coef_"):
importances = abs(model.coef_).flatten()
else:
- # Neither attribute exists; skip the plot
LOG.warning(
- f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot."
+ f"Model {model_type} does not have feature_importances_ or coef_. Skipping tree importance."
)
- self.tree_model_name = None # No plot generated
+ self.tree_model_name = None
return
- # Defensive: handle mismatch in number of features
if len(importances) != len(processed_features):
LOG.warning(
- f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot."
+ f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance."
)
self.tree_model_name = None
return
feature_importances = pd.DataFrame(
- {'Feature': processed_features, 'Importance': importances}
- ).sort_values(by='Importance', ascending=False)
+ {"Feature": processed_features, "Importance": importances}
+ ).sort_values(by="Importance", ascending=False)
plt.figure(figsize=(10, 6))
- plt.barh(feature_importances['Feature'], feature_importances['Importance'])
- plt.xlabel('Importance')
- plt.title(f'Feature Importance ({model_type})')
- plot_path = os.path.join(self.output_dir, 'tree_importance.png')
- plt.savefig(plot_path)
+ plt.barh(feature_importances["Feature"], feature_importances["Importance"])
+ plt.xlabel("Importance")
+ plt.title(f"Feature Importance ({model_type})")
+ plot_path = os.path.join(self.output_dir, "tree_importance.png")
+ plt.savefig(plot_path, bbox_inches="tight")
plt.close()
- self.plots['tree_importance'] = plot_path
+ self.plots["tree_importance"] = plot_path
- def save_shap_values(self):
-
+ def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
model = self.best_model or self.exp.get_config("best_model")
X_data = None
@@ -120,78 +114,119 @@
except KeyError:
continue
if X_data is None:
- raise RuntimeError(
- "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. "
- "Make sure PyCaret setup/compare_models was run with feature_selection=True."
- )
+ raise RuntimeError("No transformed dataset found for SHAP.")
+
+ # --- Adaptive feature limiting (proportional cap) ---
+ n_rows, n_features = X_data.shape
+ if max_features is None:
+ if n_features <= 200:
+ max_features = n_features
+ else:
+ max_features = min(200, max(20, int(n_features * 0.1)))
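+ # e.g. 1,000 transformed features -> cap of int(1000 * 0.1) = 100, always kept within [20, 200]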
try:
- used_features = model.booster_.feature_name()
- except Exception:
- used_features = getattr(model, "feature_names_in_", X_data.columns.tolist())
- X_data = X_data[used_features]
+ if hasattr(model, "feature_importances_"):
+ importances = pd.Series(
+ model.feature_importances_, index=X_data.columns
+ )
+ top_features = importances.nlargest(max_features).index
+ elif hasattr(model, "coef_"):
+ coef = abs(model.coef_).flatten()
+ importances = pd.Series(coef, index=X_data.columns)
+ top_features = importances.nlargest(max_features).index
+ else:
+ variances = X_data.var()
+ top_features = variances.nlargest(max_features).index
+
+ if len(top_features) < n_features:
+ LOG.info(
+ f"Restricted SHAP computation to top {len(top_features)} / {n_features} features"
+ )
+ X_data = X_data[top_features]
+ except Exception as e:
+ LOG.warning(
+ f"Feature limiting failed: {e}. Using all {n_features} features."
+ )
- max_bg = min(len(X_data), 100)
- bg = X_data.sample(max_bg, random_state=42)
+ # --- Adaptive row subsampling ---
+ if max_samples is None:
+ if n_rows <= 500:
+ max_samples = n_rows
+ elif n_rows <= 5000:
+ max_samples = 500
+ else:
+ max_samples = min(1000, int(n_rows * 0.1))
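+ # e.g. 20,000 rows -> min(1000, int(20000 * 0.1)) = 1000 rows retained for SHAP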
+
+ if n_rows > max_samples:
+ LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
+ X_data = X_data.sample(max_samples, random_state=42)
- predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict
+ # --- Adaptive feature display ---
+ if max_display is None:
+ if X_data.shape[1] <= 20:
+ max_display = X_data.shape[1]
+ elif X_data.shape[1] <= 100:
+ max_display = 30
+ else:
+ max_display = 50
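+ # e.g. 150 surviving features -> only the top 50 appear in the beeswarm plot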
+
+ # Background set
+ bg = X_data.sample(min(len(X_data), 100), random_state=42)
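+ # at most 100 rows serve as the explainer's reference (background) distribution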
+ predict_fn = (
+ model.predict_proba if hasattr(model, "predict_proba") else model.predict
+ )
+
+ # Optimized explainer
+ if hasattr(model, "feature_importances_"):
+ explainer = shap.TreeExplainer(
+ model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+ )
+ elif hasattr(model, "coef_"):
+ explainer = shap.LinearExplainer(model, bg)
+ else:
+ explainer = shap.Explainer(predict_fn, bg)
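+ # heuristic: feature_importances_ implies a tree ensemble (TreeExplainer), coef_ a linear model (LinearExplainer); anything else falls back to the model-agnostic Explainer on the prediction function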
try:
- explainer = shap.Explainer(predict_fn, bg)
+ shap_values = explainer(X_data)
self.shap_model_name = explainer.__class__.__name__
-
- shap_values = explainer(X_data)
except Exception as e:
LOG.error(f"SHAP computation failed: {e}")
self.shap_model_name = None
return
- output_names = getattr(shap_values, "output_names", None)
- if output_names is None and hasattr(model, "classes_"):
- output_names = list(model.classes_)
- if output_names is None:
- n_out = shap_values.values.shape[-1]
- output_names = list(map(str, range(n_out)))
+ # --- Plot SHAP summary ---
+ out_path = os.path.join(self.output_dir, "shap_summary.png")
+ plt.figure()
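+ # beeswarm expects a 2-D (samples x features) Explanation; explainers that return 3-D multiclass values would need a per-class slice here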
+ shap.plots.beeswarm(shap_values, max_display=max_display, show=False)
+ plt.title(
+ f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)"
+ )
+ plt.savefig(out_path, bbox_inches="tight")
+ plt.close()
+ self.plots["shap_summary"] = out_path
- values = shap_values.values
- if values.ndim == 3:
- for j, name in enumerate(output_names):
- safe = name.replace(" ", "_").replace("/", "_")
- out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png")
- plt.figure()
- shap.plots.beeswarm(shap_values[..., j], show=False)
- plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}")
- plt.savefig(out_path)
- plt.close()
- self.plots[f"shap_summary_{safe}"] = out_path
- else:
- plt.figure()
- shap.plots.beeswarm(shap_values, show=False)
- plt.title(f"SHAP Summary for {model.__class__.__name__}")
- out_path = os.path.join(self.output_dir, "shap_summary.png")
- plt.savefig(out_path)
- plt.close()
- self.plots["shap_summary"] = out_path
+ # --- Log summary ---
+ LOG.info(
+ f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})."
+ )
def generate_html_report(self):
LOG.info("Generating HTML report")
-
plots_html = ""
for plot_name, plot_path in self.plots.items():
- # Special handling for tree importance: skip if no model name (not generated)
- if plot_name == 'tree_importance' and not getattr(
- self, 'tree_model_name', None
+ if plot_name == "tree_importance" and not getattr(
+ self, "tree_model_name", None
):
continue
encoded_image = self.encode_image_to_base64(plot_path)
- if plot_name == 'tree_importance' and getattr(
- self, 'tree_model_name', None
+ if plot_name == "tree_importance" and getattr(
+ self, "tree_model_name", None
):
+ section_title = f"Feature importance from {self.tree_model_name}"
+ elif plot_name == "shap_summary":
section_title = (
- f"Feature importance analysis from a trained {self.tree_model_name}"
+ f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}"
)
- elif plot_name == 'shap_summary':
- section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
else:
section_title = plot_name
plots_html += f"""
@@ -200,25 +235,19 @@
"""
-
- html_content = f"""
- {plots_html}
- """
-
- return html_content
+ return f"{plots_html}"
def encode_image_to_base64(self, img_path):
- with open(img_path, 'rb') as img_file:
- return base64.b64encode(img_file.read()).decode('utf-8')
+ with open(img_path, "rb") as img_file:
+ return base64.b64encode(img_file.read()).decode("utf-8")
def run(self):
if (
self.exp is None
- or not hasattr(self.exp, 'is_setup')
+ or not hasattr(self.exp, "is_setup")
or not self.exp.is_setup
):
self.setup_pycaret()
self.save_tree_importance()
self.save_shap_values()
- html_content = self.generate_html_report()
- return html_content
+ return self.generate_html_report()
diff -r 0afd970bd8ae -r ba45bc057d70 pycaret_classification.py
--- a/pycaret_classification.py Fri Aug 22 21:13:44 2025 +0000
+++ b/pycaret_classification.py Mon Sep 08 22:38:55 2025 +0000
@@ -2,15 +2,29 @@
import types
from typing import Dict
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
from base_model_trainer import BaseModelTrainer
from dashboard import generate_classifier_explainer_dashboard
-from plotly.graph_objects import Figure
from pycaret.classification import ClassificationExperiment
+from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
from utils import predict_proba
LOG = logging.getLogger(__name__)
+def _apply_report_layout(fig: go.Figure) -> go.Figure:
+ # Give the left side more space for y-axis title/ticks and let axes auto-reserve room
+ fig.update_xaxes(automargin=True, title_standoff=12)
+ fig.update_yaxes(automargin=True, title_standoff=12)
+ fig.update_layout(
+ autosize=True,
+ margin=dict(l=120, r=40, t=60, b=60), # bump 'l' if you still see clipping
+ )
+ return fig
+
+
class ClassificationModelTrainer(BaseModelTrainer):
def __init__(
self,
@@ -50,20 +64,19 @@
)
plots = [
- 'confusion_matrix',
- 'auc',
- 'threshold',
- 'pr',
- 'error',
- 'class_report',
- 'learning',
- 'calibration',
- 'vc',
- 'dimension',
- 'manifold',
- 'rfe',
- 'feature',
- 'feature_all',
+ "auc",
+ "threshold",
+ "pr",
+ "error",
+ "class_report",
+ "learning",
+ "calibration",
+ "vc",
+ "dimension",
+ "manifold",
+ "rfe",
+ "feature",
+ "feature_all",
]
for plot_name in plots:
try:
@@ -102,24 +115,146 @@
LOG.info("Generating explainer plots")
+ # Ensure predict_proba is available here too
+ if not hasattr(self.best_model, "predict_proba"):
+ self.best_model.predict_proba = types.MethodType(
+ predict_proba, self.best_model
+ )
+ LOG.warning(
+ f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
+ )
+
X_test = self.exp.X_test_transformed.copy()
y_test = self.exp.y_test_transformed
explainer = ClassifierExplainer(self.best_model, X_test, y_test)
# a dict to hold the raw Figure objects or callables
- self.explainer_plots: Dict[str, Figure] = {}
+ self.explainer_plots: Dict[str, go.Figure] = {}
+
+ # --- Threshold-aware overrides for CM / ROC / PR ---
+ prob_thresh = getattr(self, "probability_threshold", None)
+
+ # Only for binary classification and when threshold is provided
+ if (prob_thresh is not None) and (not self.exp.is_multiclass):
+ X = self.exp.X_test_transformed
+ y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
+
+ # Get positive-class scores (robust defaults)
+ classes = list(getattr(self.best_model, "classes_", [0, 1]))
+ try:
+ pos_idx = classes.index(1) if 1 in classes else 1
+ except Exception:
+ pos_idx = 1
+
+ proba = self.best_model.predict_proba(X)
+ y_scores = proba[:, pos_idx]
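+ # y_scores holds the predicted probability of the positive class for each test row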
+
+ # Derive label names consistently
+ pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+ neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
+
+ # ---- Confusion Matrix @ threshold ----
+ try:
+ y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
+ cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
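+ # rows are true labels, columns are predicted labels, with the negative class listed first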
+ fig_cm = go.Figure(
+ data=go.Heatmap(
+ z=cm,
+ x=[f"Pred {neg_label}", f"Pred {pos_label}"],
+ y=[f"True {neg_label}", f"True {pos_label}"],
+ text=cm,
+ texttemplate="%{text}",
+ colorscale="Blues",
+ showscale=False,
+ )
+ )
+ fig_cm.update_layout(
+ title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
+ xaxis_title="Predicted label",
+ yaxis_title="True label",
+ )
+ _apply_report_layout(fig_cm)
+ self.explainer_plots["confusion_matrix"] = fig_cm
+ except Exception as e:
+ LOG.warning(
+ f"Threshold-aware confusion matrix failed; falling back: {e}"
+ )
- # these go into the Test tab
+ # ---- ROC with threshold marker ----
+ try:
+ fpr, tpr, thr = roc_curve(y, y_scores)
+ roc_auc = auc(fpr, tpr)
+ fig_roc = go.Figure()
+ fig_roc.add_scatter(
+ x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})"
+ )
+ if len(thr):
+ mask = np.isfinite(thr)
+ if mask.any():
+ idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh)))
+ idx = np.where(mask)[0][idx_local]
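+ # idx picks the ROC point whose decision threshold lies closest to the requested cutoff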
+ if 0 <= idx < len(fpr):
+ fig_roc.add_scatter(
+ x=[fpr[idx]],
+ y=[tpr[idx]],
+ mode="markers",
+ name=f"@ {prob_thresh:.2f}",
+ marker=dict(size=10),
+ )
+ fig_roc.update_layout(
+ title=f"ROC Curve (marker at threshold={prob_thresh:.2f})",
+ xaxis_title="False Positive Rate",
+ yaxis_title="True Positive Rate",
+ )
+ _apply_report_layout(fig_roc)
+ self.explainer_plots["roc_auc"] = fig_roc
+ except Exception as e:
+ LOG.warning(f"Threshold marker on ROC failed; falling back: {e}")
+
+ # ---- PR with threshold marker ----
+ try:
+ precision, recall, thr_pr = precision_recall_curve(y, y_scores)
+ pr_auc = auc(recall, precision)
+ fig_pr = go.Figure()
+ fig_pr.add_scatter(
+ x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})"
+ )
+ if len(thr_pr):
+ idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh)))
+ # note: thr_pr has length = len(precision) - 1
+ idx_pr = max(0, min(idx_pr, len(recall) - 1))
+ fig_pr.add_scatter(
+ x=[recall[idx_pr]],
+ y=[precision[idx_pr]],
+ mode="markers",
+ name=f"@ {prob_thresh:.2f}",
+ marker=dict(size=10),
+ )
+ fig_pr.update_layout(
+ title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})",
+ xaxis_title="Recall",
+ yaxis_title="Precision",
+ )
+ _apply_report_layout(fig_pr)
+ self.explainer_plots["pr_auc"] = fig_pr
+ except Exception as e:
+ LOG.warning(f"Threshold marker on PR failed; falling back: {e}")
+
+ # these go into the Test tab (don't overwrite overrides)
for key, fn in [
("roc_auc", explainer.plot_roc_auc),
("pr_auc", explainer.plot_pr_auc),
("lift_curve", explainer.plot_lift_curve),
("confusion_matrix", explainer.plot_confusion_matrix),
- ("threshold", explainer.plot_precision), # Percentage vs probability
+ ("threshold", explainer.plot_precision), # percentage vs probability
("cumulative_precision", explainer.plot_cumulative_precision),
]:
+ if key in self.explainer_plots:
+ continue
try:
- self.explainer_plots[key] = fn()
+ fig = fn()
+ if fig is not None:
+ self.explainer_plots[key] = fig
except Exception as e:
LOG.error(f"Error generating explainer plot {key}: {e}")
@@ -143,7 +278,9 @@
if feat in explainer.X.columns or feat in explainer.onehot_cols:
valid_feats.append(feat)
else:
- LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data")
+ LOG.warning(
+ f"Skipping PDP for feature {feat!r}: not found in explainer data"
+ )
for feat in valid_feats:
# wrap each PDP call to catch any unexpected AssertionErrors
@@ -157,6 +294,7 @@
except Exception as e:
LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
return None
+
return _plot
self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
diff -r 0afd970bd8ae -r ba45bc057d70 tabular_learner.xml
--- a/tabular_learner.xml Fri Aug 22 21:13:44 2025 +0000
+++ b/tabular_learner.xml Mon Sep 08 22:38:55 2025 +0000
@@ -55,7 +55,7 @@
--probability_threshold '$probability_threshold'
#end if
#end if
- #if $test_file
+ #if $has_test_file == "yes"
--test_file '$test_file'
#end if
--model_type '$model_type'