Mercurial > repos > goeckslab > tabular_learner
comparison base_model_trainer.py @ 2:77c88226bfde draft
planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
| author | goeckslab |
|---|---|
| date | Wed, 02 Jul 2025 18:59:39 +0000 |
| parents | 209b663a4f62 |
| children | f6a65e05d6ec |
comparison
equal
deleted
inserted
replaced
| 1:f69ed50c9768 | 2:77c88226bfde |
|---|---|
| 5 | 5 |
| 6 import h5py | 6 import h5py |
| 7 import joblib | 7 import joblib |
| 8 import numpy as np | 8 import numpy as np |
| 9 import pandas as pd | 9 import pandas as pd |
| 10 from feature_help_modal import get_feature_metrics_help_modal | |
| 10 from feature_importance import FeatureImportanceAnalyzer | 11 from feature_importance import FeatureImportanceAnalyzer |
| 11 from sklearn.metrics import average_precision_score | 12 from sklearn.metrics import average_precision_score |
| 12 from utils import get_html_closing, get_html_template | 13 from utils import get_html_closing, get_html_template |
| 13 | 14 |
| 14 logging.basicConfig(level=logging.DEBUG) | 15 logging.basicConfig(level=logging.DEBUG) |
| 15 LOG = logging.getLogger(__name__) | 16 LOG = logging.getLogger(__name__) |
| 16 | 17 |
| 17 | 18 |
| 18 class BaseModelTrainer: | 19 class BaseModelTrainer: |
| 19 | |
| 20 def __init__( | 20 def __init__( |
| 21 self, | 21 self, |
| 22 input_file, | 22 input_file, |
| 23 target_col, | 23 target_col, |
| 24 output_dir, | 24 output_dir, |
| 25 task_type, | 25 task_type, |
| 26 random_seed, | 26 random_seed, |
| 27 test_file=None, | 27 test_file=None, |
| 28 **kwargs): | 28 **kwargs, |
| 29 ): | |
| 29 self.exp = None # This will be set in the subclass | 30 self.exp = None # This will be set in the subclass |
| 30 self.input_file = input_file | 31 self.input_file = input_file |
| 31 self.target_col = target_col | 32 self.target_col = target_col |
| 32 self.output_dir = output_dir | 33 self.output_dir = output_dir |
| 33 self.task_type = task_type | 34 self.task_type = task_type |
| 45 setattr(self, key, value) | 46 setattr(self, key, value) |
| 46 self.setup_params = {} | 47 self.setup_params = {} |
| 47 self.test_file = test_file | 48 self.test_file = test_file |
| 48 self.test_data = None | 49 self.test_data = None |
| 49 | 50 |
| 51 if not self.output_dir: | |
| 52 raise ValueError("output_dir must be specified and not None") | |
| 53 | |
| 50 LOG.info(f"Model kwargs: {self.__dict__}") | 54 LOG.info(f"Model kwargs: {self.__dict__}") |
| 51 | 55 |
| 52 def load_data(self): | 56 def load_data(self): |
| 53 LOG.info(f"Loading data from {self.input_file}") | 57 LOG.info(f"Loading data from {self.input_file}") |
| 54 self.data = pd.read_csv(self.input_file, sep=None, engine='python') | 58 self.data = pd.read_csv(self.input_file, sep=None, engine="python") |
| 55 self.data.columns = self.data.columns.str.replace('.', '_') | 59 self.data.columns = self.data.columns.str.replace(".", "_") |
| 56 | 60 |
| 57 numeric_cols = self.data.select_dtypes(include=['number']).columns | 61 # Remove prediction_label if present |
| 58 non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns | 62 if "prediction_label" in self.data.columns: |
| 63 self.data = self.data.drop(columns=["prediction_label"]) | |
| 64 | |
| 65 numeric_cols = self.data.select_dtypes(include=["number"]).columns | |
| 66 non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns | |
| 59 | 67 |
| 60 self.data[numeric_cols] = self.data[numeric_cols].apply( | 68 self.data[numeric_cols] = self.data[numeric_cols].apply( |
| 61 pd.to_numeric, errors='coerce') | 69 pd.to_numeric, errors="coerce" |
| 70 ) | |
| 62 | 71 |
| 63 if len(non_numeric_cols) > 0: | 72 if len(non_numeric_cols) > 0: |
| 64 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") | 73 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") |
| 65 | 74 |
| 66 names = self.data.columns.to_list() | 75 names = self.data.columns.to_list() |
| 67 target_index = int(self.target_col) - 1 | 76 target_index = int(self.target_col) - 1 |
| 68 self.target = names[target_index] | 77 self.target = names[target_index] |
| 69 self.features_name = [name | 78 self.features_name = [name for i, name in enumerate(names) if i != target_index] |
| 70 for i, name in enumerate(names) | 79 if hasattr(self, "missing_value_strategy"): |
| 71 if i != target_index] | 80 if self.missing_value_strategy == "mean": |
| 72 if hasattr(self, 'missing_value_strategy'): | 81 self.data = self.data.fillna(self.data.mean(numeric_only=True)) |
| 73 if self.missing_value_strategy == 'mean': | 82 elif self.missing_value_strategy == "median": |
| 74 self.data = self.data.fillna( | 83 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
| 75 self.data.mean(numeric_only=True)) | 84 elif self.missing_value_strategy == "drop": |
| 76 elif self.missing_value_strategy == 'median': | |
| 77 self.data = self.data.fillna( | |
| 78 self.data.median(numeric_only=True)) | |
| 79 elif self.missing_value_strategy == 'drop': | |
| 80 self.data = self.data.dropna() | 85 self.data = self.data.dropna() |
| 81 else: | 86 else: |
| 82 # Default strategy if not specified | 87 # Default strategy if not specified |
| 83 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 88 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
| 84 | 89 |
| 85 if self.test_file: | 90 if self.test_file: |
| 86 LOG.info(f"Loading test data from {self.test_file}") | 91 LOG.info(f"Loading test data from {self.test_file}") |
| 87 self.test_data = pd.read_csv( | 92 self.test_data = pd.read_csv(self.test_file, sep=None, engine="python") |
| 88 self.test_file, sep=None, engine='python') | |
| 89 self.test_data = self.test_data[numeric_cols].apply( | 93 self.test_data = self.test_data[numeric_cols].apply( |
| 90 pd.to_numeric, errors='coerce') | 94 pd.to_numeric, errors="coerce" |
| 91 self.test_data.columns = self.test_data.columns.str.replace( | 95 ) |
| 92 '.', '_' | 96 self.test_data.columns = self.test_data.columns.str.replace(".", "_") |
| 93 ) | |
| 94 | 97 |
| 95 def setup_pycaret(self): | 98 def setup_pycaret(self): |
| 96 LOG.info("Initializing PyCaret") | 99 LOG.info("Initializing PyCaret") |
| 97 self.setup_params = { | 100 self.setup_params = { |
| 98 'target': self.target, | 101 "target": self.target, |
| 99 'session_id': self.random_seed, | 102 "session_id": self.random_seed, |
| 100 'html': True, | 103 "html": True, |
| 101 'log_experiment': False, | 104 "log_experiment": False, |
| 102 'system_log': False, | 105 "system_log": False, |
| 103 'index': False, | 106 "index": False, |
| 104 } | 107 } |
| 105 | 108 |
| 106 if self.test_data is not None: | 109 if self.test_data is not None: |
| 107 self.setup_params['test_data'] = self.test_data | 110 self.setup_params["test_data"] = self.test_data |
| 108 | 111 |
| 109 if hasattr(self, 'train_size') and self.train_size is not None \ | 112 if ( |
| 110 and self.test_data is None: | 113 hasattr(self, "train_size") |
| 111 self.setup_params['train_size'] = self.train_size | 114 and self.train_size is not None |
| 112 | 115 and self.test_data is None |
| 113 if hasattr(self, 'normalize') and self.normalize is not None: | 116 ): |
| 114 self.setup_params['normalize'] = self.normalize | 117 self.setup_params["train_size"] = self.train_size |
| 115 | 118 |
| 116 if hasattr(self, 'feature_selection') and \ | 119 if hasattr(self, "normalize") and self.normalize is not None: |
| 117 self.feature_selection is not None: | 120 self.setup_params["normalize"] = self.normalize |
| 118 self.setup_params['feature_selection'] = self.feature_selection | 121 |
| 119 | 122 if hasattr(self, "feature_selection") and self.feature_selection is not None: |
| 120 if hasattr(self, 'cross_validation') and \ | 123 self.setup_params["feature_selection"] = self.feature_selection |
| 121 self.cross_validation is not None \ | 124 |
| 122 and self.cross_validation is False: | 125 if ( |
| 123 self.setup_params['cross_validation'] = self.cross_validation | 126 hasattr(self, "cross_validation") |
| 124 | 127 and self.cross_validation is not None |
| 125 if hasattr(self, 'cross_validation') and \ | 128 and self.cross_validation is False |
| 126 self.cross_validation is not None: | 129 ): |
| 127 if hasattr(self, 'cross_validation_folds'): | 130 self.setup_params["cross_validation"] = self.cross_validation |
| 128 self.setup_params['fold'] = self.cross_validation_folds | 131 |
| 129 | 132 if hasattr(self, "cross_validation") and self.cross_validation is not None: |
| 130 if hasattr(self, 'remove_outliers') and \ | 133 if hasattr(self, "cross_validation_folds"): |
| 131 self.remove_outliers is not None: | 134 self.setup_params["fold"] = self.cross_validation_folds |
| 132 self.setup_params['remove_outliers'] = self.remove_outliers | 135 |
| 133 | 136 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: |
| 134 if hasattr(self, 'remove_multicollinearity') and \ | 137 self.setup_params["remove_outliers"] = self.remove_outliers |
| 135 self.remove_multicollinearity is not None: | 138 |
| 136 self.setup_params['remove_multicollinearity'] = \ | 139 if ( |
| 140 hasattr(self, "remove_multicollinearity") | |
| 141 and self.remove_multicollinearity is not None | |
| 142 ): | |
| 143 self.setup_params["remove_multicollinearity"] = ( | |
| 137 self.remove_multicollinearity | 144 self.remove_multicollinearity |
| 138 | 145 ) |
| 139 if hasattr(self, 'polynomial_features') and \ | 146 |
| 140 self.polynomial_features is not None: | 147 if ( |
| 141 self.setup_params['polynomial_features'] = self.polynomial_features | 148 hasattr(self, "polynomial_features") |
| 142 | 149 and self.polynomial_features is not None |
| 143 if hasattr(self, 'fix_imbalance') and \ | 150 ): |
| 144 self.fix_imbalance is not None: | 151 self.setup_params["polynomial_features"] = self.polynomial_features |
| 145 self.setup_params['fix_imbalance'] = self.fix_imbalance | 152 |
| 153 if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None: | |
| 154 self.setup_params["fix_imbalance"] = self.fix_imbalance | |
| 146 | 155 |
| 147 LOG.info(self.setup_params) | 156 LOG.info(self.setup_params) |
| 157 | |
| 158 # Solution: instantiate the correct PyCaret experiment based on task_type | |
| 159 if self.task_type == "classification": | |
| 160 from pycaret.classification import ClassificationExperiment | |
| 161 | |
| 162 self.exp = ClassificationExperiment() | |
| 163 elif self.task_type == "regression": | |
| 164 from pycaret.regression import RegressionExperiment | |
| 165 | |
| 166 self.exp = RegressionExperiment() | |
| 167 else: | |
| 168 raise ValueError("task_type must be 'classification' or 'regression'") | |
| 169 | |
| 148 self.exp.setup(self.data, **self.setup_params) | 170 self.exp.setup(self.data, **self.setup_params) |
| 149 | 171 |
| 150 def train_model(self): | 172 def train_model(self): |
| 151 LOG.info("Training and selecting the best model") | 173 LOG.info("Training and selecting the best model") |
| 152 if self.task_type == "classification": | 174 if self.task_type == "classification": |
| 153 average_displayed = "Weighted" | 175 average_displayed = "Weighted" |
| 154 self.exp.add_metric(id=f'PR-AUC-{average_displayed}', | 176 self.exp.add_metric( |
| 155 name=f'PR-AUC-{average_displayed}', | 177 id=f"PR-AUC-{average_displayed}", |
| 156 target='pred_proba', | 178 name=f"PR-AUC-{average_displayed}", |
| 157 score_func=average_precision_score, | 179 target="pred_proba", |
| 158 average='weighted' | 180 score_func=average_precision_score, |
| 159 ) | 181 average="weighted", |
| 160 | 182 ) |
| 161 if hasattr(self, 'models') and self.models is not None: | 183 |
| 162 self.best_model = self.exp.compare_models( | 184 if hasattr(self, "models") and self.models is not None: |
| 163 include=self.models) | 185 self.best_model = self.exp.compare_models(include=self.models) |
| 164 else: | 186 else: |
| 165 self.best_model = self.exp.compare_models() | 187 self.best_model = self.exp.compare_models() |
| 166 self.results = self.exp.pull() | 188 self.results = self.exp.pull() |
| 167 if self.task_type == "classification": | 189 if self.task_type == "classification": |
| 168 self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True) | 190 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
| 169 | 191 |
| 170 _ = self.exp.predict_model(self.best_model) | 192 _ = self.exp.predict_model(self.best_model) |
| 171 self.test_result_df = self.exp.pull() | 193 self.test_result_df = self.exp.pull() |
| 172 if self.task_type == "classification": | 194 if self.task_type == "classification": |
| 173 self.test_result_df.rename( | 195 self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
| 174 columns={'AUC': 'ROC-AUC'}, inplace=True) | |
| 175 | 196 |
| 176 def save_model(self): | 197 def save_model(self): |
| 177 hdf5_model_path = "pycaret_model.h5" | 198 hdf5_model_path = "pycaret_model.h5" |
| 178 with h5py.File(hdf5_model_path, 'w') as f: | 199 with h5py.File(hdf5_model_path, "w") as f: |
| 179 with tempfile.NamedTemporaryFile(delete=False) as temp_file: | 200 with tempfile.NamedTemporaryFile(delete=False) as temp_file: |
| 180 joblib.dump(self.best_model, temp_file.name) | 201 joblib.dump(self.best_model, temp_file.name) |
| 181 temp_file.seek(0) | 202 temp_file.seek(0) |
| 182 model_bytes = temp_file.read() | 203 model_bytes = temp_file.read() |
| 183 f.create_dataset('model', data=np.void(model_bytes)) | 204 f.create_dataset("model", data=np.void(model_bytes)) |
| 184 | 205 |
| 185 def generate_plots(self): | 206 def generate_plots(self): |
| 186 raise NotImplementedError("Subclasses should implement this method") | 207 raise NotImplementedError("Subclasses should implement this method") |
| 187 | 208 |
| 188 def encode_image_to_base64(self, img_path): | 209 def encode_image_to_base64(self, img_path): |
| 189 with open(img_path, 'rb') as img_file: | 210 with open(img_path, "rb") as img_file: |
| 190 return base64.b64encode(img_file.read()).decode('utf-8') | 211 return base64.b64encode(img_file.read()).decode("utf-8") |
| 191 | 212 |
| 192 def save_html_report(self): | 213 def save_html_report(self): |
| 193 LOG.info("Saving HTML report") | 214 LOG.info("Saving HTML report") |
| 194 | 215 |
| 216 if not self.output_dir: | |
| 217 raise ValueError("output_dir must be specified and not None") | |
| 218 | |
| 195 model_name = type(self.best_model).__name__ | 219 model_name = type(self.best_model).__name__ |
| 196 excluded_params = ['html', 'log_experiment', 'system_log', 'test_data'] | 220 excluded_params = ["html", "log_experiment", "system_log", "test_data"] |
| 197 filtered_setup_params = { | 221 filtered_setup_params = { |
| 198 k: v | 222 k: v for k, v in self.setup_params.items() if k not in excluded_params |
| 199 for k, v in self.setup_params.items() if k not in excluded_params | |
| 200 } | 223 } |
| 201 setup_params_table = pd.DataFrame( | 224 setup_params_table = pd.DataFrame( |
| 202 list(filtered_setup_params.items()), columns=['Parameter', 'Value'] | 225 list(filtered_setup_params.items()), columns=["Parameter", "Value"] |
| 203 ) | 226 ) |
| 204 | 227 |
| 205 best_model_params = pd.DataFrame( | 228 best_model_params = pd.DataFrame( |
| 206 self.best_model.get_params().items(), | 229 self.best_model.get_params().items(), columns=["Parameter", "Value"] |
| 207 columns=['Parameter', 'Value'] | |
| 208 ) | 230 ) |
| 209 best_model_params.to_csv( | 231 best_model_params.to_csv( |
| 210 os.path.join(self.output_dir, "best_model.csv"), index=False | 232 os.path.join(self.output_dir, "best_model.csv"), index=False |
| 211 ) | 233 ) |
| 212 self.results.to_csv( | 234 self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv")) |
| 213 os.path.join(self.output_dir, "comparison_results.csv") | 235 self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv")) |
| 214 ) | |
| 215 self.test_result_df.to_csv( | |
| 216 os.path.join(self.output_dir, "test_results.csv") | |
| 217 ) | |
| 218 | 236 |
| 219 plots_html = "" | 237 plots_html = "" |
| 220 length = len(self.plots) | 238 length = len(self.plots) |
| 221 for i, (plot_name, plot_path) in enumerate(self.plots.items()): | 239 for i, (plot_name, plot_path) in enumerate(self.plots.items()): |
| 222 encoded_image = self.encode_image_to_base64(plot_path) | 240 encoded_image = self.encode_image_to_base64(plot_path) |
| 223 plots_html += f""" | 241 plots_html += ( |
| 224 <div class="plot"> | 242 f'<div class="plot">' |
| 225 <h3>{plot_name.capitalize()}</h3> | 243 f"<h3>{plot_name.capitalize()}</h3>" |
| 226 <img src="data:image/png;base64,{encoded_image}" | 244 f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">' |
| 227 alt="{plot_name}"> | 245 f"</div>" |
| 228 </div> | 246 ) |
| 229 """ | |
| 230 if i < length - 1: | 247 if i < length - 1: |
| 231 plots_html += "<hr>" | 248 plots_html += "<hr>" |
| 232 | 249 |
| 233 tree_plots = "" | 250 tree_plots = "" |
| 234 for i, tree in enumerate(self.trees): | 251 for i, tree in enumerate(self.trees): |
| 235 if tree: | 252 if tree: |
| 236 tree_plots += f""" | 253 tree_plots += ( |
| 237 <div class="plot"> | 254 f'<div class="plot">' |
| 238 <h3>Tree {i+1}</h3> | 255 f"<h3>Tree {i + 1}</h3>" |
| 239 <img src="data:image/png;base64, | 256 f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">' |
| 240 {tree}" | 257 f"</div>" |
| 241 alt="tree {i+1}"> | 258 ) |
| 242 </div> | |
| 243 """ | |
| 244 | 259 |
| 245 analyzer = FeatureImportanceAnalyzer( | 260 analyzer = FeatureImportanceAnalyzer( |
| 246 data=self.data, | 261 data=self.data, |
| 247 target_col=self.target_col, | 262 target_col=self.target_col, |
| 248 task_type=self.task_type, | 263 task_type=self.task_type, |
| 249 output_dir=self.output_dir, | 264 output_dir=self.output_dir, |
| 265 exp=self.exp, | |
| 266 best_model=self.best_model, | |
| 250 ) | 267 ) |
| 251 feature_importance_html = analyzer.run() | 268 feature_importance_html = analyzer.run() |
| 252 | 269 |
| 253 html_content = f""" | 270 # --- Feature Metrics Help Button --- |
| 254 {get_html_template()} | 271 feature_metrics_button_html = ( |
| 255 <h1>PyCaret Model Training Report</h1> | 272 '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">' |
| 256 <div class="tabs"> | 273 "Help: Metrics Guide" |
| 257 <div class="tab" onclick="openTab(event, 'summary')"> | 274 "</button>" |
| 258 Setup & Best Model</div> | 275 "<style>" |
| 259 <div class="tab" onclick="openTab(event, 'plots')"> | 276 ".help-modal-btn {" |
| 260 Best Model Plots</div> | 277 "background-color: #17623b;" |
| 261 <div class="tab" onclick="openTab(event, 'feature')"> | 278 "color: #fff;" |
| 262 Feature Importance</div> | 279 "border: none;" |
| 263 """ | 280 "border-radius: 24px;" |
| 281 "padding: 10px 28px;" | |
| 282 "font-size: 1.1rem;" | |
| 283 "font-weight: bold;" | |
| 284 "letter-spacing: 0.03em;" | |
| 285 "cursor: pointer;" | |
| 286 "transition: background 0.2s, box-shadow 0.2s;" | |
| 287 "box-shadow: 0 2px 8px rgba(23,98,59,0.07);" | |
| 288 "}" | |
| 289 ".help-modal-btn:hover, .help-modal-btn:focus {" | |
| 290 "background-color: #21895e;" | |
| 291 "outline: none;" | |
| 292 "box-shadow: 0 4px 16px rgba(23,98,59,0.14);" | |
| 293 "}" | |
| 294 "</style>" | |
| 295 ) | |
| 296 | |
| 297 html_content = ( | |
| 298 f"{get_html_template()}" | |
| 299 "<h1>Tabular Learner Model Report</h1>" | |
| 300 f"{feature_metrics_button_html}" | |
| 301 '<div class="tabs">' | |
| 302 '<div class="tab" onclick="openTab(event, \'summary\')">' | |
| 303 "Validation Result Summary & Config</div>" | |
| 304 '<div class="tab" onclick="openTab(event, \'plots\')">' | |
| 305 "Test Results</div>" | |
| 306 '<div class="tab" onclick="openTab(event, \'feature\')">' | |
| 307 "Feature Importance</div>" | |
| 308 ) | |
| 264 if self.plots_explainer_html: | 309 if self.plots_explainer_html: |
| 265 html_content += """ | 310 html_content += ( |
| 266 <div class="tab" onclick="openTab(event, 'explainer')"> | 311 '<div class="tab" onclick="openTab(event, \'explainer\')">' |
| 267 Explainer Plots</div> | 312 "Explainer Plots</div>" |
| 268 """ | 313 ) |
| 269 html_content += f""" | 314 html_content += ( |
| 270 </div> | 315 "</div>" |
| 271 <div id="summary" class="tab-content"> | 316 '<div id="summary" class="tab-content">' |
| 272 <h2>Setup Parameters</h2> | 317 "<h2>Model Metrics from Cross-Validation Set</h2>" |
| 273 {setup_params_table.to_html( | 318 f"<h2>Best Model: {model_name}</h2>" |
| 274 index=False, | 319 "<h5>The best model is selected by: Accuracy (Classification)" |
| 275 header=True, | 320 " or R2 (Regression).</h5>" |
| 276 classes='table sortable' | 321 f"{self.results.to_html(index=False, classes='table sortable')}" |
| 277 )} | 322 "<h2>Best Model's Hyperparameters</h2>" |
| 278 <h5>If you want to know all the experiment setup parameters, | 323 f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}" |
| 279 please check the PyCaret documentation for | 324 "<h2>Setup Parameters</h2>" |
| 280 the classification/regression <code>exp</code> function.</h5> | 325 f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}" |
| 281 <h2>Best Model: {model_name}</h2> | 326 "<h5>If you want to know all the experiment setup parameters," |
| 282 {best_model_params.to_html( | 327 " please check the PyCaret documentation for" |
| 283 index=False, | 328 " the classification/regression <code>exp</code> function.</h5>" |
| 284 header=True, | 329 "</div>" |
| 285 classes='table sortable' | 330 '<div id="plots" class="tab-content">' |
| 286 )} | 331 f"<h2>Best Model: {model_name}</h2>" |
| 287 <h2>Comparison Results on the Cross-Validation Set</h2> | 332 "<h5>The best model is selected by: Accuracy (Classification)" |
| 288 {self.results.to_html(index=False, classes='table sortable')} | 333 " or R2 (Regression).</h5>" |
| 289 <h2>Results on the Test Set for the best model</h2> | 334 "<h2>Test Metrics</h2>" |
| 290 {self.test_result_df.to_html( | 335 f"{self.test_result_df.to_html(index=False)}" |
| 291 index=False, | 336 "<h2>Test Results</h2>" |
| 292 classes='table sortable' | 337 f"{plots_html}" |
| 293 )} | 338 "</div>" |
| 294 </div> | 339 '<div id="feature" class="tab-content">' |
| 295 <div id="plots" class="tab-content"> | 340 f"{feature_importance_html}" |
| 296 <h2>Best Model Plots on the testing set</h2> | 341 "</div>" |
| 297 {plots_html} | 342 ) |
| 298 </div> | |
| 299 <div id="feature" class="tab-content"> | |
| 300 {feature_importance_html} | |
| 301 </div> | |
| 302 """ | |
| 303 if self.plots_explainer_html: | 343 if self.plots_explainer_html: |
| 304 html_content += f""" | 344 html_content += ( |
| 305 <div id="explainer" class="tab-content"> | 345 '<div id="explainer" class="tab-content">' |
| 306 {self.plots_explainer_html} | 346 f"{self.plots_explainer_html}" |
| 307 {tree_plots} | 347 f"{tree_plots}" |
| 308 </div> | 348 "</div>" |
| 309 """ | 349 ) |
| 310 html_content += """ | 350 html_content += ( |
| 311 <script> | 351 "<script>" |
| 312 document.addEventListener("DOMContentLoaded", function() { | 352 "document.addEventListener(\"DOMContentLoaded\", function() {" |
| 313 var tables = document.querySelectorAll("table.sortable"); | 353 "var tables = document.querySelectorAll(\"table.sortable\");" |
| 314 tables.forEach(function(table) { | 354 "tables.forEach(function(table) {" |
| 315 var headers = table.querySelectorAll("th"); | 355 "var headers = table.querySelectorAll(\"th\");" |
| 316 headers.forEach(function(header, index) { | 356 "headers.forEach(function(header, index) {" |
| 317 header.style.cursor = "pointer"; | 357 "header.style.cursor = \"pointer\";" |
| 318 // Add initial arrow (up) to indicate sortability | 358 "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)" |
| 319 header.innerHTML += '<span class="sort-arrow"> ↑</span>'; | 359 "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';" |
| 320 header.addEventListener("click", function() { | 360 "header.addEventListener(\"click\", function() {" |
| 321 var direction = this.getAttribute( | 361 "var direction = this.getAttribute(" |
| 322 "data-sort-direction" | 362 "\"data-sort-direction\"" |
| 323 ) || "asc"; | 363 ") || \"asc\";" |
| 324 // Reset arrows in all headers of this table | 364 "// Reset arrows in all headers of this table" |
| 325 headers.forEach(function(h) { | 365 "headers.forEach(function(h) {" |
| 326 var arrow = h.querySelector(".sort-arrow"); | 366 "var arrow = h.querySelector(\".sort-arrow\");" |
| 327 if (arrow) arrow.textContent = " ↑"; | 367 "if (arrow) arrow.textContent = \" ↑\";" |
| 328 }); | 368 "});" |
| 329 // Set arrow for clicked header | 369 "// Set arrow for clicked header" |
| 330 var arrow = this.querySelector(".sort-arrow"); | 370 "var arrow = this.querySelector(\".sort-arrow\");" |
| 331 arrow.textContent = direction === "asc" ? " ↓" : " ↑"; | 371 "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";" |
| 332 sortTable(table, index, direction); | 372 "sortTable(table, index, direction);" |
| 333 this.setAttribute("data-sort-direction", | 373 "this.setAttribute(\"data-sort-direction\"," |
| 334 direction === "asc" ? "desc" : "asc"); | 374 "direction === \"asc\" ? \"desc\" : \"asc\");" |
| 335 }); | 375 "});" |
| 336 }); | 376 "});" |
| 337 }); | 377 "});" |
| 338 }); | 378 "});" |
| 339 | 379 "function sortTable(table, colNum, direction) {" |
| 340 function sortTable(table, colNum, direction) { | 380 "var tb = table.tBodies[0];" |
| 341 var tb = table.tBodies[0]; | 381 "var tr = Array.prototype.slice.call(tb.rows, 0);" |
| 342 var tr = Array.prototype.slice.call(tb.rows, 0); | 382 "var multiplier = direction === \"asc\" ? 1 : -1;" |
| 343 var multiplier = direction === "asc" ? 1 : -1; | 383 "tr = tr.sort(function(a, b) {" |
| 344 tr = tr.sort(function(a, b) { | 384 "var aText = a.cells[colNum].textContent.trim();" |
| 345 var aText = a.cells[colNum].textContent.trim(); | 385 "var bText = b.cells[colNum].textContent.trim();" |
| 346 var bText = b.cells[colNum].textContent.trim(); | 386 "// Remove arrow from text comparison" |
| 347 // Remove arrow from text comparison | 387 "aText = aText.replace(/[↑↓]/g, '').trim();" |
| 348 aText = aText.replace(/[↑↓]/g, '').trim(); | 388 "bText = bText.replace(/[↑↓]/g, '').trim();" |
| 349 bText = bText.replace(/[↑↓]/g, '').trim(); | 389 "if (!isNaN(aText) && !isNaN(bText)) {" |
| 350 if (!isNaN(aText) && !isNaN(bText)) { | 390 "return multiplier * (" |
| 351 return multiplier * ( | 391 "parseFloat(aText) - parseFloat(bText)" |
| 352 parseFloat(aText) - parseFloat(bText) | 392 ");" |
| 353 ); | 393 "} else {" |
| 354 } else { | 394 "return multiplier * aText.localeCompare(bText);" |
| 355 return multiplier * aText.localeCompare(bText); | 395 "}" |
| 356 } | 396 "});" |
| 357 }); | 397 "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);" |
| 358 for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]); | 398 "}" |
| 359 } | 399 "</script>" |
| 360 </script> | 400 ) |
| 361 """ | 401 # --- Add the Feature Metrics Help Modal --- |
| 362 html_content += f""" | 402 html_content += get_feature_metrics_help_modal() |
| 363 {get_html_closing()} | 403 html_content += f"{get_html_closing()}" |
| 364 """ | |
| 365 with open( | 404 with open( |
| 366 os.path.join(self.output_dir, "comparison_result.html"), | 405 os.path.join(self.output_dir, "comparison_result.html"), |
| 367 "w" | 406 "w", |
| 407 encoding="utf-8", | |
| 368 ) as file: | 408 ) as file: |
| 369 file.write(html_content) | 409 file.write(html_content) |
| 370 | 410 |
| 371 def save_dashboard(self): | 411 def save_dashboard(self): |
| 372 raise NotImplementedError("Subclasses should implement this method") | 412 raise NotImplementedError("Subclasses should implement this method") |
| 373 | 413 |
| 374 def generate_plots_explainer(self): | 414 def generate_plots_explainer(self): |
| 375 raise NotImplementedError("Subclasses should implement this method") | 415 raise NotImplementedError("Subclasses should implement this method") |
| 376 | 416 |
| 377 # not working now | |
| 378 def generate_tree_plots(self): | 417 def generate_tree_plots(self): |
| 379 from sklearn.ensemble import RandomForestClassifier, \ | 418 from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor |
| 380 RandomForestRegressor | |
| 381 from xgboost import XGBClassifier, XGBRegressor | 419 from xgboost import XGBClassifier, XGBRegressor |
| 382 from explainerdashboard.explainers import RandomForestExplainer | 420 from explainerdashboard.explainers import RandomForestExplainer |
| 383 | 421 |
| 384 LOG.info("Generating tree plots") | 422 LOG.info("Generating tree plots") |
| 385 X_test = self.exp.X_test_transformed.copy() | 423 X_test = self.exp.X_test_transformed.copy() |
| 386 y_test = self.exp.y_test_transformed | 424 y_test = self.exp.y_test_transformed |
| 387 | 425 |
| 388 is_rf = isinstance(self.best_model, RandomForestClassifier) or \ | 426 is_rf = isinstance( |
| 389 isinstance(self.best_model, RandomForestRegressor) | 427 self.best_model, (RandomForestClassifier, RandomForestRegressor) |
| 390 | 428 ) |
| 391 is_xgb = isinstance(self.best_model, XGBClassifier) or \ | 429 is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor)) |
| 392 isinstance(self.best_model, XGBRegressor) | 430 |
| 431 num_trees = None | |
| 432 if is_rf: | |
| 433 num_trees = self.best_model.n_estimators | |
| 434 elif is_xgb: | |
| 435 num_trees = len(self.best_model.get_booster().get_dump()) | |
| 436 else: | |
| 437 LOG.warning("Tree plots not supported for this model type.") | |
| 438 return | |
| 393 | 439 |
| 394 try: | 440 try: |
| 395 if is_rf: | |
| 396 num_trees = self.best_model.n_estimators | |
| 397 if is_xgb: | |
| 398 num_trees = len(self.best_model.get_booster().get_dump()) | |
| 399 explainer = RandomForestExplainer(self.best_model, X_test, y_test) | 441 explainer = RandomForestExplainer(self.best_model, X_test, y_test) |
| 400 for i in range(num_trees): | 442 for i in range(num_trees): |
| 401 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) | 443 fig = explainer.decisiontree_encoded(tree_idx=i, index=0) |
| 402 LOG.info(f"Tree {i+1}") | 444 LOG.info(f"Tree {i + 1}") |
| 403 LOG.info(fig) | 445 LOG.info(fig) |
| 404 self.trees.append(fig) | 446 self.trees.append(fig) |
| 405 except Exception as e: | 447 except Exception as e: |
| 406 LOG.error(f"Error generating tree plots: {e}") | 448 LOG.error(f"Error generating tree plots: {e}") |
| 407 | 449 |
