Mercurial > repos > goeckslab > tabular_learner
annotate pycaret_classification.py @ 11:a76dfceb62e0 draft
planemo upload for repository https://github.com/goeckslab/gleam commit bc50fef8acb44aca15d0a1746e6c0c967da5bb17
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 15:46:05 +0000 |
| parents | ba45bc057d70 |
| children |
| rev | line source |
|---|---|
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
1 import logging |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
2 import types |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
3 from typing import Dict |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
4 |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
5 import numpy as np |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
6 import pandas as pd |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
7 import plotly.graph_objects as go |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
8 from base_model_trainer import BaseModelTrainer |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
9 from dashboard import generate_classifier_explainer_dashboard |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
10 from pycaret.classification import ClassificationExperiment |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
12 from utils import predict_proba |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
13 |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
14 LOG = logging.getLogger(__name__) |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
15 |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
16 |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
def _apply_report_layout(fig: go.Figure) -> go.Figure:
    """Apply the shared report styling to ``fig`` and return it.

    Both axis families get ``automargin`` enabled, a title standoff of 12 and
    a light-grey grid colour. The layout gets white plot and paper
    backgrounds, autosizing, and explicit margins whose left edge is the
    widest so the y-axis title and ticks have room.

    The figure is updated through its own ``update_*`` methods and the same
    object is returned.
    """
    # One style dict applied to x and y axes alike.
    axis_style = dict(automargin=True, title_standoff=12, gridcolor="#e8e8e8")
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(**axis_style)

    fig.update_layout(
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
        autosize=True,
        # bump 'l' if you still see clipping
        margin=dict(l=120, r=40, t=60, b=60),
    )
    return fig
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
32 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
33 |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
34 class ClassificationModelTrainer(BaseModelTrainer): |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
def __init__(
    self,
    input_file,
    target_col,
    output_dir,
    task_type,
    random_seed,
    test_file=None,
    **kwargs,
):
    """Initialise the trainer and attach a PyCaret classification experiment.

    All arguments are forwarded unchanged, in the same positional order, to
    the base-class initialiser; any extra keyword arguments are passed
    through as well. Afterwards ``self.exp`` is set to a new
    ``ClassificationExperiment``.
    """
    # Forward positionally so the call does not depend on the base class's
    # parameter names.
    base_args = (
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file,
    )
    super().__init__(*base_args, **kwargs)
    self.exp = ClassificationExperiment()
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
55 |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
def save_dashboard(self):
    """Build the classifier explainer dashboard and save it as HTML.

    The dashboard is generated from ``self.exp`` and ``self.best_model`` and
    saved through its ``save_html`` method under the relative name
    ``dashboard.html``.
    """
    LOG.info("Saving explainer dashboard")
    generate_classifier_explainer_dashboard(
        self.exp, self.best_model
    ).save_html("dashboard.html")
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
60 |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
def generate_plots(self):
    """Generate every supported PyCaret plot for the best model.

    If the best model has no ``predict_proba`` attribute, the module's
    ``predict_proba`` helper is bound to it as a method first and a warning
    is logged.

    Each plot name is then passed to ``self.exp.plot_model`` with
    ``save=True`` and the returned value is stored in
    ``self.plots[plot_name]``. Two plots get extra ``plot_kwargs``:

    * ``"threshold"`` always gets the binary/percentage options;
    * ``"auc"`` gets binary-only options when the experiment is not
      multiclass.

    A failure on one plot is logged and does not stop the remaining plots.
    """
    LOG.info("Generating and saving plots")

    if not hasattr(self.best_model, "predict_proba"):
        self.best_model.predict_proba = types.MethodType(
            predict_proba, self.best_model
        )
        LOG.warning(
            f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
        )

    plot_names = (
        "auc",
        "threshold",
        "pr",
        "error",
        "class_report",
        "learning",
        "calibration",
        "vc",
        "dimension",
        "manifold",
        "rfe",
        "feature",
        "feature_all",
    )

    for plot_name in plot_names:
        try:
            call_kwargs = {"plot": plot_name, "save": True}
            # Only add plot_kwargs for the two special cases, so every other
            # plot still relies on plot_model's own default.
            if plot_name == "threshold":
                call_kwargs["plot_kwargs"] = {
                    "binary": True,
                    "percentage": True,
                }
            elif plot_name == "auc" and not self.exp.is_multiclass:
                call_kwargs["plot_kwargs"] = {
                    "micro": False,
                    "macro": False,
                    "per_class": False,
                    "binary": True,
                }
            self.plots[plot_name] = self.exp.plot_model(
                self.best_model, **call_kwargs
            )
        except Exception as e:
            # Best effort: record the failure and move on to the next plot.
            LOG.error(f"Error generating plot {plot_name}: {e}")
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
118 |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
119 def generate_plots_explainer(self): |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
120 from explainerdashboard import ClassifierExplainer |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
121 |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
122 LOG.info("Generating explainer plots") |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
123 |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
124 # Ensure predict_proba is available here too |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
125 if not hasattr(self.best_model, "predict_proba"): |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
126 self.best_model.predict_proba = types.MethodType( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
127 predict_proba, self.best_model |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
128 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
129 LOG.warning( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
130 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
131 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
132 |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
133 X_test = self.exp.X_test_transformed.copy() |
|
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
134 y_test = self.exp.y_test_transformed |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
135 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
136 |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
137 # a dict to hold the raw Figure objects or callables |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
138 self.explainer_plots: Dict[str, go.Figure] = {} |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
139 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
140 # --- Threshold-aware overrides for CM / ROC / PR --- |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
141 prob_thresh = getattr(self, "probability_threshold", None) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
142 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
143 # Only for binary classification and when threshold is provided |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
144 if (prob_thresh is not None) and (not self.exp.is_multiclass): |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
145 X = self.exp.X_test_transformed |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
146 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
147 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
148 # Get positive-class scores (robust defaults) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
149 classes = list(getattr(self.best_model, "classes_", [0, 1])) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
150 try: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
151 pos_idx = classes.index(1) if 1 in classes else 1 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
152 except Exception: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
153 pos_idx = 1 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
154 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
155 proba = self.best_model.predict_proba(X) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
156 y_scores = proba[:, pos_idx] |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
157 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
158 # Derive label names consistently |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
159 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
160 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
161 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
162 # ---- Confusion Matrix @ threshold ---- |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
163 try: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
164 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
165 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
166 fig_cm = go.Figure( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
167 data=go.Heatmap( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
168 z=cm, |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
169 x=[f"Pred {neg_label}", f"Pred {pos_label}"], |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
170 y=[f"True {neg_label}", f"True {pos_label}"], |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
171 text=cm, |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
172 texttemplate="%{text}", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
173 colorscale="Blues", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
174 showscale=False, |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
175 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
176 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
177 fig_cm.update_layout( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
178 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
179 xaxis_title="Predicted label", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
180 yaxis_title="True label", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
181 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
182 _apply_report_layout(fig_cm) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
183 self.explainer_plots["confusion_matrix"] = fig_cm |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
184 except Exception as e: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
185 LOG.warning( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
186 f"Threshold-aware confusion matrix failed; falling back: {e}" |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
187 ) |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
188 |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
189 # ---- ROC with threshold marker ---- |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
190 try: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
191 fpr, tpr, thr = roc_curve(y, y_scores) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
192 roc_auc = auc(fpr, tpr) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
193 fig_roc = go.Figure() |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
194 fig_roc.add_scatter( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
195 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
196 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
197 if len(thr): |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
198 mask = np.isfinite(thr) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
199 if mask.any(): |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
200 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh))) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
201 idx = np.where(mask)[0][idx_local] |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
202 if 0 <= idx < len(fpr): |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
203 fig_roc.add_scatter( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
204 x=[fpr[idx]], |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
205 y=[tpr[idx]], |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
206 mode="markers", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
207 name=f"@ {prob_thresh:.2f}", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
208 marker=dict(size=10), |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
209 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
210 fig_roc.update_layout( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
211 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
212 xaxis_title="False Positive Rate", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
213 yaxis_title="True Positive Rate", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
214 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
215 _apply_report_layout(fig_roc) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
216 self.explainer_plots["roc_auc"] = fig_roc |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
217 except Exception as e: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
218 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
219 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
220 # ---- PR with threshold marker ---- |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
221 try: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
222 precision, recall, thr_pr = precision_recall_curve(y, y_scores) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
223 pr_auc = auc(recall, precision) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
224 fig_pr = go.Figure() |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
225 fig_pr.add_scatter( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
226 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
227 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
228 if len(thr_pr): |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
229 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh))) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
230 # note: thr_pr has length = len(precision) - 1 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
231 idx_pr = max(0, min(idx_pr, len(recall) - 1)) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
232 fig_pr.add_scatter( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
233 x=[recall[idx_pr]], |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
234 y=[precision[idx_pr]], |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
235 mode="markers", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
236 name=f"@ {prob_thresh:.2f}", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
237 marker=dict(size=10), |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
238 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
239 fig_pr.update_layout( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
240 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
241 xaxis_title="Recall", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
242 yaxis_title="Precision", |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
243 ) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
244 _apply_report_layout(fig_pr) |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
245 self.explainer_plots["pr_auc"] = fig_pr |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
246 except Exception as e: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
247 LOG.warning(f"Threshold marker on PR failed; falling back: {e}") |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
248 |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
249 # these go into the Test tab (don't overwrite overrides) |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
250 for key, fn in [ |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
251 ("roc_auc", explainer.plot_roc_auc), |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
252 ("pr_auc", explainer.plot_pr_auc), |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
253 ("lift_curve", explainer.plot_lift_curve), |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
254 ("confusion_matrix", explainer.plot_confusion_matrix), |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
255 ("threshold", explainer.plot_precision), # percentage vs probability |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
256 ("cumulative_precision", explainer.plot_cumulative_precision), |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
257 ]: |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
258 if key in self.explainer_plots: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
259 continue |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
260 try: |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
261 fig = fn() |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
262 if fig is not None: |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
263 self.explainer_plots[key] = fig |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
264 except Exception as e: |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
265 LOG.error(f"Error generating explainer plot {key}: {e}") |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
266 |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
267 # mean SHAP importances |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
268 try: |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
269 self.explainer_plots["shap_mean"] = explainer.plot_importances() |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
270 except Exception as e: |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
271 LOG.warning(f"Could not generate shap_mean: {e}") |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
272 |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
273 # permutation importances |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
274 try: |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
275 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
276 kind="permutation" |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
277 ) |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
278 except Exception as e: |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
279 LOG.warning(f"Could not generate shap_perm: {e}") |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
280 |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
281 # PDPs for each feature (appended last) |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
282 valid_feats = [] |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
283 for feat in self.features_name: |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
284 if feat in explainer.X.columns or feat in explainer.onehot_cols: |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
285 valid_feats.append(feat) |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
286 else: |
|
8
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
287 LOG.warning( |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
288 f"Skipping PDP for feature {feat!r}: not found in explainer data" |
|
ba45bc057d70
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
4
diff
changeset
|
289 ) |
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
290 |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
291 for feat in valid_feats: |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
292 # wrap each PDP call to catch any unexpected AssertionErrors |
|
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
def make_pdp_plotter(f):
    """Build a zero-argument callable that draws the PDP for feature ``f``.

    The feature name is bound through this factory's parameter, so each
    returned plotter keeps its own feature even when the factory is called
    repeatedly from a loop.

    The returned callable never raises for ordinary errors: any exception
    coming out of ``explainer.plot_pdp`` is logged and ``None`` is returned
    in place of a figure.
    """

    def _plot():
        figure = None
        try:
            figure = explainer.plot_pdp(f)
        except AssertionError as ae:
            # Assertion failures are reported at warning level only.
            LOG.warning(f"PDP AssertionError for {f!r}: {ae}")
        except Exception as e:
            # Anything else is reported as an error, but still swallowed.
            LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
        return figure

    return _plot
|
0
209b663a4f62
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
goeckslab
parents:
diff
changeset
|
305 |
|
4
11fdac5affb3
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
0
diff
changeset
|
306 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |
