comparison pycaret_predict.py @ 0:209b663a4f62 draft

planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author goeckslab
date Wed, 18 Jun 2025 15:38:19 +0000
parents
children 11fdac5affb3
import argparse
import logging
import tempfile

import h5py
import joblib
import pandas as pd
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
from sklearn.metrics import average_precision_score
from utils import encode_image_to_base64, get_html_closing, get_html_template

LOG = logging.getLogger(__name__)


class PyCaretModelEvaluator:
    def __init__(self, model_path, task, target):
        self.model_path = model_path
        self.task = task.lower()
        self.model = self.load_h5_model()
        self.target = target if target != "None" else None

    def load_h5_model(self):
        """Load a PyCaret model from an HDF5 file."""
        with h5py.File(self.model_path, 'r') as f:
            model_bytes = bytes(f['model'][()])
            # Dump the serialized bytes to a temporary file so joblib can
            # load the estimator by path; seek(0) flushes the write buffer.
            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                temp_file.write(model_bytes)
                temp_file.seek(0)
                loaded_model = joblib.load(temp_file.name)
        return loaded_model

    def evaluate(self, data_path):
        """Evaluate the model using the specified data."""
        raise NotImplementedError("Subclasses must implement this method")

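
# NOTE (illustrative, not part of the tool): load_h5_model expects an HDF5
# file containing a single dataset named "model" that holds the
# joblib-serialized estimator bytes. A compatible file could be produced
# roughly like this; the variable and file names here are hypothetical:
#
#     import io
#     import h5py
#     import joblib
#     import numpy as np
#
#     buf = io.BytesIO()
#     joblib.dump(trained_model, buf)
#     with h5py.File("model.h5", "w") as f:
#         f.create_dataset("model", data=np.void(buf.getvalue()))

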
class ClassificationEvaluator(PyCaretModelEvaluator):
    def evaluate(self, data_path):
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            exp = ClassificationExperiment()
            names = data.columns.to_list()
            LOG.info(f"Column names: {names}")
            target_index = int(self.target) - 1
            target_name = names[target_index]
            exp.setup(data, target=target_name, test_data=data, index=False)
            exp.add_metric(id='PR-AUC-Weighted',
                           name='PR-AUC-Weighted',
                           target='pred_proba',
                           score_func=average_precision_score,
                           average='weighted')
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()
            plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                     'error', 'class_report', 'learning', 'calibration',
                     'vc', 'dimension', 'manifold', 'rfe', 'feature',
                     'feature_all']
            for plot_name in plots:
                try:
                    # For binary problems, restrict the AUC plot to the single
                    # ROC curve instead of micro/macro/per-class curves.
                    if plot_name == 'auc' and not exp.is_multiclass:
                        plot_path = exp.plot_model(self.model,
                                                   plot=plot_name,
                                                   save=True,
                                                   plot_kwargs={
                                                       'micro': False,
                                                       'macro': False,
                                                       'per_class': False,
                                                       'binary': True})
                        plot_paths[plot_name] = plot_path
                        continue

                    plot_path = exp.plot_model(self.model,
                                               plot=plot_name, save=True)
                    plot_paths[plot_name] = plot_path
                except Exception as e:
                    LOG.error(f"Error generating plot {plot_name}: {e}")
                    continue
            generate_html_report(plot_paths, metrics)

        else:
            # No target column given: run prediction only, without metrics.
            exp = ClassificationExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths


class RegressionEvaluator(PyCaretModelEvaluator):
    def evaluate(self, data_path):
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            names = data.columns.to_list()
            target_index = int(self.target) - 1
            target_name = names[target_index]
            exp = RegressionExperiment()
            exp.setup(data, target=target_name, test_data=data, index=False)
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()
            plots = ['residuals', 'error', 'cooks',
                     'learning', 'vc', 'manifold',
                     'rfe', 'feature', 'feature_all']
            for plot_name in plots:
                try:
                    plot_path = exp.plot_model(self.model,
                                               plot=plot_name, save=True)
                    plot_paths[plot_name] = plot_path
                except Exception as e:
                    LOG.error(f"Error generating plot {plot_name}: {e}")
                    continue
            generate_html_report(plot_paths, metrics)
        else:
            exp = RegressionExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths


def generate_html_report(plots, metrics):
    """Generate an HTML evaluation report."""
    plots_html = ""
    for plot_name, plot_path in plots.items():
        encoded_image = encode_image_to_base64(plot_path)
        plots_html += f"""
        <div class="plot">
            <h3>{plot_name.capitalize()}</h3>
            <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
        </div>
        <hr>
        """

    metrics_html = metrics.to_html(index=False, classes="table")

    html_content = f"""
    {get_html_template()}
    <h1>Model Evaluation Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'metrics')">Metrics</div>
        <div class="tab" onclick="openTab(event, 'plots')">Plots</div>
    </div>
    <div id="metrics" class="tab-content">
        <h2>Metrics</h2>
        <table>
            {metrics_html}
        </table>
    </div>
    <div id="plots" class="tab-content">
        <h2>Plots</h2>
        {plots_html}
    </div>
    {get_html_closing()}
    """

    # Save HTML report
    with open("evaluation_report.html", "w") as html_file:
        html_file.write(html_content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a PyCaret model stored in HDF5 format.")
    parser.add_argument("--model_path",
                        type=str,
                        help="Path to the HDF5 model file.")
    parser.add_argument("--data_path",
                        type=str,
                        help="Path to the evaluation data CSV file.")
    parser.add_argument("--task",
                        type=str,
                        choices=["classification", "regression"],
                        help="Specify the task: classification or regression.")
    parser.add_argument("--target",
                        default=None,
                        help="1-based column number of the target column.")
    args = parser.parse_args()

    if args.task == "classification":
        evaluator = ClassificationEvaluator(
            args.model_path, args.task, args.target)
    elif args.task == "regression":
        evaluator = RegressionEvaluator(
            args.model_path, args.task, args.target)
    else:
        raise ValueError(
            "Unsupported task type. Use 'classification' or 'regression'.")

    predictions, metrics, plots = evaluator.evaluate(args.data_path)

    predictions.to_csv("predictions.csv", index=False)
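
# Example invocation (illustrative; the file names below are placeholders):
#
#     python pycaret_predict.py --model_path model.h5 --data_path test.csv \
#         --task classification --target 1
#
# This writes predictions.csv and, when a target column is supplied,
# evaluation_report.html to the working directory.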