diff feature_importance.py @ 17:c5c324ac29fc draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
author: goeckslab
date: Sat, 06 Dec 2025 14:20:36 +0000
parents: 4fee4504646e
children:
--- a/feature_importance.py	Fri Nov 28 22:28:26 2025 +0000
+++ b/feature_importance.py	Sat Dec 06 14:20:36 2025 +0000
@@ -287,24 +287,16 @@
 
         # Background set
         bg = X_data.sample(min(len(X_data), 100), random_state=42)
-        predict_fn = (
-            model.predict_proba if hasattr(model, "predict_proba") else model.predict
-        )
+        predict_fn = self._get_predict_fn(model)
 
-        # Optimized explainer
-        explainer = None
-        explainer_label = None
-        if hasattr(model, "feature_importances_"):
-            explainer = shap.TreeExplainer(
-                model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
-            )
-            explainer_label = "tree_path_dependent"
-        elif hasattr(model, "coef_"):
-            explainer = shap.LinearExplainer(model, bg)
-            explainer_label = "linear"
-        else:
-            explainer = shap.Explainer(predict_fn, bg)
-            explainer_label = explainer.__class__.__name__
+        # Optimized explainer based on model type
+        explainer, explainer_label, tree_based = self._choose_shap_explainer(
+            model, bg, predict_fn
+        )
+        if explainer is None:
+            LOG.warning("No suitable SHAP explainer for model %s; skipping SHAP.", model)
+            self.shap_model_name = None
+            return
 
         try:
             shap_values = explainer(X_data)
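
For reference, a minimal standalone sketch of the attribute-based fallback that the new `_get_predict_fn` helper (defined later in this diff) applies at this call site; the estimators below are illustrative and not part of the tool:

    # Sketch only: mirrors the predict_proba -> decision_function -> predict
    # fallback order; the model choices are illustrative.
    from sklearn.linear_model import LinearRegression, RidgeClassifier

    def get_predict_fn(model):
        if hasattr(model, "predict_proba"):      # probabilistic classifiers
            return model.predict_proba
        if hasattr(model, "decision_function"):  # margin-based classifiers
            return model.decision_function
        return model.predict                     # regressors and everything else

    print(get_predict_fn(RidgeClassifier()).__name__)   # decision_function
    print(get_predict_fn(LinearRegression()).__name__)  # predict
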
@@ -312,7 +304,7 @@
         except Exception as e:
             error_message = str(e)
             needs_tree_fallback = (
-                hasattr(model, "feature_importances_")
+                tree_based
                 and "does not cover all the leaves" in error_message.lower()
             )
             feature_name_mismatch = "feature names should match" in error_message.lower()
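
As a sanity check, the renamed guard can be exercised in isolation; the helper name below is invented for illustration, and the error text is the one matched in the hunk above:

    def needs_tree_fallback(tree_based, error_message):
        # Only tree explainers retry when SHAP reports uncovered leaves.
        return tree_based and "does not cover all the leaves" in error_message.lower()

    assert needs_tree_fallback(True, "Background data does not cover all the leaves")
    assert not needs_tree_fallback(False, "does not cover all the leaves")
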
@@ -348,7 +340,9 @@
                     error_message,
                 )
                 try:
-                    agnostic_explainer = shap.Explainer(predict_fn, bg)
+                    agnostic_explainer = shap.Explainer(
+                        predict_fn, bg, algorithm="permutation"
+                    )
                     shap_values = agnostic_explainer(X_data)
                     self.shap_model_name = (
                         f"{agnostic_explainer.__class__.__name__} (fallback)"
@@ -485,6 +479,241 @@
         with open(img_path, "rb") as img_file:
             return base64.b64encode(img_file.read()).decode("utf-8")
 
+    def _get_predict_fn(self, model):
+        if hasattr(model, "predict_proba"):
+            return model.predict_proba
+        if hasattr(model, "decision_function"):
+            return model.decision_function
+        return model.predict
+
+    def _choose_shap_explainer(self, model, bg, predict_fn):
+        """
+        Select a SHAP explainer following the prescribed priority order for
+        algorithms. Returns (explainer, label, is_tree_based).
+        """
+        if model is None:
+            return None, None, False
+
+        name = model.__class__.__name__
+        lname = name.lower()
+        task = getattr(self, "task_type", None)
+
+        def _permutation(fn):
+            return shap.Explainer(fn, bg, algorithm="permutation")
+
+        if task == "classification":
+            # 1) Logistic Regression
+            if "logisticregression" in lname:
+                return _permutation(model.predict_proba), "permutation-proba", False
+
+            # 2) Ridge Classifier
+            if "ridgeclassifier" in lname:
+                fn = (
+                    model.decision_function
+                    if hasattr(model, "decision_function")
+                    else predict_fn
+                )
+                return _permutation(fn), "permutation-decision_function", False
+
+            # 3) LDA
+            if "lineardiscriminantanalysis" in lname:
+                return _permutation(model.predict_proba), "permutation-proba", False
+
+            # 4) Random Forest
+            if "randomforestclassifier" in lname:
+                return (
+                    shap.TreeExplainer(
+                        model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                    ),
+                    "tree_path_dependent",
+                    True,
+                )
+
+            # 5) Gradient Boosting
+            if "gradientboostingclassifier" in lname:
+                return (
+                    shap.TreeExplainer(
+                        model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                    ),
+                    "tree_path_dependent",
+                    True,
+                )
+
+            # 6) AdaBoost
+            if "adaboostclassifier" in lname:
+                return (
+                    shap.TreeExplainer(
+                        model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                    ),
+                    "tree_path_dependent",
+                    True,
+                )
+
+            # 7) Extra Trees
+            if "extratreesclassifier" in lname:
+                return (
+                    shap.TreeExplainer(
+                        model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                    ),
+                    "tree_path_dependent",
+                    True,
+                )
+
+            # 8) LightGBM
+            if "lgbmclassifier" in lname:
+                return (
+                    shap.TreeExplainer(
+                        model,
+                        bg,
+                        model_output="raw",
+                        feature_perturbation="tree_path_dependent",
+                        n_jobs=-1,
+                    ),
+                    "tree_path_dependent",
+                    True,
+                )
+
+            # 9) XGBoost
+            if "xgbclassifier" in lname:
+                return (
+                    shap.TreeExplainer(
+                        model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                    ),
+                    "tree_path_dependent",
+                    True,
+                )
+
+            # 10) CatBoost (classifier)
+            if "catboost" in lname:
+                return (
+                    shap.TreeExplainer(
+                        model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                    ),
+                    "tree_path_dependent",
+                    True,
+                )
+
+            # 11) KNN
+            if "kneighborsclassifier" in lname:
+                return _permutation(model.predict_proba), "permutation-proba", False
+
+            # 12) SVM - linear kernel
+            if "svc" in lname or "svm" in lname:
+                kernel = getattr(model, "kernel", None)
+                if kernel == "linear":
+                    return shap.LinearExplainer(model, bg), "linear", False
+                return _permutation(predict_fn), "permutation-svm", False
+
+            # 13) Decision Tree
+            if "decisiontreeclassifier" in lname:
+                return (
+                    shap.TreeExplainer(
+                        model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                    ),
+                    "tree_path_dependent",
+                    True,
+                )
+
+            # 14) Naive Bayes
+            if "naive_bayes" in lname or lname.endswith("nb"):
+                fn = model.predict_proba if hasattr(model, "predict_proba") else predict_fn
+                return _permutation(fn), "permutation-proba", False
+
+            # 15) QDA
+            if "quadraticdiscriminantanalysis" in lname:
+                return _permutation(model.predict_proba), "permutation-proba", False
+
+            # 16) Dummy
+            if "dummyclassifier" in lname:
+                return None, None, False
+
+            # Default classification: permutation on predict_fn
+            return _permutation(predict_fn), "permutation-default", False
+
+        # Regression path
+        # Linear family
+        linear_keys = [
+            "linearregression",
+            "lasso",
+            "ridge",
+            "elasticnet",
+            "lars",
+            "lassolars",
+            "orthogonalmatchingpursuit",
+            "bayesianridge",
+            "ardregression",
+            "passiveaggressiveregressor",
+            "theilsenregressor",
+            "huberregressor",
+        ]
+        if any(k in lname for k in linear_keys):
+            return shap.LinearExplainer(model, bg), "linear", False
+
+        # Kernel ridge / SVR / KNN / MLP / RANSAC (model-agnostic)
+        if "kernelridge" in lname:
+            return _permutation(predict_fn), "permutation-kernelridge", False
+        if "svr" in lname or "svm" in lname:
+            kernel = getattr(model, "kernel", None)
+            if kernel == "linear":
+                return shap.LinearExplainer(model, bg), "linear", False
+            return _permutation(predict_fn), "permutation-svr", False
+        if "kneighborsregressor" in lname:
+            return _permutation(predict_fn), "permutation-knn", False
+        if "mlpregressor" in lname:
+            return _permutation(predict_fn), "permutation-mlp", False
+        if "ransacregressor" in lname:
+            return _permutation(predict_fn), "permutation-ransac", False
+
+        # Tree-based regressors
+        tree_class_names = [
+            "decisiontreeregressor",
+            "randomforestregressor",
+            "extratreesregressor",
+            "adaboostregressor",
+            "gradientboostingregressor",
+        ]
+        if any(k in lname for k in tree_class_names):
+            return (
+                shap.TreeExplainer(
+                    model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                ),
+                "tree_path_dependent",
+                True,
+            )
+
+        # Boosting libraries
+        if "lgbmregressor" in lname or "lightgbm" in lname:
+            return (
+                shap.TreeExplainer(
+                    model,
+                    bg,
+                    model_output="raw",
+                    feature_perturbation="tree_path_dependent",
+                    n_jobs=-1,
+                ),
+                "tree_path_dependent",
+                True,
+            )
+        if "xgbregressor" in lname or "xgboost" in lname:
+            return (
+                shap.TreeExplainer(
+                    model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                ),
+                "tree_path_dependent",
+                True,
+            )
+        if "catboost" in lname:
+            return (
+                shap.TreeExplainer(
+                    model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
+                ),
+                "tree_path_dependent",
+                True,
+            )
+
+        # Default regression: model-agnostic permutation explainer
+        return _permutation(predict_fn), "permutation-default", False
+
     def run(self):
         if (
             self.exp is None
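
For the tree-based branches above, a minimal standalone sketch of the `tree_path_dependent` path, assuming only that shap and scikit-learn are installed; it deliberately omits the background sample and the extra keyword arguments used in the tool:

    import shap
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    X, y = make_classification(n_samples=300, n_features=6, random_state=0)
    model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    # tree_path_dependent reads cover statistics stored in the fitted trees,
    # so this minimal call passes no background data.
    explainer = shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")
    shap_values = explainer(X[:25])
    print(shap_values.shape)   # (n_samples, n_features, n_outputs) for a classifier
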