Mercurial > repos > goeckslab > pycaret_predict
diff feature_importance.py @ 17:c5c324ac29fc draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
| author | goeckslab |
|---|---|
| date | Sat, 06 Dec 2025 14:20:36 +0000 |
| parents | 4fee4504646e |
| children |
line wrap: on
line diff
def _get_predict_fn(self, model):
    """Return the richest prediction callable *model* exposes.

    Preference order: ``predict_proba`` (class probabilities), then
    ``decision_function`` (margins), then plain ``predict``.
    """
    if hasattr(model, "predict_proba"):
        return model.predict_proba
    if hasattr(model, "decision_function"):
        return model.decision_function
    return model.predict

def _choose_shap_explainer(self, model, bg, predict_fn):
    """Select a SHAP explainer for *model* based on estimator type.

    Parameters
    ----------
    model : fitted estimator or None
    bg : background data sample handed to the explainer
    predict_fn : fallback prediction callable (see ``_get_predict_fn``)

    Returns
    -------
    tuple
        ``(explainer, label, is_tree_based)``. ``(None, None, False)``
        means no suitable explainer (``model is None`` or a Dummy
        baseline model) and the caller skips SHAP entirely.
    """
    if model is None:
        return None, None, False

    lname = model.__class__.__name__.lower()
    task = getattr(self, "task_type", None)

    def _permutation(fn, label):
        # Model-agnostic fallback; works with any prediction callable.
        return shap.Explainer(fn, bg, algorithm="permutation"), label, False

    def _tree():
        # shap.TreeExplainer has no ``n_jobs`` parameter, so none is
        # passed here (the original call would raise TypeError).
        # NOTE(review): bg is presumably ignored under
        # "tree_path_dependent" -- confirm against the installed shap.
        return (
            shap.TreeExplainer(
                model, bg, feature_perturbation="tree_path_dependent"
            ),
            "tree_path_dependent",
            True,
        )

    if task == "classification":
        # Probability-producing linear / discriminant models.
        if "logisticregression" in lname or "lineardiscriminantanalysis" in lname:
            return _permutation(model.predict_proba, "permutation-proba")
        if "ridgeclassifier" in lname:
            fn = getattr(model, "decision_function", predict_fn)
            return _permutation(fn, "permutation-decision_function")
        # shap.TreeExplainer does not support AdaBoost ensembles; use
        # the model-agnostic permutation explainer instead.
        if "adaboostclassifier" in lname:
            return _permutation(predict_fn, "permutation-adaboost")
        tree_classifiers = (
            "randomforestclassifier",
            "gradientboostingclassifier",
            "extratreesclassifier",
            "decisiontreeclassifier",
            "lgbmclassifier",  # TreeExplainer model_output defaults to "raw"
            "xgbclassifier",
            "catboost",
        )
        if any(key in lname for key in tree_classifiers):
            return _tree()
        if "kneighborsclassifier" in lname:
            return _permutation(model.predict_proba, "permutation-proba")
        if "svc" in lname or "svm" in lname:
            # LinearExplainer only for an explicit linear kernel;
            # LinearSVC has no ``kernel`` attribute and falls through.
            if getattr(model, "kernel", None) == "linear":
                return shap.LinearExplainer(model, bg), "linear", False
            return _permutation(predict_fn, "permutation-svm")
        if "naive_bayes" in lname or lname.endswith("nb"):
            fn = getattr(model, "predict_proba", predict_fn)
            return _permutation(fn, "permutation-proba")
        if "quadraticdiscriminantanalysis" in lname:
            return _permutation(model.predict_proba, "permutation-proba")
        if "dummyclassifier" in lname:
            # Baseline model: SHAP values carry no information.
            return None, None, False
        return _permutation(predict_fn, "permutation-default")

    # ---- regression path ----
    # Must run before the linear-family check: "ridge" is a substring of
    # "kernelridge" and would otherwise wrongly pick LinearExplainer.
    if "kernelridge" in lname:
        return _permutation(predict_fn, "permutation-kernelridge")
    linear_keys = (
        "linearregression",
        "lasso",  # also covers LassoLars
        "ridge",
        "elasticnet",
        "lars",
        "orthogonalmatchingpursuit",
        "bayesianridge",
        "ardregression",
        "passiveaggressiveregressor",
        "theilsenregressor",
        "huberregressor",
    )
    if any(key in lname for key in linear_keys):
        return shap.LinearExplainer(model, bg), "linear", False
    if "svr" in lname or "svm" in lname:
        if getattr(model, "kernel", None) == "linear":
            return shap.LinearExplainer(model, bg), "linear", False
        return _permutation(predict_fn, "permutation-svr")
    if "kneighborsregressor" in lname:
        return _permutation(predict_fn, "permutation-knn")
    if "mlpregressor" in lname:
        return _permutation(predict_fn, "permutation-mlp")
    if "ransacregressor" in lname:
        return _permutation(predict_fn, "permutation-ransac")
    # TreeExplainer does not support AdaBoost regressors either.
    if "adaboostregressor" in lname:
        return _permutation(predict_fn, "permutation-adaboost")
    tree_keys = (
        "decisiontreeregressor",
        "randomforestregressor",
        "extratreesregressor",
        "gradientboostingregressor",
        "lgbmregressor",
        "lightgbm",
        "xgbregressor",
        "xgboost",
        "catboost",
    )
    if any(key in lname for key in tree_keys):
        return _tree()
    # Default: model-agnostic permutation explainer.
    return _permutation(predict_fn, "permutation-default")
