Mercurial > repos > goeckslab > multimodal_learner
annotate plot_logic.py @ 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children |
| rev | line source |
|---|---|
|
0
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1 from __future__ import annotations |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
2 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
3 import html |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
4 import os |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
5 from html import escape as _escape |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
6 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
7 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
8 import matplotlib.pyplot as plt |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
9 import numpy as np |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
10 import pandas as pd |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
11 import plotly.express as px |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
12 import plotly.graph_objects as go |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
13 import shap |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
14 from feature_help_modal import get_metrics_help_modal |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
15 from report_utils import build_tabbed_html, get_html_closing, get_html_template |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
16 from sklearn.calibration import calibration_curve |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
17 from sklearn.metrics import ( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
18 auc, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
19 average_precision_score, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
20 classification_report, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
21 confusion_matrix, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
22 log_loss, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
23 precision_recall_curve, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
24 roc_auc_score, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
25 roc_curve, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
26 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
27 from sklearn.model_selection import learning_curve as skl_learning_curve |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
28 from sklearn.preprocessing import label_binarize |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
29 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
30 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
31 # Utilities |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
32 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
33 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
34 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
35 def plot_with_table_style_title(fig, title: str) -> str: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
36 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
37 Render a Plotly figure with a report-style <h2> header so it matches the |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
38 green table section headers. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
39 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
40 # kill Plotly’s built-in title |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
41 fig.update_layout(title=None) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
42 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
43 # figure HTML without PlotlyJS (we load it once globally) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
44 plot_html = fig.to_html(full_html=False, include_plotlyjs=False) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
45 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
46 # use <h2> — your CSS already styles <h2> like the table headers |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
47 return f""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
48 <h2>{html.escape(title)}</h2> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
49 <div class="plotly-center">{plot_html}</div> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
50 """.strip() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
51 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
52 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
53 def _save_plotly(fig: go.Figure, path: Optional[str]) -> None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
54 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
55 Save a Plotly figure. If `path` ends with `.html`, save interactive HTML. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
56 If it ends with a raster extension (png/jpg/jpeg/webp), uses Kaleido. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
57 If None, do nothing (caller may choose to display in notebook). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
58 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
59 if not path: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
60 return |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
61 ext = os.path.splitext(path)[1].lower() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
62 if ext == ".html": |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
63 fig.write_html(path, include_plotlyjs="cdn", full_html=True) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
64 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
65 # Requires kaleido: pip install -U kaleido |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
66 fig.write_image(path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
67 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
68 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
69 def _save_matplotlib(path: Optional[str]) -> None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
70 """Save current Matplotlib figure if `path` is provided, else show().""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
71 if path: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
72 plt.savefig(path, bbox_inches="tight") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
73 plt.close() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
74 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
75 plt.show() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
76 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
77 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
78 # Classification Plots |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
79 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
80 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
81 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
82 def generate_confusion_matrix_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
83 y_true, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
84 y_pred, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
85 title: str = "Confusion Matrix", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
86 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
87 y_true = np.asarray(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
88 y_pred = np.asarray(y_pred) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
89 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
90 # Class order (works for strings or numbers) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
91 labels = pd.Index(np.unique(np.concatenate([y_true, y_pred])), dtype=object).tolist() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
92 cm = confusion_matrix(y_true, y_pred, labels=labels) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
93 max_val = cm.max() if cm.size else 0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
94 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
95 # Use categorical axes by passing string labels for x/y |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
96 cats = [str(label) for label in labels] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
97 total = int(cm.sum()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
98 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
99 fig = go.Figure( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
100 data=go.Heatmap( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
101 z=cm, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
102 x=cats, # categorical x |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
103 y=cats, # categorical y |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
104 colorscale="Blues", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
105 showscale=True, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
106 colorbar=dict(title="Count"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
107 xgap=2, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
108 ygap=2, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
109 hovertemplate="True=%{y}<br>Pred=%{x}<br>Count=%{z}<extra></extra>", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
110 zmin=0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
111 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
112 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
113 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
114 # Add annotations with count and percentage (all white text, matching sample_output.html) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
115 annotations = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
116 for i in range(cm.shape[0]): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
117 for j in range(cm.shape[1]): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
118 val = int(cm[i, j]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
119 pct = (val / total * 100) if total > 0 else 0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
120 text_color = "white" if max_val and val > (max_val / 2) else "black" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
121 # Count annotation (bold, bottom) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
122 annotations.append( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
123 dict( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
124 x=cats[j], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
125 y=cats[i], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
126 text=f"<b>{val}</b>", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
127 showarrow=False, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
128 font=dict(color=text_color, size=14), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
129 xanchor="center", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
130 yanchor="bottom", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
131 yshift=2 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
132 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
133 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
134 # Percentage annotation (top) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
135 annotations.append( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
136 dict( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
137 x=cats[j], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
138 y=cats[i], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
139 text=f"{pct:.1f}%", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
140 showarrow=False, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
141 font=dict(color=text_color, size=13), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
142 xanchor="center", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
143 yanchor="top", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
144 yshift=-2 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
145 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
146 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
147 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
148 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
149 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
150 xaxis_title="Predicted label", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
151 yaxis_title="True label", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
152 xaxis=dict(type="category"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
153 yaxis=dict(type="category", autorange="reversed"), # typical CM orientation |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
154 margin=dict(l=80, r=20, t=40, b=80), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
155 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
156 plot_bgcolor="white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
157 paper_bgcolor="white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
158 annotations=annotations |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
159 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
160 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
161 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
162 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
163 def generate_roc_curve_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
164 y_true_bin: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
165 y_score: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
166 title: str = "ROC Curve", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
167 marker_threshold: float | None = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
168 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
169 y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
170 y_score = np.asarray(y_score).astype(float).reshape(-1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
171 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
172 fpr, tpr, thr = roc_curve(y_true_bin, y_score) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
173 roc_auc = auc(fpr, tpr) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
174 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
175 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
176 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
177 x=fpr, y=tpr, mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
178 name=f"ROC (AUC={roc_auc:.3f})", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
179 line=dict(width=3) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
180 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
181 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
182 # 45° chance line (no legend to keep it clean) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
183 fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
184 line=dict(dash="dash", width=2, color="#888"), showlegend=False)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
185 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
186 # Optional marker at the user threshold |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
187 if marker_threshold is not None and len(thr): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
188 # roc_curve returns thresholds of same length as fpr/tpr; includes inf at idx 0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
189 finite = np.isfinite(thr) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
190 if np.any(finite): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
191 idx_local = int(np.argmin(np.abs(thr[finite] - float(marker_threshold)))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
192 idx = int(np.nonzero(finite)[0][idx_local]) # map back to original indices |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
193 x_m, y_m = float(fpr[idx]), float(tpr[idx]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
194 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
195 fig.add_trace( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
196 go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
197 x=[x_m], y=[y_m], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
198 mode="markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
199 name=f"@ {float(marker_threshold):.2f}", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
200 marker=dict(size=12, color="red", symbol="x") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
201 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
202 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
203 fig.add_annotation( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
204 x=0.02, y=0.98, xref="paper", yref="paper", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
205 text=f"threshold = {float(marker_threshold):.2f}", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
206 showarrow=False, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
207 font=dict(color="black", size=12), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
208 align="left" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
209 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
210 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
211 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
212 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
213 xaxis_title="False Positive Rate", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
214 yaxis_title="True Positive Rate", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
215 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
216 legend=dict(x=1, y=0, xanchor="right"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
217 margin=dict(l=60, r=20, t=60, b=60), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
218 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
219 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
220 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
221 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
222 def generate_pr_curve_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
223 y_true_bin: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
224 y_score: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
225 title: str = "Precision–Recall Curve", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
226 marker_threshold: float | None = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
227 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
228 y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
229 y_score = np.asarray(y_score).astype(float).reshape(-1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
230 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
231 precision, recall, thr = precision_recall_curve(y_true_bin, y_score) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
232 pr_auc = auc(recall, precision) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
233 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
234 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
235 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
236 x=recall, y=precision, mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
237 name=f"PR (AUC={pr_auc:.3f})", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
238 line=dict(width=3) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
239 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
240 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
241 # Optional marker at the user threshold |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
242 if marker_threshold is not None and len(thr): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
243 # In PR, thresholds has length len(precision)-1. The point for thr[j] is (recall[j+1], precision[j+1]). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
244 j = int(np.argmin(np.abs(thr - float(marker_threshold)))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
245 j = int(np.clip(j, 0, len(thr) - 1)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
246 x_m, y_m = float(recall[j + 1]), float(precision[j + 1]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
247 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
248 fig.add_trace( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
249 go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
250 x=[x_m], y=[y_m], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
251 mode="markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
252 name=f"@ {float(marker_threshold):.2f}", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
253 marker=dict(size=12, color="red", symbol="x") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
254 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
255 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
256 fig.add_annotation( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
257 x=0.02, y=0.98, xref="paper", yref="paper", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
258 text=f"threshold = {float(marker_threshold):.2f}", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
259 showarrow=False, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
260 font=dict(color="black", size=12), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
261 align="left" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
262 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
263 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
264 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
265 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
266 xaxis_title="Recall", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
267 yaxis_title="Precision", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
268 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
269 legend=dict(x=1, y=0, xanchor="right"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
270 margin=dict(l=60, r=20, t=60, b=60), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
271 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
272 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
273 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
274 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
275 def generate_calibration_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
276 y_true_bin: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
277 y_prob: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
278 n_bins: int = 10, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
279 title: str = "Calibration Plot", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
280 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
281 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
282 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
283 Binary calibration curve (Plotly). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
284 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
285 prob_true, prob_pred = calibration_curve(y_true_bin, y_prob, n_bins=n_bins, strategy="uniform") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
286 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
287 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
288 x=prob_pred, y=prob_true, mode="lines+markers", name="Model", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
289 line=dict(color="#1f77b4", width=3), marker=dict(size=7, color="#1f77b4") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
290 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
291 fig.add_trace( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
292 go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
293 x=[0, 1], y=[0, 1], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
294 mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
295 line=dict(dash="dash", color="#808080", width=2), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
296 name="Perfect" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
297 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
298 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
299 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
300 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
301 xaxis_title="Predicted Probability", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
302 yaxis_title="Observed Probability", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
303 yaxis=dict(range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
304 xaxis=dict(range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
305 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
306 margin=dict(l=60, r=40, t=50, b=50), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
307 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
308 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
309 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
310 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
311 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
312 def generate_threshold_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
313 y_true_bin: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
314 y_prob: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
315 title: str = "Threshold Plot", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
316 user_threshold: float | None = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
317 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
318 y_true = np.asarray(y_true_bin, dtype=int).ravel() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
319 p = np.asarray(y_prob, dtype=float).ravel() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
320 p = np.nan_to_num(p, nan=0.0) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
321 p = np.clip(p, 0.0, 1.0) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
322 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
323 def _compute_metrics(thresholds: np.ndarray): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
324 """Vectorized-ish helper to compute precision/recall/F1/queue rate arrays.""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
325 prec, rec, f1, qrate = [], [], [], [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
326 for t in thresholds: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
327 yhat = (p >= t).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
328 tp = int(((yhat == 1) & (y_true == 1)).sum()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
329 fp = int(((yhat == 1) & (y_true == 0)).sum()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
330 fn = int(((yhat == 0) & (y_true == 1)).sum()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
331 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
332 pr = tp / (tp + fp) if (tp + fp) else np.nan # undefined when no predicted positives |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
333 rc = tp / (tp + fn) if (tp + fn) else 0.0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
334 f = (2 * pr * rc) / (pr + rc) if (pr + rc) and not np.isnan(pr) else 0.0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
335 q = float(yhat.mean()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
336 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
337 prec.append(pr) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
338 rec.append(rc) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
339 f1.append(f) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
340 qrate.append(q) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
341 return np.asarray(prec, dtype=float), np.asarray(rec, dtype=float), np.asarray(f1, dtype=float), np.asarray(qrate, dtype=float) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
342 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
343 # Use uniform threshold grid for plotting (0 to 1 in steps of 0.01) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
344 th = np.linspace(0.0, 1.0, 101) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
345 prec, rec, f1_arr, qrate = _compute_metrics(th) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
346 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
347 # Compute F1*-optimal threshold using actual score distribution (more precise than grid) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
348 cand_th = np.unique(np.concatenate(([0.0, 1.0], p))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
349 # cap to a reasonable size by sampling if extremely large |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
350 if cand_th.size > 2000: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
351 cand_th = np.linspace(0.0, 1.0, 2001) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
352 _, _, f1_cand, _ = _compute_metrics(cand_th) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
353 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
354 if np.all(np.isnan(f1_cand)): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
355 t_star = 0.5 # fallback when no valid F1 can be computed |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
356 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
357 f1_max = np.nanmax(f1_cand) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
358 best_idxs = np.where(np.isclose(f1_cand, f1_max, equal_nan=False))[0] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
359 # pick the middle of the best candidates to avoid biasing toward 0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
360 best_idx = int(best_idxs[len(best_idxs) // 2]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
361 t_star = float(cand_th[best_idx]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
362 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
363 # Replace NaNs for plotting (set to 0 where precision is undefined) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
364 prec_plot = np.nan_to_num(prec, nan=0.0) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
365 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
366 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
367 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
368 # Precision (blue line) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
369 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
370 x=th, y=prec_plot, mode="lines", name="Precision", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
371 line=dict(width=3, color="#1f77b4"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
372 hovertemplate="Threshold=%{x:.3f}<br>Precision=%{y:.3f}<extra></extra>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
373 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
374 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
375 # Recall (orange line) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
376 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
377 x=th, y=rec, mode="lines", name="Recall", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
378 line=dict(width=3, color="#ff7f0e"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
379 hovertemplate="Threshold=%{x:.3f}<br>Recall=%{y:.3f}<extra></extra>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
380 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
381 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
382 # F1 (green line) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
383 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
384 x=th, y=f1_arr, mode="lines", name="F1", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
385 line=dict(width=3, color="#2ca02c"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
386 hovertemplate="Threshold=%{x:.3f}<br>F1=%{y:.3f}<extra></extra>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
387 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
388 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
389 # Queue Rate (grey dashed line) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
390 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
391 x=th, y=qrate, mode="lines", name="Queue Rate", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
392 line=dict(width=2, color="#808080", dash="dash"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
393 hovertemplate="Threshold=%{x:.3f}<br>Queue Rate=%{y:.3f}<extra></extra>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
394 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
395 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
396 # F1*-optimal threshold marker (dashed vertical line) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
397 fig.add_vline( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
398 x=t_star, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
399 line_width=2, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
400 line_dash="dash", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
401 line_color="black", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
402 annotation_text=f"t* = {t_star:.2f}", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
403 annotation_position="top" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
404 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
405 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
406 # User threshold (solid red line) if provided |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
407 if user_threshold is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
408 fig.add_vline( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
409 x=float(user_threshold), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
410 line_width=2, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
411 line_color="red", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
412 annotation_text=f"threshold = {float(user_threshold):.2f}", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
413 annotation_position="top" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
414 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
415 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
416 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
417 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
418 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
419 xaxis=dict( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
420 title="Discrimination Threshold", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
421 range=[0, 1], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
422 gridcolor="#e0e0e0", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
423 showgrid=True, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
424 zeroline=False |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
425 ), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
426 yaxis=dict( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
427 title="Score", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
428 range=[0, 1], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
429 gridcolor="#e0e0e0", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
430 showgrid=True, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
431 zeroline=False |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
432 ), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
433 legend=dict( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
434 orientation="h", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
435 yanchor="bottom", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
436 y=1.02, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
437 xanchor="right", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
438 x=1.0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
439 ), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
440 margin=dict(l=60, r=20, t=40, b=50), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
441 plot_bgcolor="white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
442 paper_bgcolor="white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
443 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
444 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
445 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
446 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
447 def generate_per_class_metrics_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
448 y_true: Sequence, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
449 y_pred: Sequence, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
450 metrics: Sequence[str] = ("precision", "recall", "f1_score"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
451 title: str = "Classification Report", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
452 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
453 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
454 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
455 Per-class metrics heatmap (Plotly), similar to sklearn's classification report. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
456 Rows = classes, columns = metrics; cell text shows the value (0–1). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
457 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
458 # Map display names -> sklearn keys |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
459 key_map = {"f1_score": "f1-score", "precision": "precision", "recall": "recall"} |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
460 report = classification_report( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
461 y_true, y_pred, output_dict=True, zero_division=0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
462 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
463 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
464 # Order classes sensibly (numeric if possible, else lexical) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
465 def _sort_key(x): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
466 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
467 return (0, float(x)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
468 except Exception: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
469 return (1, str(x)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
470 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
471 # Use all classes seen in y_true or y_pred (so rows don't jump around) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
472 uniq = sorted(set(list(y_true) + list(y_pred)), key=_sort_key) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
473 classes = [str(c) for c in uniq] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
474 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
475 # Build Z matrix (rows=classes, cols=metrics) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
476 used_metrics = [key_map.get(m, m) for m in metrics] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
477 z = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
478 for c in classes: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
479 row = report.get(c, {}) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
480 z.append([float(row.get(m, 0.0) or 0.0) for m in used_metrics]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
481 z = np.array(z, dtype=float) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
482 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
483 # Pretty cell labels |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
484 z_text = [[f"{v:.2f}" for v in r] for r in z] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
485 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
486 fig = go.Figure( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
487 data=go.Heatmap( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
488 z=z, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
489 x=list(metrics), # keep display names ("precision", "recall", "f1_score") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
490 y=classes, # classes as strings |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
491 colorscale="Reds", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
492 zmin=0.0, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
493 zmax=1.0, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
494 colorbar=dict(title="Value"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
495 text=z_text, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
496 texttemplate="%{text}", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
497 hovertemplate="Class %{y}<br>%{x}: %{z:.2f}<extra></extra>", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
498 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
499 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
500 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
501 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
502 xaxis_title="", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
503 yaxis_title="Class", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
504 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
505 margin=dict(l=60, r=60, t=70, b=40), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
506 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
507 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
508 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
509 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
510 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
511 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
512 def generate_multiclass_roc_curve_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
513 y_true: Sequence, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
514 y_prob: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
515 classes: Sequence, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
516 title: str = "Multiclass ROC Curve", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
517 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
518 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
519 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
520 One-vs-rest ROC curves for multiclass (Plotly). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
521 Handles binary passed as 2-column probs as well. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
522 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
523 y_true = np.asarray(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
524 y_prob = np.asarray(y_prob) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
525 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
526 # Normalize to shape (n_samples, n_classes) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
527 if y_prob.ndim == 1 or y_prob.shape[1] == 1: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
528 y_prob = np.hstack([1 - y_prob.reshape(-1, 1), y_prob.reshape(-1, 1)]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
529 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
530 y_true_bin = label_binarize(y_true, classes=classes) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
531 if y_true_bin.shape[1] == 1 and y_prob.shape[1] == 2: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
532 y_true_bin = np.hstack([1 - y_true_bin, y_true_bin]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
533 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
534 if y_prob.shape[1] != y_true_bin.shape[1]: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
535 raise ValueError( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
536 f"Shape mismatch: y_prob has {y_prob.shape[1]} columns but y_true_bin has {y_true_bin.shape[1]}." |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
537 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
538 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
539 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
540 for i, cls in enumerate(classes): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
541 fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
542 auc_val = roc_auc_score(y_true_bin[:, i], y_prob[:, i]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
543 fig.add_trace(go.Scatter(x=fpr, y=tpr, mode="lines", name=f"{cls} (AUC {auc_val:.2f})")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
544 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
545 fig.add_trace( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
546 go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash"), showlegend=False) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
547 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
548 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
549 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
550 xaxis_title="False Positive Rate", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
551 yaxis_title="True Positive Rate", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
552 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
553 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
554 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
555 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
556 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
557 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
558 def generate_multiclass_pr_curve_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
559 y_true: Sequence, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
560 y_prob: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
561 classes: Optional[Sequence] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
562 title: str = "Precision–Recall Curve", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
563 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
564 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
565 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
566 Multiclass PR curves (Plotly). If classes is None or len==2, shows binary PR. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
567 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
568 y_true = np.asarray(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
569 y_prob = np.asarray(y_prob) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
570 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
571 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
572 if not classes or len(classes) == 2: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
573 precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
574 ap = average_precision_score(y_true, y_prob[:, 1]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
575 fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"AP = {ap:.2f}")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
576 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
577 for i, cls in enumerate(classes): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
578 y_true_bin = (y_true == cls).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
579 y_prob_cls = y_prob[:, i] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
580 precision, recall, _ = precision_recall_curve(y_true_bin, y_prob_cls) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
581 ap = average_precision_score(y_true_bin, y_prob_cls) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
582 fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"{cls} (AP {ap:.2f})")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
583 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
584 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
585 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
586 xaxis_title="Recall", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
587 yaxis_title="Precision", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
588 yaxis=dict(range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
589 xaxis=dict(range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
590 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
591 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
592 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
593 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
594 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
595 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
596 def generate_metric_comparison_bar( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
597 metrics_scores: Mapping[str, Sequence[float]], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
598 phases: Sequence[str] = ("train", "val", "test"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
599 title: str = "Metric Comparison Across Phases", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
600 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
601 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
602 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
603 Grouped bar chart comparing metrics across phases (Plotly). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
604 metrics_scores: {metric_name: [train, val, test]} |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
605 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
606 df = pd.DataFrame(metrics_scores, index=phases).T.reset_index().rename(columns={"index": "Metric"}) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
607 df_m = df.melt(id_vars="Metric", var_name="Phase", value_name="Score") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
608 fig = px.bar(df_m, x="Metric", y="Score", color="Phase", barmode="group", title=None) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
609 ymax = max(1.0, df_m["Score"].max() * 1.05) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
610 fig.update_yaxes(range=[0, ymax]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
611 fig.update_layout(template="plotly_white") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
612 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
613 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
614 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
615 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
616 # Regression Plots |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
617 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
618 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
619 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
620 def generate_scatter_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
621 y_true: Sequence[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
622 y_pred: Sequence[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
623 title: str = "Predicted vs Actual", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
624 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
625 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
626 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
627 Predicted vs. Actual scatter with y=x reference (Plotly). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
628 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
629 y_true = np.asarray(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
630 y_pred = np.asarray(y_pred) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
631 vmin = float(min(np.min(y_true), np.min(y_pred))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
632 vmax = float(max(np.max(y_true), np.max(y_pred))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
633 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
634 fig = px.scatter(x=y_true, y=y_pred, opacity=0.6, labels={"x": "Actual", "y": "Predicted"}, title=None) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
635 fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), name="Ideal")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
636 fig.update_layout(template="plotly_white") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
637 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
638 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
639 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
640 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
641 def generate_residual_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
642 y_true: Sequence[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
643 y_pred: Sequence[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
644 title: str = "Residual Plot", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
645 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
646 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
647 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
648 Residuals vs Predicted (Plotly). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
649 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
650 y_true = np.asarray(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
651 y_pred = np.asarray(y_pred) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
652 residuals = y_true - y_pred |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
653 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
654 fig = px.scatter(x=y_pred, y=residuals, opacity=0.6, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
655 labels={"x": "Predicted", "y": "Residual (Actual - Predicted)"}, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
656 title=None) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
657 fig.add_hline(y=0, line_dash="dash") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
658 fig.update_layout(template="plotly_white") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
659 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
660 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
661 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
662 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
663 def generate_residual_histogram( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
664 y_true: Sequence[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
665 y_pred: Sequence[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
666 bins: int = 30, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
667 title: str = "Residual Histogram", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
668 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
669 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
670 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
671 Residuals histogram (Plotly). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
672 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
673 residuals = np.asarray(y_true) - np.asarray(y_pred) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
674 fig = px.histogram(x=residuals, nbins=bins, labels={"x": "Residual"}, title=None) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
675 fig.update_layout(yaxis_title="Frequency", template="plotly_white") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
676 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
677 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
678 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
679 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
680 def generate_regression_calibration_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
681 y_true: Sequence[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
682 y_pred: Sequence[float], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
683 num_bins: int = 10, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
684 title: str = "Regression Calibration Plot", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
685 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
686 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
687 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
688 Binned Actual vs Predicted means (Plotly). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
689 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
690 y_true = np.asarray(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
691 y_pred = np.asarray(y_pred) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
692 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
693 order = np.argsort(y_pred) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
694 y_true_sorted = y_true[order] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
695 y_pred_sorted = y_pred[order] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
696 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
697 bins = np.array_split(np.arange(len(y_pred_sorted)), num_bins) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
698 bin_means_pred = [float(np.mean(y_pred_sorted[idx])) for idx in bins if len(idx)] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
699 bin_means_true = [float(np.mean(y_true_sorted[idx])) for idx in bins if len(idx)] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
700 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
701 vmin = float(min(np.min(y_pred), np.min(y_true))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
702 vmax = float(max(np.max(y_pred), np.max(y_true))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
703 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
704 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
705 fig.add_trace(go.Scatter(x=bin_means_pred, y=bin_means_true, mode="lines+markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
706 name="Binned Actual vs Predicted")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
707 fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
708 name="Ideal")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
709 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
710 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
711 xaxis_title="Mean Predicted per bin", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
712 yaxis_title="Mean Actual per bin", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
713 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
714 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
715 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
716 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
717 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
718 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
719 # Confidence / Diagnostics |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
720 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
721 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
722 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
723 def plot_error_vs_confidence( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
724 y_true: Union[Sequence[int], np.ndarray], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
725 y_proba: Union[Sequence[float], np.ndarray], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
726 n_bins: int = 10, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
727 title: str = "Error vs Confidence", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
728 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
729 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
730 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
731 Error rate vs confidence (binary), confidence=max(p, 1-p). Plotly. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
732 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
733 y_true = np.asarray(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
734 y_proba = np.asarray(y_proba).reshape(-1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
735 y_pred = (y_proba >= 0.5).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
736 confidence = np.maximum(y_proba, 1 - y_proba) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
737 error = (y_pred != y_true).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
738 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
739 bins = np.linspace(0.0, 1.0, n_bins + 1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
740 idx = np.digitize(confidence, bins, right=True) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
741 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
742 centers, err_rates = [], [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
743 for i in range(1, len(bins)): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
744 mask = (idx == i) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
745 if mask.any(): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
746 centers.append(float(confidence[mask].mean())) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
747 err_rates.append(float(error[mask].mean())) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
748 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
749 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
750 fig.add_trace(go.Scatter(x=centers, y=err_rates, mode="lines+markers", name="Error rate")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
751 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
752 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
753 xaxis_title="Confidence (max predicted probability)", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
754 yaxis_title="Error Rate", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
755 yaxis=dict(range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
756 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
757 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
758 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
759 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
760 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
761 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
762 def plot_confidence_histogram( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
763 y_proba: np.ndarray, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
764 bins: int = 20, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
765 title: str = "Confidence Histogram", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
766 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
767 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
768 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
769 Histogram of max predicted probabilities (Plotly). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
770 Works for binary (n_samples,) or (n_samples,2) and multiclass (n_samples,C). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
771 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
772 y_proba = np.asarray(y_proba) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
773 if y_proba.ndim == 1: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
774 confidences = np.maximum(y_proba, 1 - y_proba) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
775 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
776 confidences = np.max(y_proba, axis=1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
777 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
778 fig = px.histogram( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
779 x=confidences, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
780 nbins=bins, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
781 range_x=(0, 1), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
782 histnorm="percent", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
783 labels={"x": "Confidence (max predicted probability)", "y": "Percent of samples (%)"}, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
784 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
785 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
786 if fig.data: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
787 fig.update_traces(hovertemplate="Conf=%{x:.2f}<br>%{y:.2f}%<extra></extra>") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
788 fig.update_layout(yaxis_title="Percent of samples (%)", template="plotly_white") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
789 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
790 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
791 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
792 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
793 # Learning Curve |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
794 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
795 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
796 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
797 def generate_learning_curve_from_predictions( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
798 y_true, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
799 y_pred=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
800 y_proba=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
801 classes=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
802 metric: str = "accuracy", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
803 train_fracs: np.ndarray = np.linspace(0.1, 1.0, 10), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
804 n_repeats: int = 5, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
805 seed: int = 42, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
806 title: str = "Learning Curve", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
807 path: str | None = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
808 return_stats: bool = False, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
809 ) -> Union[go.Figure, tuple[list[int], list[float], list[float]]]: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
810 rng = np.random.default_rng(seed) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
811 y_true = np.asarray(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
812 N = len(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
813 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
814 if metric == "accuracy" and y_pred is None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
815 raise ValueError("accuracy curve requires y_pred") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
816 if metric == "log_loss" and y_proba is None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
817 raise ValueError("log_loss curve requires y_proba") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
818 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
819 if y_proba is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
820 y_proba = np.asarray(y_proba) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
821 if y_pred is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
822 y_pred = np.asarray(y_pred) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
823 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
824 sizes = (np.clip((train_fracs * N).astype(int), 1, N)).tolist() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
825 means, stds = [], [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
826 for n in sizes: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
827 vals = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
828 for _ in range(n_repeats): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
829 idx = rng.choice(N, size=n, replace=False) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
830 if metric == "accuracy": |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
831 vals.append(float((y_true[idx] == y_pred[idx]).mean())) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
832 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
833 if y_proba.ndim == 1: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
834 p = y_proba[idx] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
835 pp = np.column_stack([1 - p, p]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
836 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
837 pp = y_proba[idx] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
838 vals.append(float(log_loss(y_true[idx], pp, labels=None if classes is None else classes))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
839 means.append(np.mean(vals)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
840 stds.append(np.std(vals)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
841 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
842 if return_stats: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
843 return sizes, means, stds |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
844 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
845 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
846 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
847 x=sizes, y=means, mode="lines+markers", name="Train", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
848 line=dict(width=3, shape="spline"), marker=dict(size=7), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
849 error_y=dict(type="data", array=stds, visible=True) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
850 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
851 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
852 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
853 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
854 xaxis=dict(title="epoch" if metric == "log_loss" else "samples", gridcolor="#eee"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
855 yaxis=dict(title=("loss" if metric == "log_loss" else "accuracy"), gridcolor="#eee"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
856 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
857 margin=dict(l=50, r=20, t=60, b=50), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
858 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
859 if path: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
860 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
861 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
862 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
863 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
864 def build_train_html_and_plots( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
865 predictor, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
866 problem_type: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
867 df_train: pd.DataFrame, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
868 label_column: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
869 tmpdir: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
870 df_val: Optional[pd.DataFrame] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
871 seed: int = 42, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
872 perf_table_html: str | None = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
873 threshold: Optional[float] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
874 section_tile: str = "Training Diagnostics", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
875 ) -> str: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
876 y_true = df_train[label_column].values |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
877 y_true_val = df_val[label_column].values if df_val is not None else None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
878 # predictions on TRAIN |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
879 pred_labels, pred_proba = None, None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
880 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
881 pred_labels = predictor.predict(df_train) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
882 except Exception: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
883 pass |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
884 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
885 proba_raw = predictor.predict_proba(df_train) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
886 pred_proba = proba_raw.to_numpy() if isinstance(proba_raw, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
887 except Exception: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
888 pred_proba = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
889 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
890 # predictions on VAL (if provided) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
891 pred_labels_val, pred_proba_val = None, None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
892 if df_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
893 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
894 pred_labels_val = predictor.predict(df_val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
895 except Exception: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
896 pred_labels_val = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
897 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
898 proba_raw_val = predictor.predict_proba(df_val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
899 pred_proba_val = proba_raw_val.to_numpy() if isinstance(proba_raw_val, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw_val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
900 except Exception: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
901 pred_proba_val = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
902 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
903 pos_scores_train: Optional[np.ndarray] = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
904 pos_scores_val: Optional[np.ndarray] = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
905 if problem_type == "binary": |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
906 if pred_proba is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
907 pos_scores_train = ( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
908 pred_proba.reshape(-1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
909 if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
910 else pred_proba[:, -1] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
911 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
912 if pred_proba_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
913 pos_scores_val = ( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
914 pred_proba_val.reshape(-1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
915 if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
916 else pred_proba_val[:, -1] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
917 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
918 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
919 # Collect plots then append in desired order |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
920 perf_card = f"<div class='card'>{perf_table_html}</div>" if perf_table_html else None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
921 acc_plot = loss_plot = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
922 cm_train = pc_train = cm_val = pc_val = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
923 threshold_val_plot = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
924 roc_combined = pr_combined = cal_combined = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
925 mc_roc_val = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
926 conf_train = conf_val = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
927 bar_train = bar_val = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
928 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
929 # 1) Learning Curve — Accuracy |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
930 if problem_type in ("binary", "multiclass"): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
931 acc_fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
932 added_acc = False |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
933 if pred_labels is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
934 train_sizes, train_means, train_stds = generate_learning_curve_from_predictions( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
935 y_true=y_true, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
936 y_pred=np.asarray(pred_labels), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
937 metric="accuracy", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
938 title="Learning Curves — Label Accuracy", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
939 seed=seed, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
940 return_stats=True, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
941 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
942 acc_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
943 x=train_sizes, y=train_means, mode="lines+markers", name="Train", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
944 line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
945 error_y=dict(type="data", array=train_stds, visible=True), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
946 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
947 added_acc = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
948 if pred_labels_val is not None and y_true_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
949 val_sizes, val_means, val_stds = generate_learning_curve_from_predictions( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
950 y_true=y_true_val, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
951 y_pred=np.asarray(pred_labels_val), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
952 metric="accuracy", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
953 title="Learning Curves — Label Accuracy", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
954 seed=seed, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
955 return_stats=True, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
956 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
957 acc_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
958 x=val_sizes, y=val_means, mode="lines+markers", name="Validation", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
959 line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
960 error_y=dict(type="data", array=val_stds, visible=True), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
961 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
962 added_acc = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
963 if added_acc: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
964 acc_fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
965 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
966 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
967 xaxis=dict(title="samples", gridcolor="#eee"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
968 yaxis=dict(title="accuracy", gridcolor="#eee"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
969 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
970 margin=dict(l=50, r=20, t=60, b=50), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
971 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
972 acc_plot = plot_with_table_style_title(acc_fig, "Learning Curves — Label Accuracy") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
973 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
974 # 2) Learning Curve — Loss |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
975 if problem_type in ("binary", "multiclass"): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
976 classes = np.unique(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
977 loss_fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
978 added_loss = False |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
979 if pred_proba is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
980 pp = pred_proba.reshape(-1) if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) else pred_proba |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
981 train_sizes, train_means, train_stds = generate_learning_curve_from_predictions( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
982 y_true=y_true, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
983 y_proba=pp, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
984 classes=classes, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
985 metric="log_loss", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
986 title="Learning Curves — Label Loss", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
987 seed=seed, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
988 return_stats=True, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
989 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
990 loss_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
991 x=train_sizes, y=train_means, mode="lines+markers", name="Train", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
992 line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
993 error_y=dict(type="data", array=train_stds, visible=True), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
994 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
995 added_loss = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
996 if pred_proba_val is not None and y_true_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
997 pp_val = pred_proba_val.reshape(-1) if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) else pred_proba_val |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
998 val_sizes, val_means, val_stds = generate_learning_curve_from_predictions( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
999 y_true=y_true_val, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1000 y_proba=pp_val, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1001 classes=classes, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1002 metric="log_loss", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1003 title="Learning Curves — Label Loss", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1004 seed=seed, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1005 return_stats=True, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1006 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1007 loss_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1008 x=val_sizes, y=val_means, mode="lines+markers", name="Validation", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1009 line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1010 error_y=dict(type="data", array=val_stds, visible=True), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1011 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1012 added_loss = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1013 if added_loss: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1014 loss_fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1015 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1016 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1017 xaxis=dict(title="epoch", gridcolor="#eee"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1018 yaxis=dict(title="loss", gridcolor="#eee"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1019 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1020 margin=dict(l=50, r=20, t=60, b=50), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1021 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1022 loss_plot = plot_with_table_style_title(loss_fig, "Learning Curves — Label Loss") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1023 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1024 # Confusion matrices & per-class metrics |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1025 cm_train = pc_train = cm_val = pc_val = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1026 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1027 # Probability diagnostics (binary) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1028 if problem_type == "binary": |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1029 # Combined Calibration (Train/Val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1030 cal_fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1031 added_cal = False |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1032 if pos_scores_train is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1033 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1034 prob_true, prob_pred = calibration_curve(y_bin_train, pos_scores_train, n_bins=10, strategy="uniform") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1035 cal_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1036 x=prob_pred, y=prob_true, mode="lines+markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1037 name="Train", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1038 line=dict(color="#1f77b4", width=3), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1039 marker=dict(size=7, color="#1f77b4"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1040 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1041 added_cal = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1042 if pos_scores_val is not None and y_true_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1043 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1044 prob_true_v, prob_pred_v = calibration_curve(y_bin_val, pos_scores_val, n_bins=10, strategy="uniform") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1045 cal_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1046 x=prob_pred_v, y=prob_true_v, mode="lines+markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1047 name="Validation", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1048 line=dict(color="#ff7f0e", width=3), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1049 marker=dict(size=7, color="#ff7f0e"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1050 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1051 added_cal = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1052 if added_cal: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1053 cal_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1054 x=[0, 1], y=[0, 1], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1055 mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1056 line=dict(dash="dash", color="#808080", width=2), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1057 name="Perfect", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1058 showlegend=True, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1059 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1060 cal_fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1061 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1062 xaxis_title="Predicted Probability", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1063 yaxis_title="Observed Probability", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1064 xaxis=dict(range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1065 yaxis=dict(range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1066 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1067 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1068 margin=dict(l=60, r=40, t=50, b=50), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1069 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1070 cal_combined = plot_with_table_style_title(cal_fig, "Calibration Curve (Train vs Validation)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1071 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1072 # Combined ROC (Train/Val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1073 roc_fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1074 added_roc = False |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1075 if pos_scores_train is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1076 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1077 fpr_tr, tpr_tr, thr_tr = roc_curve(y_bin_train, pos_scores_train) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1078 roc_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1079 x=fpr_tr, y=tpr_tr, mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1080 name="Train", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1081 line=dict(color="#1f77b4", width=3), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1082 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1083 if threshold is not None and np.isfinite(thr_tr).any(): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1084 finite = np.isfinite(thr_tr) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1085 idx_local = int(np.argmin(np.abs(thr_tr[finite] - float(threshold)))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1086 idx = int(np.nonzero(finite)[0][idx_local]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1087 roc_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1088 x=[fpr_tr[idx]], y=[tpr_tr[idx]], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1089 mode="markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1090 name="Train @ threshold", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1091 marker=dict(size=12, color="#1f77b4", symbol="x") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1092 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1093 added_roc = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1094 if pos_scores_val is not None and y_true_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1095 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1096 fpr_v, tpr_v, thr_v = roc_curve(y_bin_val, pos_scores_val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1097 roc_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1098 x=fpr_v, y=tpr_v, mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1099 name="Validation", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1100 line=dict(color="#ff7f0e", width=3), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1101 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1102 if threshold is not None and np.isfinite(thr_v).any(): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1103 finite = np.isfinite(thr_v) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1104 idx_local = int(np.argmin(np.abs(thr_v[finite] - float(threshold)))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1105 idx = int(np.nonzero(finite)[0][idx_local]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1106 roc_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1107 x=[fpr_v[idx]], y=[tpr_v[idx]], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1108 mode="markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1109 name="Val @ threshold", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1110 marker=dict(size=12, color="#ff7f0e", symbol="x") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1111 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1112 added_roc = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1113 if added_roc: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1114 roc_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1115 x=[0, 1], y=[0, 1], mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1116 line=dict(dash="dash", width=2, color="#808080"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1117 showlegend=False |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1118 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1119 roc_fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1120 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1121 xaxis_title="False Positive Rate", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1122 yaxis_title="True Positive Rate", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1123 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1124 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1125 margin=dict(l=60, r=20, t=60, b=60), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1126 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1127 roc_combined = plot_with_table_style_title(roc_fig, "ROC Curve (Train vs Validation)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1128 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1129 # Combined PR (Train/Val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1130 pr_fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1131 added_pr = False |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1132 if pos_scores_train is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1133 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1134 prec_tr, rec_tr, thr_tr = precision_recall_curve(y_bin_train, pos_scores_train) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1135 pr_auc_tr = auc(rec_tr, prec_tr) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1136 pr_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1137 x=rec_tr, y=prec_tr, mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1138 name=f"Train (AUC={pr_auc_tr:.3f})", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1139 line=dict(color="#1f77b4", width=3), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1140 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1141 if threshold is not None and len(thr_tr): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1142 j = int(np.argmin(np.abs(thr_tr - float(threshold)))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1143 j = int(np.clip(j, 0, len(thr_tr) - 1)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1144 pr_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1145 x=[rec_tr[j + 1]], y=[prec_tr[j + 1]], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1146 mode="markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1147 name="Train @ threshold", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1148 marker=dict(size=12, color="#1f77b4", symbol="x") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1149 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1150 added_pr = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1151 if pos_scores_val is not None and y_true_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1152 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1153 prec_v, rec_v, thr_v = precision_recall_curve(y_bin_val, pos_scores_val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1154 pr_auc_v = auc(rec_v, prec_v) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1155 pr_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1156 x=rec_v, y=prec_v, mode="lines", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1157 name=f"Validation (AUC={pr_auc_v:.3f})", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1158 line=dict(color="#ff7f0e", width=3), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1159 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1160 if threshold is not None and len(thr_v): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1161 j = int(np.argmin(np.abs(thr_v - float(threshold)))) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1162 j = int(np.clip(j, 0, len(thr_v) - 1)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1163 pr_fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1164 x=[rec_v[j + 1]], y=[prec_v[j + 1]], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1165 mode="markers", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1166 name="Val @ threshold", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1167 marker=dict(size=12, color="#ff7f0e", symbol="x") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1168 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1169 added_pr = True |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1170 if added_pr: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1171 pr_fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1172 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1173 xaxis_title="Recall", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1174 yaxis_title="Precision", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1175 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1176 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1177 margin=dict(l=60, r=20, t=60, b=60), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1178 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1179 pr_combined = plot_with_table_style_title(pr_fig, "Precision–Recall Curve (Train vs Validation)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1180 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1181 if pos_scores_val is not None and y_true_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1182 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1183 fig_thr_val = generate_threshold_plot(y_true_bin=y_bin_val, y_prob=pos_scores_val, title="Threshold Plot (Validation)", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1184 user_threshold=threshold) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1185 threshold_val_plot = plot_with_table_style_title(fig_thr_val, "Threshold Plot (Validation)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1186 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1187 # Multiclass OVR ROC (validation) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1188 if problem_type == "multiclass" and pred_proba_val is not None and pred_proba_val.ndim >= 2 and y_true_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1189 classes_val = np.unique(y_true_val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1190 fig_mc_roc_val = generate_multiclass_roc_curve_plot(y_true_val, pred_proba_val, classes_val, title="One-vs-Rest ROC (Validation)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1191 mc_roc_val = plot_with_table_style_title(fig_mc_roc_val, "One-vs-Rest ROC (Validation)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1192 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1193 # Prediction Confidence Histogram (train/val) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1194 conf_train = conf_val = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1195 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1196 # Per-class accuracy bars |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1197 if problem_type in ("binary", "multiclass") and pred_labels is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1198 classes_for_bar = pd.Index(np.unique(y_true), dtype=object).tolist() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1199 acc_vals = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1200 for c in classes_for_bar: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1201 mask = y_true == c |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1202 acc_vals.append(float((np.asarray(pred_labels)[mask] == c).mean()) if mask.any() else 0.0) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1203 bar_fig = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar], y=acc_vals, marker_color="#1f77b4")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1204 bar_fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1205 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1206 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1207 xaxis=dict(title="Label", gridcolor="#eee"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1208 yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1209 margin=dict(l=50, r=20, t=60, b=50), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1210 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1211 bar_train = plot_with_table_style_title(bar_fig, "Per-Class Training Accuracy") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1212 if problem_type in ("binary", "multiclass") and pred_labels_val is not None and y_true_val is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1213 classes_for_bar_val = pd.Index(np.unique(y_true_val), dtype=object).tolist() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1214 acc_vals_val = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1215 for c in classes_for_bar_val: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1216 mask = y_true_val == c |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1217 acc_vals_val.append(float((np.asarray(pred_labels_val)[mask] == c).mean()) if mask.any() else 0.0) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1218 bar_fig_val = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar_val], y=acc_vals_val, marker_color="#ff7f0e")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1219 bar_fig_val.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1220 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1221 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1222 xaxis=dict(title="Label", gridcolor="#eee"), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1223 yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1224 margin=dict(l=50, r=20, t=60, b=50), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1225 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1226 bar_val = plot_with_table_style_title(bar_fig_val, "Per-Class Validation Accuracy") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1227 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1228 # Assemble in requested order |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1229 pieces: list[str] = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1230 if perf_card: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1231 pieces.append(perf_card) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1232 for block in (threshold_val_plot, roc_combined, pr_combined): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1233 if block: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1234 pieces.append(block) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1235 # Remaining plots (keep existing order) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1236 for block in (cal_combined, cm_train, pc_train, cm_val, pc_val, mc_roc_val, conf_train, conf_val, bar_train, bar_val): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1237 if block: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1238 pieces.append(block) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1239 # Learning curves should appear last in the tab |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1240 for block in (acc_plot, loss_plot): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1241 if block: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1242 pieces.append(block) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1243 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1244 if not pieces: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1245 return "<h2>Training Diagnostics</h2><p><em>No training diagnostics available for this run.</em></p>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1246 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1247 return "<h2>Train and Validation Performance Summary</h2>" + "".join(pieces) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1248 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1249 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1250 def generate_learning_curve( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1251 estimator, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1252 X, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1253 y, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1254 scoring: str = "r2", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1255 cv_folds: int = 5, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1256 n_jobs: int = -1, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1257 train_sizes: np.ndarray = np.linspace(0.1, 1.0, 10), |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1258 title: str = "Learning Curve", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1259 path: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1260 ) -> go.Figure: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1261 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1262 Learning curve using sklearn.learning_curve, visualized with Plotly. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1263 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1264 sizes, train_scores, test_scores = skl_learning_curve( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1265 estimator, X, y, cv=cv_folds, scoring=scoring, n_jobs=n_jobs, train_sizes=train_sizes |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1266 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1267 train_mean = train_scores.mean(axis=1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1268 train_std = train_scores.std(axis=1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1269 test_mean = test_scores.mean(axis=1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1270 test_std = test_scores.std(axis=1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1271 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1272 fig = go.Figure() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1273 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1274 x=sizes, y=train_mean, mode="lines+markers", name="Training score", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1275 error_y=dict(type="data", array=train_std, visible=True) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1276 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1277 fig.add_trace(go.Scatter( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1278 x=sizes, y=test_mean, mode="lines+markers", name="CV score", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1279 error_y=dict(type="data", array=test_std, visible=True) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1280 )) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1281 fig.update_layout( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1282 title=None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1283 xaxis_title="Training examples", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1284 yaxis_title=scoring, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1285 template="plotly_white", |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1286 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1287 _save_plotly(fig, path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1288 return fig |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1289 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1290 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1291 # SHAP (Matplotlib-based) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1292 # ========================= |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1293 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1294 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1295 def generate_shap_summary_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1296 shap_values, features: pd.DataFrame, title: str = "SHAP Summary Plot", path: Optional[str] = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1297 ) -> None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1298 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1299 SHAP summary plot (Matplotlib). SHAP's interactive support with Plotly is limited; |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1300 keep matplotlib for clarity and stability. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1301 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1302 plt.figure(figsize=(10, 8)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1303 shap.summary_plot(shap_values, features, show=False) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1304 plt.title(title) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1305 _save_matplotlib(path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1306 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1307 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1308 def generate_shap_force_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1309 explainer, instance: pd.DataFrame, title: str = "SHAP Force Plot", path: Optional[str] = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1310 ) -> None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1311 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1312 SHAP force plot (Matplotlib). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1313 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1314 shap_values = explainer(instance) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1315 plt.figure(figsize=(10, 4)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1316 shap.plots.force(shap_values[0], show=False) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1317 plt.title(title) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1318 _save_matplotlib(path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1319 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1320 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1321 def generate_shap_waterfall_plot( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1322 explainer, instance: pd.DataFrame, title: str = "SHAP Waterfall Plot", path: Optional[str] = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1323 ) -> None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1324 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1325 SHAP waterfall plot (Matplotlib). |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1326 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1327 shap_values = explainer(instance) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1328 plt.figure(figsize=(10, 6)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1329 shap.plots.waterfall(shap_values[0], show=False) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1330 plt.title(title) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1331 _save_matplotlib(path) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1332 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1333 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1334 def infer_problem_type(predictor, df_train_full: pd.DataFrame, label_column: str) -> str: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1335 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1336 Return 'binary', 'multiclass', or 'regression'. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1337 Prefer the predictor's own metadata when available; otherwise infer from label dtype/uniques. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1338 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1339 # AutoGluon predictors usually expose .problem_type; be defensive. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1340 pt = getattr(predictor, "problem_type", None) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1341 if isinstance(pt, str): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1342 pt_l = pt.lower() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1343 if "regression" in pt_l: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1344 return "regression" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1345 if "binary" in pt_l: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1346 return "binary" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1347 if "multiclass" in pt_l or "multiclass" in pt_l: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1348 return "multiclass" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1349 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1350 y = df_train_full[label_column] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1351 if pd.api.types.is_numeric_dtype(y) and y.nunique() > 10: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1352 return "regression" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1353 return "binary" if y.nunique() == 2 else "multiclass" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1354 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1355 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1356 def _safe_floatify(d: Dict[str, Any]) -> Dict[str, float]: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1357 """Make evaluate() outputs JSON/csv friendly floats.""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1358 out = {} |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1359 for k, v in d.items(): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1360 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1361 out[k] = float(v) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1362 except Exception: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1363 # keep only real-valued scalars |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1364 pass |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1365 return out |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1366 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1367 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1368 def evaluate_all( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1369 predictor, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1370 df_train: pd.DataFrame, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1371 df_val: pd.DataFrame, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1372 df_test: pd.DataFrame, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1373 label_column: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1374 problem_type: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1375 ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1376 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1377 Run predictor.evaluate on train/val/test and normalize the result dicts to floats. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1378 MultiModalPredictor does not accept the `silent` kwarg, so call defensively. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1379 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1380 def _evaluate(df): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1381 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1382 return predictor.evaluate(df, silent=True) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1383 except TypeError: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1384 return predictor.evaluate(df) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1385 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1386 train_scores = _safe_floatify(_evaluate(df_train)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1387 val_scores = _safe_floatify(_evaluate(df_val)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1388 test_scores = _safe_floatify(_evaluate(df_test)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1389 return train_scores, val_scores, test_scores |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1390 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1391 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1392 def build_summary_html( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1393 predictor, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1394 df_train: pd.DataFrame, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1395 df_val: Optional[pd.DataFrame], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1396 df_test: Optional[pd.DataFrame], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1397 label_column: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1398 extra_run_rows: Optional[list[tuple[str, str]]] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1399 class_balance_html: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1400 perf_table_html: Optional[str] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1401 ) -> str: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1402 sections = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1403 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1404 # Dataset Overview (first section in the tab) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1405 if class_balance_html: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1406 sections.append(f""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1407 <section class="section"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1408 <h2 class="section-title">Dataset Overview</h2> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1409 <div class="card"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1410 {class_balance_html} |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1411 </div> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1412 </section> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1413 """.strip()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1414 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1415 # Performance Summary |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1416 if perf_table_html: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1417 sections.append(f""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1418 <section class="section"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1419 <h2 class="section-title">Model Performance Summary</h2> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1420 <div class="card"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1421 {perf_table_html} |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1422 </div> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1423 </section> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1424 """.strip()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1425 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1426 # Model Configuration |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1427 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1428 # Remove Predictor type and Framework, and ensure Model Architecture is present |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1429 base_rows: list[tuple[str, str]] = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1430 if extra_run_rows: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1431 # Remove any rows with keys 'Predictor type' or 'Framework' |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1432 base_rows.extend([(k, v) for (k, v) in extra_run_rows if k not in ("Predictor type", "Framework")]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1433 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1434 def _fmt(v): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1435 if v is None or v == "": |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1436 return "—" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1437 return _escape(str(v)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1438 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1439 rows_html = "\n".join( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1440 f"<tr><td>{_escape(str(k))}</td><td>{_fmt(v)}</td></tr>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1441 for k, v in base_rows |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1442 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1443 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1444 sections.append(f""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1445 <section class="section"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1446 <h2 class="section-title">Model Configuration</h2> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1447 <div class="card"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1448 <table class="kv-table"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1449 <thead><tr><th>Key</th><th>Value</th></tr></thead> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1450 <tbody> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1451 {rows_html} |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1452 </tbody> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1453 </table> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1454 </div> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1455 </section> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1456 """.strip()) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1457 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1458 return "\n".join(sections).strip() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1459 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1460 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1461 def build_feature_importance_html(predictor, df_train: pd.DataFrame, label_column: str) -> str: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1462 """Build a visualization of feature importance.""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1463 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1464 # Try to get feature importance from predictor |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1465 fi = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1466 if hasattr(predictor, "feature_importance") and callable(predictor.feature_importance): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1467 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1468 fi = predictor.feature_importance(df_train) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1469 except Exception as e: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1470 return f"<p>Could not compute feature importance: {e}</p>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1471 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1472 if fi is None or (isinstance(fi, pd.DataFrame) and fi.empty): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1473 return "<p>Feature importance not available for this model.</p>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1474 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1475 # Format as a sortable table |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1476 rows = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1477 if isinstance(fi, pd.DataFrame): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1478 fi = fi.sort_values("importance", ascending=False) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1479 for _, row in fi.iterrows(): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1480 feat = row.index[0] if isinstance(row.index, pd.Index) else row["feature"] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1481 imp = float(row["importance"]) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1482 rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{imp:.4f}</td></tr>") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1483 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1484 # Handle other formats (dict, etc) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1485 for feat, imp in sorted(fi.items(), key=lambda x: float(x[1]), reverse=True): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1486 rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{float(imp):.4f}</td></tr>") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1487 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1488 if not rows: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1489 return "<p>No feature importance values available.</p>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1490 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1491 table_html = f""" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1492 <table class="performance-summary"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1493 <thead> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1494 <tr> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1495 <th class="sortable">Feature</th> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1496 <th class="sortable">Importance</th> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1497 </tr> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1498 </thead> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1499 <tbody> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1500 {"".join(rows)} |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1501 </tbody> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1502 </table> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1503 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1504 return table_html |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1505 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1506 except Exception as e: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1507 return f"<p>Error building feature importance visualization: {e}</p>" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1508 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1509 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1510 def build_test_html_and_plots( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1511 predictor, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1512 problem_type: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1513 df_test: pd.DataFrame, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1514 label_column: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1515 tmpdir: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1516 threshold: Optional[float] = None, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1517 ) -> Tuple[str, List[str]]: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1518 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1519 Create a test-summary section (with a placeholder for metric rows) and a list of Plotly HTML divs. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1520 Returns: (html_template_with_{}, list_of_plot_divs) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1521 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1522 plots: List[str] = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1523 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1524 y_true = df_test[label_column].values |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1525 classes = np.unique(y_true) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1526 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1527 # Try proba/labels where meaningful |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1528 pred_labels = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1529 pred_proba = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1530 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1531 pred_labels = predictor.predict(df_test) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1532 except Exception: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1533 pass |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1534 try: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1535 # MultiModalPredictor exposes predict_proba for classification problems. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1536 pred_proba = predictor.predict_proba(df_test) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1537 except Exception: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1538 pred_proba = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1539 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1540 proba_arr = None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1541 if pred_proba is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1542 if isinstance(pred_proba, pd.Series): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1543 proba_arr = pred_proba.to_numpy().reshape(-1, 1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1544 elif isinstance(pred_proba, pd.DataFrame): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1545 proba_arr = pred_proba.to_numpy() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1546 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1547 proba_arr = np.asarray(pred_proba) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1548 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1549 # Thresholded labels for binary |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1550 if problem_type == "binary" and threshold is not None and proba_arr is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1551 pos_label, neg_label = classes.max(), classes.min() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1552 pos_scores = proba_arr.reshape(-1) if (proba_arr.ndim == 1 or proba_arr.shape[1] == 1) else proba_arr[:, -1] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1553 pred_labels = np.where(pos_scores >= float(threshold), pos_label, neg_label) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1554 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1555 # Confusion matrix / per-class now reflect thresholded labels |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1556 if problem_type in ("binary", "multiclass") and pred_labels is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1557 cm_title = "Confusion Matrix" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1558 if threshold is not None and problem_type == "binary": |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1559 thr_str = f"{float(threshold):.3f}".rstrip("0").rstrip(".") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1560 cm_title = f"Confusion Matrix (Threshold = {thr_str})" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1561 fig_cm = generate_confusion_matrix_plot(y_true, pred_labels, title=cm_title) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1562 plots.append(plot_with_table_style_title(fig_cm, cm_title)) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1563 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1564 fig_pc = generate_per_class_metrics_plot(y_true, pred_labels, title="Per-Class Metrics") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1565 plots.append(plot_with_table_style_title(fig_pc, "Per-Class Metrics")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1566 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1567 # ROC/PR where possible — choose positive-class scores safely |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1568 pos_label = classes.max() # or set explicitly, e.g., 1 or "yes" |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1569 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1570 if isinstance(pred_proba, pd.DataFrame): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1571 proba_arr = pred_proba.to_numpy() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1572 if pos_label in pred_proba.columns: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1573 pos_idx = list(pred_proba.columns).index(pos_label) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1574 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1575 pos_idx = -1 # fallback to last column |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1576 elif isinstance(pred_proba, pd.Series): |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1577 proba_arr = pred_proba.to_numpy().reshape(-1, 1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1578 pos_idx = 0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1579 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1580 proba_arr = np.asarray(pred_proba) if pred_proba is not None else None |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1581 pos_idx = -1 if (proba_arr is not None and proba_arr.ndim == 2 and proba_arr.shape[1] > 1) else 0 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1582 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1583 if proba_arr is not None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1584 y_bin = (y_true == pos_label).astype(int) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1585 pos_scores = ( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1586 proba_arr.reshape(-1) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1587 if proba_arr.ndim == 1 or proba_arr.shape[1] == 1 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1588 else proba_arr[:, pos_idx] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1589 ) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1590 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1591 fig_roc = generate_roc_curve_plot(y_bin, pos_scores, title="ROC Curve", marker_threshold=threshold) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1592 plots.append(plot_with_table_style_title(fig_roc, f"ROC Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1593 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1594 fig_pr = generate_pr_curve_plot(y_bin, pos_scores, title="Precision–Recall Curve", marker_threshold=threshold) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1595 plots.append(plot_with_table_style_title(fig_pr, f"Precision–Recall Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1596 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1597 # Additional diagnostics aligned with ImageLearner style |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1598 if problem_type == "binary": |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1599 conf_fig = plot_confidence_histogram(pos_scores, bins=20, title="Prediction Confidence (Test)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1600 plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Test)")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1601 else: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1602 conf_fig = plot_confidence_histogram(proba_arr, bins=20, title="Prediction Confidence (Top-1, Test)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1603 plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Top-1, Test)")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1604 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1605 if problem_type == "multiclass" and proba_arr is not None and proba_arr.ndim >= 2: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1606 fig_mc_roc = generate_multiclass_roc_curve_plot(y_true, proba_arr, classes, title="One-vs-Rest ROC (Test)") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1607 plots.append(plot_with_table_style_title(fig_mc_roc, "One-vs-Rest ROC (Test)")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1608 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1609 # Regression visuals |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1610 if problem_type == "regression": |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1611 if pred_labels is None: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1612 pred_labels = predictor.predict(df_test) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1613 fig_sc = generate_scatter_plot(y_true, pred_labels, title="Predicted vs Actual") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1614 plots.append(plot_with_table_style_title(fig_sc, "Predicted vs Actual")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1615 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1616 fig_res = generate_residual_plot(y_true, pred_labels, title="Residual Plot") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1617 plots.append(plot_with_table_style_title(fig_res, "Residual Plot")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1618 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1619 fig_hist = generate_residual_histogram(y_true, pred_labels, title="Residual Histogram") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1620 plots.append(plot_with_table_style_title(fig_hist, "Residual Histogram")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1621 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1622 fig_cal = generate_regression_calibration_plot(y_true, pred_labels, title="Regression Calibration") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1623 plots.append(plot_with_table_style_title(fig_cal, "Regression Calibration")) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1624 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1625 # Small HTML template with placeholder for metric rows the caller fills in |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1626 test_html_template = """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1627 <h2>Test Performance Summary</h2> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1628 <table class="performance-summary"> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1629 <thead><tr><th>Metric</th><th>Test</th></tr></thead> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1630 <tbody>{}</tbody> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1631 </table> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1632 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1633 return test_html_template, plots |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1634 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1635 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1636 def build_feature_html( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1637 predictor, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1638 df_train: pd.DataFrame, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1639 label_column: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1640 include_modalities: bool = True, # ← NEW |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1641 include_class_balance: bool = True, # ← NEW |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1642 ) -> str: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1643 sections = [] |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1644 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1645 # (Typical feature importance content…) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1646 fi_html = build_feature_importance_html(predictor, df_train, label_column) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1647 sections.append(f"<section class='section'><h2 class='section-title'>Feature Importance</h2><div class='card'>{fi_html}</div></section>") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1648 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1649 # Previously: Modalities & Inputs and/or Class Balance may have been here. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1650 # Only render them if flags are True. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1651 if include_modalities: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1652 from report_utils import build_modalities_html |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1653 modalities_html = build_modalities_html(predictor, df_train, label_column) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1654 sections.append(f"<section class='section'><h2 class='section-title'>Modalities & Inputs</h2><div class='card'>{modalities_html}</div></section>") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1655 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1656 if include_class_balance: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1657 from report_utils import build_class_balance_html |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1658 cb_html = build_class_balance_html(df_train, label_column) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1659 sections.append(f"<section class='section'><h2 class='section-title'>Class Balance (Train Full)</h2><div class='card'>{cb_html}</div></section>") |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1660 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1661 return "\n".join(sections) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1662 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1663 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1664 def assemble_full_html_report( |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1665 summary_html: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1666 train_html: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1667 test_html: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1668 plots: List[str], |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1669 feature_html: str, |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1670 ) -> str: |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1671 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1672 Wrap the four tabs using utils.build_tabbed_html and return full HTML. |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1673 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1674 # Append plots under the Test tab (already wrapped with titles) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1675 test_full = test_html + "".join(plots) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1676 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1677 tabs = build_tabbed_html(summary_html, train_html, test_full, feature_html, explainer_html=None) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1678 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1679 html_out = get_html_template() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1680 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1681 # 🔧 Ensure Plotly JS is available (we render plots with include_plotlyjs=False) |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1682 html_out += '\n<script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script>\n' |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1683 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1684 # Optional: centering tweaks |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1685 html_out += """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1686 <style> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1687 .plotly-center { display: flex; justify-content: center; } |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1688 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; } |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1689 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; } |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1690 </style> |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1691 """ |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1692 # Help modal HTML/JS |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1693 html_out += get_metrics_help_modal() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1694 |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1695 html_out += tabs |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1696 html_out += get_html_closing() |
|
375c36923da1
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff
changeset
|
1697 return html_out |
