comparison larch_artemis.py @ 0:2752b2dd7ad6 draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 5be486890442dedfb327289d597e1c8110240735
author muon-spectroscopy-computational-project
date Tue, 14 Nov 2023 15:34:23 +0000
parents
children 84c8e04bc1a1
import csv
import faulthandler
import gc
import json
import os
import sys

from common import get_group

from larch.fitting import guess, param, param_group
from larch.io import read_athena
from larch.symboltable import Group
from larch.xafs import (
    FeffPathGroup,
    FeffitDataSet,
    TransformGroup,
    autobk,
    feffit,
    feffit_report,
    pre_edge,
    xftf,
)

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


def read_csv_data(input_file, id_field="id"):
    """Read rows from a CSV file into a dict keyed by the integer id column."""
    csv_data = {}
    try:
        with open(input_file, encoding="utf8") as csvfile:
            reader = csv.DictReader(csvfile, skipinitialspace=True)
            for row in reader:
                csv_data[int(row[id_field])] = row
    except FileNotFoundError:
        print(f"The specified file {input_file} does not exist")
    return csv_data


def calc_with_defaults(xafs_group: Group) -> Group:
    """Calculate pre_edge and background with default arguments"""
    pre_edge(xafs_group)
    autobk(xafs_group)
    xftf(xafs_group)
    return xafs_group


def dict_to_gds(data_dict):
    """Build a Larch parameter group from rows of GDS CSV data."""
    dgs_group = param_group()
    for par_idx in data_dict:
        # expected columns in the GDS file: name, value, expr, vary
        gds_name = data_dict[par_idx]["name"]
        try:
            gds_val = float(data_dict[par_idx]["value"])
        except ValueError:
            gds_val = 0.0
        gds_expr = data_dict[par_idx]["expr"]
        gds_vary = (
            str(data_dict[par_idx]["vary"]).strip().capitalize() == "True"
        )
        if gds_vary:
            # equivalent to a guess parameter in Demeter
            one_par = guess(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        else:
            # equivalent to a defined parameter in Demeter
            one_par = param(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        setattr(dgs_group, gds_name, one_par)
    return dgs_group
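

# For illustration, the GDS CSV consumed above might look like the following
# (a sketch: the column names are taken from dict_to_gds, but the parameter
# names and values are hypothetical):
#
#   id,name,value,expr,vary
#   1,amp,1.0,,True
#   2,enot,0.0,,True
#   3,alpha,0.00001,,True
#   4,ss,0.003,,True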


def plot_rmr(path: str, data_set, rmin, rmax):
    """Plot magnitude and real part of chi(R) for data and fit."""
    plt.figure()
    plt.plot(data_set.data.r, data_set.data.chir_mag, color="b")
    plt.plot(data_set.data.r, data_set.data.chir_re, color="b", label="expt.")
    plt.plot(data_set.model.r, data_set.model.chir_mag, color="r")
    plt.plot(data_set.model.r, data_set.model.chir_re, color="r", label="fit")
    plt.ylabel(
        "Magnitude of Fourier Transform of "
        r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
    )
    plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
    plt.xlim(0, 5)

    plt.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    plt.text(rmax - 0.65, -rmax + 0.5, "fit range")
    plt.legend()
    plt.savefig(path, format="png")
    plt.close("all")


def plot_chikr(path: str, data_set, rmin, rmax, kmin, kmax):
    """Plot k-space chi(k) and R-space |chi(R)| for data and fit."""
    fig = plt.figure(figsize=(16, 4))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.plot(
        data_set.data.k,
        data_set.data.chi * data_set.data.k**2,
        color="b",
        label="expt.",
    )
    ax1.plot(
        data_set.model.k,
        data_set.model.chi * data_set.model.k**2,
        color="r",
        label="fit",
    )
    ax1.set_xlim(0, 15)
    ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
    ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")

    ax1.fill(
        [kmin, kmin, kmax, kmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax1.text(kmax - 1.65, -rmax + 0.5, "fit range")
    ax1.legend()

    ax2.plot(data_set.data.r, data_set.data.chir_mag, color="b", label="expt.")
    ax2.plot(data_set.model.r, data_set.model.chir_mag, color="r", label="fit")
    ax2.set_xlim(0, 5)
    ax2.set_xlabel(r"$R(\mathrm{\AA})$")
    ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
    ax2.legend(loc="upper right")

    ax2.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax2.text(rmax - 0.65, -rmax + 0.5, "fit range")
    fig.savefig(path, format="png")
    plt.close("all")


def read_gds(gds_file):
    """Read GDS parameters from CSV and build a Larch parameter group."""
    gds_pars = read_csv_data(gds_file)
    dgs_group = dict_to_gds(gds_pars)
    return dgs_group


def read_selected_paths_list(file_name):
    """Read the selected FEFF paths CSV into a list of FeffPathGroups."""
    sp_dict = read_csv_data(file_name)
    sp_list = []
    for path_id in sp_dict:
        filename = sp_dict[path_id]["filename"]
        print(f"Reading selected path for file {filename}")
        new_path = FeffPathGroup(
            filename=filename,
            label=sp_dict[path_id]["label"],
            s02=sp_dict[path_id]["s02"],
            e0=sp_dict[path_id]["e0"],
            sigma2=sp_dict[path_id]["sigma2"],
            deltar=sp_dict[path_id]["deltar"],
        )
        sp_list.append(new_path)
    return sp_list
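

# For illustration, the selected-paths CSV consumed above might look like the
# following (a sketch: the column names are taken from
# read_selected_paths_list; the path file and the GDS parameter expressions in
# each column are hypothetical):
#
#   id,filename,label,s02,e0,sigma2,deltar
#   1,feff0001.dat,S1.1,amp,enot,ss,alpha*reff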


def run_fit(data_group, gds, selected_paths, fv):
    """Fit the selected FEFF paths to the data with the given variables."""
    # create the transform group (prepare the fit space).
    trans = TransformGroup(
        fitspace=fv["fitspace"],
        kmin=fv["kmin"],
        kmax=fv["kmax"],
        kweight=fv["kweight"],
        dk=fv["dk"],
        window=fv["window"],
        rmin=fv["rmin"],
        rmax=fv["rmax"],
    )

    dset = FeffitDataSet(
        data=data_group, pathlist=selected_paths, transform=trans
    )

    out = feffit(gds, dset)
    return dset, out


def main(
    prj_file: str,
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    series_id: str = "",
) -> Group:
    """Fit a single Athena project, writing the report and optional plots."""
    report_path = f"report/fit_report{series_id}.txt"
    rmr_path = f"rmr/rmr{series_id}.png"
    chikr_path = f"chikr/chikr{series_id}.png"

    athena_project = read_athena(prj_file)
    athena_group = get_group(athena_project)
    # calc_with_defaults will hang indefinitely (>6 hours recorded) if the
    # data contains any NaNs, so it is worth failing fast here until this is
    # fixed in Larch.
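    # A minimal early check (a sketch: assumes the group read from the
    # Athena project exposes `energy` and `mu` arrays):
    if (np.isnan(athena_group.energy).any()
            or np.isnan(athena_group.mu).any()):
        raise ValueError(f"NaN values found in data from {prj_file}")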
    data_group = calc_with_defaults(athena_group)

    print(f"Fitting project from file {data_group.filename}")

    gds = read_gds(gds_file)
    selected_paths = read_selected_paths_list(sp_file)
    dset, out = run_fit(data_group, gds, selected_paths, fit_vars)

    fit_report = feffit_report(out)
    with open(report_path, "w") as fit_report_file:
        fit_report_file.write(fit_report)

    if plot_graph:
        plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"])
        plot_chikr(
            chikr_path,
            dset,
            fit_vars["rmin"],
            fit_vars["rmax"],
            fit_vars["kmin"],
            fit_vars["kmax"],
        )
    return out


def check_threshold(
    series_id: str,
    threshold: float,
    variable: str,
    value: float,
    early_stopping: bool = False,
) -> bool:
    """Report when `value` exceeds `threshold`; return whether to stop."""
    if abs(value) > threshold:
        if early_stopping:
            message = (
                "ERROR: Stopping series fit after project "
                f"{series_id} as {variable} > {threshold}"
            )
        else:
            message = (
                f"WARNING: Project {series_id} has {variable} > {threshold}"
            )

        print(message)
        return early_stopping

    return False


def series_execution(
    filepaths: "list[str]",
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    report_criteria: "list[dict]",
    stop_on_error: bool,
) -> "list[list[str]]":
    """Fit each project in `filepaths` in turn, collecting criteria rows."""
    id_length = len(str(len(filepaths)))
    stop = False
    rows = [[f"{c['variable']:>12s}" for c in report_criteria]]
    for series_index, series_file in enumerate(filepaths):
        series_id = str(series_index).zfill(id_length)
        try:
            out = main(
                series_file,
                gds_file,
                sp_file,
                fit_vars,
                plot_graph,
                f"_{series_id}",
            )
        except ValueError as e:
            rows.append([np.nan for _ in report_criteria])
            if stop_on_error:
                print(
                    f"ERROR: fitting failed for {series_id}"
                    f" due to the following error, stopping:\n{e}"
                )
                break
            else:
                print(
                    f"WARNING: fitting failed for {series_id} due to the"
                    f" following error, continuing to next project:\n{e}"
                )
                continue

        row = []
        for criterium in report_criteria:
            stop = parse_row(series_id, out, row, criterium) or stop
        rows.append(row)

        gc.collect()

        if stop:
            break

    return rows


def parse_row(
    series_id: str, group: Group, row: "list[str]", criterium: dict
) -> bool:
    """Append the criterium's value to `row`; return whether to stop."""
    action = criterium["action"]["action"]
    variable = criterium["variable"]
    try:
        value = getattr(group, variable)
    except AttributeError:
        value = group.params[variable].value

    row.append(f"{value:>12f}")
    if action == "stop":
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            True,
        )
    elif action == "warn":
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            False,
        )

    return False


if __name__ == "__main__":
    faulthandler.enable()
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    prj_file = sys.argv[1]
    gds_file = sys.argv[2]
    sp_file = sys.argv[3]
    with open(sys.argv[4], encoding="utf-8") as input_file:
        input_values = json.load(input_file)
    fit_vars = input_values["fit_vars"]
    plot_graph = input_values["plot_graph"]
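
    # Expected shape of the JSON inputs file (a sketch, inferred from the
    # lookups in this block and in run_fit; the keys are real but the
    # variable names and values shown are illustrative):
    # {
    #     "fit_vars": {"fitspace": "r", "kmin": 3, "kmax": 14, "kweight": 2,
    #                  "dk": 1, "window": "hanning", "rmin": 1.4,
    #                  "rmax": 3.0},
    #     "plot_graph": true,
    #     "execution": {
    #         "execution": "series",
    #         "report_criteria": [
    #             {"variable": "rfactor",
    #              "action": {"action": "warn", "threshold": 0.05}}
    #         ],
    #         "stop_on_error": false
    #     }
    # }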

    if input_values["execution"]["execution"] == "parallel":
        main(prj_file, gds_file, sp_file, fit_vars, plot_graph)

    else:
        if os.path.isdir(prj_file):
            # Sort the unzipped directory, all filenames should be
            # zero-padded
            filepaths = [
                os.path.join(prj_file, p) for p in os.listdir(prj_file)
            ]
            filepaths.sort()
        else:
            # DO NOT sort if we have multiple Galaxy datasets - the filenames
            # are arbitrary but should be in order
            filepaths = prj_file.split(",")

        rows = series_execution(
            filepaths,
            gds_file,
            sp_file,
            fit_vars,
            plot_graph,
            input_values["execution"]["report_criteria"],
            input_values["execution"]["stop_on_error"],
        )
        if len(rows[0]) > 0:
            with open("criteria_report.csv", "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerows(rows)