comparison base_model_trainer.py @ 6:a32ff7201629 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author goeckslab
date Wed, 02 Jul 2025 19:00:03 +0000
parents ccd798db5abb
children
comparison
equal deleted inserted replaced
5:c846405830eb 6:a32ff7201629
5 5
6 import h5py 6 import h5py
7 import joblib 7 import joblib
8 import numpy as np 8 import numpy as np
9 import pandas as pd 9 import pandas as pd
10 from feature_help_modal import get_feature_metrics_help_modal
10 from feature_importance import FeatureImportanceAnalyzer 11 from feature_importance import FeatureImportanceAnalyzer
11 from sklearn.metrics import average_precision_score 12 from sklearn.metrics import average_precision_score
12 from utils import get_html_closing, get_html_template 13 from utils import get_html_closing, get_html_template
13 14
14 logging.basicConfig(level=logging.DEBUG) 15 logging.basicConfig(level=logging.DEBUG)
15 LOG = logging.getLogger(__name__) 16 LOG = logging.getLogger(__name__)
16 17
17 18
18 class BaseModelTrainer: 19 class BaseModelTrainer:
19
20 def __init__( 20 def __init__(
21 self, 21 self,
22 input_file, 22 input_file,
23 target_col, 23 target_col,
24 output_dir, 24 output_dir,
25 task_type, 25 task_type,
26 random_seed, 26 random_seed,
27 test_file=None, 27 test_file=None,
28 **kwargs): 28 **kwargs,
29 ):
29 self.exp = None # This will be set in the subclass 30 self.exp = None # This will be set in the subclass
30 self.input_file = input_file 31 self.input_file = input_file
31 self.target_col = target_col 32 self.target_col = target_col
32 self.output_dir = output_dir 33 self.output_dir = output_dir
33 self.task_type = task_type 34 self.task_type = task_type
45 setattr(self, key, value) 46 setattr(self, key, value)
46 self.setup_params = {} 47 self.setup_params = {}
47 self.test_file = test_file 48 self.test_file = test_file
48 self.test_data = None 49 self.test_data = None
49 50
51 if not self.output_dir:
52 raise ValueError("output_dir must be specified and not None")
53
50 LOG.info(f"Model kwargs: {self.__dict__}") 54 LOG.info(f"Model kwargs: {self.__dict__}")
51 55
def load_data(self):
    """Load the training table and, if configured, the separate test table.

    Method of ``BaseModelTrainer``. Reads ``self.input_file`` with delimiter
    sniffing, replaces dots in column names with underscores, drops a
    ``prediction_label`` column if one is present, resolves the target
    column from the 1-based ``self.target_col`` and handles missing values
    according to ``self.missing_value_strategy``.

    Sets ``self.data``, ``self.target``, ``self.features_name`` and, when
    ``self.test_file`` is truthy, ``self.test_data``.

    Raises:
        IndexError: if ``self.target_col`` does not address an existing
            column (1-based).
        KeyError: if the test file lacks one of the numeric columns found
            in the training data.
    """
    LOG.info(f"Loading data from {self.input_file}")
    self.data = pd.read_csv(self.input_file, sep=None, engine="python")
    # regex=False pins "." to a literal dot regardless of the pandas
    # version's default for the `regex` argument.
    self.data.columns = self.data.columns.str.replace(".", "_", regex=False)

    # Remove prediction_label if present
    if "prediction_label" in self.data.columns:
        self.data = self.data.drop(columns=["prediction_label"])

    numeric_cols = self.data.select_dtypes(include=["number"]).columns
    non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns

    self.data[numeric_cols] = self.data[numeric_cols].apply(
        pd.to_numeric, errors="coerce"
    )

    if len(non_numeric_cols) > 0:
        LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

    names = self.data.columns.to_list()
    target_index = int(self.target_col) - 1
    # Without this check a target_col of 0 would wrap around to the last
    # column through negative indexing instead of failing.
    if not 0 <= target_index < len(names):
        raise IndexError(
            f"target_col {self.target_col} is out of range for "
            f"{len(names)} columns (1-based index expected)"
        )
    self.target = names[target_index]
    self.features_name = [name for i, name in enumerate(names) if i != target_index]

    if hasattr(self, "missing_value_strategy"):
        # A configured value other than mean/median/drop leaves the data
        # untouched.
        if self.missing_value_strategy == "mean":
            self.data = self.data.fillna(self.data.mean(numeric_only=True))
        elif self.missing_value_strategy == "median":
            self.data = self.data.fillna(self.data.median(numeric_only=True))
        elif self.missing_value_strategy == "drop":
            self.data = self.data.dropna()
    else:
        # Default strategy if not specified
        self.data = self.data.fillna(self.data.median(numeric_only=True))

    if self.test_file:
        LOG.info(f"Loading test data from {self.test_file}")
        self.test_data = pd.read_csv(self.test_file, sep=None, engine="python")
        # Rename first: numeric_cols holds the underscore form of the
        # training headers, so selecting before renaming raised KeyError
        # for any header that contained a dot.
        self.test_data.columns = self.test_data.columns.str.replace(
            ".", "_", regex=False
        )
        # NOTE(review): only the numeric training columns are kept here, so
        # non-numeric features (and a non-numeric target) are absent from
        # test_data - confirm this is intended.
        self.test_data = self.test_data[numeric_cols].apply(
            pd.to_numeric, errors="coerce"
        )
94 97
def setup_pycaret(self):
    """Assemble ``self.setup_params`` and run the PyCaret experiment setup.

    Method of ``BaseModelTrainer``. Starts from the fixed base options,
    adds each optional setting only when the matching attribute exists on
    the trainer and is not ``None``, then creates a classification or
    regression experiment according to ``self.task_type`` and calls its
    ``setup`` with ``self.data``.

    Raises:
        ValueError: if ``self.task_type`` is neither ``"classification"``
            nor ``"regression"``.
    """
    LOG.info("Initializing PyCaret")
    params = {
        "target": self.target,
        "session_id": self.random_seed,
        "html": True,
        "log_experiment": False,
        "system_log": False,
        "index": False,
    }
    self.setup_params = params

    def copy_if_set(*attr_names):
        # Forward attributes that are present and not None, in the order
        # given, so the key order of the resulting dict is stable.
        for attr_name in attr_names:
            attr_value = getattr(self, attr_name, None)
            if attr_value is not None:
                params[attr_name] = attr_value

    external_test_set = self.test_data is not None
    if external_test_set:
        params["test_data"] = self.test_data

    # train_size is only forwarded when no explicit test set was supplied.
    if not external_test_set:
        copy_if_set("train_size")

    copy_if_set("normalize", "feature_selection")

    cv_flag = getattr(self, "cross_validation", None)
    # The flag itself is forwarded only when it is exactly False.
    if cv_flag is False:
        params["cross_validation"] = cv_flag
    # The fold count is forwarded whenever the flag is set at all and the
    # folds attribute exists, even if that attribute is None.
    if cv_flag is not None and hasattr(self, "cross_validation_folds"):
        params["fold"] = self.cross_validation_folds

    copy_if_set(
        "remove_outliers",
        "remove_multicollinearity",
        "polynomial_features",
        "fix_imbalance",
    )

    LOG.info(params)

    # Instantiate the PyCaret experiment matching the task type; the import
    # is deferred so only the needed PyCaret module is loaded.
    task = self.task_type
    if task == "classification":
        from pycaret.classification import ClassificationExperiment

        self.exp = ClassificationExperiment()
    elif task == "regression":
        from pycaret.regression import RegressionExperiment

        self.exp = RegressionExperiment()
    else:
        raise ValueError("task_type must be 'classification' or 'regression'")

    self.exp.setup(self.data, **params)
149 171
def train_model(self):
    """Compare candidate models, keep the best one and score it.

    Method of ``BaseModelTrainer``. For classification a weighted PR-AUC
    metric is registered on the experiment first. The comparison table is
    stored in ``self.results`` and the table pulled after
    ``predict_model`` in ``self.test_result_df``; for classification the
    ``AUC`` column of both tables is renamed to ``ROC-AUC``.
    """
    LOG.info("Training and selecting the best model")
    classification = self.task_type == "classification"
    auc_rename = {"AUC": "ROC-AUC"}

    if classification:
        average_displayed = "Weighted"
        metric_label = f"PR-AUC-{average_displayed}"
        self.exp.add_metric(
            id=metric_label,
            name=metric_label,
            target="pred_proba",
            score_func=average_precision_score,
            average="weighted",
        )

    # Restrict the comparison to self.models only when such a list was set.
    compare_kwargs = {}
    requested_models = getattr(self, "models", None)
    if requested_models is not None:
        compare_kwargs["include"] = requested_models
    self.best_model = self.exp.compare_models(**compare_kwargs)

    self.results = self.exp.pull()
    if classification:
        self.results.rename(columns=auc_rename, inplace=True)

    # The prediction frame itself is discarded; only the metrics table that
    # pull() returns afterwards is kept.
    self.exp.predict_model(self.best_model)
    self.test_result_df = self.exp.pull()
    if classification:
        self.test_result_df.rename(columns=auc_rename, inplace=True)
175 196
def save_model(self):
    """Serialise ``self.best_model`` into ``pycaret_model.h5``.

    Method of ``BaseModelTrainer``. The model is dumped with joblib and the
    resulting bytes are stored as an opaque ``model`` dataset in an HDF5
    file created in the current working directory.
    """
    import io  # local import: only this method needs an in-memory buffer

    # Dump into memory rather than into a NamedTemporaryFile(delete=False):
    # that file was never removed, so each call leaked a temp file, and it
    # was re-read by name while still open.
    buffer = io.BytesIO()
    joblib.dump(self.best_model, buffer)
    model_bytes = buffer.getvalue()

    hdf5_model_path = "pycaret_model.h5"
    # Serialising before opening the HDF5 file means a failed dump no
    # longer leaves an empty model file behind.
    with h5py.File(hdf5_model_path, "w") as f:
        f.create_dataset("model", data=np.void(model_bytes))
184 205
def generate_plots(self):
    """Placeholder that concrete trainers are expected to override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Subclasses should implement this method"
    raise NotImplementedError(message)
187 208
def encode_image_to_base64(self, img_path):
    """Return the contents of the file at ``img_path`` as base64 text.

    Args:
        img_path: path of the file to read in binary mode.

    Returns:
        The base64 encoding of the file's bytes, decoded as UTF-8.
    """
    with open(img_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
191 212
def save_html_report(self):
    """Write the CSV summaries and the tabbed HTML report to ``self.output_dir``.

    Method of ``BaseModelTrainer``. Produces ``best_model.csv``,
    ``comparison_results.csv``, ``test_results.csv`` and
    ``comparison_result.html``. The HTML page embeds the plots listed in
    ``self.plots`` as base64 images, the output of
    ``FeatureImportanceAnalyzer.run()`` and, when
    ``self.plots_explainer_html`` is truthy, an extra explainer tab that
    also contains the images in ``self.trees``.

    Raises:
        ValueError: if ``self.output_dir`` is empty or ``None``.
    """
    LOG.info("Saving HTML report")

    if not self.output_dir:
        raise ValueError("output_dir must be specified and not None")

    model_name = type(self.best_model).__name__
    # These setup entries are left out of the "Setup Parameters" table.
    excluded_params = ["html", "log_experiment", "system_log", "test_data"]
    filtered_setup_params = {
        k: v for k, v in self.setup_params.items() if k not in excluded_params
    }
    setup_params_table = pd.DataFrame(
        list(filtered_setup_params.items()), columns=["Parameter", "Value"]
    )

    best_model_params = pd.DataFrame(
        list(self.best_model.get_params().items()), columns=["Parameter", "Value"]
    )
    best_model_params.to_csv(
        os.path.join(self.output_dir, "best_model.csv"), index=False
    )
    self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
    self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv"))

    # One block per plot, separated (not terminated) by a horizontal rule.
    plot_blocks = []
    for plot_name, plot_path in self.plots.items():
        encoded_image = self.encode_image_to_base64(plot_path)
        plot_blocks.append(
            '<div class="plot">'
            f"<h3>{plot_name.capitalize()}</h3>"
            f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">'
            "</div>"
        )
    plots_html = "<hr>".join(plot_blocks)

    # Falsy entries are skipped but still consume a tree number.
    tree_plots = "".join(
        '<div class="plot">'
        f"<h3>Tree {i + 1}</h3>"
        f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">'
        "</div>"
        for i, tree in enumerate(self.trees)
        if tree
    )

    analyzer = FeatureImportanceAnalyzer(
        data=self.data,
        target_col=self.target_col,
        task_type=self.task_type,
        output_dir=self.output_dir,
        exp=self.exp,
        best_model=self.best_model,
    )
    feature_importance_html = analyzer.run()

    # --- Feature Metrics Help Button ---
    feature_metrics_button_html = (
        '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">'
        "Help: Metrics Guide"
        "</button>"
        "<style>"
        ".help-modal-btn {"
        "background-color: #17623b;"
        "color: #fff;"
        "border: none;"
        "border-radius: 24px;"
        "padding: 10px 28px;"
        "font-size: 1.1rem;"
        "font-weight: bold;"
        "letter-spacing: 0.03em;"
        "cursor: pointer;"
        "transition: background 0.2s, box-shadow 0.2s;"
        "box-shadow: 0 2px 8px rgba(23,98,59,0.07);"
        "}"
        ".help-modal-btn:hover, .help-modal-btn:focus {"
        "background-color: #21895e;"
        "outline: none;"
        "box-shadow: 0 4px 16px rgba(23,98,59,0.14);"
        "}"
        "</style>"
    )

    # The script must keep real line breaks: it contains "//" comments, and
    # on a single physical line the first of them would comment out the
    # rest of the script and break table sorting.
    sortable_script = """<script>
document.addEventListener("DOMContentLoaded", function() {
    var tables = document.querySelectorAll("table.sortable");
    tables.forEach(function(table) {
        var headers = table.querySelectorAll("th");
        headers.forEach(function(header, index) {
            header.style.cursor = "pointer";
            // Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)
            header.innerHTML += '<span class="sort-arrow"> ↑</span>';
            header.addEventListener("click", function() {
                var direction = this.getAttribute(
                    "data-sort-direction"
                ) || "asc";
                // Reset arrows in all headers of this table
                headers.forEach(function(h) {
                    var arrow = h.querySelector(".sort-arrow");
                    if (arrow) arrow.textContent = " ↑";
                });
                // Set arrow for clicked header
                var arrow = this.querySelector(".sort-arrow");
                arrow.textContent = direction === "asc" ? " ↓" : " ↑";
                sortTable(table, index, direction);
                this.setAttribute("data-sort-direction",
                    direction === "asc" ? "desc" : "asc");
            });
        });
    });
});
function sortTable(table, colNum, direction) {
    var tb = table.tBodies[0];
    var tr = Array.prototype.slice.call(tb.rows, 0);
    var multiplier = direction === "asc" ? 1 : -1;
    tr = tr.sort(function(a, b) {
        var aText = a.cells[colNum].textContent.trim();
        var bText = b.cells[colNum].textContent.trim();
        // Remove arrow from text comparison
        aText = aText.replace(/[↑↓]/g, '').trim();
        bText = bText.replace(/[↑↓]/g, '').trim();
        if (!isNaN(aText) && !isNaN(bText)) {
            return multiplier * (
                parseFloat(aText) - parseFloat(bText)
            );
        } else {
            return multiplier * aText.localeCompare(bText);
        }
    });
    for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);
}
</script>
"""

    results_html = self.results.to_html(index=False, classes="table sortable")
    best_params_html = best_model_params.to_html(
        index=False, header=True, classes="table sortable"
    )
    setup_params_html = setup_params_table.to_html(
        index=False, header=True, classes="table sortable"
    )
    test_results_html = self.test_result_df.to_html(index=False)
    best_model_note = (
        f"<h2>Best Model: {model_name}</h2>"
        "<h5>The best model is selected by: Accuracy (Classification)"
        " or R2 (Regression).</h5>"
    )

    # Fragments are collected and joined once instead of growing a string
    # with repeated "+=".
    parts = [
        f"{get_html_template()}",
        "<h1>Tabular Learner Model Report</h1>",
        f"{feature_metrics_button_html}",
        '<div class="tabs">',
        '<div class="tab" onclick="openTab(event, \'summary\')">',
        "Validation Result Summary & Config</div>",
        '<div class="tab" onclick="openTab(event, \'plots\')">',
        "Test Results</div>",
        '<div class="tab" onclick="openTab(event, \'feature\')">',
        "Feature Importance</div>",
    ]
    if self.plots_explainer_html:
        parts.append(
            '<div class="tab" onclick="openTab(event, \'explainer\')">'
            "Explainer Plots</div>"
        )
    parts.extend(
        [
            "</div>",
            '<div id="summary" class="tab-content">',
            "<h2>Model Metrics from Cross-Validation Set</h2>",
            best_model_note,
            results_html,
            "<h2>Best Model's Hyperparameters</h2>",
            best_params_html,
            "<h2>Setup Parameters</h2>",
            setup_params_html,
            "<h5>If you want to know all the experiment setup parameters,"
            " please check the PyCaret documentation for"
            " the classification/regression <code>exp</code> function.</h5>",
            "</div>",
            '<div id="plots" class="tab-content">',
            best_model_note,
            "<h2>Test Metrics</h2>",
            test_results_html,
            "<h2>Test Results</h2>",
            plots_html,
            "</div>",
            '<div id="feature" class="tab-content">',
            f"{feature_importance_html}",
            "</div>",
        ]
    )
    if self.plots_explainer_html:
        parts.extend(
            [
                '<div id="explainer" class="tab-content">',
                f"{self.plots_explainer_html}",
                tree_plots,
                "</div>",
            ]
        )
    parts.append(sortable_script)
    # --- Add the Feature Metrics Help Modal ---
    parts.append(get_feature_metrics_help_modal())
    parts.append(f"{get_html_closing()}")
    html_content = "".join(parts)

    # The script embeds non-ASCII arrow characters, hence the explicit
    # UTF-8 encoding.
    with open(
        os.path.join(self.output_dir, "comparison_result.html"),
        "w",
        encoding="utf-8",
    ) as file:
        file.write(html_content)
370 410
def save_dashboard(self):
    """Placeholder that concrete trainers are expected to override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Subclasses should implement this method"
    raise NotImplementedError(message)
373 413
def generate_plots_explainer(self):
    """Placeholder that concrete trainers are expected to override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Subclasses should implement this method"
    raise NotImplementedError(message)
376 416
377 # not working now
def generate_tree_plots(self):
    """Append one encoded decision-tree figure per estimator to ``self.trees``.

    Method of ``BaseModelTrainer``. Only random-forest and XGBoost best
    models are handled; for any other model type a warning is logged and
    nothing is appended. Failures while rendering are logged and swallowed,
    so figures produced before the failure stay in ``self.trees``.
    """
    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
    from xgboost import XGBClassifier, XGBRegressor
    from explainerdashboard.explainers import RandomForestExplainer

    LOG.info("Generating tree plots")
    X_test = self.exp.X_test_transformed.copy()
    y_test = self.exp.y_test_transformed

    is_rf = isinstance(
        self.best_model, (RandomForestClassifier, RandomForestRegressor)
    )
    is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))

    if is_rf:
        num_trees = self.best_model.n_estimators
    elif is_xgb:
        num_trees = len(self.best_model.get_booster().get_dump())
    else:
        LOG.warning("Tree plots not supported for this model type.")
        return

    try:
        # NOTE(review): XGBoost models are also passed to
        # RandomForestExplainer here - confirm that is supported, or
        # whether a dedicated XGBoost explainer is required.
        explainer = RandomForestExplainer(self.best_model, X_test, y_test)
        for i in range(num_trees):
            fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
            LOG.info(f"Tree {i + 1}")
            LOG.info(fig)
            self.trees.append(fig)
    except Exception as e:
        # Deliberately best-effort, but keep the traceback in the log so
        # the failure can actually be diagnosed.
        LOG.exception(f"Error generating tree plots: {e}")
407 449