diff larch_artemis.py @ 5:7acb53ffb96f draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 4814f53888643f1d3667789050914675fffb7d59
author muon-spectroscopy-computational-project
date Fri, 23 Aug 2024 13:46:13 +0000 (6 months ago)
parents 39ab361e6d59
children
line wrap: on
line diff
--- a/larch_artemis.py	Mon Jun 17 13:54:30 2024 +0000
+++ b/larch_artemis.py	Fri Aug 23 13:46:13 2024 +0000
@@ -45,7 +45,7 @@
         try:
             gds_val = float(data_dict[par_idx]["value"])
         except ValueError:
-            gds_val = 0.00
+            continue
         gds_expr = data_dict[par_idx]["expr"]
         gds_vary = (
             True
@@ -68,74 +68,50 @@
     return dgs_group
 
 
-def plot_rmr(path: str, data_set, rmin, rmax):
+def plot_rmr(path: str, datasets: list, rmin, rmax):
     plt.figure()
-    plt.plot(data_set.data.r, data_set.data.chir_mag, color="b")
-    plt.plot(data_set.data.r, data_set.data.chir_re, color="b", label="expt.")
-    plt.plot(data_set.model.r, data_set.model.chir_mag, color="r")
-    plt.plot(data_set.model.r, data_set.model.chir_re, color="r", label="fit")
-    plt.ylabel(
-        "Magnitude of Fourier Transform of "
-        r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
-    )
-    plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
-    plt.xlim(0, 5)
+    for i, dataset in enumerate(datasets):
+        plt.subplot(len(datasets), 1, i + 1)
+        data = dataset.data
+        model = dataset.model
+        plt.plot(data.r, data.chir_mag, color="b")
+        plt.plot(data.r, data.chir_re, color="b", label="expt.")
+        plt.plot(model.r, model.chir_mag, color="r")
+        plt.plot(model.r, model.chir_re, color="r", label="fit")
+        plt.ylabel(
+            "Magnitude of Fourier Transform of "
+            r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
+        )
+        plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
+        plt.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1)
+        plt.legend()
 
-    plt.fill(
-        [rmin, rmin, rmax, rmax],
-        [-rmax, rmax, rmax, -rmax],
-        color="g",
-        alpha=0.1,
-    )
-    plt.text(rmax - 0.65, -rmax + 0.5, "fit range")
-    plt.legend()
     plt.savefig(path, format="png")
     plt.close("all")
 
 
-def plot_chikr(path: str, data_set, rmin, rmax, kmin, kmax):
+def plot_chikr(path: str, datasets, rmin, rmax, kmin, kmax):
     fig = plt.figure(figsize=(16, 4))
-    ax1 = fig.add_subplot(121)
-    ax2 = fig.add_subplot(122)
-    ax1.plot(
-        data_set.data.k,
-        data_set.data.chi * data_set.data.k**2,
-        color="b",
-        label="expt.",
-    )
-    ax1.plot(
-        data_set.model.k,
-        data_set.model.chi * data_set.data.k**2,
-        color="r",
-        label="fit",
-    )
-    ax1.set_xlim(0, 15)
-    ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
-    ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")
+    for i, dataset in enumerate(datasets):
+        data = dataset.data
+        model = dataset.model
+        ax1 = fig.add_subplot(len(datasets), 2, 2*i + 1)
+        ax2 = fig.add_subplot(len(datasets), 2, 2*i + 2)
+        ax1.plot(data.k, data.chi * data.k**2, color="b", label="expt.")
+        ax1.plot(model.k, model.chi * data.k**2, color="r", label="fit")
+        ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
+        ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")
+        ax1.axvspan(xmin=kmin, xmax=kmax, color="g", alpha=0.1)
+        ax1.legend()
 
-    ax1.fill(
-        [kmin, kmin, kmax, kmax],
-        [-rmax, rmax, rmax, -rmax],
-        color="g",
-        alpha=0.1,
-    )
-    ax1.text(kmax - 1.65, -rmax + 0.5, "fit range")
-    ax1.legend()
+        ax2.plot(data.r, data.chir_mag, color="b", label="expt.")
+        ax2.plot(model.r, model.chir_mag, color="r", label="fit")
+        ax2.set_xlim(0, 5)
+        ax2.set_xlabel(r"$R(\mathrm{\AA})$")
+        ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
+        ax2.legend(loc="upper right")
+        ax2.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1)
 
-    ax2.plot(data_set.data.r, data_set.data.chir_mag, color="b", label="expt.")
-    ax2.plot(data_set.model.r, data_set.model.chir_mag, color="r", label="fit")
-    ax2.set_xlim(0, 5)
-    ax2.set_xlabel(r"$R(\mathrm{\AA})$")
-    ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
-    ax2.legend(loc="upper right")
-
-    ax2.fill(
-        [rmin, rmin, rmax, rmax],
-        [-rmax, rmax, rmax, -rmax],
-        color="g",
-        alpha=0.1,
-    )
-    ax2.text(rmax - 0.65, -rmax + 0.5, "fit range")
     fig.savefig(path, format="png")
     plt.close("all")
 
@@ -149,22 +125,31 @@
 def read_selected_paths_list(file_name):
     sp_dict = read_csv_data(file_name)
     sp_list = []
-    for path_id in sp_dict:
-        filename = sp_dict[path_id]["filename"]
+    for path_dict in sp_dict.values():
+        filename = path_dict["filename"]
+        if not os.path.isfile(filename):
+            raise FileNotFoundError(
+                f"{filename} not found, check paths in the Selected Paths "
+                "table match those in the zipped directory structure."
+            )
+
         print(f"Reading selected path for file {filename}")
         new_path = FeffPathGroup(
             filename=filename,
-            label=sp_dict[path_id]["label"],
-            s02=sp_dict[path_id]["s02"],
-            e0=sp_dict[path_id]["e0"],
-            sigma2=sp_dict[path_id]["sigma2"],
-            deltar=sp_dict[path_id]["deltar"],
+            label=path_dict["label"],
+            degen=path_dict["degen"] if path_dict["degen"] != "" else None,
+            s02=path_dict["s02"],
+            e0=path_dict["e0"],
+            sigma2=path_dict["sigma2"],
+            deltar=path_dict["deltar"],
         )
         sp_list.append(new_path)
     return sp_list
 
 
-def run_fit(data_group, gds, selected_paths, fv):
+def run_fit(
+        data_groups: list, gds, pathlist, fv, selected_path_ids: list = None
+):
     # create the transform group (prepare the fit space).
     trans = TransformGroup(
         fitspace=fv["fitspace"],
@@ -177,16 +162,30 @@
         rmax=fv["rmax"],
     )
 
-    dset = FeffitDataSet(
-        data=data_group, pathlist=selected_paths, transform=trans
-    )
+    datasets = []
+    for i, data_group in enumerate(data_groups):
+        if selected_path_ids:
+            selected_paths = []
+            for path_id in selected_path_ids[i]:
+                selected_paths.append(pathlist[path_id - 1])
 
-    out = feffit(gds, dset)
-    return dset, out
+            dataset = FeffitDataSet(
+                data=data_group, pathlist=selected_paths, transform=trans
+            )
+
+        else:
+            dataset = FeffitDataSet(
+                data=data_group, pathlist=pathlist, transform=trans
+            )
+
+        datasets.append(dataset)
+
+    out = feffit(gds, datasets)
+    return datasets, out
 
 
 def main(
-    prj_file: str,
+    prj_file: list,
     gds_file: str,
     sp_file: str,
     fit_vars: dict,
@@ -197,26 +196,40 @@
     rmr_path = f"rmr/rmr{series_id}.png"
     chikr_path = f"chikr/chikr{series_id}.png"
 
+    gds = read_gds(gds_file)
+    pathlist = read_selected_paths_list(sp_file)
+
     # calc_with_defaults will hang indefinitely (>6 hours recorded) if the
     # data contains any NaNs - consider adding an early error here if this is
     # not fixed in Larch?
-    data_group = read_group(prj_file)
-
-    print(f"Fitting project from file {data_group.filename}")
+    selected_path_ids = []
+    if isinstance(prj_file[0], dict):
+        data_groups = []
+        for dataset in prj_file:
+            data_groups.append(read_group(dataset["prj_file"]))
+            selected_path_ids.append([p["path_id"] for p in dataset["paths"]])
+    else:
+        data_groups = [read_group(p) for p in prj_file]
 
-    gds = read_gds(gds_file)
-    selected_paths = read_selected_paths_list(sp_file)
-    dset, out = run_fit(data_group, gds, selected_paths, fit_vars)
+    print(f"Fitting project from file {[d.filename for d in data_groups]}")
+
+    datasets, out = run_fit(
+        data_groups,
+        gds,
+        pathlist,
+        fit_vars,
+        selected_path_ids=selected_path_ids,
+    )
 
     fit_report = feffit_report(out)
     with open(report_path, "w") as fit_report_file:
         fit_report_file.write(fit_report)
 
     if plot_graph:
-        plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"])
+        plot_rmr(rmr_path, datasets, fit_vars["rmin"], fit_vars["rmax"])
         plot_chikr(
             chikr_path,
-            dset,
+            datasets,
             fit_vars["rmin"],
             fit_vars["rmax"],
             fit_vars["kmin"],
@@ -266,7 +279,7 @@
         series_id = str(series_index).zfill(id_length)
         try:
             out = main(
-                series_file,
+                [series_file],
                 gds_file,
                 sp_file,
                 fit_vars,
@@ -343,8 +356,10 @@
     plot_graph = input_values["plot_graph"]
 
     if input_values["execution"]["execution"] == "parallel":
-        main(prj_file, gds_file, sp_file, fit_vars, plot_graph)
-
+        main([prj_file], gds_file, sp_file, fit_vars, plot_graph)
+    elif input_values["execution"]["execution"] == "simultaneous":
+        dataset_dicts = input_values["execution"]["simultaneous"]
+        main(dataset_dicts, gds_file, sp_file, fit_vars, plot_graph)
     else:
         if os.path.isdir(prj_file):
             # Sort the unzipped directory, all filenames should be zero-padded