diff larch_artemis.py @ 0:2752b2dd7ad6 draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 5be486890442dedfb327289d597e1c8110240735
| author | muon-spectroscopy-computational-project |
|---|---|
| date | Tue, 14 Nov 2023 15:34:23 +0000 |
| parents | |
| children | 84c8e04bc1a1 |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/larch_artemis.py	Tue Nov 14 15:34:23 2023 +0000
@@ -0,0 +1,386 @@
import csv
import faulthandler
import gc
import json
import os
import sys

from common import get_group

from larch.fitting import guess, param, param_group
from larch.io import read_athena
from larch.symboltable import Group
from larch.xafs import (
    FeffPathGroup,
    FeffitDataSet,
    TransformGroup,
    autobk,
    feffit,
    feffit_report,
    pre_edge,
    xftf,
)

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


def read_csv_data(input_file, id_field="id"):
    """Read a CSV file into a dict of rows, keyed by the integer id column."""
    csv_data = {}
    try:
        with open(input_file, encoding="utf8") as csvfile:
            reader = csv.DictReader(csvfile, skipinitialspace=True)
            for row in reader:
                csv_data[int(row[id_field])] = row
    except FileNotFoundError:
        print(f"The specified file {input_file} does not exist")
    return csv_data


def calc_with_defaults(xafs_group: Group) -> Group:
    """Calculate pre_edge and background with default arguments"""
    pre_edge(xafs_group)
    autobk(xafs_group)
    xftf(xafs_group)
    return xafs_group


def dict_to_gds(data_dict):
    """Build a larch parameter group from the rows of the GDS CSV file."""
    dgs_group = param_group()
    for par_idx in data_dict:
        # gds file structure: id, name, value, expr, vary
        gds_name = data_dict[par_idx]["name"]
        gds_val = 0.0
        gds_expr = ""
        try:
            gds_val = float(data_dict[par_idx]["value"])
        except ValueError:
            # a non-numeric value means the parameter is defined by an
            # expression instead
            gds_expr = data_dict[par_idx]["expr"]
        gds_vary = (
            str(data_dict[par_idx]["vary"]).strip().capitalize() == "True"
        )
        if gds_vary:
            # equivalent to a guess parameter in Demeter
            one_par = guess(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        else:
            # equivalent to a defined parameter in Demeter
            one_par = param(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        setattr(dgs_group, gds_name, one_par)
    return dgs_group
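# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original commit: the GDS CSV layout
# assumed by read_csv_data/dict_to_gds above. The column names match the keys
# the code reads; the parameter names and values are hypothetical.
#
#   id,name,value,expr,vary
#   1,amp,1.0,,True
#   2,enot,0.1,,True
#   3,ss,0.003,,True
#   4,ss2,,ss*2,False
#
# Rows with vary=True become `guess` parameters; a non-numeric value (row 4)
# falls back to the expr column and becomes a defined `param`:
#
#   gds = dict_to_gds(read_csv_data("gds_parameters.csv"))
# ---------------------------------------------------------------------------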
def plot_rmr(path: str, data_set, rmin, rmax):
    plt.figure()
    plt.plot(data_set.data.r, data_set.data.chir_mag, color="b")
    plt.plot(data_set.data.r, data_set.data.chir_re, color="b", label="expt.")
    plt.plot(data_set.model.r, data_set.model.chir_mag, color="r")
    plt.plot(data_set.model.r, data_set.model.chir_re, color="r", label="fit")
    plt.ylabel(
        "Magnitude of Fourier Transform of "
        r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
    )
    plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
    plt.xlim(0, 5)

    plt.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    plt.text(rmax - 0.65, -rmax + 0.5, "fit range")
    plt.legend()
    plt.savefig(path, format="png")
    plt.close("all")


def plot_chikr(path: str, data_set, rmin, rmax, kmin, kmax):
    fig = plt.figure(figsize=(16, 4))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.plot(
        data_set.data.k,
        data_set.data.chi * data_set.data.k**2,
        color="b",
        label="expt.",
    )
    ax1.plot(
        data_set.model.k,
        data_set.model.chi * data_set.model.k**2,
        color="r",
        label="fit",
    )
    ax1.set_xlim(0, 15)
    ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
    ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")

    ax1.fill(
        [kmin, kmin, kmax, kmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax1.text(kmax - 1.65, -rmax + 0.5, "fit range")
    ax1.legend()

    ax2.plot(data_set.data.r, data_set.data.chir_mag, color="b", label="expt.")
    ax2.plot(data_set.model.r, data_set.model.chir_mag, color="r", label="fit")
    ax2.set_xlim(0, 5)
    ax2.set_xlabel(r"$R(\mathrm{\AA})$")
    ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
    ax2.legend(loc="upper right")

    ax2.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax2.text(rmax - 0.65, -rmax + 0.5, "fit range")
    fig.savefig(path, format="png")
    plt.close("all")


def read_gds(gds_file):
    """Read GDS parameters from CSV and convert them to a larch group."""
    gds_pars = read_csv_data(gds_file)
    dgs_group = dict_to_gds(gds_pars)
    return dgs_group


def read_selected_paths_list(file_name):
    """Build a FeffPathGroup for each row of the selected-paths CSV."""
    sp_dict = read_csv_data(file_name)
    sp_list = []
    for path_id in sp_dict:
        filename = sp_dict[path_id]["filename"]
        print(f"Reading selected path for file {filename}")
        new_path = FeffPathGroup(
            filename=filename,
            label=sp_dict[path_id]["label"],
            s02=sp_dict[path_id]["s02"],
            e0=sp_dict[path_id]["e0"],
            sigma2=sp_dict[path_id]["sigma2"],
            deltar=sp_dict[path_id]["deltar"],
        )
        sp_list.append(new_path)
    return sp_list
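# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original commit: a selected-paths CSV
# as consumed by read_selected_paths_list above. The FEFF path files and
# labels are hypothetical; s02/e0/sigma2/deltar are expressions written in
# terms of the GDS parameter names.
#
#   id,filename,label,s02,e0,sigma2,deltar
#   1,feff/feff0001.dat,path1,amp,enot,ss,delr
#   2,feff/feff0002.dat,path2,amp,enot,ss2,delr
#
#   selected_paths = read_selected_paths_list("selected_paths.csv")
# ---------------------------------------------------------------------------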
def run_fit(data_group, gds, selected_paths, fv):
    # create the transform group (prepare the fit space)
    trans = TransformGroup(
        fitspace=fv["fitspace"],
        kmin=fv["kmin"],
        kmax=fv["kmax"],
        kweight=fv["kweight"],
        dk=fv["dk"],
        window=fv["window"],
        rmin=fv["rmin"],
        rmax=fv["rmax"],
    )

    dset = FeffitDataSet(
        data=data_group, pathlist=selected_paths, transform=trans
    )

    out = feffit(gds, dset)
    return dset, out


def main(
    prj_file: str,
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    series_id: str = "",
) -> Group:
    report_path = f"report/fit_report{series_id}.txt"
    rmr_path = f"rmr/rmr{series_id}.png"
    chikr_path = f"chikr/chikr{series_id}.png"

    athena_project = read_athena(prj_file)
    athena_group = get_group(athena_project)
    # calc_with_defaults will hang indefinitely (>6 hours recorded) if the
    # data contains any NaNs - consider adding an early error here if this is
    # not fixed in Larch?
    data_group = calc_with_defaults(athena_group)

    print(f"Fitting project from file {data_group.filename}")

    gds = read_gds(gds_file)
    selected_paths = read_selected_paths_list(sp_file)
    dset, out = run_fit(data_group, gds, selected_paths, fit_vars)

    fit_report = feffit_report(out)
    with open(report_path, "w") as fit_report_file:
        fit_report_file.write(fit_report)

    if plot_graph:
        plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"])
        plot_chikr(
            chikr_path,
            dset,
            fit_vars["rmin"],
            fit_vars["rmax"],
            fit_vars["kmin"],
            fit_vars["kmax"],
        )
    return out


def check_threshold(
    series_id: str,
    threshold: float,
    variable: str,
    value: float,
    early_stopping: bool = False,
):
    if abs(value) > threshold:
        if early_stopping:
            message = (
                "ERROR: Stopping series fit after project "
                f"{series_id} as {variable} > {threshold}"
            )
        else:
            message = (
                f"WARNING: Project {series_id} has {variable} > {threshold}"
            )

        print(message)
        return early_stopping

    return False


def series_execution(
    filepaths: "list[str]",
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    report_criteria: "list[dict]",
    stop_on_error: bool,
) -> "list[list[str]]":
    id_length = len(str(len(filepaths)))
    stop = False
    rows = [[f"{c['variable']:>12s}" for c in report_criteria]]
    for series_index, series_file in enumerate(filepaths):
        series_id = str(series_index).zfill(id_length)
        try:
            out = main(
                series_file,
                gds_file,
                sp_file,
                fit_vars,
                plot_graph,
                f"_{series_id}",
            )
        except ValueError as e:
            rows.append([np.nan for _ in report_criteria])
            if stop_on_error:
                print(
                    f"ERROR: fitting failed for {series_id}"
                    f" due to following error, stopping:\n{e}"
                )
                break
            else:
                print(
                    f"WARNING: fitting failed for {series_id} due to following"
                    f" error, continuing to next project:\n{e}"
                )
                continue

        row = []
        for criterium in report_criteria:
            stop = parse_row(series_id, out, row, criterium) or stop
        rows.append(row)

        gc.collect()

        if stop:
            break

    return rows


def parse_row(series_id: str, group: Group, row: "list[str]", criterium: dict):
    action = criterium["action"]["action"]
    variable = criterium["variable"]
    try:
        value = getattr(group, variable)
    except AttributeError:
        value = group.params[variable].value

    row.append(f"{value:>12f}")
    if action == "stop":
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            True,
        )
    elif action == "warn":
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            False,
        )

    return False
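# ---------------------------------------------------------------------------
# A minimal sketch, not in the original commit, of the early NaN check that
# the comment in main() suggests. It assumes the Athena group exposes the
# usual `energy` and `mu` arrays, and would sit just before the
# calc_with_defaults call:
#
#   energies, mus = athena_group.energy, athena_group.mu
#   if np.isnan(energies).any() or np.isnan(mus).any():
#       raise ValueError(f"NaN values in data from {prj_file}")
#
# Raising ValueError means series_execution's existing except clause would
# catch it and either warn or stop, according to stop_on_error.
# ---------------------------------------------------------------------------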
if __name__ == "__main__":
    faulthandler.enable()
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    prj_file = sys.argv[1]
    gds_file = sys.argv[2]
    sp_file = sys.argv[3]
    with open(sys.argv[4], "r", encoding="utf-8") as settings_file:
        input_values = json.load(settings_file)
    fit_vars = input_values["fit_vars"]
    plot_graph = input_values["plot_graph"]

    if input_values["execution"]["execution"] == "parallel":
        main(prj_file, gds_file, sp_file, fit_vars, plot_graph)

    else:
        if os.path.isdir(prj_file):
            # Sort the unzipped directory, all filenames should be zero-padded
            filepaths = [
                os.path.join(prj_file, p) for p in os.listdir(prj_file)
            ]
            filepaths.sort()
        else:
            # DO NOT sort if we have multiple Galaxy datasets - the filenames
            # are arbitrary but should be in order
            filepaths = prj_file.split(",")

        rows = series_execution(
            filepaths,
            gds_file,
            sp_file,
            fit_vars,
            plot_graph,
            input_values["execution"]["report_criteria"],
            input_values["execution"]["stop_on_error"],
        )
        if len(rows[0]) > 0:
            with open("criteria_report.csv", "w") as f:
                writer = csv.writer(f)
                writer.writerows(rows)
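# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original commit: the shape of the JSON
# settings file read from sys.argv[4], as implied by the keys accessed above.
# All values are hypothetical; "rfactor" stands in for any attribute of the
# feffit output group (or any entry in its params).
#
#   {
#     "fit_vars": {"fitspace": "r", "kmin": 3, "kmax": 14, "kweight": 2,
#                  "dk": 1, "window": "hanning", "rmin": 1.4, "rmax": 3.0},
#     "plot_graph": true,
#     "execution": {
#       "execution": "series",
#       "report_criteria": [
#         {"variable": "rfactor",
#          "action": {"action": "stop", "threshold": 0.05}}
#       ],
#       "stop_on_error": false
#     }
#   }
# ---------------------------------------------------------------------------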