comparison base_model_trainer.py @ 6:a32ff7201629 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author   | goeckslab
date     | Wed, 02 Jul 2025 19:00:03 +0000
parents  | ccd798db5abb
children |
5:c846405830eb | 6:a32ff7201629 |
---|---|
5 | 5 |
6 import h5py | 6 import h5py |
7 import joblib | 7 import joblib |
8 import numpy as np | 8 import numpy as np |
9 import pandas as pd | 9 import pandas as pd |
10 from feature_help_modal import get_feature_metrics_help_modal | |
10 from feature_importance import FeatureImportanceAnalyzer | 11 from feature_importance import FeatureImportanceAnalyzer |
11 from sklearn.metrics import average_precision_score | 12 from sklearn.metrics import average_precision_score |
12 from utils import get_html_closing, get_html_template | 13 from utils import get_html_closing, get_html_template |
13 | 14 |
14 logging.basicConfig(level=logging.DEBUG) | 15 logging.basicConfig(level=logging.DEBUG) |
15 LOG = logging.getLogger(__name__) | 16 LOG = logging.getLogger(__name__) |
16 | 17 |
17 | 18 |
18 class BaseModelTrainer: | 19 class BaseModelTrainer: |
19 | |
20 def __init__( | 20 def __init__( |
21 self, | 21 self, |
22 input_file, | 22 input_file, |
23 target_col, | 23 target_col, |
24 output_dir, | 24 output_dir, |
25 task_type, | 25 task_type, |
26 random_seed, | 26 random_seed, |
27 test_file=None, | 27 test_file=None, |
28 **kwargs): | 28 **kwargs, |
29 ): | |
29 self.exp = None # This will be set in the subclass | 30 self.exp = None # This will be set in the subclass |
30 self.input_file = input_file | 31 self.input_file = input_file |
31 self.target_col = target_col | 32 self.target_col = target_col |
32 self.output_dir = output_dir | 33 self.output_dir = output_dir |
33 self.task_type = task_type | 34 self.task_type = task_type |
45 setattr(self, key, value) | 46 setattr(self, key, value) |
46 self.setup_params = {} | 47 self.setup_params = {} |
47 self.test_file = test_file | 48 self.test_file = test_file |
48 self.test_data = None | 49 self.test_data = None |
49 | 50 |
51 if not self.output_dir: | |
52 raise ValueError("output_dir must be specified and not None") | |
53 | |
50 LOG.info(f"Model kwargs: {self.__dict__}") | 54 LOG.info(f"Model kwargs: {self.__dict__}") |
51 | 55 |
52 def load_data(self): | 56 def load_data(self): |
53 LOG.info(f"Loading data from {self.input_file}") | 57 LOG.info(f"Loading data from {self.input_file}") |
54 self.data = pd.read_csv(self.input_file, sep=None, engine='python') | 58 self.data = pd.read_csv(self.input_file, sep=None, engine="python") |
55 self.data.columns = self.data.columns.str.replace('.', '_') | 59 self.data.columns = self.data.columns.str.replace(".", "_") |
56 | 60 |
57 numeric_cols = self.data.select_dtypes(include=['number']).columns | 61 # Remove prediction_label if present |
58 non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns | 62 if "prediction_label" in self.data.columns: |
63 self.data = self.data.drop(columns=["prediction_label"]) | |
64 | |
65 numeric_cols = self.data.select_dtypes(include=["number"]).columns | |
66 non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns | |
59 | 67 |
60 self.data[numeric_cols] = self.data[numeric_cols].apply( | 68 self.data[numeric_cols] = self.data[numeric_cols].apply( |
61 pd.to_numeric, errors='coerce') | 69 pd.to_numeric, errors="coerce" |
70 ) | |
62 | 71 |
63 if len(non_numeric_cols) > 0: | 72 if len(non_numeric_cols) > 0: |
64 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") | 73 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") |
65 | 74 |
66 names = self.data.columns.to_list() | 75 names = self.data.columns.to_list() |
67 target_index = int(self.target_col) - 1 | 76 target_index = int(self.target_col) - 1 |
68 self.target = names[target_index] | 77 self.target = names[target_index] |
69 self.features_name = [name | 78 self.features_name = [name for i, name in enumerate(names) if i != target_index] |
70 for i, name in enumerate(names) | 79 if hasattr(self, "missing_value_strategy"): |
71 if i != target_index] | 80 if self.missing_value_strategy == "mean": |
72 if hasattr(self, 'missing_value_strategy'): | 81 self.data = self.data.fillna(self.data.mean(numeric_only=True)) |
73 if self.missing_value_strategy == 'mean': | 82 elif self.missing_value_strategy == "median": |
74 self.data = self.data.fillna( | 83 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
75 self.data.mean(numeric_only=True)) | 84 elif self.missing_value_strategy == "drop": |
76 elif self.missing_value_strategy == 'median': | |
77 self.data = self.data.fillna( | |
78 self.data.median(numeric_only=True)) | |
79 elif self.missing_value_strategy == 'drop': | |
80 self.data = self.data.dropna() | 85 self.data = self.data.dropna() |
81 else: | 86 else: |
82 # Default strategy if not specified | 87 # Default strategy if not specified |
83 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 88 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
84 | 89 |
85 if self.test_file: | 90 if self.test_file: |
86 LOG.info(f"Loading test data from {self.test_file}") | 91 LOG.info(f"Loading test data from {self.test_file}") |
87 self.test_data = pd.read_csv( | 92 self.test_data = pd.read_csv(self.test_file, sep=None, engine="python") |
88 self.test_file, sep=None, engine='python') | |
89 self.test_data = self.test_data[numeric_cols].apply( | 93 self.test_data = self.test_data[numeric_cols].apply( |
90 pd.to_numeric, errors='coerce') | 94 pd.to_numeric, errors="coerce" |
91 self.test_data.columns = self.test_data.columns.str.replace( | 95 ) |
92 '.', '_' | 96 self.test_data.columns = self.test_data.columns.str.replace(".", "_") |
93 ) | |
94 | 97 |
95 def setup_pycaret(self): | 98 def setup_pycaret(self): |
96 LOG.info("Initializing PyCaret") | 99 LOG.info("Initializing PyCaret") |
97 self.setup_params = { | 100 self.setup_params = { |
98 'target': self.target, | 101 "target": self.target, |
99 'session_id': self.random_seed, | 102 "session_id": self.random_seed, |
100 'html': True, | 103 "html": True, |
101 'log_experiment': False, | 104 "log_experiment": False, |
102 'system_log': False, | 105 "system_log": False, |
103 'index': False, | 106 "index": False, |
104 } | 107 } |
105 | 108 |
106 if self.test_data is not None: | 109 if self.test_data is not None: |
107 self.setup_params['test_data'] = self.test_data | 110 self.setup_params["test_data"] = self.test_data |
108 | 111 |
109 if hasattr(self, 'train_size') and self.train_size is not None \ | 112 if ( |
110 and self.test_data is None: | 113 hasattr(self, "train_size") |
111 self.setup_params['train_size'] = self.train_size | 114 and self.train_size is not None |
112 | 115 and self.test_data is None |
113 if hasattr(self, 'normalize') and self.normalize is not None: | 116 ): |
114 self.setup_params['normalize'] = self.normalize | 117 self.setup_params["train_size"] = self.train_size |
115 | 118 |
116 if hasattr(self, 'feature_selection') and \ | 119 if hasattr(self, "normalize") and self.normalize is not None: |
117 self.feature_selection is not None: | 120 self.setup_params["normalize"] = self.normalize |
118 self.setup_params['feature_selection'] = self.feature_selection | 121 |
119 | 122 if hasattr(self, "feature_selection") and self.feature_selection is not None: |
120 if hasattr(self, 'cross_validation') and \ | 123 self.setup_params["feature_selection"] = self.feature_selection |
121 self.cross_validation is not None \ | 124 |
122 and self.cross_validation is False: | 125 if ( |
123 self.setup_params['cross_validation'] = self.cross_validation | 126 hasattr(self, "cross_validation") |
124 | 127 and self.cross_validation is not None |
125 if hasattr(self, 'cross_validation') and \ | 128 and self.cross_validation is False |
126 self.cross_validation is not None: | 129 ): |
127 if hasattr(self, 'cross_validation_folds'): | 130 self.setup_params["cross_validation"] = self.cross_validation |
128 self.setup_params['fold'] = self.cross_validation_folds | 131 |
129 | 132 if hasattr(self, "cross_validation") and self.cross_validation is not None: |
130 if hasattr(self, 'remove_outliers') and \ | 133 if hasattr(self, "cross_validation_folds"): |
131 self.remove_outliers is not None: | 134 self.setup_params["fold"] = self.cross_validation_folds |
132 self.setup_params['remove_outliers'] = self.remove_outliers | 135 |
133 | 136 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: |
134 if hasattr(self, 'remove_multicollinearity') and \ | 137 self.setup_params["remove_outliers"] = self.remove_outliers |
135 self.remove_multicollinearity is not None: | 138 |
136 self.setup_params['remove_multicollinearity'] = \ | 139 if ( |
140 hasattr(self, "remove_multicollinearity") | |
141 and self.remove_multicollinearity is not None | |
142 ): | |
143 self.setup_params["remove_multicollinearity"] = ( | |
137 self.remove_multicollinearity | 144 self.remove_multicollinearity |
138 | 145 ) |
139 if hasattr(self, 'polynomial_features') and \ | 146 |
140 self.polynomial_features is not None: | 147 if ( |
141 self.setup_params['polynomial_features'] = self.polynomial_features | 148 hasattr(self, "polynomial_features") |
142 | 149 and self.polynomial_features is not None |
143 if hasattr(self, 'fix_imbalance') and \ | 150 ): |
144 self.fix_imbalance is not None: | 151 self.setup_params["polynomial_features"] = self.polynomial_features |
145 self.setup_params['fix_imbalance'] = self.fix_imbalance | 152 |
153 if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None: | |
154 self.setup_params["fix_imbalance"] = self.fix_imbalance | |
146 | 155 |
147 LOG.info(self.setup_params) | 156 LOG.info(self.setup_params) |
157 | |
158 # Solution: instantiate the correct PyCaret experiment based on task_type | |
159 if self.task_type == "classification": | |
160 from pycaret.classification import ClassificationExperiment | |
161 | |
162 self.exp = ClassificationExperiment() | |
163 elif self.task_type == "regression": | |
164 from pycaret.regression import RegressionExperiment | |
165 | |
166 self.exp = RegressionExperiment() | |
167 else: | |
168 raise ValueError("task_type must be 'classification' or 'regression'") | |
169 | |
148 self.exp.setup(self.data, **self.setup_params) | 170 self.exp.setup(self.data, **self.setup_params) |
149 | 171 |
150 def train_model(self): | 172 def train_model(self): |
151 LOG.info("Training and selecting the best model") | 173 LOG.info("Training and selecting the best model") |
152 if self.task_type == "classification": | 174 if self.task_type == "classification": |
153 average_displayed = "Weighted" | 175 average_displayed = "Weighted" |
154 self.exp.add_metric(id=f'PR-AUC-{average_displayed}', | 176 self.exp.add_metric( |
155 name=f'PR-AUC-{average_displayed}', | 177 id=f"PR-AUC-{average_displayed}", |
156 target='pred_proba', | 178 name=f"PR-AUC-{average_displayed}", |
157 score_func=average_precision_score, | 179 target="pred_proba", |
158 average='weighted' | 180 score_func=average_precision_score, |
159 ) | 181 average="weighted", |
160 | 182 ) |
161 if hasattr(self, 'models') and self.models is not None: | 183 |
162 self.best_model = self.exp.compare_models( | 184 if hasattr(self, "models") and self.models is not None: |
163 include=self.models) | 185 self.best_model = self.exp.compare_models(include=self.models) |
164 else: | 186 else: |
165 self.best_model = self.exp.compare_models() | 187 self.best_model = self.exp.compare_models() |
166 self.results = self.exp.pull() | 188 self.results = self.exp.pull() |
167 if self.task_type == "classification": | 189 if self.task_type == "classification": |
168 self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True) | 190 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
169 | 191 |
170 _ = self.exp.predict_model(self.best_model) | 192 _ = self.exp.predict_model(self.best_model) |
171 self.test_result_df = self.exp.pull() | 193 self.test_result_df = self.exp.pull() |
172 if self.task_type == "classification": | 194 if self.task_type == "classification": |
173 self.test_result_df.rename( | 195 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
174 columns={'AUC': 'ROC-AUC'}, inplace=True) | |
175 | 196 |
176 def save_model(self): | 197 def save_model(self): |
177 hdf5_model_path = "pycaret_model.h5" | 198 hdf5_model_path = "pycaret_model.h5" |
178 with h5py.File(hdf5_model_path, 'w') as f: | 199 with h5py.File(hdf5_model_path, "w") as f: |
179 with tempfile.NamedTemporaryFile(delete=False) as temp_file: | 200 with tempfile.NamedTemporaryFile(delete=False) as temp_file: |
180 joblib.dump(self.best_model, temp_file.name) | 201 joblib.dump(self.best_model, temp_file.name) |
181 temp_file.seek(0) | 202 temp_file.seek(0) |
182 model_bytes = temp_file.read() | 203 model_bytes = temp_file.read() |
183 f.create_dataset('model', data=np.void(model_bytes)) | 204 f.create_dataset("model", data=np.void(model_bytes)) |
184 | 205 |
185 def generate_plots(self): | 206 def generate_plots(self): |
186 raise NotImplementedError("Subclasses should implement this method") | 207 raise NotImplementedError("Subclasses should implement this method") |
187 | 208 |
188 def encode_image_to_base64(self, img_path): | 209 def encode_image_to_base64(self, img_path): |
189 with open(img_path, 'rb') as img_file: | 210 with open(img_path, "rb") as img_file: |
190 return base64.b64encode(img_file.read()).decode('utf-8') | 211 return base64.b64encode(img_file.read()).decode("utf-8") |
191 | 212 |
192 def save_html_report(self): | 213 def save_html_report(self): |
193 LOG.info("Saving HTML report") | 214 LOG.info("Saving HTML report") |
194 | 215 |
216 if not self.output_dir: | |
217 raise ValueError("output_dir must be specified and not None") | |
218 | |
195 model_name = type(self.best_model).__name__ | 219 model_name = type(self.best_model).__name__ |
196 excluded_params = ['html', 'log_experiment', 'system_log', 'test_data'] | 220 excluded_params = ["html", "log_experiment", "system_log", "test_data"] |
197 filtered_setup_params = { | 221 filtered_setup_params = { |
198 k: v | 222 k: v for k, v in self.setup_params.items() if k not in excluded_params |
199 for k, v in self.setup_params.items() if k not in excluded_params | |
200 } | 223 } |
201 setup_params_table = pd.DataFrame( | 224 setup_params_table = pd.DataFrame( |
202 list(filtered_setup_params.items()), columns=['Parameter', 'Value'] | 225 list(filtered_setup_params.items()), columns=["Parameter", "Value"] |
203 ) | 226 ) |
204 | 227 |
205 best_model_params = pd.DataFrame( | 228 best_model_params = pd.DataFrame( |
206 self.best_model.get_params().items(), | 229 self.best_model.get_params().items(), columns=["Parameter", "Value"] |
207 columns=['Parameter', 'Value'] | |
208 ) | 230 ) |
209 best_model_params.to_csv( | 231 best_model_params.to_csv( |
210 os.path.join(self.output_dir, "best_model.csv"), index=False | 232 os.path.join(self.output_dir, "best_model.csv"), index=False |
211 ) | 233 ) |
212 self.results.to_csv( | 234 self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv")) |
213 os.path.join(self.output_dir, "comparison_results.csv") | 235 self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv")) |
214 ) | |
215 self.test_result_df.to_csv( | |
216 os.path.join(self.output_dir, "test_results.csv") | |
217 ) | |
218 | 236 |
219 plots_html = "" | 237 plots_html = "" |
220 length = len(self.plots) | 238 length = len(self.plots) |
221 for i, (plot_name, plot_path) in enumerate(self.plots.items()): | 239 for i, (plot_name, plot_path) in enumerate(self.plots.items()): |
222 encoded_image = self.encode_image_to_base64(plot_path) | 240 encoded_image = self.encode_image_to_base64(plot_path) |
223 plots_html += f""" | 241 plots_html += ( |
224 <div class="plot"> | 242 f'<div class="plot">' |
225 <h3>{plot_name.capitalize()}</h3> | 243 f"<h3>{plot_name.capitalize()}</h3>" |
226 <img src="data:image/png;base64,{encoded_image}" | 244 f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">' |
227 alt="{plot_name}"> | 245 f"</div>" |
228 </div> | 246 ) |
229 """ | |
230 if i < length - 1: | 247 if i < length - 1: |
231 plots_html += "<hr>" | 248 plots_html += "<hr>" |
232 | 249 |
233 tree_plots = "" | 250 tree_plots = "" |
234 for i, tree in enumerate(self.trees): | 251 for i, tree in enumerate(self.trees): |
235 if tree: | 252 if tree: |
236 tree_plots += f""" | 253 tree_plots += ( |
237 <div class="plot"> | 254 f'<div class="plot">' |
238 <h3>Tree {i+1}</h3> | 255 f"<h3>Tree {i + 1}</h3>" |
239 <img src="data:image/png;base64, | 256 f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">' |
240 {tree}" | 257 f"</div>" |
241 alt="tree {i+1}"> | 258 ) |
242 </div> | |
243 """ | |
244 | 259 |
245 analyzer = FeatureImportanceAnalyzer( | 260 analyzer = FeatureImportanceAnalyzer( |
246 data=self.data, | 261 data=self.data, |
247 target_col=self.target_col, | 262 target_col=self.target_col, |
248 task_type=self.task_type, | 263 task_type=self.task_type, |
249 output_dir=self.output_dir, | 264 output_dir=self.output_dir, |
265 exp=self.exp, | |
266 best_model=self.best_model, | |
250 ) | 267 ) |
251 feature_importance_html = analyzer.run() | 268 feature_importance_html = analyzer.run() |
252 | 269 |
253 html_content = f""" | 270 # --- Feature Metrics Help Button --- |
254 {get_html_template()} | 271 feature_metrics_button_html = ( |
255 <h1>PyCaret Model Training Report</h1> | 272 '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">' |
256 <div class="tabs"> | 273 "Help: Metrics Guide" |
257 <div class="tab" onclick="openTab(event, 'summary')"> | 274 "</button>" |
258 Setup & Best Model</div> | 275 "<style>" |
259 <div class="tab" onclick="openTab(event, 'plots')"> | 276 ".help-modal-btn {" |
260 Best Model Plots</div> | 277 "background-color: #17623b;" |
261 <div class="tab" onclick="openTab(event, 'feature')"> | 278 "color: #fff;" |
262 Feature Importance</div> | 279 "border: none;" |
263 """ | 280 "border-radius: 24px;" |
281 "padding: 10px 28px;" | |
282 "font-size: 1.1rem;" | |
283 "font-weight: bold;" | |
284 "letter-spacing: 0.03em;" | |
285 "cursor: pointer;" | |
286 "transition: background 0.2s, box-shadow 0.2s;" | |
287 "box-shadow: 0 2px 8px rgba(23,98,59,0.07);" | |
288 "}" | |
289 ".help-modal-btn:hover, .help-modal-btn:focus {" | |
290 "background-color: #21895e;" | |
291 "outline: none;" | |
292 "box-shadow: 0 4px 16px rgba(23,98,59,0.14);" | |
293 "}" | |
294 "</style>" | |
295 ) | |
296 | |
297 html_content = ( | |
298 f"{get_html_template()}" | |
299 "<h1>Tabular Learner Model Report</h1>" | |
300 f"{feature_metrics_button_html}" | |
301 '<div class="tabs">' | |
302 '<div class="tab" onclick="openTab(event, \'summary\')">' | |
303 "Validation Result Summary & Config</div>" | |
304 '<div class="tab" onclick="openTab(event, \'plots\')">' | |
305 "Test Results</div>" | |
306 '<div class="tab" onclick="openTab(event, \'feature\')">' | |
307 "Feature Importance</div>" | |
308 ) | |
264 if self.plots_explainer_html: | 309 if self.plots_explainer_html: |
265 html_content += """ | 310 html_content += ( |
266 <div class="tab" onclick="openTab(event, 'explainer')"> | 311 '<div class="tab" onclick="openTab(event, \'explainer\')">' |
267 Explainer Plots</div> | 312 "Explainer Plots</div>" |
268 """ | 313 ) |
269 html_content += f""" | 314 html_content += ( |
270 </div> | 315 "</div>" |
271 <div id="summary" class="tab-content"> | 316 '<div id="summary" class="tab-content">' |
272 <h2>Setup Parameters</h2> | 317 "<h2>Model Metrics from Cross-Validation Set</h2>" |
273 {setup_params_table.to_html( | 318 f"<h2>Best Model: {model_name}</h2>" |
274 index=False, | 319 "<h5>The best model is selected by: Accuracy (Classification)" |
275 header=True, | 320 " or R2 (Regression).</h5>" |
276 classes='table sortable' | 321 f"{self.results.to_html(index=False, classes='table sortable')}" |
277 )} | 322 "<h2>Best Model's Hyperparameters</h2>" |
278 <h5>If you want to know all the experiment setup parameters, | 323 f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}" |
279 please check the PyCaret documentation for | 324 "<h2>Setup Parameters</h2>" |
280 the classification/regression <code>exp</code> function.</h5> | 325 f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}" |
281 <h2>Best Model: {model_name}</h2> | 326 "<h5>If you want to know all the experiment setup parameters," |
282 {best_model_params.to_html( | 327 " please check the PyCaret documentation for" |
283 index=False, | 328 " the classification/regression <code>exp</code> function.</h5>" |
284 header=True, | 329 "</div>" |
285 classes='table sortable' | 330 '<div id="plots" class="tab-content">' |
286 )} | 331 f"<h2>Best Model: {model_name}</h2>" |
287 <h2>Comparison Results on the Cross-Validation Set</h2> | 332 "<h5>The best model is selected by: Accuracy (Classification)" |
288 {self.results.to_html(index=False, classes='table sortable')} | 333 " or R2 (Regression).</h5>" |
289 <h2>Results on the Test Set for the best model</h2> | 334 "<h2>Test Metrics</h2>" |
290 {self.test_result_df.to_html( | 335 f"{self.test_result_df.to_html(index=False)}" |
291 index=False, | 336 "<h2>Test Results</h2>" |
292 classes='table sortable' | 337 f"{plots_html}" |
293 )} | 338 "</div>" |
294 </div> | 339 '<div id="feature" class="tab-content">' |
295 <div id="plots" class="tab-content"> | 340 f"{feature_importance_html}" |
296 <h2>Best Model Plots on the testing set</h2> | 341 "</div>" |
297 {plots_html} | 342 ) |
298 </div> | |
299 <div id="feature" class="tab-content"> | |
300 {feature_importance_html} | |
301 </div> | |
302 """ | |
303 if self.plots_explainer_html: | 343 if self.plots_explainer_html: |
304 html_content += f""" | 344 html_content += ( |
305 <div id="explainer" class="tab-content"> | 345 '<div id="explainer" class="tab-content">' |
306 {self.plots_explainer_html} | 346 f"{self.plots_explainer_html}" |
307 {tree_plots} | 347 f"{tree_plots}" |
308 </div> | 348 "</div>" |
309 """ | 349 ) |
310 html_content += """ | 350 html_content += ( |
311 <script> | 351 "<script>" |
312 document.addEventListener("DOMContentLoaded", function() { | 352 "document.addEventListener(\"DOMContentLoaded\", function() {" |
313 var tables = document.querySelectorAll("table.sortable"); | 353 "var tables = document.querySelectorAll(\"table.sortable\");" |
314 tables.forEach(function(table) { | 354 "tables.forEach(function(table) {" |
315 var headers = table.querySelectorAll("th"); | 355 "var headers = table.querySelectorAll(\"th\");" |
316 headers.forEach(function(header, index) { | 356 "headers.forEach(function(header, index) {" |
317 header.style.cursor = "pointer"; | 357 "header.style.cursor = \"pointer\";" |
318 // Add initial arrow (up) to indicate sortability | 358 "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)" |
319 header.innerHTML += '<span class="sort-arrow"> ↑</span>'; | 359 "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';" |
320 header.addEventListener("click", function() { | 360 "header.addEventListener(\"click\", function() {" |
321 var direction = this.getAttribute( | 361 "var direction = this.getAttribute(" |
322 "data-sort-direction" | 362 "\"data-sort-direction\"" |
323 ) || "asc"; | 363 ") || \"asc\";" |
324 // Reset arrows in all headers of this table | 364 "// Reset arrows in all headers of this table" |
325 headers.forEach(function(h) { | 365 "headers.forEach(function(h) {" |
326 var arrow = h.querySelector(".sort-arrow"); | 366 "var arrow = h.querySelector(\".sort-arrow\");" |
327 if (arrow) arrow.textContent = " ↑"; | 367 "if (arrow) arrow.textContent = \" ↑\";" |
328 }); | 368 "});" |
329 // Set arrow for clicked header | 369 "// Set arrow for clicked header" |
330 var arrow = this.querySelector(".sort-arrow"); | 370 "var arrow = this.querySelector(\".sort-arrow\");" |
331 arrow.textContent = direction === "asc" ? " ↓" : " ↑"; | 371 "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";" |
332 sortTable(table, index, direction); | 372 "sortTable(table, index, direction);" |
333 this.setAttribute("data-sort-direction", | 373 "this.setAttribute(\"data-sort-direction\"," |
334 direction === "asc" ? "desc" : "asc"); | 374 "direction === \"asc\" ? \"desc\" : \"asc\");" |
335 }); | 375 "});" |
336 }); | 376 "});" |
337 }); | 377 "});" |
338 }); | 378 "});" |
339 | 379 "function sortTable(table, colNum, direction) {" |
340 function sortTable(table, colNum, direction) { | 380 "var tb = table.tBodies[0];" |
341 var tb = table.tBodies[0]; | 381 "var tr = Array.prototype.slice.call(tb.rows, 0);" |
342 var tr = Array.prototype.slice.call(tb.rows, 0); | 382 "var multiplier = direction === \"asc\" ? 1 : -1;" |
343 var multiplier = direction === "asc" ? 1 : -1; | 383 "tr = tr.sort(function(a, b) {" |
344 tr = tr.sort(function(a, b) { | 384 "var aText = a.cells[colNum].textContent.trim();" |
345 var aText = a.cells[colNum].textContent.trim(); | 385 "var bText = b.cells[colNum].textContent.trim();" |
346 var bText = b.cells[colNum].textContent.trim(); | 386 "// Remove arrow from text comparison" |
347 // Remove arrow from text comparison | 387 "aText = aText.replace(/[↑↓]/g, '').trim();" |
348 aText = aText.replace(/[↑↓]/g, '').trim(); | 388 "bText = bText.replace(/[↑↓]/g, '').trim();" |
349 bText = bText.replace(/[↑↓]/g, '').trim(); | 389 "if (!isNaN(aText) && !isNaN(bText)) {" |
350 if (!isNaN(aText) && !isNaN(bText)) { | 390 "return multiplier * (" |
351 return multiplier * ( | 391 "parseFloat(aText) - parseFloat(bText)" |
352 parseFloat(aText) - parseFloat(bText) | 392 ");" |
353 ); | 393 "} else {" |
354 } else { | 394 "return multiplier * aText.localeCompare(bText);" |
355 return multiplier * aText.localeCompare(bText); | 395 "}" |
356 } | 396 "});" |
357 }); | 397 "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);" |
358 for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]); | 398 "}" |
359 } | 399 "</script>" |
360 </script> | 400 ) |
361 """ | 401 # --- Add the Feature Metrics Help Modal --- |
362 html_content += f""" | 402 html_content += get_feature_metrics_help_modal() |
363 {get_html_closing()} | 403 html_content += f"{get_html_closing()}" |
364 """ | |
365 with open( | 404 with open( |
366 os.path.join(self.output_dir, "comparison_result.html"), | 405 os.path.join(self.output_dir, "comparison_result.html"), |
367 "w" | 406 "w", |
407 encoding="utf-8", | |
368 ) as file: | 408 ) as file: |
369 file.write(html_content) | 409 file.write(html_content) |
370 | 410 |
371 def save_dashboard(self): | 411 def save_dashboard(self): |
372 raise NotImplementedError("Subclasses should implement this method") | 412 raise NotImplementedError("Subclasses should implement this method") |
373 | 413 |
374 def generate_plots_explainer(self): | 414 def generate_plots_explainer(self): |
375 raise NotImplementedError("Subclasses should implement this method") | 415 raise NotImplementedError("Subclasses should implement this method") |
376 | 416 |
377 # not working now | |
378 def generate_tree_plots(self): | 417 def generate_tree_plots(self): |
379 from sklearn.ensemble import RandomForestClassifier, \ | 418 from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor |
380 RandomForestRegressor | |
381 from xgboost import XGBClassifier, XGBRegressor | 419 from xgboost import XGBClassifier, XGBRegressor |
382 from explainerdashboard.explainers import RandomForestExplainer | 420 from explainerdashboard.explainers import RandomForestExplainer |
383 | 421 |
384 LOG.info("Generating tree plots") | 422 LOG.info("Generating tree plots") |
385 X_test = self.exp.X_test_transformed.copy() | 423 X_test = self.exp.X_test_transformed.copy() |
386 y_test = self.exp.y_test_transformed | 424 y_test = self.exp.y_test_transformed |
387 | 425 |
388 is_rf = isinstance(self.best_model, RandomForestClassifier) or \ | 426 is_rf = isinstance( |
389 isinstance(self.best_model, RandomForestRegressor) | 427 self.best_model, (RandomForestClassifier, RandomForestRegressor) |
390 | 428 ) |
391 is_xgb = isinstance(self.best_model, XGBClassifier) or \ | 429 is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor)) |
392 isinstance(self.best_model, XGBRegressor) | 430 |
431 num_trees = None | |
432 if is_rf: | |
433 num_trees = self.best_model.n_estimators | |
434 elif is_xgb: | |
435 num_trees = len(self.best_model.get_booster().get_dump()) | |
436 else: | |
437 LOG.warning("Tree plots not supported for this model type.") | |
438 return | |
393 | 439 |
394 try: | 440 try: |
395 if is_rf: | |
396 num_trees = self.best_model.n_estimators | |
397 if is_xgb: | |
398 num_trees = len(self.best_model.get_booster().get_dump()) | |
399 explainer = RandomForestExplainer(self.best_model, X_test, y_test) | 441 explainer = RandomForestExplainer(self.best_model, X_test, y_test) |
400 for i in range(num_trees): | 442 for i in range(num_trees): |
401 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) | 443 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) |
402 LOG.info(f"Tree {i+1}") | 444 LOG.info(f"Tree {i + 1}") |
403 LOG.info(fig) | 445 LOG.info(fig) |
404 self.trees.append(fig) | 446 self.trees.append(fig) |
405 except Exception as e: | 447 except Exception as e: |
406 LOG.error(f"Error generating tree plots: {e}") | 448 LOG.error(f"Error generating tree plots: {e}") |
407 | 449 |
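
Note (not part of the changeset): `save_model()` in the diff above serializes the fitted estimator with joblib and stores the raw bytes as a single opaque `model` dataset inside `pycaret_model.h5`. Below is a minimal sketch of how such an artifact could be read back, assuming the layout shown in the diff (an `np.void` scalar dataset named `model` in that file); it is an illustration, not code from the repository.

```python
import io

import h5py
import joblib

# Read the opaque byte blob written by save_model() and rebuild the estimator.
# Assumes the dataset name ("model") and file name ("pycaret_model.h5") used above.
with h5py.File("pycaret_model.h5", "r") as f:
    model_bytes = f["model"][()].tobytes()  # np.void scalar -> raw bytes

best_model = joblib.load(io.BytesIO(model_bytes))  # joblib.load accepts file-like objects
print(type(best_model).__name__)
```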