annotate plot_logic.py @ 0:375c36923da1 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
author goeckslab
date Tue, 09 Dec 2025 23:49:47 +0000
parents
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1 from __future__ import annotations
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
2
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
3 import html
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
4 import os
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
5 from html import escape as _escape
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
6 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
7
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
8 import matplotlib.pyplot as plt
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
9 import numpy as np
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
10 import pandas as pd
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
11 import plotly.express as px
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
12 import plotly.graph_objects as go
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
13 import shap
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
14 from feature_help_modal import get_metrics_help_modal
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
15 from report_utils import build_tabbed_html, get_html_closing, get_html_template
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
16 from sklearn.calibration import calibration_curve
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
17 from sklearn.metrics import (
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
18 auc,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
19 average_precision_score,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
20 classification_report,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
21 confusion_matrix,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
22 log_loss,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
23 precision_recall_curve,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
24 roc_auc_score,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
25 roc_curve,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
26 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
27 from sklearn.model_selection import learning_curve as skl_learning_curve
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
28 from sklearn.preprocessing import label_binarize
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
29
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
30 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
31 # Utilities
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
32 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
33
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
34
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
35 def plot_with_table_style_title(fig, title: str) -> str:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
36 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
37 Render a Plotly figure with a report-style <h2> header so it matches the
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
38 green table section headers.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
39 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
40 # kill Plotly’s built-in title
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
41 fig.update_layout(title=None)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
42
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
43 # figure HTML without PlotlyJS (we load it once globally)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
44 plot_html = fig.to_html(full_html=False, include_plotlyjs=False)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
45
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
46 # use <h2> — your CSS already styles <h2> like the table headers
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
47 return f"""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
48 <h2>{html.escape(title)}</h2>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
49 <div class="plotly-center">{plot_html}</div>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
50 """.strip()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
51
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
52
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
53 def _save_plotly(fig: go.Figure, path: Optional[str]) -> None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
54 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
55 Save a Plotly figure. If `path` ends with `.html`, save interactive HTML.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
56 If it ends with a raster extension (png/jpg/jpeg/webp), uses Kaleido.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
57 If None, do nothing (caller may choose to display in notebook).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
58 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
59 if not path:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
60 return
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
61 ext = os.path.splitext(path)[1].lower()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
62 if ext == ".html":
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
63 fig.write_html(path, include_plotlyjs="cdn", full_html=True)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
64 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
65 # Requires kaleido: pip install -U kaleido
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
66 fig.write_image(path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
67
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
68
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
69 def _save_matplotlib(path: Optional[str]) -> None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
70 """Save current Matplotlib figure if `path` is provided, else show()."""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
71 if path:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
72 plt.savefig(path, bbox_inches="tight")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
73 plt.close()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
74 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
75 plt.show()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
76
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
77 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
78 # Classification Plots
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
79 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
80
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
81
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
82 def generate_confusion_matrix_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
83 y_true,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
84 y_pred,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
85 title: str = "Confusion Matrix",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
86 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
87 y_true = np.asarray(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
88 y_pred = np.asarray(y_pred)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
89
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
90 # Class order (works for strings or numbers)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
91 labels = pd.Index(np.unique(np.concatenate([y_true, y_pred])), dtype=object).tolist()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
92 cm = confusion_matrix(y_true, y_pred, labels=labels)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
93 max_val = cm.max() if cm.size else 0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
94
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
95 # Use categorical axes by passing string labels for x/y
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
96 cats = [str(label) for label in labels]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
97 total = int(cm.sum())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
98
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
99 fig = go.Figure(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
100 data=go.Heatmap(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
101 z=cm,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
102 x=cats, # categorical x
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
103 y=cats, # categorical y
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
104 colorscale="Blues",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
105 showscale=True,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
106 colorbar=dict(title="Count"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
107 xgap=2,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
108 ygap=2,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
109 hovertemplate="True=%{y}<br>Pred=%{x}<br>Count=%{z}<extra></extra>",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
110 zmin=0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
111 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
112 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
113
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
114 # Add annotations with count and percentage (all white text, matching sample_output.html)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
115 annotations = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
116 for i in range(cm.shape[0]):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
117 for j in range(cm.shape[1]):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
118 val = int(cm[i, j])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
119 pct = (val / total * 100) if total > 0 else 0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
120 text_color = "white" if max_val and val > (max_val / 2) else "black"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
121 # Count annotation (bold, bottom)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
122 annotations.append(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
123 dict(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
124 x=cats[j],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
125 y=cats[i],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
126 text=f"<b>{val}</b>",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
127 showarrow=False,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
128 font=dict(color=text_color, size=14),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
129 xanchor="center",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
130 yanchor="bottom",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
131 yshift=2
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
132 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
133 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
134 # Percentage annotation (top)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
135 annotations.append(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
136 dict(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
137 x=cats[j],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
138 y=cats[i],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
139 text=f"{pct:.1f}%",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
140 showarrow=False,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
141 font=dict(color=text_color, size=13),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
142 xanchor="center",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
143 yanchor="top",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
144 yshift=-2
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
145 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
146 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
147
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
148 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
149 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
150 xaxis_title="Predicted label",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
151 yaxis_title="True label",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
152 xaxis=dict(type="category"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
153 yaxis=dict(type="category", autorange="reversed"), # typical CM orientation
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
154 margin=dict(l=80, r=20, t=40, b=80),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
155 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
156 plot_bgcolor="white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
157 paper_bgcolor="white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
158 annotations=annotations
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
159 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
160 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
161
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
162
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
163 def generate_roc_curve_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
164 y_true_bin: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
165 y_score: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
166 title: str = "ROC Curve",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
167 marker_threshold: float | None = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
168 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
169 y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
170 y_score = np.asarray(y_score).astype(float).reshape(-1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
171
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
172 fpr, tpr, thr = roc_curve(y_true_bin, y_score)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
173 roc_auc = auc(fpr, tpr)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
174
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
175 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
176 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
177 x=fpr, y=tpr, mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
178 name=f"ROC (AUC={roc_auc:.3f})",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
179 line=dict(width=3)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
180 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
181
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
182 # 45° chance line (no legend to keep it clean)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
183 fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
184 line=dict(dash="dash", width=2, color="#888"), showlegend=False))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
185
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
186 # Optional marker at the user threshold
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
187 if marker_threshold is not None and len(thr):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
188 # roc_curve returns thresholds of same length as fpr/tpr; includes inf at idx 0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
189 finite = np.isfinite(thr)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
190 if np.any(finite):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
191 idx_local = int(np.argmin(np.abs(thr[finite] - float(marker_threshold))))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
192 idx = int(np.nonzero(finite)[0][idx_local]) # map back to original indices
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
193 x_m, y_m = float(fpr[idx]), float(tpr[idx])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
194
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
195 fig.add_trace(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
196 go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
197 x=[x_m], y=[y_m],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
198 mode="markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
199 name=f"@ {float(marker_threshold):.2f}",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
200 marker=dict(size=12, color="red", symbol="x")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
201 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
202 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
203 fig.add_annotation(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
204 x=0.02, y=0.98, xref="paper", yref="paper",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
205 text=f"threshold = {float(marker_threshold):.2f}",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
206 showarrow=False,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
207 font=dict(color="black", size=12),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
208 align="left"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
209 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
210
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
211 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
212 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
213 xaxis_title="False Positive Rate",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
214 yaxis_title="True Positive Rate",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
215 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
216 legend=dict(x=1, y=0, xanchor="right"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
217 margin=dict(l=60, r=20, t=60, b=60),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
218 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
219 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
220
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
221
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
222 def generate_pr_curve_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
223 y_true_bin: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
224 y_score: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
225 title: str = "Precision–Recall Curve",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
226 marker_threshold: float | None = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
227 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
228 y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
229 y_score = np.asarray(y_score).astype(float).reshape(-1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
230
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
231 precision, recall, thr = precision_recall_curve(y_true_bin, y_score)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
232 pr_auc = auc(recall, precision)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
233
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
234 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
235 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
236 x=recall, y=precision, mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
237 name=f"PR (AUC={pr_auc:.3f})",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
238 line=dict(width=3)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
239 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
240
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
241 # Optional marker at the user threshold
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
242 if marker_threshold is not None and len(thr):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
243 # In PR, thresholds has length len(precision)-1. The point for thr[j] is (recall[j+1], precision[j+1]).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
244 j = int(np.argmin(np.abs(thr - float(marker_threshold))))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
245 j = int(np.clip(j, 0, len(thr) - 1))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
246 x_m, y_m = float(recall[j + 1]), float(precision[j + 1])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
247
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
248 fig.add_trace(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
249 go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
250 x=[x_m], y=[y_m],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
251 mode="markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
252 name=f"@ {float(marker_threshold):.2f}",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
253 marker=dict(size=12, color="red", symbol="x")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
254 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
255 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
256 fig.add_annotation(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
257 x=0.02, y=0.98, xref="paper", yref="paper",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
258 text=f"threshold = {float(marker_threshold):.2f}",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
259 showarrow=False,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
260 font=dict(color="black", size=12),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
261 align="left"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
262 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
263
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
264 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
265 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
266 xaxis_title="Recall",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
267 yaxis_title="Precision",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
268 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
269 legend=dict(x=1, y=0, xanchor="right"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
270 margin=dict(l=60, r=20, t=60, b=60),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
271 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
272 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
273
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
274
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
275 def generate_calibration_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
276 y_true_bin: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
277 y_prob: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
278 n_bins: int = 10,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
279 title: str = "Calibration Plot",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
280 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
281 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
282 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
283 Binary calibration curve (Plotly).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
284 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
285 prob_true, prob_pred = calibration_curve(y_true_bin, y_prob, n_bins=n_bins, strategy="uniform")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
286 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
287 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
288 x=prob_pred, y=prob_true, mode="lines+markers", name="Model",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
289 line=dict(color="#1f77b4", width=3), marker=dict(size=7, color="#1f77b4")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
290 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
291 fig.add_trace(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
292 go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
293 x=[0, 1], y=[0, 1],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
294 mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
295 line=dict(dash="dash", color="#808080", width=2),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
296 name="Perfect"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
297 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
298 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
299 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
300 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
301 xaxis_title="Predicted Probability",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
302 yaxis_title="Observed Probability",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
303 yaxis=dict(range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
304 xaxis=dict(range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
305 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
306 margin=dict(l=60, r=40, t=50, b=50),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
307 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
308 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
309 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
310
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
311
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
312 def generate_threshold_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
313 y_true_bin: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
314 y_prob: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
315 title: str = "Threshold Plot",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
316 user_threshold: float | None = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
317 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
318 y_true = np.asarray(y_true_bin, dtype=int).ravel()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
319 p = np.asarray(y_prob, dtype=float).ravel()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
320 p = np.nan_to_num(p, nan=0.0)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
321 p = np.clip(p, 0.0, 1.0)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
322
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
323 def _compute_metrics(thresholds: np.ndarray):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
324 """Vectorized-ish helper to compute precision/recall/F1/queue rate arrays."""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
325 prec, rec, f1, qrate = [], [], [], []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
326 for t in thresholds:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
327 yhat = (p >= t).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
328 tp = int(((yhat == 1) & (y_true == 1)).sum())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
329 fp = int(((yhat == 1) & (y_true == 0)).sum())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
330 fn = int(((yhat == 0) & (y_true == 1)).sum())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
331
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
332 pr = tp / (tp + fp) if (tp + fp) else np.nan # undefined when no predicted positives
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
333 rc = tp / (tp + fn) if (tp + fn) else 0.0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
334 f = (2 * pr * rc) / (pr + rc) if (pr + rc) and not np.isnan(pr) else 0.0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
335 q = float(yhat.mean())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
336
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
337 prec.append(pr)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
338 rec.append(rc)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
339 f1.append(f)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
340 qrate.append(q)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
341 return np.asarray(prec, dtype=float), np.asarray(rec, dtype=float), np.asarray(f1, dtype=float), np.asarray(qrate, dtype=float)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
342
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
343 # Use uniform threshold grid for plotting (0 to 1 in steps of 0.01)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
344 th = np.linspace(0.0, 1.0, 101)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
345 prec, rec, f1_arr, qrate = _compute_metrics(th)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
346
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
347 # Compute F1*-optimal threshold using actual score distribution (more precise than grid)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
348 cand_th = np.unique(np.concatenate(([0.0, 1.0], p)))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
349 # cap to a reasonable size by sampling if extremely large
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
350 if cand_th.size > 2000:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
351 cand_th = np.linspace(0.0, 1.0, 2001)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
352 _, _, f1_cand, _ = _compute_metrics(cand_th)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
353
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
354 if np.all(np.isnan(f1_cand)):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
355 t_star = 0.5 # fallback when no valid F1 can be computed
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
356 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
357 f1_max = np.nanmax(f1_cand)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
358 best_idxs = np.where(np.isclose(f1_cand, f1_max, equal_nan=False))[0]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
359 # pick the middle of the best candidates to avoid biasing toward 0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
360 best_idx = int(best_idxs[len(best_idxs) // 2])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
361 t_star = float(cand_th[best_idx])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
362
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
363 # Replace NaNs for plotting (set to 0 where precision is undefined)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
364 prec_plot = np.nan_to_num(prec, nan=0.0)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
365
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
366 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
367
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
368 # Precision (blue line)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
369 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
370 x=th, y=prec_plot, mode="lines", name="Precision",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
371 line=dict(width=3, color="#1f77b4"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
372 hovertemplate="Threshold=%{x:.3f}<br>Precision=%{y:.3f}<extra></extra>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
373 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
374
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
375 # Recall (orange line)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
376 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
377 x=th, y=rec, mode="lines", name="Recall",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
378 line=dict(width=3, color="#ff7f0e"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
379 hovertemplate="Threshold=%{x:.3f}<br>Recall=%{y:.3f}<extra></extra>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
380 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
381
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
382 # F1 (green line)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
383 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
384 x=th, y=f1_arr, mode="lines", name="F1",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
385 line=dict(width=3, color="#2ca02c"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
386 hovertemplate="Threshold=%{x:.3f}<br>F1=%{y:.3f}<extra></extra>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
387 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
388
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
389 # Queue Rate (grey dashed line)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
390 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
391 x=th, y=qrate, mode="lines", name="Queue Rate",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
392 line=dict(width=2, color="#808080", dash="dash"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
393 hovertemplate="Threshold=%{x:.3f}<br>Queue Rate=%{y:.3f}<extra></extra>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
394 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
395
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
396 # F1*-optimal threshold marker (dashed vertical line)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
397 fig.add_vline(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
398 x=t_star,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
399 line_width=2,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
400 line_dash="dash",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
401 line_color="black",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
402 annotation_text=f"t* = {t_star:.2f}",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
403 annotation_position="top"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
404 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
405
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
406 # User threshold (solid red line) if provided
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
407 if user_threshold is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
408 fig.add_vline(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
409 x=float(user_threshold),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
410 line_width=2,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
411 line_color="red",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
412 annotation_text=f"threshold = {float(user_threshold):.2f}",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
413 annotation_position="top"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
414 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
415
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
416 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
417 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
418 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
419 xaxis=dict(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
420 title="Discrimination Threshold",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
421 range=[0, 1],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
422 gridcolor="#e0e0e0",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
423 showgrid=True,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
424 zeroline=False
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
425 ),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
426 yaxis=dict(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
427 title="Score",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
428 range=[0, 1],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
429 gridcolor="#e0e0e0",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
430 showgrid=True,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
431 zeroline=False
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
432 ),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
433 legend=dict(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
434 orientation="h",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
435 yanchor="bottom",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
436 y=1.02,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
437 xanchor="right",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
438 x=1.0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
439 ),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
440 margin=dict(l=60, r=20, t=40, b=50),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
441 plot_bgcolor="white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
442 paper_bgcolor="white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
443 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
444 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
445
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
446
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
447 def generate_per_class_metrics_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
448 y_true: Sequence,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
449 y_pred: Sequence,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
450 metrics: Sequence[str] = ("precision", "recall", "f1_score"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
451 title: str = "Classification Report",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
452 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
453 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
454 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
455 Per-class metrics heatmap (Plotly), similar to sklearn's classification report.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
456 Rows = classes, columns = metrics; cell text shows the value (0–1).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
457 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
458 # Map display names -> sklearn keys
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
459 key_map = {"f1_score": "f1-score", "precision": "precision", "recall": "recall"}
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
460 report = classification_report(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
461 y_true, y_pred, output_dict=True, zero_division=0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
462 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
463
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
464 # Order classes sensibly (numeric if possible, else lexical)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
465 def _sort_key(x):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
466 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
467 return (0, float(x))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
468 except Exception:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
469 return (1, str(x))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
470
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
471 # Use all classes seen in y_true or y_pred (so rows don't jump around)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
472 uniq = sorted(set(list(y_true) + list(y_pred)), key=_sort_key)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
473 classes = [str(c) for c in uniq]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
474
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
475 # Build Z matrix (rows=classes, cols=metrics)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
476 used_metrics = [key_map.get(m, m) for m in metrics]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
477 z = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
478 for c in classes:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
479 row = report.get(c, {})
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
480 z.append([float(row.get(m, 0.0) or 0.0) for m in used_metrics])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
481 z = np.array(z, dtype=float)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
482
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
483 # Pretty cell labels
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
484 z_text = [[f"{v:.2f}" for v in r] for r in z]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
485
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
486 fig = go.Figure(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
487 data=go.Heatmap(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
488 z=z,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
489 x=list(metrics), # keep display names ("precision", "recall", "f1_score")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
490 y=classes, # classes as strings
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
491 colorscale="Reds",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
492 zmin=0.0,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
493 zmax=1.0,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
494 colorbar=dict(title="Value"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
495 text=z_text,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
496 texttemplate="%{text}",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
497 hovertemplate="Class %{y}<br>%{x}: %{z:.2f}<extra></extra>",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
498 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
499 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
500 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
501 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
502 xaxis_title="",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
503 yaxis_title="Class",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
504 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
505 margin=dict(l=60, r=60, t=70, b=40),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
506 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
507
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
508 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
509 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
510
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
511
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
512 def generate_multiclass_roc_curve_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
513 y_true: Sequence,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
514 y_prob: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
515 classes: Sequence,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
516 title: str = "Multiclass ROC Curve",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
517 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
518 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
519 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
520 One-vs-rest ROC curves for multiclass (Plotly).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
521 Handles binary passed as 2-column probs as well.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
522 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
523 y_true = np.asarray(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
524 y_prob = np.asarray(y_prob)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
525
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
526 # Normalize to shape (n_samples, n_classes)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
527 if y_prob.ndim == 1 or y_prob.shape[1] == 1:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
528 y_prob = np.hstack([1 - y_prob.reshape(-1, 1), y_prob.reshape(-1, 1)])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
529
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
530 y_true_bin = label_binarize(y_true, classes=classes)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
531 if y_true_bin.shape[1] == 1 and y_prob.shape[1] == 2:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
532 y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
533
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
534 if y_prob.shape[1] != y_true_bin.shape[1]:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
535 raise ValueError(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
536 f"Shape mismatch: y_prob has {y_prob.shape[1]} columns but y_true_bin has {y_true_bin.shape[1]}."
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
537 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
538
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
539 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
540 for i, cls in enumerate(classes):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
541 fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
542 auc_val = roc_auc_score(y_true_bin[:, i], y_prob[:, i])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
543 fig.add_trace(go.Scatter(x=fpr, y=tpr, mode="lines", name=f"{cls} (AUC {auc_val:.2f})"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
544
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
545 fig.add_trace(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
546 go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash"), showlegend=False)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
547 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
548 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
549 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
550 xaxis_title="False Positive Rate",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
551 yaxis_title="True Positive Rate",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
552 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
553 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
554 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
555 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
556
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
557
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
558 def generate_multiclass_pr_curve_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
559 y_true: Sequence,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
560 y_prob: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
561 classes: Optional[Sequence] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
562 title: str = "Precision–Recall Curve",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
563 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
564 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
565 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
566 Multiclass PR curves (Plotly). If classes is None or len==2, shows binary PR.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
567 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
568 y_true = np.asarray(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
569 y_prob = np.asarray(y_prob)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
570 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
571
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
572 if not classes or len(classes) == 2:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
573 precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
574 ap = average_precision_score(y_true, y_prob[:, 1])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
575 fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"AP = {ap:.2f}"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
576 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
577 for i, cls in enumerate(classes):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
578 y_true_bin = (y_true == cls).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
579 y_prob_cls = y_prob[:, i]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
580 precision, recall, _ = precision_recall_curve(y_true_bin, y_prob_cls)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
581 ap = average_precision_score(y_true_bin, y_prob_cls)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
582 fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"{cls} (AP {ap:.2f})"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
583
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
584 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
585 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
586 xaxis_title="Recall",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
587 yaxis_title="Precision",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
588 yaxis=dict(range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
589 xaxis=dict(range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
590 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
591 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
592 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
593 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
594
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
595
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
596 def generate_metric_comparison_bar(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
597 metrics_scores: Mapping[str, Sequence[float]],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
598 phases: Sequence[str] = ("train", "val", "test"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
599 title: str = "Metric Comparison Across Phases",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
600 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
601 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
602 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
603 Grouped bar chart comparing metrics across phases (Plotly).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
604 metrics_scores: {metric_name: [train, val, test]}
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
605 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
606 df = pd.DataFrame(metrics_scores, index=phases).T.reset_index().rename(columns={"index": "Metric"})
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
607 df_m = df.melt(id_vars="Metric", var_name="Phase", value_name="Score")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
608 fig = px.bar(df_m, x="Metric", y="Score", color="Phase", barmode="group", title=None)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
609 ymax = max(1.0, df_m["Score"].max() * 1.05)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
610 fig.update_yaxes(range=[0, ymax])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
611 fig.update_layout(template="plotly_white")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
612 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
613 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
614
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
615 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
616 # Regression Plots
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
617 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
618
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
619
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
620 def generate_scatter_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
621 y_true: Sequence[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
622 y_pred: Sequence[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
623 title: str = "Predicted vs Actual",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
624 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
625 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
626 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
627 Predicted vs. Actual scatter with y=x reference (Plotly).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
628 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
629 y_true = np.asarray(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
630 y_pred = np.asarray(y_pred)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
631 vmin = float(min(np.min(y_true), np.min(y_pred)))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
632 vmax = float(max(np.max(y_true), np.max(y_pred)))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
633
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
634 fig = px.scatter(x=y_true, y=y_pred, opacity=0.6, labels={"x": "Actual", "y": "Predicted"}, title=None)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
635 fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), name="Ideal"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
636 fig.update_layout(template="plotly_white")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
637 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
638 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
639
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
640
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
641 def generate_residual_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
642 y_true: Sequence[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
643 y_pred: Sequence[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
644 title: str = "Residual Plot",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
645 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
646 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
647 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
648 Residuals vs Predicted (Plotly).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
649 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
650 y_true = np.asarray(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
651 y_pred = np.asarray(y_pred)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
652 residuals = y_true - y_pred
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
653
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
654 fig = px.scatter(x=y_pred, y=residuals, opacity=0.6,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
655 labels={"x": "Predicted", "y": "Residual (Actual - Predicted)"},
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
656 title=None)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
657 fig.add_hline(y=0, line_dash="dash")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
658 fig.update_layout(template="plotly_white")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
659 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
660 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
661
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
662
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
663 def generate_residual_histogram(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
664 y_true: Sequence[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
665 y_pred: Sequence[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
666 bins: int = 30,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
667 title: str = "Residual Histogram",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
668 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
669 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
670 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
671 Residuals histogram (Plotly).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
672 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
673 residuals = np.asarray(y_true) - np.asarray(y_pred)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
674 fig = px.histogram(x=residuals, nbins=bins, labels={"x": "Residual"}, title=None)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
675 fig.update_layout(yaxis_title="Frequency", template="plotly_white")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
676 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
677 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
678
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
679
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
680 def generate_regression_calibration_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
681 y_true: Sequence[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
682 y_pred: Sequence[float],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
683 num_bins: int = 10,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
684 title: str = "Regression Calibration Plot",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
685 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
686 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
687 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
688 Binned Actual vs Predicted means (Plotly).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
689 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
690 y_true = np.asarray(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
691 y_pred = np.asarray(y_pred)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
692
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
693 order = np.argsort(y_pred)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
694 y_true_sorted = y_true[order]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
695 y_pred_sorted = y_pred[order]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
696
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
697 bins = np.array_split(np.arange(len(y_pred_sorted)), num_bins)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
698 bin_means_pred = [float(np.mean(y_pred_sorted[idx])) for idx in bins if len(idx)]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
699 bin_means_true = [float(np.mean(y_true_sorted[idx])) for idx in bins if len(idx)]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
700
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
701 vmin = float(min(np.min(y_pred), np.min(y_true)))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
702 vmax = float(max(np.max(y_pred), np.max(y_true)))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
703
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
704 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
705 fig.add_trace(go.Scatter(x=bin_means_pred, y=bin_means_true, mode="lines+markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
706 name="Binned Actual vs Predicted"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
707 fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
708 name="Ideal"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
709 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
710 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
711 xaxis_title="Mean Predicted per bin",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
712 yaxis_title="Mean Actual per bin",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
713 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
714 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
715 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
716 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
717
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
718 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
719 # Confidence / Diagnostics
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
720 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
721
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
722
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
723 def plot_error_vs_confidence(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
724 y_true: Union[Sequence[int], np.ndarray],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
725 y_proba: Union[Sequence[float], np.ndarray],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
726 n_bins: int = 10,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
727 title: str = "Error vs Confidence",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
728 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
729 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
730 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
731 Error rate vs confidence (binary), confidence=max(p, 1-p). Plotly.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
732 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
733 y_true = np.asarray(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
734 y_proba = np.asarray(y_proba).reshape(-1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
735 y_pred = (y_proba >= 0.5).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
736 confidence = np.maximum(y_proba, 1 - y_proba)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
737 error = (y_pred != y_true).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
738
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
739 bins = np.linspace(0.0, 1.0, n_bins + 1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
740 idx = np.digitize(confidence, bins, right=True)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
741
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
742 centers, err_rates = [], []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
743 for i in range(1, len(bins)):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
744 mask = (idx == i)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
745 if mask.any():
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
746 centers.append(float(confidence[mask].mean()))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
747 err_rates.append(float(error[mask].mean()))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
748
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
749 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
750 fig.add_trace(go.Scatter(x=centers, y=err_rates, mode="lines+markers", name="Error rate"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
751 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
752 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
753 xaxis_title="Confidence (max predicted probability)",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
754 yaxis_title="Error Rate",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
755 yaxis=dict(range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
756 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
757 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
758 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
759 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
760
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
761
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
762 def plot_confidence_histogram(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
763 y_proba: np.ndarray,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
764 bins: int = 20,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
765 title: str = "Confidence Histogram",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
766 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
767 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
768 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
769 Histogram of max predicted probabilities (Plotly).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
770 Works for binary (n_samples,) or (n_samples,2) and multiclass (n_samples,C).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
771 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
772 y_proba = np.asarray(y_proba)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
773 if y_proba.ndim == 1:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
774 confidences = np.maximum(y_proba, 1 - y_proba)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
775 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
776 confidences = np.max(y_proba, axis=1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
777
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
778 fig = px.histogram(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
779 x=confidences,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
780 nbins=bins,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
781 range_x=(0, 1),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
782 histnorm="percent",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
783 labels={"x": "Confidence (max predicted probability)", "y": "Percent of samples (%)"},
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
784 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
785 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
786 if fig.data:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
787 fig.update_traces(hovertemplate="Conf=%{x:.2f}<br>%{y:.2f}%<extra></extra>")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
788 fig.update_layout(yaxis_title="Percent of samples (%)", template="plotly_white")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
789 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
790 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
791
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
792 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
793 # Learning Curve
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
794 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
795
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
796
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
797 def generate_learning_curve_from_predictions(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
798 y_true,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
799 y_pred=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
800 y_proba=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
801 classes=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
802 metric: str = "accuracy",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
803 train_fracs: np.ndarray = np.linspace(0.1, 1.0, 10),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
804 n_repeats: int = 5,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
805 seed: int = 42,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
806 title: str = "Learning Curve",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
807 path: str | None = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
808 return_stats: bool = False,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
809 ) -> Union[go.Figure, tuple[list[int], list[float], list[float]]]:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
810 rng = np.random.default_rng(seed)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
811 y_true = np.asarray(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
812 N = len(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
813
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
814 if metric == "accuracy" and y_pred is None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
815 raise ValueError("accuracy curve requires y_pred")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
816 if metric == "log_loss" and y_proba is None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
817 raise ValueError("log_loss curve requires y_proba")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
818
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
819 if y_proba is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
820 y_proba = np.asarray(y_proba)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
821 if y_pred is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
822 y_pred = np.asarray(y_pred)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
823
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
824 sizes = (np.clip((train_fracs * N).astype(int), 1, N)).tolist()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
825 means, stds = [], []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
826 for n in sizes:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
827 vals = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
828 for _ in range(n_repeats):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
829 idx = rng.choice(N, size=n, replace=False)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
830 if metric == "accuracy":
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
831 vals.append(float((y_true[idx] == y_pred[idx]).mean()))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
832 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
833 if y_proba.ndim == 1:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
834 p = y_proba[idx]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
835 pp = np.column_stack([1 - p, p])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
836 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
837 pp = y_proba[idx]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
838 vals.append(float(log_loss(y_true[idx], pp, labels=None if classes is None else classes)))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
839 means.append(np.mean(vals))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
840 stds.append(np.std(vals))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
841
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
842 if return_stats:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
843 return sizes, means, stds
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
844
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
845 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
846 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
847 x=sizes, y=means, mode="lines+markers", name="Train",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
848 line=dict(width=3, shape="spline"), marker=dict(size=7),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
849 error_y=dict(type="data", array=stds, visible=True)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
850 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
851 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
852 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
853 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
854 xaxis=dict(title="epoch" if metric == "log_loss" else "samples", gridcolor="#eee"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
855 yaxis=dict(title=("loss" if metric == "log_loss" else "accuracy"), gridcolor="#eee"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
856 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
857 margin=dict(l=50, r=20, t=60, b=50),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
858 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
859 if path:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
860 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
861 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
862
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
863
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
864 def build_train_html_and_plots(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
865 predictor,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
866 problem_type: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
867 df_train: pd.DataFrame,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
868 label_column: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
869 tmpdir: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
870 df_val: Optional[pd.DataFrame] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
871 seed: int = 42,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
872 perf_table_html: str | None = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
873 threshold: Optional[float] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
874 section_tile: str = "Training Diagnostics",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
875 ) -> str:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
876 y_true = df_train[label_column].values
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
877 y_true_val = df_val[label_column].values if df_val is not None else None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
878 # predictions on TRAIN
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
879 pred_labels, pred_proba = None, None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
880 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
881 pred_labels = predictor.predict(df_train)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
882 except Exception:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
883 pass
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
884 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
885 proba_raw = predictor.predict_proba(df_train)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
886 pred_proba = proba_raw.to_numpy() if isinstance(proba_raw, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
887 except Exception:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
888 pred_proba = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
889
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
890 # predictions on VAL (if provided)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
891 pred_labels_val, pred_proba_val = None, None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
892 if df_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
893 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
894 pred_labels_val = predictor.predict(df_val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
895 except Exception:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
896 pred_labels_val = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
897 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
898 proba_raw_val = predictor.predict_proba(df_val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
899 pred_proba_val = proba_raw_val.to_numpy() if isinstance(proba_raw_val, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw_val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
900 except Exception:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
901 pred_proba_val = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
902
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
903 pos_scores_train: Optional[np.ndarray] = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
904 pos_scores_val: Optional[np.ndarray] = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
905 if problem_type == "binary":
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
906 if pred_proba is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
907 pos_scores_train = (
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
908 pred_proba.reshape(-1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
909 if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
910 else pred_proba[:, -1]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
911 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
912 if pred_proba_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
913 pos_scores_val = (
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
914 pred_proba_val.reshape(-1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
915 if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
916 else pred_proba_val[:, -1]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
917 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
918
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
919 # Collect plots then append in desired order
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
920 perf_card = f"<div class='card'>{perf_table_html}</div>" if perf_table_html else None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
921 acc_plot = loss_plot = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
922 cm_train = pc_train = cm_val = pc_val = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
923 threshold_val_plot = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
924 roc_combined = pr_combined = cal_combined = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
925 mc_roc_val = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
926 conf_train = conf_val = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
927 bar_train = bar_val = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
928
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
929 # 1) Learning Curve — Accuracy
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
930 if problem_type in ("binary", "multiclass"):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
931 acc_fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
932 added_acc = False
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
933 if pred_labels is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
934 train_sizes, train_means, train_stds = generate_learning_curve_from_predictions(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
935 y_true=y_true,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
936 y_pred=np.asarray(pred_labels),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
937 metric="accuracy",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
938 title="Learning Curves — Label Accuracy",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
939 seed=seed,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
940 return_stats=True,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
941 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
942 acc_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
943 x=train_sizes, y=train_means, mode="lines+markers", name="Train",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
944 line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
945 error_y=dict(type="data", array=train_stds, visible=True),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
946 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
947 added_acc = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
948 if pred_labels_val is not None and y_true_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
949 val_sizes, val_means, val_stds = generate_learning_curve_from_predictions(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
950 y_true=y_true_val,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
951 y_pred=np.asarray(pred_labels_val),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
952 metric="accuracy",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
953 title="Learning Curves — Label Accuracy",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
954 seed=seed,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
955 return_stats=True,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
956 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
957 acc_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
958 x=val_sizes, y=val_means, mode="lines+markers", name="Validation",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
959 line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
960 error_y=dict(type="data", array=val_stds, visible=True),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
961 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
962 added_acc = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
963 if added_acc:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
964 acc_fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
965 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
966 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
967 xaxis=dict(title="samples", gridcolor="#eee"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
968 yaxis=dict(title="accuracy", gridcolor="#eee"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
969 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
970 margin=dict(l=50, r=20, t=60, b=50),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
971 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
972 acc_plot = plot_with_table_style_title(acc_fig, "Learning Curves — Label Accuracy")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
973
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
974 # 2) Learning Curve — Loss
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
975 if problem_type in ("binary", "multiclass"):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
976 classes = np.unique(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
977 loss_fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
978 added_loss = False
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
979 if pred_proba is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
980 pp = pred_proba.reshape(-1) if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) else pred_proba
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
981 train_sizes, train_means, train_stds = generate_learning_curve_from_predictions(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
982 y_true=y_true,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
983 y_proba=pp,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
984 classes=classes,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
985 metric="log_loss",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
986 title="Learning Curves — Label Loss",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
987 seed=seed,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
988 return_stats=True,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
989 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
990 loss_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
991 x=train_sizes, y=train_means, mode="lines+markers", name="Train",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
992 line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
993 error_y=dict(type="data", array=train_stds, visible=True),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
994 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
995 added_loss = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
996 if pred_proba_val is not None and y_true_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
997 pp_val = pred_proba_val.reshape(-1) if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) else pred_proba_val
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
998 val_sizes, val_means, val_stds = generate_learning_curve_from_predictions(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
999 y_true=y_true_val,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1000 y_proba=pp_val,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1001 classes=classes,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1002 metric="log_loss",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1003 title="Learning Curves — Label Loss",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1004 seed=seed,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1005 return_stats=True,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1006 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1007 loss_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1008 x=val_sizes, y=val_means, mode="lines+markers", name="Validation",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1009 line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1010 error_y=dict(type="data", array=val_stds, visible=True),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1011 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1012 added_loss = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1013 if added_loss:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1014 loss_fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1015 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1016 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1017 xaxis=dict(title="epoch", gridcolor="#eee"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1018 yaxis=dict(title="loss", gridcolor="#eee"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1019 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1020 margin=dict(l=50, r=20, t=60, b=50),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1021 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1022 loss_plot = plot_with_table_style_title(loss_fig, "Learning Curves — Label Loss")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1023
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1024 # Confusion matrices & per-class metrics
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1025 cm_train = pc_train = cm_val = pc_val = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1026
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1027 # Probability diagnostics (binary)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1028 if problem_type == "binary":
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1029 # Combined Calibration (Train/Val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1030 cal_fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1031 added_cal = False
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1032 if pos_scores_train is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1033 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1034 prob_true, prob_pred = calibration_curve(y_bin_train, pos_scores_train, n_bins=10, strategy="uniform")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1035 cal_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1036 x=prob_pred, y=prob_true, mode="lines+markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1037 name="Train",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1038 line=dict(color="#1f77b4", width=3),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1039 marker=dict(size=7, color="#1f77b4"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1040 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1041 added_cal = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1042 if pos_scores_val is not None and y_true_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1043 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1044 prob_true_v, prob_pred_v = calibration_curve(y_bin_val, pos_scores_val, n_bins=10, strategy="uniform")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1045 cal_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1046 x=prob_pred_v, y=prob_true_v, mode="lines+markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1047 name="Validation",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1048 line=dict(color="#ff7f0e", width=3),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1049 marker=dict(size=7, color="#ff7f0e"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1050 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1051 added_cal = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1052 if added_cal:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1053 cal_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1054 x=[0, 1], y=[0, 1],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1055 mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1056 line=dict(dash="dash", color="#808080", width=2),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1057 name="Perfect",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1058 showlegend=True,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1059 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1060 cal_fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1061 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1062 xaxis_title="Predicted Probability",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1063 yaxis_title="Observed Probability",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1064 xaxis=dict(range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1065 yaxis=dict(range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1066 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1067 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1068 margin=dict(l=60, r=40, t=50, b=50),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1069 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1070 cal_combined = plot_with_table_style_title(cal_fig, "Calibration Curve (Train vs Validation)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1071
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1072 # Combined ROC (Train/Val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1073 roc_fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1074 added_roc = False
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1075 if pos_scores_train is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1076 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1077 fpr_tr, tpr_tr, thr_tr = roc_curve(y_bin_train, pos_scores_train)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1078 roc_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1079 x=fpr_tr, y=tpr_tr, mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1080 name="Train",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1081 line=dict(color="#1f77b4", width=3),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1082 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1083 if threshold is not None and np.isfinite(thr_tr).any():
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1084 finite = np.isfinite(thr_tr)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1085 idx_local = int(np.argmin(np.abs(thr_tr[finite] - float(threshold))))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1086 idx = int(np.nonzero(finite)[0][idx_local])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1087 roc_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1088 x=[fpr_tr[idx]], y=[tpr_tr[idx]],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1089 mode="markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1090 name="Train @ threshold",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1091 marker=dict(size=12, color="#1f77b4", symbol="x")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1092 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1093 added_roc = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1094 if pos_scores_val is not None and y_true_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1095 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1096 fpr_v, tpr_v, thr_v = roc_curve(y_bin_val, pos_scores_val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1097 roc_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1098 x=fpr_v, y=tpr_v, mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1099 name="Validation",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1100 line=dict(color="#ff7f0e", width=3),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1101 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1102 if threshold is not None and np.isfinite(thr_v).any():
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1103 finite = np.isfinite(thr_v)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1104 idx_local = int(np.argmin(np.abs(thr_v[finite] - float(threshold))))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1105 idx = int(np.nonzero(finite)[0][idx_local])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1106 roc_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1107 x=[fpr_v[idx]], y=[tpr_v[idx]],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1108 mode="markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1109 name="Val @ threshold",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1110 marker=dict(size=12, color="#ff7f0e", symbol="x")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1111 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1112 added_roc = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1113 if added_roc:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1114 roc_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1115 x=[0, 1], y=[0, 1], mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1116 line=dict(dash="dash", width=2, color="#808080"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1117 showlegend=False
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1118 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1119 roc_fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1120 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1121 xaxis_title="False Positive Rate",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1122 yaxis_title="True Positive Rate",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1123 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1124 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1125 margin=dict(l=60, r=20, t=60, b=60),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1126 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1127 roc_combined = plot_with_table_style_title(roc_fig, "ROC Curve (Train vs Validation)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1128
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1129 # Combined PR (Train/Val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1130 pr_fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1131 added_pr = False
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1132 if pos_scores_train is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1133 y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1134 prec_tr, rec_tr, thr_tr = precision_recall_curve(y_bin_train, pos_scores_train)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1135 pr_auc_tr = auc(rec_tr, prec_tr)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1136 pr_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1137 x=rec_tr, y=prec_tr, mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1138 name=f"Train (AUC={pr_auc_tr:.3f})",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1139 line=dict(color="#1f77b4", width=3),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1140 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1141 if threshold is not None and len(thr_tr):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1142 j = int(np.argmin(np.abs(thr_tr - float(threshold))))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1143 j = int(np.clip(j, 0, len(thr_tr) - 1))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1144 pr_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1145 x=[rec_tr[j + 1]], y=[prec_tr[j + 1]],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1146 mode="markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1147 name="Train @ threshold",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1148 marker=dict(size=12, color="#1f77b4", symbol="x")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1149 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1150 added_pr = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1151 if pos_scores_val is not None and y_true_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1152 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1153 prec_v, rec_v, thr_v = precision_recall_curve(y_bin_val, pos_scores_val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1154 pr_auc_v = auc(rec_v, prec_v)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1155 pr_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1156 x=rec_v, y=prec_v, mode="lines",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1157 name=f"Validation (AUC={pr_auc_v:.3f})",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1158 line=dict(color="#ff7f0e", width=3),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1159 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1160 if threshold is not None and len(thr_v):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1161 j = int(np.argmin(np.abs(thr_v - float(threshold))))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1162 j = int(np.clip(j, 0, len(thr_v) - 1))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1163 pr_fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1164 x=[rec_v[j + 1]], y=[prec_v[j + 1]],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1165 mode="markers",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1166 name="Val @ threshold",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1167 marker=dict(size=12, color="#ff7f0e", symbol="x")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1168 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1169 added_pr = True
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1170 if added_pr:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1171 pr_fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1172 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1173 xaxis_title="Recall",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1174 yaxis_title="Precision",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1175 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1176 legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1177 margin=dict(l=60, r=20, t=60, b=60),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1178 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1179 pr_combined = plot_with_table_style_title(pr_fig, "Precision–Recall Curve (Train vs Validation)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1180
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1181 if pos_scores_val is not None and y_true_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1182 y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1183 fig_thr_val = generate_threshold_plot(y_true_bin=y_bin_val, y_prob=pos_scores_val, title="Threshold Plot (Validation)",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1184 user_threshold=threshold)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1185 threshold_val_plot = plot_with_table_style_title(fig_thr_val, "Threshold Plot (Validation)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1186
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1187 # Multiclass OVR ROC (validation)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1188 if problem_type == "multiclass" and pred_proba_val is not None and pred_proba_val.ndim >= 2 and y_true_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1189 classes_val = np.unique(y_true_val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1190 fig_mc_roc_val = generate_multiclass_roc_curve_plot(y_true_val, pred_proba_val, classes_val, title="One-vs-Rest ROC (Validation)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1191 mc_roc_val = plot_with_table_style_title(fig_mc_roc_val, "One-vs-Rest ROC (Validation)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1192
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1193 # Prediction Confidence Histogram (train/val)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1194 conf_train = conf_val = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1195
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1196 # Per-class accuracy bars
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1197 if problem_type in ("binary", "multiclass") and pred_labels is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1198 classes_for_bar = pd.Index(np.unique(y_true), dtype=object).tolist()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1199 acc_vals = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1200 for c in classes_for_bar:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1201 mask = y_true == c
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1202 acc_vals.append(float((np.asarray(pred_labels)[mask] == c).mean()) if mask.any() else 0.0)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1203 bar_fig = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar], y=acc_vals, marker_color="#1f77b4"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1204 bar_fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1205 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1206 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1207 xaxis=dict(title="Label", gridcolor="#eee"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1208 yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1209 margin=dict(l=50, r=20, t=60, b=50),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1210 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1211 bar_train = plot_with_table_style_title(bar_fig, "Per-Class Training Accuracy")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1212 if problem_type in ("binary", "multiclass") and pred_labels_val is not None and y_true_val is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1213 classes_for_bar_val = pd.Index(np.unique(y_true_val), dtype=object).tolist()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1214 acc_vals_val = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1215 for c in classes_for_bar_val:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1216 mask = y_true_val == c
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1217 acc_vals_val.append(float((np.asarray(pred_labels_val)[mask] == c).mean()) if mask.any() else 0.0)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1218 bar_fig_val = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar_val], y=acc_vals_val, marker_color="#ff7f0e"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1219 bar_fig_val.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1220 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1221 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1222 xaxis=dict(title="Label", gridcolor="#eee"),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1223 yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1224 margin=dict(l=50, r=20, t=60, b=50),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1225 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1226 bar_val = plot_with_table_style_title(bar_fig_val, "Per-Class Validation Accuracy")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1227
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1228 # Assemble in requested order
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1229 pieces: list[str] = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1230 if perf_card:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1231 pieces.append(perf_card)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1232 for block in (threshold_val_plot, roc_combined, pr_combined):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1233 if block:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1234 pieces.append(block)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1235 # Remaining plots (keep existing order)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1236 for block in (cal_combined, cm_train, pc_train, cm_val, pc_val, mc_roc_val, conf_train, conf_val, bar_train, bar_val):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1237 if block:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1238 pieces.append(block)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1239 # Learning curves should appear last in the tab
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1240 for block in (acc_plot, loss_plot):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1241 if block:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1242 pieces.append(block)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1243
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1244 if not pieces:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1245 return "<h2>Training Diagnostics</h2><p><em>No training diagnostics available for this run.</em></p>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1246
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1247 return "<h2>Train and Validation Performance Summary</h2>" + "".join(pieces)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1248
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1249
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1250 def generate_learning_curve(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1251 estimator,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1252 X,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1253 y,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1254 scoring: str = "r2",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1255 cv_folds: int = 5,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1256 n_jobs: int = -1,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1257 train_sizes: np.ndarray = np.linspace(0.1, 1.0, 10),
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1258 title: str = "Learning Curve",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1259 path: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1260 ) -> go.Figure:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1261 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1262 Learning curve using sklearn.learning_curve, visualized with Plotly.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1263 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1264 sizes, train_scores, test_scores = skl_learning_curve(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1265 estimator, X, y, cv=cv_folds, scoring=scoring, n_jobs=n_jobs, train_sizes=train_sizes
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1266 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1267 train_mean = train_scores.mean(axis=1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1268 train_std = train_scores.std(axis=1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1269 test_mean = test_scores.mean(axis=1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1270 test_std = test_scores.std(axis=1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1271
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1272 fig = go.Figure()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1273 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1274 x=sizes, y=train_mean, mode="lines+markers", name="Training score",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1275 error_y=dict(type="data", array=train_std, visible=True)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1276 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1277 fig.add_trace(go.Scatter(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1278 x=sizes, y=test_mean, mode="lines+markers", name="CV score",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1279 error_y=dict(type="data", array=test_std, visible=True)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1280 ))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1281 fig.update_layout(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1282 title=None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1283 xaxis_title="Training examples",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1284 yaxis_title=scoring,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1285 template="plotly_white",
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1286 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1287 _save_plotly(fig, path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1288 return fig
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1289
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1290 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1291 # SHAP (Matplotlib-based)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1292 # =========================
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1293
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1294
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1295 def generate_shap_summary_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1296 shap_values, features: pd.DataFrame, title: str = "SHAP Summary Plot", path: Optional[str] = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1297 ) -> None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1298 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1299 SHAP summary plot (Matplotlib). SHAP's interactive support with Plotly is limited;
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1300 keep matplotlib for clarity and stability.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1301 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1302 plt.figure(figsize=(10, 8))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1303 shap.summary_plot(shap_values, features, show=False)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1304 plt.title(title)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1305 _save_matplotlib(path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1306
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1307
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1308 def generate_shap_force_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1309 explainer, instance: pd.DataFrame, title: str = "SHAP Force Plot", path: Optional[str] = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1310 ) -> None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1311 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1312 SHAP force plot (Matplotlib).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1313 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1314 shap_values = explainer(instance)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1315 plt.figure(figsize=(10, 4))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1316 shap.plots.force(shap_values[0], show=False)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1317 plt.title(title)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1318 _save_matplotlib(path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1319
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1320
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1321 def generate_shap_waterfall_plot(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1322 explainer, instance: pd.DataFrame, title: str = "SHAP Waterfall Plot", path: Optional[str] = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1323 ) -> None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1324 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1325 SHAP waterfall plot (Matplotlib).
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1326 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1327 shap_values = explainer(instance)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1328 plt.figure(figsize=(10, 6))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1329 shap.plots.waterfall(shap_values[0], show=False)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1330 plt.title(title)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1331 _save_matplotlib(path)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1332
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1333
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1334 def infer_problem_type(predictor, df_train_full: pd.DataFrame, label_column: str) -> str:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1335 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1336 Return 'binary', 'multiclass', or 'regression'.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1337 Prefer the predictor's own metadata when available; otherwise infer from label dtype/uniques.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1338 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1339 # AutoGluon predictors usually expose .problem_type; be defensive.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1340 pt = getattr(predictor, "problem_type", None)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1341 if isinstance(pt, str):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1342 pt_l = pt.lower()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1343 if "regression" in pt_l:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1344 return "regression"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1345 if "binary" in pt_l:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1346 return "binary"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1347 if "multiclass" in pt_l or "multiclass" in pt_l:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1348 return "multiclass"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1349
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1350 y = df_train_full[label_column]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1351 if pd.api.types.is_numeric_dtype(y) and y.nunique() > 10:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1352 return "regression"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1353 return "binary" if y.nunique() == 2 else "multiclass"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1354
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1355
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1356 def _safe_floatify(d: Dict[str, Any]) -> Dict[str, float]:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1357 """Make evaluate() outputs JSON/csv friendly floats."""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1358 out = {}
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1359 for k, v in d.items():
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1360 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1361 out[k] = float(v)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1362 except Exception:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1363 # keep only real-valued scalars
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1364 pass
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1365 return out
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1366
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1367
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1368 def evaluate_all(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1369 predictor,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1370 df_train: pd.DataFrame,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1371 df_val: pd.DataFrame,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1372 df_test: pd.DataFrame,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1373 label_column: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1374 problem_type: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1375 ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1376 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1377 Run predictor.evaluate on train/val/test and normalize the result dicts to floats.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1378 MultiModalPredictor does not accept the `silent` kwarg, so call defensively.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1379 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1380 def _evaluate(df):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1381 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1382 return predictor.evaluate(df, silent=True)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1383 except TypeError:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1384 return predictor.evaluate(df)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1385
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1386 train_scores = _safe_floatify(_evaluate(df_train))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1387 val_scores = _safe_floatify(_evaluate(df_val))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1388 test_scores = _safe_floatify(_evaluate(df_test))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1389 return train_scores, val_scores, test_scores
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1390
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1391
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1392 def build_summary_html(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1393 predictor,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1394 df_train: pd.DataFrame,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1395 df_val: Optional[pd.DataFrame],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1396 df_test: Optional[pd.DataFrame],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1397 label_column: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1398 extra_run_rows: Optional[list[tuple[str, str]]] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1399 class_balance_html: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1400 perf_table_html: Optional[str] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1401 ) -> str:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1402 sections = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1403
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1404 # Dataset Overview (first section in the tab)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1405 if class_balance_html:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1406 sections.append(f"""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1407 <section class="section">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1408 <h2 class="section-title">Dataset Overview</h2>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1409 <div class="card">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1410 {class_balance_html}
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1411 </div>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1412 </section>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1413 """.strip())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1414
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1415 # Performance Summary
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1416 if perf_table_html:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1417 sections.append(f"""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1418 <section class="section">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1419 <h2 class="section-title">Model Performance Summary</h2>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1420 <div class="card">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1421 {perf_table_html}
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1422 </div>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1423 </section>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1424 """.strip())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1425
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1426 # Model Configuration
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1427
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1428 # Remove Predictor type and Framework, and ensure Model Architecture is present
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1429 base_rows: list[tuple[str, str]] = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1430 if extra_run_rows:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1431 # Remove any rows with keys 'Predictor type' or 'Framework'
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1432 base_rows.extend([(k, v) for (k, v) in extra_run_rows if k not in ("Predictor type", "Framework")])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1433
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1434 def _fmt(v):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1435 if v is None or v == "":
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1436 return "—"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1437 return _escape(str(v))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1438
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1439 rows_html = "\n".join(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1440 f"<tr><td>{_escape(str(k))}</td><td>{_fmt(v)}</td></tr>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1441 for k, v in base_rows
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1442 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1443
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1444 sections.append(f"""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1445 <section class="section">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1446 <h2 class="section-title">Model Configuration</h2>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1447 <div class="card">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1448 <table class="kv-table">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1449 <thead><tr><th>Key</th><th>Value</th></tr></thead>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1450 <tbody>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1451 {rows_html}
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1452 </tbody>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1453 </table>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1454 </div>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1455 </section>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1456 """.strip())
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1457
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1458 return "\n".join(sections).strip()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1459
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1460
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1461 def build_feature_importance_html(predictor, df_train: pd.DataFrame, label_column: str) -> str:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1462 """Build a visualization of feature importance."""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1463 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1464 # Try to get feature importance from predictor
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1465 fi = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1466 if hasattr(predictor, "feature_importance") and callable(predictor.feature_importance):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1467 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1468 fi = predictor.feature_importance(df_train)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1469 except Exception as e:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1470 return f"<p>Could not compute feature importance: {e}</p>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1471
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1472 if fi is None or (isinstance(fi, pd.DataFrame) and fi.empty):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1473 return "<p>Feature importance not available for this model.</p>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1474
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1475 # Format as a sortable table
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1476 rows = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1477 if isinstance(fi, pd.DataFrame):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1478 fi = fi.sort_values("importance", ascending=False)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1479 for _, row in fi.iterrows():
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1480 feat = row.index[0] if isinstance(row.index, pd.Index) else row["feature"]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1481 imp = float(row["importance"])
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1482 rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{imp:.4f}</td></tr>")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1483 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1484 # Handle other formats (dict, etc)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1485 for feat, imp in sorted(fi.items(), key=lambda x: float(x[1]), reverse=True):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1486 rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{float(imp):.4f}</td></tr>")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1487
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1488 if not rows:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1489 return "<p>No feature importance values available.</p>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1490
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1491 table_html = f"""
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1492 <table class="performance-summary">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1493 <thead>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1494 <tr>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1495 <th class="sortable">Feature</th>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1496 <th class="sortable">Importance</th>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1497 </tr>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1498 </thead>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1499 <tbody>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1500 {"".join(rows)}
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1501 </tbody>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1502 </table>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1503 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1504 return table_html
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1505
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1506 except Exception as e:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1507 return f"<p>Error building feature importance visualization: {e}</p>"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1508
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1509
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1510 def build_test_html_and_plots(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1511 predictor,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1512 problem_type: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1513 df_test: pd.DataFrame,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1514 label_column: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1515 tmpdir: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1516 threshold: Optional[float] = None,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1517 ) -> Tuple[str, List[str]]:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1518 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1519 Create a test-summary section (with a placeholder for metric rows) and a list of Plotly HTML divs.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1520 Returns: (html_template_with_{}, list_of_plot_divs)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1521 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1522 plots: List[str] = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1523
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1524 y_true = df_test[label_column].values
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1525 classes = np.unique(y_true)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1526
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1527 # Try proba/labels where meaningful
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1528 pred_labels = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1529 pred_proba = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1530 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1531 pred_labels = predictor.predict(df_test)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1532 except Exception:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1533 pass
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1534 try:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1535 # MultiModalPredictor exposes predict_proba for classification problems.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1536 pred_proba = predictor.predict_proba(df_test)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1537 except Exception:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1538 pred_proba = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1539
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1540 proba_arr = None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1541 if pred_proba is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1542 if isinstance(pred_proba, pd.Series):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1543 proba_arr = pred_proba.to_numpy().reshape(-1, 1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1544 elif isinstance(pred_proba, pd.DataFrame):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1545 proba_arr = pred_proba.to_numpy()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1546 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1547 proba_arr = np.asarray(pred_proba)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1548
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1549 # Thresholded labels for binary
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1550 if problem_type == "binary" and threshold is not None and proba_arr is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1551 pos_label, neg_label = classes.max(), classes.min()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1552 pos_scores = proba_arr.reshape(-1) if (proba_arr.ndim == 1 or proba_arr.shape[1] == 1) else proba_arr[:, -1]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1553 pred_labels = np.where(pos_scores >= float(threshold), pos_label, neg_label)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1554
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1555 # Confusion matrix / per-class now reflect thresholded labels
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1556 if problem_type in ("binary", "multiclass") and pred_labels is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1557 cm_title = "Confusion Matrix"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1558 if threshold is not None and problem_type == "binary":
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1559 thr_str = f"{float(threshold):.3f}".rstrip("0").rstrip(".")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1560 cm_title = f"Confusion Matrix (Threshold = {thr_str})"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1561 fig_cm = generate_confusion_matrix_plot(y_true, pred_labels, title=cm_title)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1562 plots.append(plot_with_table_style_title(fig_cm, cm_title))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1563
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1564 fig_pc = generate_per_class_metrics_plot(y_true, pred_labels, title="Per-Class Metrics")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1565 plots.append(plot_with_table_style_title(fig_pc, "Per-Class Metrics"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1566
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1567 # ROC/PR where possible — choose positive-class scores safely
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1568 pos_label = classes.max() # or set explicitly, e.g., 1 or "yes"
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1569
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1570 if isinstance(pred_proba, pd.DataFrame):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1571 proba_arr = pred_proba.to_numpy()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1572 if pos_label in pred_proba.columns:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1573 pos_idx = list(pred_proba.columns).index(pos_label)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1574 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1575 pos_idx = -1 # fallback to last column
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1576 elif isinstance(pred_proba, pd.Series):
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1577 proba_arr = pred_proba.to_numpy().reshape(-1, 1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1578 pos_idx = 0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1579 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1580 proba_arr = np.asarray(pred_proba) if pred_proba is not None else None
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1581 pos_idx = -1 if (proba_arr is not None and proba_arr.ndim == 2 and proba_arr.shape[1] > 1) else 0
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1582
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1583 if proba_arr is not None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1584 y_bin = (y_true == pos_label).astype(int)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1585 pos_scores = (
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1586 proba_arr.reshape(-1)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1587 if proba_arr.ndim == 1 or proba_arr.shape[1] == 1
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1588 else proba_arr[:, pos_idx]
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1589 )
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1590
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1591 fig_roc = generate_roc_curve_plot(y_bin, pos_scores, title="ROC Curve", marker_threshold=threshold)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1592 plots.append(plot_with_table_style_title(fig_roc, f"ROC Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1593
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1594 fig_pr = generate_pr_curve_plot(y_bin, pos_scores, title="Precision–Recall Curve", marker_threshold=threshold)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1595 plots.append(plot_with_table_style_title(fig_pr, f"Precision–Recall Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1596
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1597 # Additional diagnostics aligned with ImageLearner style
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1598 if problem_type == "binary":
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1599 conf_fig = plot_confidence_histogram(pos_scores, bins=20, title="Prediction Confidence (Test)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1600 plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Test)"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1601 else:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1602 conf_fig = plot_confidence_histogram(proba_arr, bins=20, title="Prediction Confidence (Top-1, Test)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1603 plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Top-1, Test)"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1604
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1605 if problem_type == "multiclass" and proba_arr is not None and proba_arr.ndim >= 2:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1606 fig_mc_roc = generate_multiclass_roc_curve_plot(y_true, proba_arr, classes, title="One-vs-Rest ROC (Test)")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1607 plots.append(plot_with_table_style_title(fig_mc_roc, "One-vs-Rest ROC (Test)"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1608
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1609 # Regression visuals
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1610 if problem_type == "regression":
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1611 if pred_labels is None:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1612 pred_labels = predictor.predict(df_test)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1613 fig_sc = generate_scatter_plot(y_true, pred_labels, title="Predicted vs Actual")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1614 plots.append(plot_with_table_style_title(fig_sc, "Predicted vs Actual"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1615
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1616 fig_res = generate_residual_plot(y_true, pred_labels, title="Residual Plot")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1617 plots.append(plot_with_table_style_title(fig_res, "Residual Plot"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1618
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1619 fig_hist = generate_residual_histogram(y_true, pred_labels, title="Residual Histogram")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1620 plots.append(plot_with_table_style_title(fig_hist, "Residual Histogram"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1621
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1622 fig_cal = generate_regression_calibration_plot(y_true, pred_labels, title="Regression Calibration")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1623 plots.append(plot_with_table_style_title(fig_cal, "Regression Calibration"))
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1624
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1625 # Small HTML template with placeholder for metric rows the caller fills in
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1626 test_html_template = """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1627 <h2>Test Performance Summary</h2>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1628 <table class="performance-summary">
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1629 <thead><tr><th>Metric</th><th>Test</th></tr></thead>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1630 <tbody>{}</tbody>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1631 </table>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1632 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1633 return test_html_template, plots
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1634
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1635
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1636 def build_feature_html(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1637 predictor,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1638 df_train: pd.DataFrame,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1639 label_column: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1640 include_modalities: bool = True, # ← NEW
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1641 include_class_balance: bool = True, # ← NEW
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1642 ) -> str:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1643 sections = []
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1644
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1645 # (Typical feature importance content…)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1646 fi_html = build_feature_importance_html(predictor, df_train, label_column)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1647 sections.append(f"<section class='section'><h2 class='section-title'>Feature Importance</h2><div class='card'>{fi_html}</div></section>")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1648
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1649 # Previously: Modalities & Inputs and/or Class Balance may have been here.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1650 # Only render them if flags are True.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1651 if include_modalities:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1652 from report_utils import build_modalities_html
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1653 modalities_html = build_modalities_html(predictor, df_train, label_column)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1654 sections.append(f"<section class='section'><h2 class='section-title'>Modalities & Inputs</h2><div class='card'>{modalities_html}</div></section>")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1655
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1656 if include_class_balance:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1657 from report_utils import build_class_balance_html
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1658 cb_html = build_class_balance_html(df_train, label_column)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1659 sections.append(f"<section class='section'><h2 class='section-title'>Class Balance (Train Full)</h2><div class='card'>{cb_html}</div></section>")
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1660
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1661 return "\n".join(sections)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1662
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1663
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1664 def assemble_full_html_report(
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1665 summary_html: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1666 train_html: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1667 test_html: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1668 plots: List[str],
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1669 feature_html: str,
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1670 ) -> str:
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1671 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1672 Wrap the four tabs using utils.build_tabbed_html and return full HTML.
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1673 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1674 # Append plots under the Test tab (already wrapped with titles)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1675 test_full = test_html + "".join(plots)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1676
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1677 tabs = build_tabbed_html(summary_html, train_html, test_full, feature_html, explainer_html=None)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1678
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1679 html_out = get_html_template()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1680
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1681 # 🔧 Ensure Plotly JS is available (we render plots with include_plotlyjs=False)
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1682 html_out += '\n<script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script>\n'
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1683
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1684 # Optional: centering tweaks
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1685 html_out += """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1686 <style>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1687 .plotly-center { display: flex; justify-content: center; }
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1688 .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1689 .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1690 </style>
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1691 """
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1692 # Help modal HTML/JS
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1693 html_out += get_metrics_help_modal()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1694
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1695 html_out += tabs
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1696 html_out += get_html_closing()
375c36923da1 planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
goeckslab
parents:
diff changeset
1697 return html_out