Mercurial > repos > muon-spectroscopy-computational-project > larch_artemis
comparison larch_artemis.py @ 0:2752b2dd7ad6 draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 5be486890442dedfb327289d597e1c8110240735
author | muon-spectroscopy-computational-project |
---|---|
date | Tue, 14 Nov 2023 15:34:23 +0000 |
parents | |
children | 84c8e04bc1a1 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:2752b2dd7ad6 |
---|---|
1 import csv | |
2 import faulthandler | |
3 import gc | |
4 import json | |
5 import os | |
6 import sys | |
7 | |
8 from common import get_group | |
9 | |
10 from larch.fitting import guess, param, param_group | |
11 from larch.io import read_athena | |
12 from larch.symboltable import Group | |
13 from larch.xafs import ( | |
14 FeffPathGroup, | |
15 FeffitDataSet, | |
16 TransformGroup, | |
17 autobk, | |
18 feffit, | |
19 feffit_report, | |
20 pre_edge, | |
21 xftf, | |
22 ) | |
23 | |
24 import matplotlib | |
25 import matplotlib.pyplot as plt | |
26 | |
27 import numpy as np | |
28 | |
29 | |
def read_csv_data(input_file, id_field="id"):
    """Load a CSV file into a dict keyed by the integer value of `id_field`.

    Values are the raw `csv.DictReader` rows. If the file does not exist,
    a message is printed and an empty dict is returned (best-effort read).
    """
    rows_by_id = {}
    try:
        with open(input_file, encoding="utf8") as handle:
            for record in csv.DictReader(handle, skipinitialspace=True):
                rows_by_id[int(record[id_field])] = record
    except FileNotFoundError:
        print("The specified file does not exist")
    return rows_by_id
40 | |
41 | |
def calc_with_defaults(xafs_group: Group) -> Group:
    """Calculate pre_edge and background with default arguments.

    Runs pre-edge subtraction, background removal and the forward Fourier
    transform in order, mutating and returning the same group.
    """
    for processing_step in (pre_edge, autobk, xftf):
        processing_step(xafs_group)
    return xafs_group
48 | |
49 | |
def dict_to_gds(data_dict):
    """Build a larch parameter group from parsed GDS CSV rows.

    Each row in `data_dict` must provide "name", "value", "expr" and "vary"
    fields. Rows with vary=True become `guess` parameters (fitted by larch),
    all others become fixed `param` definitions — mirroring Demeter's
    guess/def distinction.
    """
    dgs_group = param_group()
    for par_idx in data_dict:
        # gds file structure:
        row = data_dict[par_idx]
        gds_name = row["name"]
        try:
            gds_val = float(row["value"])
        except ValueError:
            # Non-numeric values fall back to 0.0 (expression-driven params)
            gds_val = 0.0
        gds_expr = row["expr"]
        # Accept e.g. "true", " True " etc. from hand-edited CSVs
        gds_vary = str(row["vary"]).strip().capitalize() == "True"
        if gds_vary:
            # equivalent to a guess parameter in Demeter
            one_par = guess(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        else:
            # equivalent to a defined parameter in Demeter
            one_par = param(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        setattr(dgs_group, gds_name, one_par)
    return dgs_group
81 | |
82 | |
def plot_rmr(path: str, data_set, rmin, rmax):
    """Save a PNG at `path` comparing |chi(R)| and Re[chi(R)] for the
    experimental data (blue) against the fit (red), with the fitting
    window [rmin, rmax] shaded and labelled.
    """
    fig, ax = plt.subplots()
    ax.plot(data_set.data.r, data_set.data.chir_mag, color="b")
    ax.plot(data_set.data.r, data_set.data.chir_re, color="b", label="expt.")
    ax.plot(data_set.model.r, data_set.model.chir_mag, color="r")
    ax.plot(data_set.model.r, data_set.model.chir_re, color="r", label="fit")
    ax.set_ylabel(
        "Magnitude of Fourier Transform of "
        r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
    )
    ax.set_xlabel(r"Radial distance/$\mathrm{\AA}$")
    ax.set_xlim(0, 5)

    # Shade the region of R-space that was used in the fit
    ax.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax.text(rmax - 0.65, -rmax + 0.5, "fit range")
    ax.legend()
    fig.savefig(path, format="png")
    plt.close("all")
106 | |
107 | |
def plot_chikr(path: str, data_set, rmin, rmax, kmin, kmax):
    """Save a two-panel PNG at `path`: k-space chi(k)*k^2 (left) and
    R-space |chi(R)| (right), data vs. fit, with fit ranges shaded.

    Fix: the model trace in k-space is now weighted by the model's own
    k grid (`data_set.model.k**2`) instead of the data's — the original
    mixed `model.chi` with `data.k`, a copy-paste inconsistency.
    """
    fig = plt.figure(figsize=(16, 4))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.plot(
        data_set.data.k,
        data_set.data.chi * data_set.data.k**2,
        color="b",
        label="expt.",
    )
    ax1.plot(
        data_set.model.k,
        data_set.model.chi * data_set.model.k**2,
        color="r",
        label="fit",
    )
    ax1.set_xlim(0, 15)
    ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
    ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")

    # Shade the k-range used in the fit
    ax1.fill(
        [kmin, kmin, kmax, kmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax1.text(kmax - 1.65, -rmax + 0.5, "fit range")
    ax1.legend()

    ax2.plot(data_set.data.r, data_set.data.chir_mag, color="b", label="expt.")
    ax2.plot(data_set.model.r, data_set.model.chir_mag, color="r", label="fit")
    ax2.set_xlim(0, 5)
    ax2.set_xlabel(r"$R(\mathrm{\AA})$")
    ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
    ax2.legend(loc="upper right")

    # Shade the R-range used in the fit
    ax2.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax2.text(rmax - 0.65, -rmax + 0.5, "fit range")
    fig.savefig(path, format="png")
    plt.close("all")
153 | |
154 | |
def read_gds(gds_file):
    """Read a GDS parameter CSV and return it as a larch parameter group."""
    return dict_to_gds(read_csv_data(gds_file))
159 | |
160 | |
def read_selected_paths_list(file_name):
    """Read the selected-paths CSV and build one FeffPathGroup per row.

    Each row must provide "filename" (the FEFF path file), a "label" and
    the parameter expressions "s02", "e0", "sigma2" and "deltar".
    """
    sp_dict = read_csv_data(file_name)
    sp_list = []
    for path_id in sp_dict:
        filename = sp_dict[path_id]["filename"]
        # Fix: the original f-string had no placeholder; report which FEFF
        # file is being loaded so progress is visible for long path lists.
        print(f"Reading selected path for file {filename}")
        new_path = FeffPathGroup(
            filename=filename,
            label=sp_dict[path_id]["label"],
            s02=sp_dict[path_id]["s02"],
            e0=sp_dict[path_id]["e0"],
            sigma2=sp_dict[path_id]["sigma2"],
            deltar=sp_dict[path_id]["deltar"],
        )
        sp_list.append(new_path)
    return sp_list
177 | |
178 | |
def run_fit(data_group, gds, selected_paths, fv):
    """Fit the data against the selected FEFF paths.

    Builds the transform group (the fit space) from the fit variables `fv`,
    wraps everything in a FeffitDataSet and runs feffit. Returns the
    (dataset, fit output) pair.
    """
    transform_keys = (
        "fitspace",
        "kmin",
        "kmax",
        "kweight",
        "dk",
        "window",
        "rmin",
        "rmax",
    )
    trans = TransformGroup(**{key: fv[key] for key in transform_keys})

    dset = FeffitDataSet(
        data=data_group, pathlist=selected_paths, transform=trans
    )

    return dset, feffit(gds, dset)
198 | |
199 | |
def main(
    prj_file: str,
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    series_id: str = "",
) -> Group:
    """Fit one Athena project and write report (and optionally plot) files.

    Reads the project from `prj_file`, the GDS parameters from `gds_file`
    and the selected FEFF paths from `sp_file`, runs the fit with the
    transform settings in `fit_vars`, and writes the fit report to
    report/fit_report{series_id}.txt (these output directories are assumed
    to already exist). When `plot_graph` is set, also saves RMR and chi(k)/
    chi(R) PNGs. Returns the feffit output group.
    """
    # Output locations are fixed by the Galaxy tool convention; series_id
    # distinguishes files when fitting a series of projects.
    report_path = f"report/fit_report{series_id}.txt"
    rmr_path = f"rmr/rmr{series_id}.png"
    chikr_path = f"chikr/chikr{series_id}.png"

    athena_project = read_athena(prj_file)
    athena_group = get_group(athena_project)
    # calc_with_defaults will hang indefinitely (>6 hours recorded) if the
    # data contains any NaNs - consider adding an early error here if this is
    # not fixed in Larch?
    data_group = calc_with_defaults(athena_group)

    print(f"Fitting project from file {data_group.filename}")

    gds = read_gds(gds_file)
    selected_paths = read_selected_paths_list(sp_file)
    dset, out = run_fit(data_group, gds, selected_paths, fit_vars)

    fit_report = feffit_report(out)
    with open(report_path, "w") as fit_report_file:
        fit_report_file.write(fit_report)

    if plot_graph:
        plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"])
        plot_chikr(
            chikr_path,
            dset,
            fit_vars["rmin"],
            fit_vars["rmax"],
            fit_vars["kmin"],
            fit_vars["kmax"],
        )
    return out
240 | |
241 | |
def check_threshold(
    series_id: str,
    threshold: float,
    variable: str,
    value: float,
    early_stopping: bool = False,
):
    """Report when |value| exceeds `threshold` for the given project.

    Prints an ERROR (stop requested) or WARNING (continue) message, and
    returns True only when the threshold is breached AND `early_stopping`
    is set — i.e. when the series fit should be aborted.
    """
    # Guard clause: within threshold means nothing to report.
    if abs(value) <= threshold:
        return False

    if early_stopping:
        print(
            "ERROR: Stopping series fit after project "
            f"{series_id} as {variable} > {threshold}"
        )
    else:
        print(f"WARNING: Project {series_id} has {variable} > {threshold}")

    return early_stopping
264 | |
265 | |
def series_execution(
    filepaths: "list[str]",
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    report_criteria: "list[dict]",
    stop_on_error: bool,
) -> "list[list[str]]":
    """Fit each project file in sequence, collecting criteria values.

    Returns a table: a header row naming each criterium variable, then one
    row per attempted fit. A failed fit contributes a row of NaNs. The loop
    stops early when a fit raises and `stop_on_error` is set, or when a
    criterium with action "stop" is breached.

    Fix: the parameter `report_criteria` was previously overwritten from
    the module-level `input_values`, which broke any caller other than the
    __main__ block; the parameter is now used as passed.
    """
    id_length = len(str(len(filepaths)))
    stop = False
    rows = [[f"{c['variable']:>12s}" for c in report_criteria]]
    for series_index, series_file in enumerate(filepaths):
        # Zero-pad IDs so output filenames sort in series order
        series_id = str(series_index).zfill(id_length)
        try:
            out = main(
                series_file,
                gds_file,
                sp_file,
                fit_vars,
                plot_graph,
                f"_{series_id}",
            )
        except ValueError as e:
            # Keep one row per project even on failure (np.nan, not the
            # removed-in-NumPy-2.0 np.NaN alias)
            rows.append([np.nan for _ in report_criteria])
            if stop_on_error:
                print(
                    f"ERROR: fitting failed for {series_id}"
                    f" due to following error, stopping:\n{e}"
                )
                break
            else:
                print(
                    f"WARNING: fitting failed for {series_id} due to following"
                    f" error, continuing to next project:\n{e}"
                )
                continue

        row = []
        for criterium in report_criteria:
            # Keep evaluating all criteria (for the report) even once one
            # has requested a stop
            stop = parse_row(series_id, out, row, criterium) or stop
        rows.append(row)

        # Free larch group memory between fits in long series
        gc.collect()

        if stop:
            break

    return rows
316 | |
317 | |
def parse_row(series_id: str, group: Group, row: "list[str]", criterium: dict):
    """Append the criterium's value to `row` and evaluate its action.

    The variable is looked up first as an attribute of the fit output group
    (fit statistics such as rfactor), falling back to the fitted parameters.
    Returns True when the series should stop (action "stop" with its
    threshold breached), False otherwise.
    """
    action = criterium["action"]["action"]
    variable = criterium["variable"]
    try:
        # Idiomatic getattr rather than calling __getattribute__ directly
        value = getattr(group, variable)
    except AttributeError:
        # Not a top-level statistic; look among the fitted parameters
        value = group.params[variable].value

    row.append(f"{value:>12f}")
    if action in ("stop", "warn"):
        # Only "stop" may abort the series; "warn" just prints
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            action == "stop",
        )

    return False
345 | |
346 | |
if __name__ == "__main__":
    faulthandler.enable()
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    prj_file = sys.argv[1]
    gds_file = sys.argv[2]
    sp_file = sys.argv[3]
    # Fix: close the JSON inputs file (was json.load(open(...)), leaking
    # the handle)
    with open(sys.argv[4], "r", encoding="utf-8") as inputs_file:
        input_values = json.load(inputs_file)
    fit_vars = input_values["fit_vars"]
    plot_graph = input_values["plot_graph"]

    if input_values["execution"]["execution"] == "parallel":
        main(prj_file, gds_file, sp_file, fit_vars, plot_graph)

    else:
        if os.path.isdir(prj_file):
            # Sort the unzipped directory, all filenames should be zero-padded
            filepaths = [
                os.path.join(prj_file, p) for p in os.listdir(prj_file)
            ]
            filepaths.sort()
        else:
            # DO NOT sort if we have multiple Galaxy datasets - the filenames
            # are arbitrary but should be in order
            filepaths = prj_file.split(",")

        rows = series_execution(
            filepaths,
            gds_file,
            sp_file,
            fit_vars,
            plot_graph,
            input_values["execution"]["report_criteria"],
            input_values["execution"]["stop_on_error"],
        )
        # Header row is empty when no criteria were requested
        if len(rows[0]) > 0:
            # newline="" is required by the csv module to avoid blank lines
            # on platforms with \r\n line endings
            with open("criteria_report.csv", "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerows(rows)