Mercurial > repos > goeckslab > tabular_learner
annotate pycaret_classification.py @ 8:ba45bc057d70 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
author | goeckslab |
---|---|
date | Mon, 08 Sep 2025 22:38:55 +0000 |
parents | 11fdac5affb3 |
children |
rev | line source |
---|---|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
1 import logging |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
2 import types |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
3 from typing import Dict |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
4 |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
5 import numpy as np |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
6 import pandas as pd |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
7 import plotly.graph_objects as go |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
8 from base_model_trainer import BaseModelTrainer |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
9 from dashboard import generate_classifier_explainer_dashboard |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
10 from pycaret.classification import ClassificationExperiment |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
12 from utils import predict_proba |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
13 |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
# Module-level logger, named after this module.
LOG = logging.getLogger(name=__name__)
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
15 |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
16 |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
def _apply_report_layout(fig: go.Figure) -> go.Figure:
    """Apply the shared report layout to a Plotly figure in place.

    Both axes get ``automargin=True`` and a title standoff of 12, and the
    layout is set to autosize with a wide left margin so y-axis titles and
    tick labels are not clipped.

    Args:
        fig: The figure to restyle; it is mutated in place.

    Returns:
        The same ``fig`` object, to allow call chaining.
    """
    axis_options = dict(automargin=True, title_standoff=12)
    # x-axis first, then y-axis (same order as the per-axis calls).
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(**axis_options)
    fig.update_layout(
        autosize=True,
        # Left margin is deliberately generous; bump 'l' if you still see clipping.
        margin=dict(l=120, r=40, t=60, b=60),
    )
    return fig
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
26 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
27 |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
28 class ClassificationModelTrainer(BaseModelTrainer): |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
29 def __init__( |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
30 self, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
31 input_file, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
32 target_col, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
33 output_dir, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
34 task_type, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
35 random_seed, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
36 test_file=None, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
37 **kwargs, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
38 ): |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
39 super().__init__( |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
40 input_file, |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
41 target_col, |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
42 output_dir, |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
43 task_type, |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
44 random_seed, |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
45 test_file, |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
46 **kwargs, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
47 ) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
48 self.exp = ClassificationExperiment() |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
49 |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
50 def save_dashboard(self): |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
51 LOG.info("Saving explainer dashboard") |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
52 dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
53 dashboard.save_html("dashboard.html") |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
54 |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
55 def generate_plots(self): |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
56 LOG.info("Generating and saving plots") |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
57 |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
58 if not hasattr(self.best_model, "predict_proba"): |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
59 self.best_model.predict_proba = types.MethodType( |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
60 predict_proba, self.best_model |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
61 ) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
62 LOG.warning( |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
63 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
64 ) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
65 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
66 plots = [ |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
67 "auc", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
68 "threshold", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
69 "pr", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
70 "error", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
71 "class_report", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
72 "learning", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
73 "calibration", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
74 "vc", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
75 "dimension", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
76 "manifold", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
77 "rfe", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
78 "feature", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
79 "feature_all", |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
80 ] |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
81 for plot_name in plots: |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
82 try: |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
83 if plot_name == "threshold": |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
84 plot_path = self.exp.plot_model( |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
85 self.best_model, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
86 plot=plot_name, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
87 save=True, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
88 plot_kwargs={"binary": True, "percentage": True}, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
89 ) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
90 self.plots[plot_name] = plot_path |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
91 elif plot_name == "auc" and not self.exp.is_multiclass: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
92 plot_path = self.exp.plot_model( |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
93 self.best_model, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
94 plot=plot_name, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
95 save=True, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
96 plot_kwargs={ |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
97 "micro": False, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
98 "macro": False, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
99 "per_class": False, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
100 "binary": True, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
101 }, |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
102 ) |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
103 self.plots[plot_name] = plot_path |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
104 else: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
105 plot_path = self.exp.plot_model( |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
106 self.best_model, plot=plot_name, save=True |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
107 ) |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
108 self.plots[plot_name] = plot_path |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
109 except Exception as e: |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
110 LOG.error(f"Error generating plot {plot_name}: {e}") |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
111 continue |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
112 |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
113 def generate_plots_explainer(self): |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
114 from explainerdashboard import ClassifierExplainer |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
115 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
116 LOG.info("Generating explainer plots") |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
117 |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
118 # Ensure predict_proba is available here too |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
119 if not hasattr(self.best_model, "predict_proba"): |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
120 self.best_model.predict_proba = types.MethodType( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
121 predict_proba, self.best_model |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
122 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
123 LOG.warning( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
124 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
125 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
126 |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
127 X_test = self.exp.X_test_transformed.copy() |
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
128 y_test = self.exp.y_test_transformed |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
129 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
130 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
131 # a dict to hold the raw Figure objects or callables |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
132 self.explainer_plots: Dict[str, go.Figure] = {} |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
133 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
134 # --- Threshold-aware overrides for CM / ROC / PR --- |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
135 prob_thresh = getattr(self, "probability_threshold", None) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
136 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
137 # Only for binary classification and when threshold is provided |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
138 if (prob_thresh is not None) and (not self.exp.is_multiclass): |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
139 X = self.exp.X_test_transformed |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
140 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
141 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
142 # Get positive-class scores (robust defaults) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
143 classes = list(getattr(self.best_model, "classes_", [0, 1])) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
144 try: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
145 pos_idx = classes.index(1) if 1 in classes else 1 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
146 except Exception: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
147 pos_idx = 1 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
148 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
149 proba = self.best_model.predict_proba(X) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
150 y_scores = proba[:, pos_idx] |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
151 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
152 # Derive label names consistently |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
153 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
154 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
155 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
156 # ---- Confusion Matrix @ threshold ---- |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
157 try: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
158 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
159 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
160 fig_cm = go.Figure( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
161 data=go.Heatmap( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
162 z=cm, |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
163 x=[f"Pred {neg_label}", f"Pred {pos_label}"], |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
164 y=[f"True {neg_label}", f"True {pos_label}"], |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
165 text=cm, |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
166 texttemplate="%{text}", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
167 colorscale="Blues", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
168 showscale=False, |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
169 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
170 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
171 fig_cm.update_layout( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
172 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
173 xaxis_title="Predicted label", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
174 yaxis_title="True label", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
175 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
176 _apply_report_layout(fig_cm) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
177 self.explainer_plots["confusion_matrix"] = fig_cm |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
178 except Exception as e: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
179 LOG.warning( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
180 f"Threshold-aware confusion matrix failed; falling back: {e}" |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
181 ) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
182 |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
183 # ---- ROC with threshold marker ---- |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
184 try: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
185 fpr, tpr, thr = roc_curve(y, y_scores) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
186 roc_auc = auc(fpr, tpr) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
187 fig_roc = go.Figure() |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
188 fig_roc.add_scatter( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
189 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
190 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
191 if len(thr): |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
192 mask = np.isfinite(thr) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
193 if mask.any(): |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
194 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh))) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
195 idx = np.where(mask)[0][idx_local] |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
196 if 0 <= idx < len(fpr): |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
197 fig_roc.add_scatter( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
198 x=[fpr[idx]], |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
199 y=[tpr[idx]], |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
200 mode="markers", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
201 name=f"@ {prob_thresh:.2f}", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
202 marker=dict(size=10), |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
203 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
204 fig_roc.update_layout( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
205 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
206 xaxis_title="False Positive Rate", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
207 yaxis_title="True Positive Rate", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
208 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
209 _apply_report_layout(fig_roc) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
210 self.explainer_plots["roc_auc"] = fig_roc |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
211 except Exception as e: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
212 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
213 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
214 # ---- PR with threshold marker ---- |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
215 try: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
216 precision, recall, thr_pr = precision_recall_curve(y, y_scores) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
217 pr_auc = auc(recall, precision) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
218 fig_pr = go.Figure() |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
219 fig_pr.add_scatter( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
220 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
221 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
222 if len(thr_pr): |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
223 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh))) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
224 # note: thr_pr has length = len(precision) - 1 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
225 idx_pr = max(0, min(idx_pr, len(recall) - 1)) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
226 fig_pr.add_scatter( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
227 x=[recall[idx_pr]], |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
228 y=[precision[idx_pr]], |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
229 mode="markers", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
230 name=f"@ {prob_thresh:.2f}", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
231 marker=dict(size=10), |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
232 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
233 fig_pr.update_layout( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
234 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
235 xaxis_title="Recall", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
236 yaxis_title="Precision", |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
237 ) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
238 _apply_report_layout(fig_pr) |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
239 self.explainer_plots["pr_auc"] = fig_pr |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
240 except Exception as e: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
241 LOG.warning(f"Threshold marker on PR failed; falling back: {e}") |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
242 |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
243 # these go into the Test tab (don't overwrite overrides) |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
244 for key, fn in [ |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
245 ("roc_auc", explainer.plot_roc_auc), |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
246 ("pr_auc", explainer.plot_pr_auc), |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
247 ("lift_curve", explainer.plot_lift_curve), |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
248 ("confusion_matrix", explainer.plot_confusion_matrix), |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
249 ("threshold", explainer.plot_precision), # percentage vs probability |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
250 ("cumulative_precision", explainer.plot_cumulative_precision), |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
251 ]: |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
252 if key in self.explainer_plots: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
253 continue |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
254 try: |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
255 fig = fn() |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
256 if fig is not None: |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
257 self.explainer_plots[key] = fig |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
258 except Exception as e: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
259 LOG.error(f"Error generating explainer plot {key}: {e}") |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
260 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
261 # mean SHAP importances |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
262 try: |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
263 self.explainer_plots["shap_mean"] = explainer.plot_importances() |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
264 except Exception as e: |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
265 LOG.warning(f"Could not generate shap_mean: {e}") |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
266 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
267 # permutation importances |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
268 try: |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
269 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
270 kind="permutation" |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
271 ) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
272 except Exception as e: |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
273 LOG.warning(f"Could not generate shap_perm: {e}") |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
274 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
275 # PDPs for each feature (appended last) |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
276 valid_feats = [] |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
277 for feat in self.features_name: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
278 if feat in explainer.X.columns or feat in explainer.onehot_cols: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
279 valid_feats.append(feat) |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
280 else: |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
281 LOG.warning( |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
282 f"Skipping PDP for feature {feat!r}: not found in explainer data" |
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
283 ) |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
284 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
285 for feat in valid_feats: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
286 # wrap each PDP call to catch any unexpected AssertionErrors |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
287 def make_pdp_plotter(f): |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
288 def _plot(): |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
289 try: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
290 return explainer.plot_pdp(f) |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
291 except AssertionError as ae: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
292 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
293 return None |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
294 except Exception as e: |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
295 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") |
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
296 return None |
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
297 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
298 return _plot |
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
299 |
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
300 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |