Comparison: larch_artemis.py @ 0:2752b2dd7ad6 (draft)
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 5be486890442dedfb327289d597e1c8110240735
| field | value |
|---|---|
| author | muon-spectroscopy-computational-project |
| date | Tue, 14 Nov 2023 15:34:23 +0000 |
| parents | (none) |
| children | 84c8e04bc1a1 |
This comparison adds the file: the parent is the null revision (-1:000000000000) and the new revision is 0:2752b2dd7ad6. The full contents of larch_artemis.py at this revision follow.
```python
import csv
import faulthandler
import gc
import json
import os
import sys

from common import get_group

from larch.fitting import guess, param, param_group
from larch.io import read_athena
from larch.symboltable import Group
from larch.xafs import (
    FeffPathGroup,
    FeffitDataSet,
    TransformGroup,
    autobk,
    feffit,
    feffit_report,
    pre_edge,
    xftf,
)

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


def read_csv_data(input_file, id_field="id"):
    csv_data = {}
    try:
        with open(input_file, encoding="utf8") as csvfile:
            reader = csv.DictReader(csvfile, skipinitialspace=True)
            for row in reader:
                csv_data[int(row[id_field])] = row
    except FileNotFoundError:
        print(f"The specified file {input_file} does not exist")
    return csv_data


def calc_with_defaults(xafs_group: Group) -> Group:
    """Calculate pre-edge, background subtraction and forward Fourier
    transform with default arguments"""
    pre_edge(xafs_group)
    autobk(xafs_group)
    xftf(xafs_group)
    return xafs_group


def dict_to_gds(data_dict):
    dgs_group = param_group()
    for par_idx in data_dict:
        # GDS file structure: each row defines one parameter via the
        # columns name, value, expr and vary
        gds_name = data_dict[par_idx]["name"]
        gds_expr = data_dict[par_idx]["expr"]
        try:
            gds_val = float(data_dict[par_idx]["value"])
        except ValueError:
            gds_val = 0.0
        gds_vary = (
            str(data_dict[par_idx]["vary"]).strip().capitalize() == "True"
        )
        if gds_vary:
            # equivalent to a guess parameter in Demeter
            one_par = guess(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        else:
            # equivalent to a defined parameter in Demeter
            one_par = param(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        setattr(dgs_group, gds_name, one_par)
    return dgs_group
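
# Illustrative GDS CSV as consumed by read_csv_data and dict_to_gds above;
# the column names are the ones the code reads, but the parameter names,
# values and expressions are hypothetical:
#
#   id,name,value,expr,vary
#   1,amp,1.0,,True
#   2,enot,0.0,,True
#   3,alpha,0.0,,True
#   4,delr_feo,0.0,alpha*reff,False
#
# Rows with vary=True become guess parameters; anything else becomes a
# defined (param) parameter, optionally driven by expr.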


def plot_rmr(path: str, data_set, rmin, rmax):
    plt.figure()
    plt.plot(data_set.data.r, data_set.data.chir_mag, color="b")
    plt.plot(data_set.data.r, data_set.data.chir_re, color="b", label="expt.")
    plt.plot(data_set.model.r, data_set.model.chir_mag, color="r")
    plt.plot(data_set.model.r, data_set.model.chir_re, color="r", label="fit")
    plt.ylabel(
        "Magnitude of Fourier Transform of "
        r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
    )
    plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
    plt.xlim(0, 5)

    plt.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    plt.text(rmax - 0.65, -rmax + 0.5, "fit range")
    plt.legend()
    plt.savefig(path, format="png")
    plt.close("all")


def plot_chikr(path: str, data_set, rmin, rmax, kmin, kmax):
    fig = plt.figure(figsize=(16, 4))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.plot(
        data_set.data.k,
        data_set.data.chi * data_set.data.k**2,
        color="b",
        label="expt.",
    )
    ax1.plot(
        data_set.model.k,
        data_set.model.chi * data_set.model.k**2,
        color="r",
        label="fit",
    )
    ax1.set_xlim(0, 15)
    ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
    ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")

    ax1.fill(
        [kmin, kmin, kmax, kmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax1.text(kmax - 1.65, -rmax + 0.5, "fit range")
    ax1.legend()

    ax2.plot(data_set.data.r, data_set.data.chir_mag, color="b", label="expt.")
    ax2.plot(data_set.model.r, data_set.model.chir_mag, color="r", label="fit")
    ax2.set_xlim(0, 5)
    ax2.set_xlabel(r"$R(\mathrm{\AA})$")
    ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
    ax2.legend(loc="upper right")

    ax2.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax2.text(rmax - 0.65, -rmax + 0.5, "fit range")
    fig.savefig(path, format="png")
    plt.close("all")


def read_gds(gds_file):
    gds_pars = read_csv_data(gds_file)
    dgs_group = dict_to_gds(gds_pars)
    return dgs_group


def read_selected_paths_list(file_name):
    sp_dict = read_csv_data(file_name)
    sp_list = []
    for path_id in sp_dict:
        filename = sp_dict[path_id]["filename"]
        print(f"Reading selected path for file {filename}")
        new_path = FeffPathGroup(
            filename=filename,
            label=sp_dict[path_id]["label"],
            s02=sp_dict[path_id]["s02"],
            e0=sp_dict[path_id]["e0"],
            sigma2=sp_dict[path_id]["sigma2"],
            deltar=sp_dict[path_id]["deltar"],
        )
        sp_list.append(new_path)
    return sp_list
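
# Illustrative selected-paths CSV as consumed by read_selected_paths_list
# above; the column names match the code, but the FEFF path files, labels
# and parameter expressions are hypothetical:
#
#   id,filename,label,s02,e0,sigma2,deltar
#   1,feff0001.dat,Fe.O.1,amp,enot,ss_feo,delr_feo
#   2,feff0002.dat,Fe.Fe.1,amp,enot,ss_fefe,delr_fefe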


def run_fit(data_group, gds, selected_paths, fv):
    # create the transform group (prepare the fit space)
    trans = TransformGroup(
        fitspace=fv["fitspace"],
        kmin=fv["kmin"],
        kmax=fv["kmax"],
        kweight=fv["kweight"],
        dk=fv["dk"],
        window=fv["window"],
        rmin=fv["rmin"],
        rmax=fv["rmax"],
    )

    dset = FeffitDataSet(
        data=data_group, pathlist=selected_paths, transform=trans
    )

    out = feffit(gds, dset)
    return dset, out


def main(
    prj_file: str,
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    series_id: str = "",
) -> Group:
    report_path = f"report/fit_report{series_id}.txt"
    rmr_path = f"rmr/rmr{series_id}.png"
    chikr_path = f"chikr/chikr{series_id}.png"

    athena_project = read_athena(prj_file)
    athena_group = get_group(athena_project)
    # calc_with_defaults will hang indefinitely (>6 hours recorded) if the
    # data contains any NaNs - consider adding an early error here if this is
    # not fixed in Larch?
    data_group = calc_with_defaults(athena_group)

    print(f"Fitting project from file {data_group.filename}")

    gds = read_gds(gds_file)
    selected_paths = read_selected_paths_list(sp_file)
    dset, out = run_fit(data_group, gds, selected_paths, fit_vars)

    fit_report = feffit_report(out)
    with open(report_path, "w") as fit_report_file:
        fit_report_file.write(fit_report)

    if plot_graph:
        plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"])
        plot_chikr(
            chikr_path,
            dset,
            fit_vars["rmin"],
            fit_vars["rmax"],
            fit_vars["kmin"],
            fit_vars["kmax"],
        )
    return out


def check_threshold(
    series_id: str,
    threshold: float,
    variable: str,
    value: float,
    early_stopping: bool = False,
):
    if abs(value) > threshold:
        if early_stopping:
            message = (
                "ERROR: Stopping series fit after project "
                f"{series_id} as {variable} > {threshold}"
            )
        else:
            message = (
                f"WARNING: Project {series_id} has {variable} > {threshold}"
            )

        print(message)
        return early_stopping

    return False


def series_execution(
    filepaths: "list[str]",
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    report_criteria: "list[dict]",
    stop_on_error: bool,
) -> "list[list[str]]":
    id_length = len(str(len(filepaths)))
    stop = False
    rows = [[f"{c['variable']:>12s}" for c in report_criteria]]
    for series_index, series_file in enumerate(filepaths):
        series_id = str(series_index).zfill(id_length)
        try:
            out = main(
                series_file,
                gds_file,
                sp_file,
                fit_vars,
                plot_graph,
                f"_{series_id}",
            )
        except ValueError as e:
            rows.append([np.nan for _ in report_criteria])
            if stop_on_error:
                print(
                    f"ERROR: fitting failed for {series_id}"
                    f" due to following error, stopping:\n{e}"
                )
                break
            else:
                print(
                    f"WARNING: fitting failed for {series_id} due to following"
                    f" error, continuing to next project:\n{e}"
                )
                continue

        row = []
        for criterion in report_criteria:
            stop = parse_row(series_id, out, row, criterion) or stop
        rows.append(row)

        gc.collect()

        if stop:
            break

    return rows


def parse_row(series_id: str, group: Group, row: "list[str]", criterion: dict):
    action = criterion["action"]["action"]
    variable = criterion["variable"]
    try:
        value = getattr(group, variable)
    except AttributeError:
        value = group.params[variable].value

    row.append(f"{value:>12f}")
    if action in ("stop", "warn"):
        # "stop" requests early stopping once the threshold is breached,
        # "warn" only prints a warning and carries on
        return check_threshold(
            series_id,
            criterion["action"]["threshold"],
            variable,
            value,
            action == "stop",
        )

    return False
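
# Illustrative report_criteria entry, matching the keys parse_row reads
# ("variable", "action"/"action" and "action"/"threshold"); the variable
# name is hypothetical and assumes the quantity is exposed either as an
# attribute of the feffit output group or via its params:
#
#   {"variable": "rfactor", "action": {"action": "warn", "threshold": 0.05}}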


if __name__ == "__main__":
    faulthandler.enable()
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    prj_file = sys.argv[1]
    gds_file = sys.argv[2]
    sp_file = sys.argv[3]
    with open(sys.argv[4], encoding="utf-8") as input_file:
        input_values = json.load(input_file)
    fit_vars = input_values["fit_vars"]
    plot_graph = input_values["plot_graph"]

    if input_values["execution"]["execution"] == "parallel":
        main(prj_file, gds_file, sp_file, fit_vars, plot_graph)

    else:
        if os.path.isdir(prj_file):
            # Sort the unzipped directory, all filenames should be zero-padded
            filepaths = [
                os.path.join(prj_file, p) for p in os.listdir(prj_file)
            ]
            filepaths.sort()
        else:
            # DO NOT sort if we have multiple Galaxy datasets - the filenames
            # are arbitrary but should be in order
            filepaths = prj_file.split(",")

        rows = series_execution(
            filepaths,
            gds_file,
            sp_file,
            fit_vars,
            plot_graph,
            input_values["execution"]["report_criteria"],
            input_values["execution"]["stop_on_error"],
        )
        if len(rows[0]) > 0:
            with open("criteria_report.csv", "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerows(rows)
```
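
The script takes its remaining options from a JSON file passed as the fourth argument. A minimal sketch of that file, assuming only the keys the script actually reads (the values and the `inputs.json` filename are illustrative, not taken from the repository):

```python
import json

# Hypothetical inputs.json for a sequential ("series") execution; every key
# below is read by the script, but the values are examples only.
inputs = {
    "fit_vars": {  # passed through to TransformGroup in run_fit
        "fitspace": "r",
        "kmin": 3,
        "kmax": 14,
        "kweight": 2,
        "dk": 1,
        "window": "hanning",
        "rmin": 1.0,
        "rmax": 3.0,
    },
    "plot_graph": True,  # write rmr/*.png and chikr/*.png
    "execution": {
        "execution": "series",  # any value other than "parallel" runs in series
        "stop_on_error": False,
        "report_criteria": [  # consumed by parse_row/check_threshold
            {
                "variable": "rfactor",  # hypothetical reported quantity
                "action": {"action": "warn", "threshold": 0.05},
            }
        ],
    },
}

with open("inputs.json", "w", encoding="utf-8") as f:
    json.dump(inputs, f, indent=2)
```

With that file in place, the invocation implied by the `sys.argv` handling would be `python larch_artemis.py <Athena .prj file, directory, or comma-separated list> <GDS .csv> <selected paths .csv> inputs.json`. Note that `main` writes into `report/`, `rmr/` and `chikr/` without creating them, so those directories must already exist (in the Galaxy tool they are presumably set up by the wrapper).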
