Mercurial > repos > goeckslab > tabular_learner
comparison of feature_importance.py @ 13:bf0df21a1ea3 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
| author | goeckslab |
|---|---|
| date | Sat, 06 Dec 2025 14:20:23 +0000 |
| parents | 15707141e7da |
| children |
comparison (legend: equal, deleted, inserted, replaced)
| 12:15707141e7da | 13:bf0df21a1ea3 |
|---|---|
| 285 display_features = list(X_data.columns) | 285 display_features = list(X_data.columns) |
| 286 max_display = len(display_features) | 286 max_display = len(display_features) |
| 287 | 287 |
| 288 # Background set | 288 # Background set |
| 289 bg = X_data.sample(min(len(X_data), 100), random_state=42) | 289 bg = X_data.sample(min(len(X_data), 100), random_state=42) |
| 290 predict_fn = ( | 290 predict_fn = self._get_predict_fn(model) |
| 291 model.predict_proba if hasattr(model, "predict_proba") else model.predict | 291 |
| | 292 # Optimized explainer based on model type |
| | 293 explainer, explainer_label, tree_based = self._choose_shap_explainer( |
| | 294 model, bg, predict_fn |
| 292 ) | 295 ) |
| 293 | 296 if explainer is None: |
| 294 # Optimized explainer | 297 LOG.warning("No suitable SHAP explainer for model %s; skipping SHAP.", model) |
| 295 explainer = None | 298 self.shap_model_name = None |
| 296 explainer_label = None | 299 return |
| 297 if hasattr(model, "feature_importances_"): | |
| 298 explainer = shap.TreeExplainer( | |
| 299 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 300 ) | |
| 301 explainer_label = "tree_path_dependent" | |
| 302 elif hasattr(model, "coef_"): | |
| 303 explainer = shap.LinearExplainer(model, bg) | |
| 304 explainer_label = "linear" | |
| 305 else: | |
| 306 explainer = shap.Explainer(predict_fn, bg) | |
| 307 explainer_label = explainer.__class__.__name__ | |
| 308 | 300 |
| 309 try: | 301 try: |
| 310 shap_values = explainer(X_data) | 302 shap_values = explainer(X_data) |
| 311 self.shap_model_name = explainer.__class__.__name__ | 303 self.shap_model_name = explainer.__class__.__name__ |
| 312 except Exception as e: | 304 except Exception as e: |
| 313 error_message = str(e) | 305 error_message = str(e) |
| 314 needs_tree_fallback = ( | 306 needs_tree_fallback = ( |
| 315 hasattr(model, "feature_importances_") | 307 tree_based |
| 316 and "does not cover all the leaves" in error_message.lower() | 308 and "does not cover all the leaves" in error_message.lower() |
| 317 ) | 309 ) |
| 318 feature_name_mismatch = "feature names should match" in error_message.lower() | 310 feature_name_mismatch = "feature names should match" in error_message.lower() |
| 319 if needs_tree_fallback: | 311 if needs_tree_fallback: |
| 320 LOG.warning( | 312 LOG.warning( |
| 346 "SHAP computation failed due to feature-name mismatch (%s). " | 338 "SHAP computation failed due to feature-name mismatch (%s). " |
| 347 "Falling back to model-agnostic SHAP explainer.", | 339 "Falling back to model-agnostic SHAP explainer.", |
| 348 error_message, | 340 error_message, |
| 349 ) | 341 ) |
| 350 try: | 342 try: |
| 351 agnostic_explainer = shap.Explainer(predict_fn, bg) | 343 agnostic_explainer = shap.Explainer( |
| | 344 predict_fn, bg, algorithm="permutation" |
| | 345 ) |
| 352 shap_values = agnostic_explainer(X_data) | 346 shap_values = agnostic_explainer(X_data) |
| 353 self.shap_model_name = ( | 347 self.shap_model_name = ( |
| 354 f"{agnostic_explainer.__class__.__name__} (fallback)" | 348 f"{agnostic_explainer.__class__.__name__} (fallback)" |
| 355 ) | 349 ) |
| 356 except Exception as fallback_exc: | 350 except Exception as fallback_exc: |
| 483 | 477 |
| 484 def encode_image_to_base64(self, img_path): | 478 def encode_image_to_base64(self, img_path): |
| 485 with open(img_path, "rb") as img_file: | 479 with open(img_path, "rb") as img_file: |
| 486 return base64.b64encode(img_file.read()).decode("utf-8") | 480 return base64.b64encode(img_file.read()).decode("utf-8") |
| 487 | 481 |
| 482 def _get_predict_fn(self, model): | |
| 483 if hasattr(model, "predict_proba"): | |
| 484 return model.predict_proba | |
| 485 if hasattr(model, "decision_function"): | |
| 486 return model.decision_function | |
| 487 return model.predict | |
| 488 | |
| 489 def _choose_shap_explainer(self, model, bg, predict_fn): | |
| 490 """ | |
| 491 Select a SHAP explainer following the prescribed priority order for | |
| 492 algorithms. Returns (explainer, label, is_tree_based). | |
| 493 """ | |
| 494 if model is None: | |
| 495 return None, None, False | |
| 496 | |
| 497 name = model.__class__.__name__ | |
| 498 lname = name.lower() | |
| 499 task = getattr(self, "task_type", None) | |
| 500 | |
| 501 def _permutation(fn): | |
| 502 return shap.Explainer(fn, bg, algorithm="permutation") | |
| 503 | |
| 504 if task == "classification": | |
| 505 # 1) Logistic Regression | |
| 506 if "logisticregression" in lname: | |
| 507 return _permutation(model.predict_proba), "permutation-proba", False | |
| 508 | |
| 509 # 2) Ridge Classifier | |
| 510 if "ridgeclassifier" in lname: | |
| 511 fn = ( | |
| 512 model.decision_function | |
| 513 if hasattr(model, "decision_function") | |
| 514 else predict_fn | |
| 515 ) | |
| 516 return _permutation(fn), "permutation-decision_function", False | |
| 517 | |
| 518 # 3) LDA | |
| 519 if "lineardiscriminantanalysis" in lname: | |
| 520 return _permutation(model.predict_proba), "permutation-proba", False | |
| 521 | |
| 522 # 4) Random Forest | |
| 523 if "randomforestclassifier" in lname: | |
| 524 return ( | |
| 525 shap.TreeExplainer( | |
| 526 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 527 ), | |
| 528 "tree_path_dependent", | |
| 529 True, | |
| 530 ) | |
| 531 | |
| 532 # 5) Gradient Boosting | |
| 533 if "gradientboostingclassifier" in lname: | |
| 534 return ( | |
| 535 shap.TreeExplainer( | |
| 536 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 537 ), | |
| 538 "tree_path_dependent", | |
| 539 True, | |
| 540 ) | |
| 541 | |
| 542 # 6) AdaBoost | |
| 543 if "adaboostclassifier" in lname: | |
| 544 return ( | |
| 545 shap.TreeExplainer( | |
| 546 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 547 ), | |
| 548 "tree_path_dependent", | |
| 549 True, | |
| 550 ) | |
| 551 | |
| 552 # 7) Extra Trees | |
| 553 if "extratreesclassifier" in lname: | |
| 554 return ( | |
| 555 shap.TreeExplainer( | |
| 556 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 557 ), | |
| 558 "tree_path_dependent", | |
| 559 True, | |
| 560 ) | |
| 561 | |
| 562 # 8) LightGBM | |
| 563 if "lgbmclassifier" in lname: | |
| 564 return ( | |
| 565 shap.TreeExplainer( | |
| 566 model, | |
| 567 bg, | |
| 568 model_output="raw", | |
| 569 feature_perturbation="tree_path_dependent", | |
| 570 n_jobs=-1, | |
| 571 ), | |
| 572 "tree_path_dependent", | |
| 573 True, | |
| 574 ) | |
| 575 | |
| 576 # 9) XGBoost | |
| 577 if "xgbclassifier" in lname: | |
| 578 return ( | |
| 579 shap.TreeExplainer( | |
| 580 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 581 ), | |
| 582 "tree_path_dependent", | |
| 583 True, | |
| 584 ) | |
| 585 | |
| 586 # 10) CatBoost (classifier) | |
| 587 if "catboost" in lname: | |
| 588 return ( | |
| 589 shap.TreeExplainer( | |
| 590 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 591 ), | |
| 592 "tree_path_dependent", | |
| 593 True, | |
| 594 ) | |
| 595 | |
| 596 # 11) KNN | |
| 597 if "kneighborsclassifier" in lname: | |
| 598 return _permutation(model.predict_proba), "permutation-proba", False | |
| 599 | |
| 600 # 12) SVM - linear kernel | |
| 601 if "svc" in lname or "svm" in lname: | |
| 602 kernel = getattr(model, "kernel", None) | |
| 603 if kernel == "linear": | |
| 604 return shap.LinearExplainer(model, bg), "linear", False | |
| 605 return _permutation(predict_fn), "permutation-svm", False | |
| 606 | |
| 607 # 13) Decision Tree | |
| 608 if "decisiontreeclassifier" in lname: | |
| 609 return ( | |
| 610 shap.TreeExplainer( | |
| 611 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 612 ), | |
| 613 "tree_path_dependent", | |
| 614 True, | |
| 615 ) | |
| 616 | |
| 617 # 14) Naive Bayes | |
| 618 if "naive_bayes" in lname or lname.endswith("nb"): | |
| 619 fn = model.predict_proba if hasattr(model, "predict_proba") else predict_fn | |
| 620 return _permutation(fn), "permutation-proba", False | |
| 621 | |
| 622 # 15) QDA | |
| 623 if "quadraticdiscriminantanalysis" in lname: | |
| 624 return _permutation(model.predict_proba), "permutation-proba", False | |
| 625 | |
| 626 # 16) Dummy | |
| 627 if "dummyclassifier" in lname: | |
| 628 return None, None, False | |
| 629 | |
| 630 # Default classification: permutation on predict_fn | |
| 631 return _permutation(predict_fn), "permutation-default", False | |
| 632 | |
| 633 # Regression path | |
| 634 # Linear family | |
| 635 linear_keys = [ | |
| 636 "linearregression", | |
| 637 "lasso", | |
| 638 "ridge", | |
| 639 "elasticnet", | |
| 640 "lars", | |
| 641 "lassolars", | |
| 642 "orthogonalmatchingpursuit", | |
| 643 "bayesianridge", | |
| 644 "ardregression", | |
| 645 "passiveaggressiveregressor", | |
| 646 "theilsenregressor", | |
| 647 "huberregressor", | |
| 648 ] | |
| 649 if any(k in lname for k in linear_keys): | |
| 650 return shap.LinearExplainer(model, bg), "linear", False | |
| 651 | |
| 652 # Kernel ridge / SVR / KNN / MLP / RANSAC (model-agnostic) | |
| 653 if "kernelridge" in lname: | |
| 654 return _permutation(predict_fn), "permutation-kernelridge", False | |
| 655 if "svr" in lname or "svm" in lname: | |
| 656 kernel = getattr(model, "kernel", None) | |
| 657 if kernel == "linear": | |
| 658 return shap.LinearExplainer(model, bg), "linear", False | |
| 659 return _permutation(predict_fn), "permutation-svr", False | |
| 660 if "kneighborsregressor" in lname: | |
| 661 return _permutation(predict_fn), "permutation-knn", False | |
| 662 if "mlpregressor" in lname: | |
| 663 return _permutation(predict_fn), "permutation-mlp", False | |
| 664 if "ransacregressor" in lname: | |
| 665 return _permutation(predict_fn), "permutation-ransac", False | |
| 666 | |
| 667 # Tree-based regressors | |
| 668 tree_class_names = [ | |
| 669 "decisiontreeregressor", | |
| 670 "randomforestregressor", | |
| 671 "extratreesregressor", | |
| 672 "adaboostregressor", | |
| 673 "gradientboostingregressor", | |
| 674 ] | |
| 675 if any(k in lname for k in tree_class_names): | |
| 676 return ( | |
| 677 shap.TreeExplainer( | |
| 678 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 679 ), | |
| 680 "tree_path_dependent", | |
| 681 True, | |
| 682 ) | |
| 683 | |
| 684 # Boosting libraries | |
| 685 if "lgbmregressor" in lname or "lightgbm" in lname: | |
| 686 return ( | |
| 687 shap.TreeExplainer( | |
| 688 model, | |
| 689 bg, | |
| 690 model_output="raw", | |
| 691 feature_perturbation="tree_path_dependent", | |
| 692 n_jobs=-1, | |
| 693 ), | |
| 694 "tree_path_dependent", | |
| 695 True, | |
| 696 ) | |
| 697 if "xgbregressor" in lname or "xgboost" in lname: | |
| 698 return ( | |
| 699 shap.TreeExplainer( | |
| 700 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 701 ), | |
| 702 "tree_path_dependent", | |
| 703 True, | |
| 704 ) | |
| 705 if "catboost" in lname: | |
| 706 return ( | |
| 707 shap.TreeExplainer( | |
| 708 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 709 ), | |
| 710 "tree_path_dependent", | |
| 711 True, | |
| 712 ) | |
| 713 | |
| 714 # Default regression: model-agnostic permutation explainer | |
| 715 return _permutation(predict_fn), "permutation-default", False | |
| 716 | |
| 488 def run(self): | 717 def run(self): |
| 489 if ( | 718 if ( |
| 490 self.exp is None | 719 self.exp is None |
| 491 or not hasattr(self.exp, "is_setup") | 720 or not hasattr(self.exp, "is_setup") |
| 492 or not self.exp.is_setup | 721 or not self.exp.is_setup |
