goeckslab/tabular_learner: pycaret_predict.py @ 0:209b663a4f62 (draft)
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
| field | value |
|---|---|
| author | goeckslab |
| date | Wed, 18 Jun 2025 15:38:19 +0000 |
| parents | (none) |
| children | 11fdac5affb3 |
This is the initial revision of the file (compared against the null revision -1:000000000000).
```python
import argparse
import logging
import tempfile

import h5py
import joblib
import pandas as pd
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
from sklearn.metrics import average_precision_score
from utils import encode_image_to_base64, get_html_closing, get_html_template

LOG = logging.getLogger(__name__)


class PyCaretModelEvaluator:
    def __init__(self, model_path, task, target):
        self.model_path = model_path
        self.task = task.lower()
        self.model = self.load_h5_model()
        # --target arrives as a string: a 1-based column number, or the
        # literal "None" when no label column is present.
        self.target = target if target != "None" else None

    def load_h5_model(self):
        """Load a PyCaret model from an HDF5 file."""
        with h5py.File(self.model_path, 'r') as f:
            # The joblib-pickled pipeline is stored as raw bytes in the
            # 'model' dataset of the HDF5 container.
            model_bytes = bytes(f['model'][()])
            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                temp_file.write(model_bytes)
                # Flush so the bytes are on disk before joblib reopens
                # the file by name.
                temp_file.flush()
                loaded_model = joblib.load(temp_file.name)
            return loaded_model

    def evaluate(self, data_path):
        """Evaluate the model using the specified data."""
        raise NotImplementedError("Subclasses must implement this method")


class ClassificationEvaluator(PyCaretModelEvaluator):
    def evaluate(self, data_path):
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            exp = ClassificationExperiment()
            names = data.columns.to_list()
            LOG.info(f"Column names: {names}")
            # --target is 1-based, so shift to a 0-based index.
            target_index = int(self.target) - 1
            target_name = names[target_index]
            exp.setup(data, target=target_name, test_data=data, index=False)
            # Register weighted PR-AUC as an extra metric, computed from
            # the predicted class probabilities.
            exp.add_metric(id='PR-AUC-Weighted',
                           name='PR-AUC-Weighted',
                           target='pred_proba',
                           score_func=average_precision_score,
                           average='weighted')
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()
            plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                     'error', 'class_report', 'learning', 'calibration',
                     'vc', 'dimension', 'manifold', 'rfe', 'feature',
                     'feature_all']
            for plot_name in plots:
                try:
                    # For binary problems, draw a single ROC curve instead
                    # of per-class, micro, and macro averages.
                    if plot_name == 'auc' and not exp.is_multiclass:
                        plot_path = exp.plot_model(self.model,
                                                   plot=plot_name,
                                                   save=True,
                                                   plot_kwargs={
                                                       'micro': False,
                                                       'macro': False,
                                                       'per_class': False,
                                                       'binary': True})
                        plot_paths[plot_name] = plot_path
                        continue

                    plot_path = exp.plot_model(self.model,
                                               plot=plot_name, save=True)
                    plot_paths[plot_name] = plot_path
                except Exception as e:
                    LOG.error(f"Error generating plot {plot_name}: {e}")
                    continue
            generate_html_report(plot_paths, metrics)

        else:
            # No label column: run plain predictions without metrics or
            # evaluation plots.
            exp = ClassificationExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths


class RegressionEvaluator(PyCaretModelEvaluator):
    def evaluate(self, data_path):
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            names = data.columns.to_list()
            # --target is 1-based, so shift to a 0-based index.
            target_index = int(self.target) - 1
            target_name = names[target_index]
            exp = RegressionExperiment()
            exp.setup(data, target=target_name, test_data=data, index=False)
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()
            plots = ['residuals', 'error', 'cooks',
                     'learning', 'vc', 'manifold',
                     'rfe', 'feature', 'feature_all']
            for plot_name in plots:
                try:
                    plot_path = exp.plot_model(self.model,
                                               plot=plot_name, save=True)
                    plot_paths[plot_name] = plot_path
                except Exception as e:
                    LOG.error(f"Error generating plot {plot_name}: {e}")
                    continue
            generate_html_report(plot_paths, metrics)
        else:
            # No label column: run plain predictions without metrics or
            # evaluation plots.
            exp = RegressionExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths


def generate_html_report(plots, metrics):
    """Generate an HTML evaluation report."""
    plots_html = ""
    for plot_name, plot_path in plots.items():
        # Inline each plot as a base64-encoded PNG so the report is a
        # single self-contained file.
        encoded_image = encode_image_to_base64(plot_path)
        plots_html += f"""
        <div class="plot">
            <h3>{plot_name.capitalize()}</h3>
            <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
        </div>
        <hr>
        """

    # to_html() already emits a complete <table> element, so it is not
    # wrapped in another one below.
    metrics_html = metrics.to_html(index=False, classes="table")

    html_content = f"""
    {get_html_template()}
    <h1>Model Evaluation Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'metrics')">Metrics</div>
        <div class="tab" onclick="openTab(event, 'plots')">Plots</div>
    </div>
    <div id="metrics" class="tab-content">
        <h2>Metrics</h2>
        {metrics_html}
    </div>
    <div id="plots" class="tab-content">
        <h2>Plots</h2>
        {plots_html}
    </div>
    {get_html_closing()}
    """

    # Save HTML report
    with open("evaluation_report.html", "w") as html_file:
        html_file.write(html_content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a PyCaret model stored in HDF5 format.")
    parser.add_argument("--model_path",
                        type=str,
                        help="Path to the HDF5 model file.")
    parser.add_argument("--data_path",
                        type=str,
                        help="Path to the evaluation data CSV file.")
    parser.add_argument("--task",
                        type=str,
                        choices=["classification", "regression"],
                        help="Specify the task: classification or regression.")
    parser.add_argument("--target",
                        default=None,
                        help="1-based column number of the target column")
    args = parser.parse_args()

    if args.task == "classification":
        evaluator = ClassificationEvaluator(
            args.model_path, args.task, args.target)
    elif args.task == "regression":
        evaluator = RegressionEvaluator(
            args.model_path, args.task, args.target)
    else:
        raise ValueError(
            "Unsupported task type. Use 'classification' or 'regression'.")

    predictions, metrics, plots = evaluator.evaluate(args.data_path)

    predictions.to_csv("predictions.csv", index=False)
```
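For context, here is a minimal sketch of how the evaluator classes could be driven programmatically outside Galaxy. The module name `pycaret_predict`, the file names `model.h5` and `test.csv`, and the target column number `2` are illustrative assumptions, not part of the tool; invoked as a script, the equivalent call passes the same values via --model_path, --data_path, --task, and --target.

```python
# Hypothetical driver, assuming this file is importable as pycaret_predict
# and that test.csv holds its label in the second column (the tool takes
# 1-based column numbers and converts them to 0-based indices internally).
from pycaret_predict import ClassificationEvaluator

evaluator = ClassificationEvaluator("model.h5", "classification", "2")
predictions, metrics, plots = evaluator.evaluate("test.csv")
predictions.to_csv("predictions.csv", index=False)
```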

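`load_h5_model` relies on a single layout detail: the joblib-pickled pipeline lives as raw bytes under the `'model'` dataset. The sketch below shows one way such a file could be produced; `save_model_to_h5` is a hypothetical helper, and the actual training tool may store additional datasets alongside the model.

```python
# Hypothetical counterpart to load_h5_model: pack a joblib-pickled
# model into an HDF5 file under the 'model' dataset.
import tempfile

import h5py
import joblib
import numpy as np


def save_model_to_h5(model, out_path):
    # Serialize the model with joblib, then read the raw bytes back so
    # they can be embedded in the HDF5 container.
    with tempfile.NamedTemporaryFile() as tmp:
        joblib.dump(model, tmp.name)
        tmp.seek(0)
        model_bytes = tmp.read()
    with h5py.File(out_path, "w") as f:
        # np.void stores the pickle as an opaque scalar, matching what
        # bytes(f['model'][()]) in load_h5_model reads back.
        f.create_dataset("model", data=np.void(model_bytes))
```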