changeset 8:ba45bc057d70 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
author goeckslab
date Mon, 08 Sep 2025 22:38:55 +0000
parents 0afd970bd8ae
children
files feature_importance.py pycaret_classification.py tabular_learner.xml
diffstat 3 files changed, 294 insertions(+), 118 deletions(-) [+]
line wrap: on
line diff
--- a/feature_importance.py	Fri Aug 22 21:13:44 2025 +0000
+++ b/feature_importance.py	Mon Sep 08 22:38:55 2025 +0000
@@ -23,7 +23,6 @@
         exp=None,
         best_model=None,
     ):
-
         self.task_type = task_type
         self.output_dir = output_dir
         self.exp = exp
@@ -40,8 +39,8 @@
                 LOG.info("Data loaded from memory")
             else:
                 self.target_col = target_col
-                self.data = pd.read_csv(data_path, sep=None, engine='python')
-                self.data.columns = self.data.columns.str.replace('.', '_')
+                self.data = pd.read_csv(data_path, sep=None, engine="python")
+                self.data.columns = self.data.columns.str.replace(".", "_")
                 self.data = self.data.fillna(self.data.median(numeric_only=True))
             self.target = self.data.columns[int(target_col) - 1]
             self.exp = (
@@ -53,63 +52,58 @@
         self.plots = {}
 
     def setup_pycaret(self):
-        if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup:
+        if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup:
             LOG.info("Experiment already set up. Skipping PyCaret setup.")
             return
         LOG.info("Initializing PyCaret")
         setup_params = {
-            'target': self.target,
-            'session_id': 123,
-            'html': True,
-            'log_experiment': False,
-            'system_log': False,
+            "target": self.target,
+            "session_id": 123,
+            "html": True,
+            "log_experiment": False,
+            "system_log": False,
         }
         self.exp.setup(self.data, **setup_params)
 
     def save_tree_importance(self):
-        model = self.best_model or self.exp.get_config('best_model')
-        processed_features = self.exp.get_config('X_transformed').columns
+        model = self.best_model or self.exp.get_config("best_model")
+        processed_features = self.exp.get_config("X_transformed").columns
 
-        # Try feature_importances_ or coef_ if available
         importances = None
         model_type = model.__class__.__name__
-        self.tree_model_name = model_type  # Store the model name for reporting
+        self.tree_model_name = model_type
 
-        if hasattr(model, 'feature_importances_'):
+        if hasattr(model, "feature_importances_"):
             importances = model.feature_importances_
-        elif hasattr(model, 'coef_'):
-            # For linear models, flatten coef_ and take abs (importance as magnitude)
+        elif hasattr(model, "coef_"):
             importances = abs(model.coef_).flatten()
         else:
-            # Neither attribute exists; skip the plot
             LOG.warning(
-                f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot."
+                f"Model {model_type} does not have feature_importances_ or coef_. Skipping tree importance."
             )
-            self.tree_model_name = None  # No plot generated
+            self.tree_model_name = None
             return
 
-        # Defensive: handle mismatch in number of features
         if len(importances) != len(processed_features):
             LOG.warning(
-                f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot."
+                f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance."
             )
             self.tree_model_name = None
             return
 
         feature_importances = pd.DataFrame(
-            {'Feature': processed_features, 'Importance': importances}
-        ).sort_values(by='Importance', ascending=False)
+            {"Feature": processed_features, "Importance": importances}
+        ).sort_values(by="Importance", ascending=False)
         plt.figure(figsize=(10, 6))
-        plt.barh(feature_importances['Feature'], feature_importances['Importance'])
-        plt.xlabel('Importance')
-        plt.title(f'Feature Importance ({model_type})')
-        plot_path = os.path.join(self.output_dir, 'tree_importance.png')
-        plt.savefig(plot_path)
+        plt.barh(feature_importances["Feature"], feature_importances["Importance"])
+        plt.xlabel("Importance")
+        plt.title(f"Feature Importance ({model_type})")
+        plot_path = os.path.join(self.output_dir, "tree_importance.png")
+        plt.savefig(plot_path, bbox_inches="tight")
         plt.close()
-        self.plots['tree_importance'] = plot_path
+        self.plots["tree_importance"] = plot_path
 
-    def save_shap_values(self):
-
+    def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
         model = self.best_model or self.exp.get_config("best_model")
 
         X_data = None
@@ -120,78 +114,119 @@
             except KeyError:
                 continue
         if X_data is None:
-            raise RuntimeError(
-                "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. "
-                "Make sure PyCaret setup/compare_models was run with feature_selection=True."
-            )
+            raise RuntimeError("No transformed dataset found for SHAP.")
+
+        # --- Adaptive feature limiting (proportional cap) ---
+        n_rows, n_features = X_data.shape
+        if max_features is None:
+            if n_features <= 200:
+                max_features = n_features
+            else:
+                max_features = min(200, max(20, int(n_features * 0.1)))
 
         try:
-            used_features = model.booster_.feature_name()
-        except Exception:
-            used_features = getattr(model, "feature_names_in_", X_data.columns.tolist())
-        X_data = X_data[used_features]
+            if hasattr(model, "feature_importances_"):
+                importances = pd.Series(
+                    model.feature_importances_, index=X_data.columns
+                )
+                top_features = importances.nlargest(max_features).index
+            elif hasattr(model, "coef_"):
+                coef = abs(model.coef_).flatten()
+                importances = pd.Series(coef, index=X_data.columns)
+                top_features = importances.nlargest(max_features).index
+            else:
+                variances = X_data.var()
+                top_features = variances.nlargest(max_features).index
+
+            if len(top_features) < n_features:
+                LOG.info(
+                    f"Restricted SHAP computation to top {len(top_features)} / {n_features} features"
+                )
+            X_data = X_data[top_features]
+        except Exception as e:
+            LOG.warning(
+                f"Feature limiting failed: {e}. Using all {n_features} features."
+            )
 
-        max_bg = min(len(X_data), 100)
-        bg = X_data.sample(max_bg, random_state=42)
+        # --- Adaptive row subsampling ---
+        if max_samples is None:
+            if n_rows <= 500:
+                max_samples = n_rows
+            elif n_rows <= 5000:
+                max_samples = 500
+            else:
+                max_samples = min(1000, int(n_rows * 0.1))
+
+        if n_rows > max_samples:
+            LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
+            X_data = X_data.sample(max_samples, random_state=42)
 
-        predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict
+        # --- Adaptive feature display ---
+        if max_display is None:
+            if X_data.shape[1] <= 20:
+                max_display = X_data.shape[1]
+            elif X_data.shape[1] <= 100:
+                max_display = 30
+            else:
+                max_display = 50
+
+        # Background set
+        bg = X_data.sample(min(len(X_data), 100), random_state=42)
+        predict_fn = (
+            model.predict_proba if hasattr(model, "predict_proba") else model.predict
+        )
+
+        # Optimized explainer
+        if hasattr(model, "feature_importances_"):
+            explainer = shap.TreeExplainer(
+                model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+            )
+        elif hasattr(model, "coef_"):
+            explainer = shap.LinearExplainer(model, bg)
+        else:
+            explainer = shap.Explainer(predict_fn, bg)
 
         try:
-            explainer = shap.Explainer(predict_fn, bg)
+            shap_values = explainer(X_data)
             self.shap_model_name = explainer.__class__.__name__
-
-            shap_values = explainer(X_data)
         except Exception as e:
             LOG.error(f"SHAP computation failed: {e}")
             self.shap_model_name = None
             return
 
-        output_names = getattr(shap_values, "output_names", None)
-        if output_names is None and hasattr(model, "classes_"):
-            output_names = list(model.classes_)
-        if output_names is None:
-            n_out = shap_values.values.shape[-1]
-            output_names = list(map(str, range(n_out)))
+        # --- Plot SHAP summary ---
+        out_path = os.path.join(self.output_dir, "shap_summary.png")
+        plt.figure()
+        shap.plots.beeswarm(shap_values, max_display=max_display, show=False)
+        plt.title(
+            f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)"
+        )
+        plt.savefig(out_path, bbox_inches="tight")
+        plt.close()
+        self.plots["shap_summary"] = out_path
 
-        values = shap_values.values
-        if values.ndim == 3:
-            for j, name in enumerate(output_names):
-                safe = name.replace(" ", "_").replace("/", "_")
-                out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png")
-                plt.figure()
-                shap.plots.beeswarm(shap_values[..., j], show=False)
-                plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}")
-                plt.savefig(out_path)
-                plt.close()
-                self.plots[f"shap_summary_{safe}"] = out_path
-        else:
-            plt.figure()
-            shap.plots.beeswarm(shap_values, show=False)
-            plt.title(f"SHAP Summary for {model.__class__.__name__}")
-            out_path = os.path.join(self.output_dir, "shap_summary.png")
-            plt.savefig(out_path)
-            plt.close()
-            self.plots["shap_summary"] = out_path
+        # --- Log summary ---
+        LOG.info(
+            f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})."
+        )
 
     def generate_html_report(self):
         LOG.info("Generating HTML report")
-
         plots_html = ""
         for plot_name, plot_path in self.plots.items():
-            # Special handling for tree importance: skip if no model name (not generated)
-            if plot_name == 'tree_importance' and not getattr(
-                self, 'tree_model_name', None
+            if plot_name == "tree_importance" and not getattr(
+                self, "tree_model_name", None
             ):
                 continue
             encoded_image = self.encode_image_to_base64(plot_path)
-            if plot_name == 'tree_importance' and getattr(
-                self, 'tree_model_name', None
+            if plot_name == "tree_importance" and getattr(
+                self, "tree_model_name", None
             ):
+                section_title = f"Feature importance from {self.tree_model_name}"
+            elif plot_name == "shap_summary":
                 section_title = (
-                    f"Feature importance analysis from a trained {self.tree_model_name}"
+                    f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}"
                 )
-            elif plot_name == 'shap_summary':
-                section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
             else:
                 section_title = plot_name
             plots_html += f"""
@@ -200,25 +235,19 @@
                 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
             </div>
             """
-
-        html_content = f"""
-            {plots_html}
-        """
-
-        return html_content
+        return f"{plots_html}"
 
     def encode_image_to_base64(self, img_path):
-        with open(img_path, 'rb') as img_file:
-            return base64.b64encode(img_file.read()).decode('utf-8')
+        with open(img_path, "rb") as img_file:
+            return base64.b64encode(img_file.read()).decode("utf-8")
 
     def run(self):
         if (
             self.exp is None
-            or not hasattr(self.exp, 'is_setup')
+            or not hasattr(self.exp, "is_setup")
             or not self.exp.is_setup
         ):
             self.setup_pycaret()
         self.save_tree_importance()
         self.save_shap_values()
-        html_content = self.generate_html_report()
-        return html_content
+        return self.generate_html_report()
--- a/pycaret_classification.py	Fri Aug 22 21:13:44 2025 +0000
+++ b/pycaret_classification.py	Mon Sep 08 22:38:55 2025 +0000
@@ -2,15 +2,29 @@
 import types
 from typing import Dict
 
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_classifier_explainer_dashboard
-from plotly.graph_objects import Figure
 from pycaret.classification import ClassificationExperiment
+from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
 from utils import predict_proba
 
 LOG = logging.getLogger(__name__)
 
 
+def _apply_report_layout(fig: go.Figure) -> go.Figure:
+    # Give the left side more space for y-axis title/ticks and let axes auto-reserve room
+    fig.update_xaxes(automargin=True, title_standoff=12)
+    fig.update_yaxes(automargin=True, title_standoff=12)
+    fig.update_layout(
+        autosize=True,
+        margin=dict(l=120, r=40, t=60, b=60),  # bump 'l' if you still see clipping
+    )
+    return fig
+
+
 class ClassificationModelTrainer(BaseModelTrainer):
     def __init__(
         self,
@@ -50,20 +64,19 @@
             )
 
         plots = [
-            'confusion_matrix',
-            'auc',
-            'threshold',
-            'pr',
-            'error',
-            'class_report',
-            'learning',
-            'calibration',
-            'vc',
-            'dimension',
-            'manifold',
-            'rfe',
-            'feature',
-            'feature_all',
+            "auc",
+            "threshold",
+            "pr",
+            "error",
+            "class_report",
+            "learning",
+            "calibration",
+            "vc",
+            "dimension",
+            "manifold",
+            "rfe",
+            "feature",
+            "feature_all",
         ]
         for plot_name in plots:
             try:
@@ -102,24 +115,146 @@
 
         LOG.info("Generating explainer plots")
 
+        # Ensure predict_proba is available here too
+        if not hasattr(self.best_model, "predict_proba"):
+            self.best_model.predict_proba = types.MethodType(
+                predict_proba, self.best_model
+            )
+            LOG.warning(
+                f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
+            )
+
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
         explainer = ClassifierExplainer(self.best_model, X_test, y_test)
 
         # a dict to hold the raw Figure objects or callables
-        self.explainer_plots: Dict[str, Figure] = {}
+        self.explainer_plots: Dict[str, go.Figure] = {}
+
+        # --- Threshold-aware overrides for CM / ROC / PR ---
+        prob_thresh = getattr(self, "probability_threshold", None)
+
+        # Only for binary classification and when threshold is provided
+        if (prob_thresh is not None) and (not self.exp.is_multiclass):
+            X = self.exp.X_test_transformed
+            y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
+
+            # Get positive-class scores (robust defaults)
+            classes = list(getattr(self.best_model, "classes_", [0, 1]))
+            try:
+                pos_idx = classes.index(1) if 1 in classes else 1
+            except Exception:
+                pos_idx = 1
+
+            proba = self.best_model.predict_proba(X)
+            y_scores = proba[:, pos_idx]
+
+            # Derive label names consistently
+            pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+            neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
+
+            # ---- Confusion Matrix @ threshold ----
+            try:
+                y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
+                cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
+                fig_cm = go.Figure(
+                    data=go.Heatmap(
+                        z=cm,
+                        x=[f"Pred {neg_label}", f"Pred {pos_label}"],
+                        y=[f"True {neg_label}", f"True {pos_label}"],
+                        text=cm,
+                        texttemplate="%{text}",
+                        colorscale="Blues",
+                        showscale=False,
+                    )
+                )
+                fig_cm.update_layout(
+                    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
+                    xaxis_title="Predicted label",
+                    yaxis_title="True label",
+                )
+                _apply_report_layout(fig_cm)
+                self.explainer_plots["confusion_matrix"] = fig_cm
+            except Exception as e:
+                LOG.warning(
+                    f"Threshold-aware confusion matrix failed; falling back: {e}"
+                )
 
-        # these go into the Test tab
+            # ---- ROC with threshold marker ----
+            try:
+                fpr, tpr, thr = roc_curve(y, y_scores)
+                roc_auc = auc(fpr, tpr)
+                fig_roc = go.Figure()
+                fig_roc.add_scatter(
+                    x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})"
+                )
+                if len(thr):
+                    mask = np.isfinite(thr)
+                    if mask.any():
+                        idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh)))
+                        idx = np.where(mask)[0][idx_local]
+                        if 0 <= idx < len(fpr):
+                            fig_roc.add_scatter(
+                                x=[fpr[idx]],
+                                y=[tpr[idx]],
+                                mode="markers",
+                                name=f"@ {prob_thresh:.2f}",
+                                marker=dict(size=10),
+                            )
+                fig_roc.update_layout(
+                    title=f"ROC Curve (marker at threshold={prob_thresh:.2f})",
+                    xaxis_title="False Positive Rate",
+                    yaxis_title="True Positive Rate",
+                )
+                _apply_report_layout(fig_roc)
+                self.explainer_plots["roc_auc"] = fig_roc
+            except Exception as e:
+                LOG.warning(f"Threshold marker on ROC failed; falling back: {e}")
+
+            # ---- PR with threshold marker ----
+            try:
+                precision, recall, thr_pr = precision_recall_curve(y, y_scores)
+                pr_auc = auc(recall, precision)
+                fig_pr = go.Figure()
+                fig_pr.add_scatter(
+                    x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})"
+                )
+                if len(thr_pr):
+                    idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh)))
+                    # note: thr_pr has length = len(precision) - 1
+                    idx_pr = max(0, min(idx_pr, len(recall) - 1))
+                    fig_pr.add_scatter(
+                        x=[recall[idx_pr]],
+                        y=[precision[idx_pr]],
+                        mode="markers",
+                        name=f"@ {prob_thresh:.2f}",
+                        marker=dict(size=10),
+                    )
+                fig_pr.update_layout(
+                    title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})",
+                    xaxis_title="Recall",
+                    yaxis_title="Precision",
+                )
+                _apply_report_layout(fig_pr)
+                self.explainer_plots["pr_auc"] = fig_pr
+            except Exception as e:
+                LOG.warning(f"Threshold marker on PR failed; falling back: {e}")
+
+        # these go into the Test tab (don't overwrite overrides)
         for key, fn in [
             ("roc_auc", explainer.plot_roc_auc),
             ("pr_auc", explainer.plot_pr_auc),
             ("lift_curve", explainer.plot_lift_curve),
             ("confusion_matrix", explainer.plot_confusion_matrix),
-            ("threshold", explainer.plot_precision),  # Percentage vs probability
+            ("threshold", explainer.plot_precision),  # percentage vs probability
             ("cumulative_precision", explainer.plot_cumulative_precision),
         ]:
+            if key in self.explainer_plots:
+                continue
             try:
-                self.explainer_plots[key] = fn()
+                fig = fn()
+                if fig is not None:
+                    self.explainer_plots[key] = fig
             except Exception as e:
                 LOG.error(f"Error generating explainer plot {key}: {e}")
 
@@ -143,7 +278,9 @@
             if feat in explainer.X.columns or feat in explainer.onehot_cols:
                 valid_feats.append(feat)
             else:
-                LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data")
+                LOG.warning(
+                    f"Skipping PDP for feature {feat!r}: not found in explainer data"
+                )
 
         for feat in valid_feats:
             # wrap each PDP call to catch any unexpected AssertionErrors
@@ -157,6 +294,7 @@
                     except Exception as e:
                         LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
                         return None
+
                 return _plot
 
             self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
--- a/tabular_learner.xml	Fri Aug 22 21:13:44 2025 +0000
+++ b/tabular_learner.xml	Mon Sep 08 22:38:55 2025 +0000
@@ -55,7 +55,7 @@
                 --probability_threshold '$probability_threshold'
                 #end if
         #end if
-        #if $test_file
+        #if $has_test_file == "yes"
             --test_file '$test_file'
         #end if
         --model_type '$model_type'
@@ -63,12 +63,21 @@
     </command>
     <inputs>
         <param name="input_file" type="data" format="csv,tabular" label="Tabular Input Dataset" />
-        <param name="test_file" type="data" format="csv,tabular" optional="true" label="Tabular Test Dataset"
-        help="If a test dataset is not provided,
-        the input dataset will be split into training, validation, and test sets.
-        If a test set is provided, the input dataset will be split into training and validation sets.
-        Cross-validation is applied by default during training." />
-       <param name="target_feature" multiple="false" type="data_column" use_header_names="true" data_ref="input_file" label="Select the target column:" />
+        <conditional name="test_data_choice">
+            <param name="has_test_file" type="select" label="Do you have a separate test dataset?">
+                <option value="no" selected="true">No</option>
+                <option value="yes">Yes</option>
+            </param>
+            <when value="yes">
+                <param name="test_file" type="data" format="csv,tabular" label="Tabular Test Dataset"
+                       help="If a test dataset is provided, the input dataset will be split into training and validation sets only. 
+                             If not, the tool will split your input into training, validation, and test automatically." />
+            </when>
+            <when value="no">
+                <!-- Nothing extra shown -->
+            </when>
+        </conditional>
+        <param name="target_feature" multiple="false" type="data_column" use_header_names="true" data_ref="input_file" label="Select the target column:" />
         <conditional name="model_selection">
             <param name="model_type" type="select" label="Task">
                 <option value="classification">classification</option>
@@ -161,9 +170,9 @@
         </conditional>
     </inputs>
     <outputs>
-        <data name="comparison_result" format="html" from_work_dir="comparison_result.html" label="${tool.name} analysis report on ${on_string}"/>
         <data name="model" format="h5" from_work_dir="pycaret_model.h5" label="${tool.name} best model on ${on_string}" />
         <data name="best_model_csv" format="csv" from_work_dir="best_model.csv" label="${tool.name} The parameters of the best model on ${on_string}" hidden="true" />
+        <data name="comparison_result" format="html" from_work_dir="comparison_result.html" label="${tool.name} analysis report on ${on_string}" />
     </outputs>
     <tests>
         <test>