Mercurial > repos > goeckslab > pycaret_predict
annotate pycaret_classification.py @ 17:c5c324ac29fc draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
| author | goeckslab |
|---|---|
| date | Sat, 06 Dec 2025 14:20:36 +0000 |
| parents | a2aeeb754d76 |
| children |
| rev | line source |
|---|---|
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
1 import logging |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
2 import types |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
3 from typing import Dict |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
4 |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
5 import numpy as np |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
6 import pandas as pd |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
7 import plotly.graph_objects as go |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
8 from base_model_trainer import BaseModelTrainer |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
9 from dashboard import generate_classifier_explainer_dashboard |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
10 from pycaret.classification import ClassificationExperiment |
|
17
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
11 from sklearn.metrics import ( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
12 auc, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
13 confusion_matrix, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
14 matthews_corrcoef, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
15 precision_recall_curve, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
16 precision_recall_fscore_support, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
17 roc_curve, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
18 ) |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
19 from utils import predict_proba |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
20 |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
# Module-level logger, named after this module (``__name__``).
LOG = logging.getLogger(__name__)
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
22 |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
23 |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
def _apply_report_layout(fig: go.Figure) -> go.Figure:
    """Apply the shared report styling to a Plotly figure and return it.

    Both axes get automatic margins, a 12px title standoff and light-grey
    grid lines. The layout gets white plot/paper backgrounds, autosizing and
    fixed margins. The same figure object that was passed in is returned, so
    the call can be chained.
    """
    # Let each axis reserve room for its title/ticks automatically, keep a
    # small gap between axis title and tick labels, and use soft grid lines.
    axis_style = dict(automargin=True, title_standoff=12, gridcolor="#e8e8e8")
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    fig.update_layout(
        plot_bgcolor="#ffffff",
        paper_bgcolor="#ffffff",
        autosize=True,
        # Generous left margin so y-axis titles/ticks are not clipped;
        # bump 'l' if you still see clipping.
        margin=dict(l=120, r=40, t=60, b=60),
    )
    return fig
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
39 |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
40 |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
41 class ClassificationModelTrainer(BaseModelTrainer): |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
42 def __init__( |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
43 self, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
44 input_file, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
45 target_col, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
46 output_dir, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
47 task_type, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
48 random_seed, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
49 test_file=None, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
50 **kwargs, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
51 ): |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
52 super().__init__( |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
53 input_file, |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
54 target_col, |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
55 output_dir, |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
56 task_type, |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
57 random_seed, |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
58 test_file, |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
59 **kwargs, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
60 ) |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
61 self.exp = ClassificationExperiment() |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
62 |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
63 def save_dashboard(self): |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
64 LOG.info("Saving explainer dashboard") |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
65 dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model) |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
66 dashboard.save_html("dashboard.html") |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
67 |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
68 def generate_plots(self): |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
69 LOG.info("Generating and saving plots") |
|
2
0314dad38aaa
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
goeckslab
parents:
0
diff
changeset
|
70 |
|
0314dad38aaa
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
goeckslab
parents:
0
diff
changeset
|
71 if not hasattr(self.best_model, "predict_proba"): |
|
0314dad38aaa
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
goeckslab
parents:
0
diff
changeset
|
72 self.best_model.predict_proba = types.MethodType( |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
73 predict_proba, self.best_model |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
74 ) |
|
2
0314dad38aaa
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
goeckslab
parents:
0
diff
changeset
|
75 LOG.warning( |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
76 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
77 ) |
|
2
0314dad38aaa
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
goeckslab
parents:
0
diff
changeset
|
78 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
79 plots = [ |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
80 "auc", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
81 "threshold", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
82 "pr", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
83 "error", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
84 "class_report", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
85 "learning", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
86 "calibration", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
87 "vc", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
88 "dimension", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
89 "manifold", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
90 "rfe", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
91 "feature", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
92 "feature_all", |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
93 ] |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
94 for plot_name in plots: |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
95 try: |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
96 if plot_name == "threshold": |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
97 plot_path = self.exp.plot_model( |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
98 self.best_model, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
99 plot=plot_name, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
100 save=True, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
101 plot_kwargs={"binary": True, "percentage": True}, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
102 ) |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
103 self.plots[plot_name] = plot_path |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
104 elif plot_name == "auc" and not self.exp.is_multiclass: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
105 plot_path = self.exp.plot_model( |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
106 self.best_model, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
107 plot=plot_name, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
108 save=True, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
109 plot_kwargs={ |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
110 "micro": False, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
111 "macro": False, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
112 "per_class": False, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
113 "binary": True, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
114 }, |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
115 ) |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
116 self.plots[plot_name] = plot_path |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
117 else: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
118 plot_path = self.exp.plot_model( |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
119 self.best_model, plot=plot_name, save=True |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
120 ) |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
121 self.plots[plot_name] = plot_path |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
122 except Exception as e: |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
123 LOG.error(f"Error generating plot {plot_name}: {e}") |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
124 continue |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
125 |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
126 def generate_plots_explainer(self): |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
127 from explainerdashboard import ClassifierExplainer |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
128 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
129 LOG.info("Generating explainer plots") |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
130 |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
131 # Ensure predict_proba is available here too |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
132 if not hasattr(self.best_model, "predict_proba"): |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
133 self.best_model.predict_proba = types.MethodType( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
134 predict_proba, self.best_model |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
135 ) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
136 LOG.warning( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
137 f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch." |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
138 ) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
139 |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
140 X_test = self.exp.X_test_transformed.copy() |
|
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
141 y_test = self.exp.y_test_transformed |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
142 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
143 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
144 # a dict to hold the raw Figure objects or callables |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
145 self.explainer_plots: Dict[str, go.Figure] = {} |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
146 |
|
17
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
147 y_true, y_pred, label_values, y_scores = self._get_test_predictions() |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
148 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
149 # — Classification report (Plotly table) — |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
150 try: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
151 fig_report = self._build_classification_report_fig( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
152 y_true, y_pred, label_values |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
153 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
154 if fig_report is not None: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
155 self.explainer_plots["class_report"] = fig_report |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
156 except Exception as e: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
157 LOG.warning(f"Could not generate Plotly classification report: {e}") |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
158 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
159 # — Confusion matrix with actual labels — |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
160 try: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
161 fig_cm = self._build_confusion_matrix_fig(y_true, y_pred, label_values) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
162 if fig_cm is not None: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
163 self.explainer_plots["confusion_matrix"] = fig_cm |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
164 except Exception as e: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
165 LOG.warning(f"Could not generate Plotly confusion matrix: {e}") |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
166 |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
167 # --- Threshold-aware overrides for CM / ROC / PR --- |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
168 prob_thresh = getattr(self, "probability_threshold", None) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
169 |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
170 # Only for binary classification and when threshold is provided |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
171 if (prob_thresh is not None) and (not self.exp.is_multiclass): |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
172 # ---- ROC with threshold marker ---- |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
173 try: |
|
17
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
174 if y_scores is None: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
175 raise ValueError("Predicted probabilities unavailable") |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
176 fpr, tpr, thr = roc_curve(y_true, y_scores) |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
177 roc_auc = auc(fpr, tpr) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
178 fig_roc = go.Figure() |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
179 fig_roc.add_scatter( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
180 x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})" |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
181 ) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
182 if len(thr): |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
183 mask = np.isfinite(thr) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
184 if mask.any(): |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
185 idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh))) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
186 idx = np.where(mask)[0][idx_local] |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
187 if 0 <= idx < len(fpr): |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
188 fig_roc.add_scatter( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
189 x=[fpr[idx]], |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
190 y=[tpr[idx]], |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
191 mode="markers", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
192 name=f"@ {prob_thresh:.2f}", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
193 marker=dict(size=10), |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
194 ) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
195 fig_roc.update_layout( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
196 title=f"ROC Curve (marker at threshold={prob_thresh:.2f})", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
197 xaxis_title="False Positive Rate", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
198 yaxis_title="True Positive Rate", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
199 ) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
200 _apply_report_layout(fig_roc) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
201 self.explainer_plots["roc_auc"] = fig_roc |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
202 except Exception as e: |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
203 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}") |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
204 |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
205 # ---- PR with threshold marker ---- |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
206 try: |
|
17
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
207 if y_scores is None: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
208 raise ValueError("Predicted probabilities unavailable") |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
209 precision, recall, thr_pr = precision_recall_curve(y_true, y_scores) |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
210 pr_auc = auc(recall, precision) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
211 fig_pr = go.Figure() |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
212 fig_pr.add_scatter( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
213 x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})" |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
214 ) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
215 if len(thr_pr): |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
216 idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh))) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
217 # note: thr_pr has length = len(precision) - 1 |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
218 idx_pr = max(0, min(idx_pr, len(recall) - 1)) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
219 fig_pr.add_scatter( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
220 x=[recall[idx_pr]], |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
221 y=[precision[idx_pr]], |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
222 mode="markers", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
223 name=f"@ {prob_thresh:.2f}", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
224 marker=dict(size=10), |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
225 ) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
226 fig_pr.update_layout( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
227 title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
228 xaxis_title="Recall", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
229 yaxis_title="Precision", |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
230 ) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
231 _apply_report_layout(fig_pr) |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
232 self.explainer_plots["pr_auc"] = fig_pr |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
233 except Exception as e: |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
234 LOG.warning(f"Threshold marker on PR failed; falling back: {e}") |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
235 |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
236 # these go into the Test tab (don't overwrite overrides) |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
237 for key, fn in [ |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
238 ("roc_auc", explainer.plot_roc_auc), |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
239 ("pr_auc", explainer.plot_pr_auc), |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
240 ("lift_curve", explainer.plot_lift_curve), |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
241 ("confusion_matrix", explainer.plot_confusion_matrix), |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
242 ("threshold", explainer.plot_precision), # percentage vs probability |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
243 ("cumulative_precision", explainer.plot_cumulative_precision), |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
244 ]: |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
245 if key in self.explainer_plots: |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
246 continue |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
247 try: |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
248 fig = fn() |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
249 if fig is not None: |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
250 self.explainer_plots[key] = fig |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
251 except Exception as e: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
252 LOG.error(f"Error generating explainer plot {key}: {e}") |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
253 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
254 # mean SHAP importances |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
255 try: |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
256 self.explainer_plots["shap_mean"] = explainer.plot_importances() |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
257 except Exception as e: |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
258 LOG.warning(f"Could not generate shap_mean: {e}") |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
259 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
260 # permutation importances |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
261 try: |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
262 self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances( |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
263 kind="permutation" |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
264 ) |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
265 except Exception as e: |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
266 LOG.warning(f"Could not generate shap_perm: {e}") |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
267 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
268 # PDPs for each feature (appended last) |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
269 valid_feats = [] |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
270 for feat in self.features_name: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
271 if feat in explainer.X.columns or feat in explainer.onehot_cols: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
272 valid_feats.append(feat) |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
273 else: |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
274 LOG.warning( |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
275 f"Skipping PDP for feature {feat!r}: not found in explainer data" |
|
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
276 ) |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
277 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
278 for feat in valid_feats: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
279 # wrap each PDP call to catch any unexpected AssertionErrors |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
280 def make_pdp_plotter(f): |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
281 def _plot(): |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
282 try: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
283 return explainer.plot_pdp(f) |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
284 except AssertionError as ae: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
285 LOG.warning(f"PDP AssertionError for {f!r}: {ae}") |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
286 return None |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
287 except Exception as e: |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
288 LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}") |
|
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
289 return None |
|
12
e674b9e946fb
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
goeckslab
parents:
8
diff
changeset
|
290 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
291 return _plot |
|
0
1f20fe57fdee
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff
changeset
|
292 |
|
8
1aed7d47c5ec
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
goeckslab
parents:
3
diff
changeset
|
293 self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) |
|
17
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
294 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
295 def _get_test_predictions(self): |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
296 """ |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
297 Return y_true, y_pred, label list, and (optionally) positive-class |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
298 probabilities when available. Ensures predictions respect the optional |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
299 probability threshold for binary tasks. |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
300 """ |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
301 y_true = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
302 X_test = self.exp.X_test_transformed |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
303 prob_thresh = getattr(self, "probability_threshold", None) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
304 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
305 y_scores = None |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
306 try: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
307 proba = self.best_model.predict_proba(X_test) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
308 y_scores = proba |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
309 except Exception: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
310 LOG.debug("predict_proba unavailable for test predictions.") |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
311 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
312 try: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
313 if ( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
314 prob_thresh is not None |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
315 and not self.exp.is_multiclass |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
316 and y_scores is not None |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
317 and y_scores.ndim == 2 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
318 and y_scores.shape[1] > 1 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
319 ): |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
320 classes = list(getattr(self.best_model, "classes_", [])) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
321 try: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
322 pos_idx = classes.index(1) if 1 in classes else 1 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
323 except Exception: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
324 pos_idx = 1 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
325 neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
326 pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
327 neg_label = classes[neg_idx] if len(classes) > neg_idx else 0 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
328 y_pred = np.where(y_scores[:, pos_idx] >= prob_thresh, pos_label, neg_label) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
329 y_scores = y_scores[:, pos_idx] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
330 else: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
331 y_pred = self.best_model.predict(X_test) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
332 except Exception as exc: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
333 LOG.warning("Falling back to raw predict for test predictions: %s", exc) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
334 y_pred = self.best_model.predict(X_test) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
335 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
336 y_pred = pd.Series(y_pred).reset_index(drop=True) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
337 if y_scores is not None: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
338 y_scores = np.asarray(y_scores) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
339 if y_scores.ndim > 1 and y_scores.shape[1] == 1: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
340 y_scores = y_scores.ravel() |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
341 if self.exp.is_multiclass and y_scores.ndim > 1: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
342 # Avoid passing multiclass score matrices to ROC/PR utilities |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
343 y_scores = None |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
344 label_values = pd.unique(pd.concat([y_true, y_pred], ignore_index=True)) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
345 return y_true, y_pred, label_values.tolist(), y_scores |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
346 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
347 def _threshold_suffix(self) -> str: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
348 """ |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
349 Build a suffix like ' (threshold=0.50)' for binary tasks; omit for |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
350 multiclass where thresholds are not applied. |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
351 """ |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
352 if getattr(self, "task_type", None) != "classification": |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
353 return "" |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
354 if getattr(self.exp, "is_multiclass", False): |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
355 return "" |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
356 prob_thresh = getattr(self, "probability_threshold", None) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
357 if prob_thresh is None: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
358 return " (threshold=0.50)" |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
359 try: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
360 return f" (threshold={float(prob_thresh):.2f})" |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
361 except Exception: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
362 return f" (threshold={prob_thresh})" |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
363 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
364 def _build_confusion_matrix_fig(self, y_true, y_pred, labels): |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
365 def _label_sort_key(lbl): |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
366 try: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
367 return (0, float(lbl)) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
368 except Exception: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
369 return (1, str(lbl)) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
370 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
371 ordered_labels = sorted(labels, key=_label_sort_key) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
372 cm = confusion_matrix(y_true, y_pred, labels=ordered_labels) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
373 label_names = [str(lbl) for lbl in ordered_labels] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
374 fig_cm = go.Figure( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
375 data=go.Heatmap( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
376 z=cm, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
377 x=[f"Pred {lbl}" for lbl in label_names], |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
378 y=[f"True {lbl}" for lbl in label_names], |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
379 text=cm, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
380 texttemplate="%{text}", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
381 colorscale="Blues", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
382 showscale=False, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
383 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
384 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
385 fig_cm.update_layout( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
386 title=f"Confusion Matrix{self._threshold_suffix()}", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
387 xaxis_title=f"Predicted label ({self.target})", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
388 yaxis_title=f"True label ({self.target})", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
389 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
390 fig_cm.update_xaxes( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
391 type="category", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
392 categoryorder="array", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
393 categoryarray=[f"Pred {lbl}" for lbl in label_names], |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
394 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
395 fig_cm.update_yaxes( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
396 type="category", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
397 categoryorder="array", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
398 categoryarray=[f"True {lbl}" for lbl in label_names], |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
399 autorange="reversed", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
400 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
401 _apply_report_layout(fig_cm) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
402 return fig_cm |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
403 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
404 def _build_classification_report_fig(self, y_true, y_pred, labels): |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
405 precision, recall, f1, support = precision_recall_fscore_support( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
406 y_true, y_pred, labels=labels, zero_division=0 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
407 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
408 mcc_scores = [] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
409 for lbl in labels: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
410 y_true_bin = (y_true == lbl).astype(int) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
411 y_pred_bin = (y_pred == lbl).astype(int) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
412 try: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
413 mcc_val = matthews_corrcoef(y_true_bin, y_pred_bin) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
414 except Exception: |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
415 mcc_val = 0.0 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
416 mcc_scores.append(mcc_val) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
417 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
418 label_names = [str(lbl) for lbl in labels] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
419 metrics = ["precision", "recall", "f1", "support"] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
420 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
421 max_support = float(max(support) if len(support) else 0) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
422 z_rows = [] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
423 text_rows = [] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
424 for i, lbl in enumerate(label_names): |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
425 norm_support = (support[i] / max_support) if max_support else 0.0 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
426 z_rows.append( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
427 [ |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
428 precision[i], |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
429 recall[i], |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
430 f1[i], |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
431 norm_support, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
432 ] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
433 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
434 text_rows.append( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
435 [ |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
436 f"{precision[i]:.3f}", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
437 f"{recall[i]:.3f}", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
438 f"{f1[i]:.3f}", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
439 f"{int(support[i])}", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
440 ] |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
441 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
442 |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
443 fig = go.Figure( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
444 data=go.Heatmap( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
445 z=z_rows, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
446 x=metrics, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
447 y=label_names, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
448 colorscale="YlOrRd", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
449 zmin=0, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
450 zmax=1, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
451 colorbar=dict(title="Scale"), |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
452 text=text_rows, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
453 texttemplate="%{text}", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
454 hovertemplate="Label=%{y}<br>Metric=%{x}<br>Value=%{text}<extra></extra>", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
455 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
456 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
457 fig.update_yaxes( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
458 title_text=f"Label ({self.target})", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
459 autorange="reversed", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
460 type="category", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
461 tickmode="array", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
462 tickvals=label_names, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
463 ticktext=label_names, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
464 showgrid=False, |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
465 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
466 fig.update_xaxes(title_text="", tickangle=45) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
467 fig.update_layout( |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
468 title=f"Per-Class Metrics{self._threshold_suffix()}", |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
469 margin=dict(l=70, r=60, t=70, b=80), |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
470 ) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
471 _apply_report_layout(fig) |
|
c5c324ac29fc
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
goeckslab
parents:
15
diff
changeset
|
472 return fig |
