pycaret_predict.py @ 0:209b663a4f62 (draft)
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
| author | goeckslab |
|---|---|
| date | Wed, 18 Jun 2025 15:38:19 +0000 |
| parents | |
| children | 11fdac5affb3 |
import argparse
import logging
import os
import tempfile

import h5py
import joblib
import pandas as pd
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
from sklearn.metrics import average_precision_score
from utils import encode_image_to_base64, get_html_closing, get_html_template

LOG = logging.getLogger(__name__)


class PyCaretModelEvaluator:
    def __init__(self, model_path, task, target):
        self.model_path = model_path
        self.task = task.lower()
        self.model = self.load_h5_model()
        # Treat the literal string "None" (e.g. passed as text by the
        # tool wrapper) as "no target column selected".
        self.target = target if target != "None" else None

    def load_h5_model(self):
        """Load a PyCaret model from an HDF5 file."""
        with h5py.File(self.model_path, 'r') as f:
            model_bytes = bytes(f['model'][()])
        # joblib loads from a path, so stage the bytes in a temporary
        # file and remove it once the model has been deserialized.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_file.write(model_bytes)
            temp_path = temp_file.name
        try:
            return joblib.load(temp_path)
        finally:
            os.unlink(temp_path)
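    # Writer-side sketch (hypothetical; not part of this repo): the loader
    # above assumes the training step stored the joblib-serialized model
    # as a top-level byte dataset named 'model', for example:
    #
    #     import numpy as np
    #     with h5py.File("model.h5", "w") as f:
    #         f.create_dataset("model", data=np.void(model_bytes))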

    def evaluate(self, data_path):
        """Evaluate the model using the specified data."""
        raise NotImplementedError("Subclasses must implement this method")


class ClassificationEvaluator(PyCaretModelEvaluator):
    def evaluate(self, data_path):
        metrics = None
        plot_paths = {}
        # sep=None with the python engine lets pandas sniff the delimiter.
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            exp = ClassificationExperiment()
            names = data.columns.to_list()
            LOG.info(f"Column names: {names}")
            # --target is 1-based, so shift to a 0-based column index.
            target_index = int(self.target) - 1
            target_name = names[target_index]
            # The input data doubles as the hold-out set, so the metrics
            # pulled below are computed on this dataset.
            exp.setup(data, target=target_name, test_data=data, index=False)
            exp.add_metric(id='PR-AUC-Weighted',
                           name='PR-AUC-Weighted',
                           target='pred_proba',
                           score_func=average_precision_score,
                           average='weighted')
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()
            plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                     'error', 'class_report', 'learning', 'calibration',
                     'vc', 'dimension', 'manifold', 'rfe', 'feature',
                     'feature_all']
            for plot_name in plots:
                try:
                    # Draw the plain binary ROC curve for two-class tasks.
                    if plot_name == 'auc' and not exp.is_multiclass:
                        plot_path = exp.plot_model(self.model,
                                                   plot=plot_name,
                                                   save=True,
                                                   plot_kwargs={
                                                       'micro': False,
                                                       'macro': False,
                                                       'per_class': False,
                                                       'binary': True})
                        plot_paths[plot_name] = plot_path
                        continue

                    plot_path = exp.plot_model(self.model,
                                               plot=plot_name, save=True)
                    plot_paths[plot_name] = plot_path
                except Exception as e:
                    LOG.error(f"Error generating plot {plot_name}: {e}")
                    continue
            generate_html_report(plot_paths, metrics)

        else:
            exp = ClassificationExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths
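    # Note: in PyCaret 3.x, predict_model() is expected to return the input
    # frame with an added 'prediction_label' column (plus 'prediction_score'
    # for classifiers); those columns are what end up in predictions.csv.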


class RegressionEvaluator(PyCaretModelEvaluator):
    def evaluate(self, data_path):
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            names = data.columns.to_list()
            target_index = int(self.target) - 1
            target_name = names[target_index]
            exp = RegressionExperiment()
            exp.setup(data, target=target_name, test_data=data, index=False)
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()
            plots = ['residuals', 'error', 'cooks',
                     'learning', 'vc', 'manifold',
                     'rfe', 'feature', 'feature_all']
            for plot_name in plots:
                try:
                    plot_path = exp.plot_model(self.model,
                                               plot=plot_name, save=True)
                    plot_paths[plot_name] = plot_path
                except Exception as e:
                    LOG.error(f"Error generating plot {plot_name}: {e}")
                    continue
            generate_html_report(plot_paths, metrics)
        else:
            exp = RegressionExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths


def generate_html_report(plots, metrics):
    """Generate an HTML evaluation report."""
    plots_html = ""
    for plot_name, plot_path in plots.items():
        encoded_image = encode_image_to_base64(plot_path)
        plots_html += f"""
        <div class="plot">
            <h3>{plot_name.capitalize()}</h3>
            <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
        </div>
        <hr>
        """

    # to_html() already returns a complete <table> element, so embed it
    # directly rather than nesting it inside another <table> tag.
    metrics_html = metrics.to_html(index=False, classes="table")

    html_content = f"""
    {get_html_template()}
    <h1>Model Evaluation Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'metrics')">Metrics</div>
        <div class="tab" onclick="openTab(event, 'plots')">Plots</div>
    </div>
    <div id="metrics" class="tab-content">
        <h2>Metrics</h2>
        {metrics_html}
    </div>
    <div id="plots" class="tab-content">
        <h2>Plots</h2>
        {plots_html}
    </div>
    {get_html_closing()}
    """

    # Save HTML report
    with open("evaluation_report.html", "w") as html_file:
        html_file.write(html_content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a PyCaret model stored in HDF5 format.")
    parser.add_argument("--model_path",
                        type=str,
                        help="Path to the HDF5 model file.")
    parser.add_argument("--data_path",
                        type=str,
                        help="Path to the evaluation data CSV file.")
    parser.add_argument("--task",
                        type=str,
                        choices=["classification", "regression"],
                        help="Specify the task: classification or regression.")
    parser.add_argument("--target",
                        default=None,
                        help="1-based column number of the target.")
    args = parser.parse_args()

    if args.task == "classification":
        evaluator = ClassificationEvaluator(
            args.model_path, args.task, args.target)
    elif args.task == "regression":
        evaluator = RegressionEvaluator(
            args.model_path, args.task, args.target)
    else:
        raise ValueError(
            "Unsupported task type. Use 'classification' or 'regression'.")

    predictions, metrics, plots = evaluator.evaluate(args.data_path)

    predictions.to_csv("predictions.csv", index=False)
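# Example invocation (file names and the target column are illustrative):
#
#     python pycaret_predict.py \
#         --model_path model.h5 \
#         --data_path test.csv \
#         --task classification \
#         --target 5
#
# With --target set, the script writes predictions.csv plus an
# evaluation_report.html; without it, only predictions.csv is produced.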