Mercurial > repos > goeckslab > pycaret_predict
annotate pycaret_classification.py @ 12:e674b9e946fb draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
author | goeckslab |
---|---|
date | Mon, 08 Sep 2025 22:39:12 +0000 |
parents | 1aed7d47c5ec |
children |
rev | line source |
---|---|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
1 import logging |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
2 import types |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
3 from typing import Dict |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
4 |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
5 import numpy as np |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
6 import pandas as pd |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
7 import plotly.graph_objects as go |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
8 from base_model_trainer import BaseModelTrainer |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
9 from dashboard import generate_classifier_explainer_dashboard |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
10 from pycaret.classification import ClassificationExperiment |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
11 from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
12 from utils import predict_proba |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
13 |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
14 LOG = logging.getLogger(__name__) |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
15 |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
16 |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
def _apply_report_layout(fig: go.Figure) -> go.Figure:
    """Apply the report's axis spacing and margins to ``fig`` and return it.

    Both axes get ``automargin`` enabled with a title standoff of 12, and the
    layout is set to autosize with a wide left margin so the y-axis title and
    tick labels have room.

    Args:
        fig: The figure to adjust.

    Returns:
        The same ``fig`` object that was passed in.
    """
    # Identical settings for both axes: let each axis reserve its own room.
    axis_options = {"automargin": True, "title_standoff": 12}
    fig.update_xaxes(**axis_options)
    fig.update_yaxes(**axis_options)

    # Increase "l" if the y-axis labels are still clipped.
    report_margin = {"l": 120, "r": 40, "t": 60, "b": 60}
    fig.update_layout(autosize=True, margin=report_margin)
    return fig
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
26 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
27 |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
28 class ClassificationModelTrainer(BaseModelTrainer): |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
def __init__(
    self,
    input_file,
    target_col,
    output_dir,
    task_type,
    random_seed,
    test_file=None,
    **kwargs,
):
    """Initialise the base trainer and create the classification experiment.

    All arguments are forwarded unchanged (positionally, in the order
    received) to ``BaseModelTrainer.__init__``; any extra keyword arguments
    are passed through as well. ``self.exp`` is then set to a new
    ``ClassificationExperiment``.
    """
    # Kept positional: the base-class parameter names are not visible here.
    base_args = (
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file,
    )
    super().__init__(*base_args, **kwargs)
    self.exp = ClassificationExperiment()
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
49 |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
def save_dashboard(self):
    """Build the explainer dashboard for the best model and save it as HTML.

    The dashboard is produced by ``generate_classifier_explainer_dashboard``
    from ``self.exp`` and ``self.best_model`` and written to the relative
    path ``dashboard.html``.
    """
    LOG.info("Saving explainer dashboard")
    generate_classifier_explainer_dashboard(
        self.exp, self.best_model
    ).save_html("dashboard.html")
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
54 |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
def generate_plots(self):
    """Render the standard diagnostic plots for ``self.best_model``.

    Every plot in the fixed list below is requested through
    ``self.exp.plot_model(self.best_model, plot=<name>, save=True)`` and the
    value returned by that call is stored in ``self.plots`` under the plot
    name. Two plots receive extra ``plot_kwargs``:

    * ``"threshold"`` always gets ``{"binary": True, "percentage": True}``.
    * ``"auc"`` gets the binary-only ROC options when
      ``self.exp.is_multiclass`` is false.

    A plot that raises is logged and skipped so that the remaining plots are
    still attempted; the method itself does not re-raise.

    If the model has no ``predict_proba`` attribute, the ``predict_proba``
    helper imported at module level is bound to it first, and a warning is
    logged.
    """
    LOG.info("Generating and saving plots")

    if not hasattr(self.best_model, "predict_proba"):
        self.best_model.predict_proba = types.MethodType(
            predict_proba, self.best_model
        )
        LOG.warning(
            "The model %s does not support `predict_proba`. "
            "Applying monkey patch.",
            type(self.best_model).__name__,
        )

    plot_names = (
        "auc",
        "threshold",
        "pr",
        "error",
        "class_report",
        "learning",
        "calibration",
        "vc",
        "dimension",
        "manifold",
        "rfe",
        "feature",
        "feature_all",
    )
    for plot_name in plot_names:
        try:
            # Only `plot_kwargs` differs between plots, so build the call
            # arguments once and keep a single call site.
            call_kwargs = {"plot": plot_name, "save": True}
            if plot_name == "threshold":
                call_kwargs["plot_kwargs"] = {
                    "binary": True,
                    "percentage": True,
                }
            elif plot_name == "auc" and not self.exp.is_multiclass:
                call_kwargs["plot_kwargs"] = {
                    "micro": False,
                    "macro": False,
                    "per_class": False,
                    "binary": True,
                }
            self.plots[plot_name] = self.exp.plot_model(
                self.best_model, **call_kwargs
            )
        except Exception as exc:
            # Deliberately broad: plotting is best-effort and one failing
            # plot must not stop the others.
            LOG.error("Error generating plot %s: %s", plot_name, exc)
            # Keep the traceback available without flooding normal logs.
            LOG.debug(
                "Traceback for failed plot %s", plot_name, exc_info=True
            )
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
112 |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
113 def generate_plots_explainer(self): |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
114 from explainerdashboard import ClassifierExplainer |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
115 |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
116 LOG.info("Generating explainer plots") |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
117 |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
118 # Ensure predict_proba is available here too |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
119 if not hasattr(self.best_model, "predict_proba"): |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
120 self.best_model.predict_proba = types.MethodType( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
121 predict_proba, self.best_model |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
122 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
123 LOG.warning( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
124 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
125 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
126 |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
127 X_test = self.exp.X_test_transformed.copy() |
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
128 y_test = self.exp.y_test_transformed |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
129 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
130 |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
131 # a dict to hold the raw Figure objects or callables |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
132 self.explainer_plots: Dict[str, go.Figure] = {} |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
133 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
134 # --- Threshold-aware overrides for CM / ROC / PR --- |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
135 prob_thresh = getattr(self, "probability_threshold", None) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
136 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
137 # Only for binary classification and when threshold is provided |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
138 if (prob_thresh is not None) and (not self.exp.is_multiclass): |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
139 X = self.exp.X_test_transformed |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
140 y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
141 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
142 # Get positive-class scores (robust defaults) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
143 classes = list(getattr(self.best_model, "classes_", [0, 1])) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
144 try: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
145 pos_idx = classes.index(1) if 1 in classes else 1 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
146 except Exception: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
147 pos_idx = 1 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
148 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
149 proba = self.best_model.predict_proba(X) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
150 y_scores = proba[:, pos_idx] |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
151 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
152 # Derive label names consistently |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
153 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
154 neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
155 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
156 # ---- Confusion Matrix @ threshold ---- |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
157 try: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
158 y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
159 cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
160 fig_cm = go.Figure( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
161 data=go.Heatmap( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
162 z=cm, |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
163 x=[f"Pred {neg_label}", f"Pred {pos_label}"], |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
164 y=[f"True {neg_label}", f"True {pos_label}"], |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
165 text=cm, |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
166 texttemplate="%{text}", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
167 colorscale="Blues", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
168 showscale=False, |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
169 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
170 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
171 fig_cm.update_layout( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
172 title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
173 xaxis_title="Predicted label", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
174 yaxis_title="True label", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
175 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
176 _apply_report_layout(fig_cm) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
177 self.explainer_plots["confusion_matrix"] = fig_cm |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
178 except Exception as e: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
179 LOG.warning( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
180 f"Threshold-aware confusion matrix failed; falling back: {e}" |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
181 ) |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
182 |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
183 # ---- ROC with threshold marker ---- |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
184 try: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
185 fpr, tpr, thr = roc_curve(y, y_scores) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
186 roc_auc = auc(fpr, tpr) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
187 fig_roc = go.Figure() |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
188 fig_roc.add_scatter( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
189 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
190 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
191 if len(thr): |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
192 mask = np.isfinite(thr) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
193 if mask.any(): |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
194 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh))) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
195 idx = np.where(mask)[0][idx_local] |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
196 if 0 <= idx < len(fpr): |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
197 fig_roc.add_scatter( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
198 x=[fpr[idx]], |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
199 y=[tpr[idx]], |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
200 mode="markers", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
201 name=f"@ {prob_thresh:.2f}", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
202 marker=dict(size=10), |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
203 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
204 fig_roc.update_layout( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
205 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
206 xaxis_title="False Positive Rate", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
207 yaxis_title="True Positive Rate", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
208 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
209 _apply_report_layout(fig_roc) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
210 self.explainer_plots["roc_auc"] = fig_roc |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
211 except Exception as e: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
212 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
213 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
214 # ---- PR with threshold marker ---- |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
215 try: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
216 precision, recall, thr_pr = precision_recall_curve(y, y_scores) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
217 pr_auc = auc(recall, precision) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
218 fig_pr = go.Figure() |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
219 fig_pr.add_scatter( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
220 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
221 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
222 if len(thr_pr): |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
223 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh))) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
224 # note: thr_pr has length = len(precision) - 1 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
225 idx_pr = max(0, min(idx_pr, len(recall) - 1)) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
226 fig_pr.add_scatter( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
227 x=[recall[idx_pr]], |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
228 y=[precision[idx_pr]], |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
229 mode="markers", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
230 name=f"@ {prob_thresh:.2f}", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
231 marker=dict(size=10), |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
232 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
233 fig_pr.update_layout( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
234 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
235 xaxis_title="Recall", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
236 yaxis_title="Precision", |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
237 ) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
238 _apply_report_layout(fig_pr) |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
239 self.explainer_plots["pr_auc"] = fig_pr |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
240 except Exception as e: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
241 LOG.warning(f"Threshold marker on PR failed; falling back: {e}") |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
242 |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
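243 # these go into the Test tab; keys already filled by the threshold-aware plots above (confusion_matrix, roc_auc, pr_auc) are skipped so they are not overwritten |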
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
244 for key, fn in [ |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
245 ("roc_auc", explainer.plot_roc_auc), |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
246 ("pr_auc", explainer.plot_pr_auc), |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
247 ("lift_curve", explainer.plot_lift_curve), |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
248 ("confusion_matrix", explainer.plot_confusion_matrix), |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
249 ("threshold", explainer.plot_precision), # percentage vs probability |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
250 ("cumulative_precision", explainer.plot_cumulative_precision), |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
251 ]: |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
252 if key in self.explainer_plots: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
253 continue |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
254 try: |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
255 fig = fn() |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
256 if fig is not None: |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
257 self.explainer_plots[key] = fig |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
258 except Exception as e: |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
259 LOG.error(f"Error generating explainer plot {key}: {e}") |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
260 |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
261 # mean SHAP importances |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
262 try: |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
263 self.explainer_plots["shap_mean"] = explainer.plot_importances() |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
264 except Exception as e: |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
265 LOG.warning(f"Could not generate shap_mean: {e}") |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
266 |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
267 # permutation importances |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
268 try: |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
269 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
270 kind="permutation" |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
271 ) |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
272 except Exception as e: |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
273 LOG.warning(f"Could not generate shap_perm: {e}") |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
274 |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
275 # PDPs for each feature (appended last) |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
276 valid_feats = [] |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
277 for feat in self.features_name: |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
278 if feat in explainer.X.columns or feat in explainer.onehot_cols: |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
279 valid_feats.append(feat) |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
280 else: |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
281 LOG.warning( |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
282 f"Skipping PDP for feature {feat!r}: not found in explainer data" |
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
283 ) |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
284 |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
285 for feat in valid_feats: |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
286 # wrap each PDP call so that an AssertionError (or any other exception) is logged and the plotter returns None instead of raising |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
287 def make_pdp_plotter(f): |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
288 def _plot(): |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
289 try: |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
290 return explainer.plot_pdp(f) |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
291 except AssertionError as ae: |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
292 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
293 return None |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
294 except Exception as e: |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
295 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") |
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
296 return None |
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
297 |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
298 return _plot |
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
299 |
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
300 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |