comparison feature_importance.py @ 17:c5c324ac29fc draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
author goeckslab
date Sat, 06 Dec 2025 14:20:36 +0000
parents 4fee4504646e
children
comparison
equal deleted inserted replaced
16:4fee4504646e 17:c5c324ac29fc
285 display_features = list(X_data.columns) 285 display_features = list(X_data.columns)
286 max_display = len(display_features) 286 max_display = len(display_features)
287 287
288 # Background set 288 # Background set
289 bg = X_data.sample(min(len(X_data), 100), random_state=42) 289 bg = X_data.sample(min(len(X_data), 100), random_state=42)
290 predict_fn = ( 290 predict_fn = self._get_predict_fn(model)
291 model.predict_proba if hasattr(model, "predict_proba") else model.predict 291
292 # Optimized explainer based on model type
293 explainer, explainer_label, tree_based = self._choose_shap_explainer(
294 model, bg, predict_fn
292 ) 295 )
293 296 if explainer is None:
294 # Optimized explainer 297 LOG.warning("No suitable SHAP explainer for model %s; skipping SHAP.", model)
295 explainer = None 298 self.shap_model_name = None
296 explainer_label = None 299 return
297 if hasattr(model, "feature_importances_"):
298 explainer = shap.TreeExplainer(
299 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
300 )
301 explainer_label = "tree_path_dependent"
302 elif hasattr(model, "coef_"):
303 explainer = shap.LinearExplainer(model, bg)
304 explainer_label = "linear"
305 else:
306 explainer = shap.Explainer(predict_fn, bg)
307 explainer_label = explainer.__class__.__name__
308 300
309 try: 301 try:
310 shap_values = explainer(X_data) 302 shap_values = explainer(X_data)
311 self.shap_model_name = explainer.__class__.__name__ 303 self.shap_model_name = explainer.__class__.__name__
312 except Exception as e: 304 except Exception as e:
313 error_message = str(e) 305 error_message = str(e)
314 needs_tree_fallback = ( 306 needs_tree_fallback = (
315 hasattr(model, "feature_importances_") 307 tree_based
316 and "does not cover all the leaves" in error_message.lower() 308 and "does not cover all the leaves" in error_message.lower()
317 ) 309 )
318 feature_name_mismatch = "feature names should match" in error_message.lower() 310 feature_name_mismatch = "feature names should match" in error_message.lower()
319 if needs_tree_fallback: 311 if needs_tree_fallback:
320 LOG.warning( 312 LOG.warning(
346 "SHAP computation failed due to feature-name mismatch (%s). " 338 "SHAP computation failed due to feature-name mismatch (%s). "
347 "Falling back to model-agnostic SHAP explainer.", 339 "Falling back to model-agnostic SHAP explainer.",
348 error_message, 340 error_message,
349 ) 341 )
350 try: 342 try:
351 agnostic_explainer = shap.Explainer(predict_fn, bg) 343 agnostic_explainer = shap.Explainer(
344 predict_fn, bg, algorithm="permutation"
345 )
352 shap_values = agnostic_explainer(X_data) 346 shap_values = agnostic_explainer(X_data)
353 self.shap_model_name = ( 347 self.shap_model_name = (
354 f"{agnostic_explainer.__class__.__name__} (fallback)" 348 f"{agnostic_explainer.__class__.__name__} (fallback)"
355 ) 349 )
356 except Exception as fallback_exc: 350 except Exception as fallback_exc:
483 477
def encode_image_to_base64(self, img_path):
    """Read the file at *img_path* in binary mode and return it base64-encoded.

    The encoded bytes are decoded as UTF-8 so the result is a plain ``str``.
    """
    with open(img_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
487 481
482 def _get_predict_fn(self, model):
483 if hasattr(model, "predict_proba"):
484 return model.predict_proba
485 if hasattr(model, "decision_function"):
486 return model.decision_function
487 return model.predict
488
489 def _choose_shap_explainer(self, model, bg, predict_fn):
490 """
491 Select a SHAP explainer following the prescribed priority order for
492 algorithms. Returns (explainer, label, is_tree_based).
493 """
494 if model is None:
495 return None, None, False
496
497 name = model.__class__.__name__
498 lname = name.lower()
499 task = getattr(self, "task_type", None)
500
501 def _permutation(fn):
502 return shap.Explainer(fn, bg, algorithm="permutation")
503
504 if task == "classification":
505 # 1) Logistic Regression
506 if "logisticregression" in lname:
507 return _permutation(model.predict_proba), "permutation-proba", False
508
509 # 2) Ridge Classifier
510 if "ridgeclassifier" in lname:
511 fn = (
512 model.decision_function
513 if hasattr(model, "decision_function")
514 else predict_fn
515 )
516 return _permutation(fn), "permutation-decision_function", False
517
518 # 3) LDA
519 if "lineardiscriminantanalysis" in lname:
520 return _permutation(model.predict_proba), "permutation-proba", False
521
522 # 4) Random Forest
523 if "randomforestclassifier" in lname:
524 return (
525 shap.TreeExplainer(
526 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
527 ),
528 "tree_path_dependent",
529 True,
530 )
531
532 # 5) Gradient Boosting
533 if "gradientboostingclassifier" in lname:
534 return (
535 shap.TreeExplainer(
536 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
537 ),
538 "tree_path_dependent",
539 True,
540 )
541
542 # 6) AdaBoost
543 if "adaboostclassifier" in lname:
544 return (
545 shap.TreeExplainer(
546 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
547 ),
548 "tree_path_dependent",
549 True,
550 )
551
552 # 7) Extra Trees
553 if "extratreesclassifier" in lname:
554 return (
555 shap.TreeExplainer(
556 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
557 ),
558 "tree_path_dependent",
559 True,
560 )
561
562 # 8) LightGBM
563 if "lgbmclassifier" in lname:
564 return (
565 shap.TreeExplainer(
566 model,
567 bg,
568 model_output="raw",
569 feature_perturbation="tree_path_dependent",
570 n_jobs=-1,
571 ),
572 "tree_path_dependent",
573 True,
574 )
575
576 # 9) XGBoost
577 if "xgbclassifier" in lname:
578 return (
579 shap.TreeExplainer(
580 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
581 ),
582 "tree_path_dependent",
583 True,
584 )
585
586 # 10) CatBoost (classifier)
587 if "catboost" in lname:
588 return (
589 shap.TreeExplainer(
590 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
591 ),
592 "tree_path_dependent",
593 True,
594 )
595
596 # 11) KNN
597 if "kneighborsclassifier" in lname:
598 return _permutation(model.predict_proba), "permutation-proba", False
599
600 # 12) SVM - linear kernel
601 if "svc" in lname or "svm" in lname:
602 kernel = getattr(model, "kernel", None)
603 if kernel == "linear":
604 return shap.LinearExplainer(model, bg), "linear", False
605 return _permutation(predict_fn), "permutation-svm", False
606
607 # 13) Decision Tree
608 if "decisiontreeclassifier" in lname:
609 return (
610 shap.TreeExplainer(
611 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
612 ),
613 "tree_path_dependent",
614 True,
615 )
616
617 # 14) Naive Bayes
618 if "naive_bayes" in lname or lname.endswith("nb"):
619 fn = model.predict_proba if hasattr(model, "predict_proba") else predict_fn
620 return _permutation(fn), "permutation-proba", False
621
622 # 15) QDA
623 if "quadraticdiscriminantanalysis" in lname:
624 return _permutation(model.predict_proba), "permutation-proba", False
625
626 # 16) Dummy
627 if "dummyclassifier" in lname:
628 return None, None, False
629
630 # Default classification: permutation on predict_fn
631 return _permutation(predict_fn), "permutation-default", False
632
633 # Regression path
634 # Linear family
635 linear_keys = [
636 "linearregression",
637 "lasso",
638 "ridge",
639 "elasticnet",
640 "lars",
641 "lassolars",
642 "orthogonalmatchingpursuit",
643 "bayesianridge",
644 "ardregression",
645 "passiveaggressiveregressor",
646 "theilsenregressor",
647 "huberregressor",
648 ]
649 if any(k in lname for k in linear_keys):
650 return shap.LinearExplainer(model, bg), "linear", False
651
652 # Kernel ridge / SVR / KNN / MLP / RANSAC (model-agnostic)
653 if "kernelridge" in lname:
654 return _permutation(predict_fn), "permutation-kernelridge", False
655 if "svr" in lname or "svm" in lname:
656 kernel = getattr(model, "kernel", None)
657 if kernel == "linear":
658 return shap.LinearExplainer(model, bg), "linear", False
659 return _permutation(predict_fn), "permutation-svr", False
660 if "kneighborsregressor" in lname:
661 return _permutation(predict_fn), "permutation-knn", False
662 if "mlpregressor" in lname:
663 return _permutation(predict_fn), "permutation-mlp", False
664 if "ransacregressor" in lname:
665 return _permutation(predict_fn), "permutation-ransac", False
666
667 # Tree-based regressors
668 tree_class_names = [
669 "decisiontreeregressor",
670 "randomforestregressor",
671 "extratreesregressor",
672 "adaboostregressor",
673 "gradientboostingregressor",
674 ]
675 if any(k in lname for k in tree_class_names):
676 return (
677 shap.TreeExplainer(
678 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
679 ),
680 "tree_path_dependent",
681 True,
682 )
683
684 # Boosting libraries
685 if "lgbmregressor" in lname or "lightgbm" in lname:
686 return (
687 shap.TreeExplainer(
688 model,
689 bg,
690 model_output="raw",
691 feature_perturbation="tree_path_dependent",
692 n_jobs=-1,
693 ),
694 "tree_path_dependent",
695 True,
696 )
697 if "xgbregressor" in lname or "xgboost" in lname:
698 return (
699 shap.TreeExplainer(
700 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
701 ),
702 "tree_path_dependent",
703 True,
704 )
705 if "catboost" in lname:
706 return (
707 shap.TreeExplainer(
708 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
709 ),
710 "tree_path_dependent",
711 True,
712 )
713
714 # Default regression: model-agnostic permutation explainer
715 return _permutation(predict_fn), "permutation-default", False
716
488 def run(self): 717 def run(self):
489 if ( 718 if (
490 self.exp is None 719 self.exp is None
491 or not hasattr(self.exp, "is_setup") 720 or not hasattr(self.exp, "is_setup")
492 or not self.exp.is_setup 721 or not self.exp.is_setup