Mercurial > repos > muon-spectroscopy-computational-project > larch_artemis
diff larch_artemis.py @ 5:7acb53ffb96f draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 4814f53888643f1d3667789050914675fffb7d59
author | muon-spectroscopy-computational-project |
---|---|
date | Fri, 23 Aug 2024 13:46:13 +0000 (6 months ago) |
parents | 39ab361e6d59 |
children |
line wrap: on
line diff
--- a/larch_artemis.py Mon Jun 17 13:54:30 2024 +0000 +++ b/larch_artemis.py Fri Aug 23 13:46:13 2024 +0000 @@ -45,7 +45,7 @@ try: gds_val = float(data_dict[par_idx]["value"]) except ValueError: - gds_val = 0.00 + continue gds_expr = data_dict[par_idx]["expr"] gds_vary = ( True @@ -68,74 +68,50 @@ return dgs_group -def plot_rmr(path: str, data_set, rmin, rmax): +def plot_rmr(path: str, datasets: list, rmin, rmax): plt.figure() - plt.plot(data_set.data.r, data_set.data.chir_mag, color="b") - plt.plot(data_set.data.r, data_set.data.chir_re, color="b", label="expt.") - plt.plot(data_set.model.r, data_set.model.chir_mag, color="r") - plt.plot(data_set.model.r, data_set.model.chir_re, color="r", label="fit") - plt.ylabel( - "Magnitude of Fourier Transform of " - r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$" - ) - plt.xlabel(r"Radial distance/$\mathrm{\AA}$") - plt.xlim(0, 5) + for i, dataset in enumerate(datasets): + plt.subplot(len(datasets), 1, i + 1) + data = dataset.data + model = dataset.model + plt.plot(data.r, data.chir_mag, color="b") + plt.plot(data.r, data.chir_re, color="b", label="expt.") + plt.plot(model.r, model.chir_mag, color="r") + plt.plot(model.r, model.chir_re, color="r", label="fit") + plt.ylabel( + "Magnitude of Fourier Transform of " + r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$" + ) + plt.xlabel(r"Radial distance/$\mathrm{\AA}$") + plt.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1) + plt.legend() - plt.fill( - [rmin, rmin, rmax, rmax], - [-rmax, rmax, rmax, -rmax], - color="g", - alpha=0.1, - ) - plt.text(rmax - 0.65, -rmax + 0.5, "fit range") - plt.legend() plt.savefig(path, format="png") plt.close("all") -def plot_chikr(path: str, data_set, rmin, rmax, kmin, kmax): +def plot_chikr(path: str, datasets, rmin, rmax, kmin, kmax): fig = plt.figure(figsize=(16, 4)) - ax1 = fig.add_subplot(121) - ax2 = fig.add_subplot(122) - ax1.plot( - data_set.data.k, - data_set.data.chi * data_set.data.k**2, - color="b", - label="expt.", - ) - ax1.plot( - data_set.model.k, - data_set.model.chi * data_set.data.k**2, - color="r", - label="fit", - ) - ax1.set_xlim(0, 15) - ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$") - ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$") + for i, dataset in enumerate(datasets): + data = dataset.data + model = dataset.model + ax1 = fig.add_subplot(len(datasets), 2, 2*i + 1) + ax2 = fig.add_subplot(len(datasets), 2, 2*i + 2) + ax1.plot(data.k, data.chi * data.k**2, color="b", label="expt.") + ax1.plot(model.k, model.chi * data.k**2, color="r", label="fit") + ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$") + ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$") + ax1.axvspan(xmin=kmin, xmax=kmax, color="g", alpha=0.1) + ax1.legend() - ax1.fill( - [kmin, kmin, kmax, kmax], - [-rmax, rmax, rmax, -rmax], - color="g", - alpha=0.1, - ) - ax1.text(kmax - 1.65, -rmax + 0.5, "fit range") - ax1.legend() + ax2.plot(data.r, data.chir_mag, color="b", label="expt.") + ax2.plot(model.r, model.chir_mag, color="r", label="fit") + ax2.set_xlim(0, 5) + ax2.set_xlabel(r"$R(\mathrm{\AA})$") + ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$") + ax2.legend(loc="upper right") + ax2.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1) - ax2.plot(data_set.data.r, data_set.data.chir_mag, color="b", label="expt.") - ax2.plot(data_set.model.r, data_set.model.chir_mag, color="r", label="fit") - ax2.set_xlim(0, 5) - ax2.set_xlabel(r"$R(\mathrm{\AA})$") - ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$") - ax2.legend(loc="upper right") - - ax2.fill( - [rmin, rmin, rmax, rmax], - [-rmax, rmax, rmax, -rmax], - color="g", - alpha=0.1, - ) - ax2.text(rmax - 0.65, -rmax + 0.5, "fit range") fig.savefig(path, format="png") plt.close("all") @@ -149,22 +125,31 @@ def read_selected_paths_list(file_name): sp_dict = read_csv_data(file_name) sp_list = [] - for path_id in sp_dict: - filename = sp_dict[path_id]["filename"] + for path_dict in sp_dict.values(): + filename = path_dict["filename"] + if not os.path.isfile(filename): + raise FileNotFoundError( + f"{filename} not found, check paths in the Selected Paths " + "table match those in the zipped directory structure." + ) + print(f"Reading selected path for file {filename}") new_path = FeffPathGroup( filename=filename, - label=sp_dict[path_id]["label"], - s02=sp_dict[path_id]["s02"], - e0=sp_dict[path_id]["e0"], - sigma2=sp_dict[path_id]["sigma2"], - deltar=sp_dict[path_id]["deltar"], + label=path_dict["label"], + degen=path_dict["degen"] if path_dict["degen"] != "" else None, + s02=path_dict["s02"], + e0=path_dict["e0"], + sigma2=path_dict["sigma2"], + deltar=path_dict["deltar"], ) sp_list.append(new_path) return sp_list -def run_fit(data_group, gds, selected_paths, fv): +def run_fit( + data_groups: list, gds, pathlist, fv, selected_path_ids: list = None +): # create the transform group (prepare the fit space). trans = TransformGroup( fitspace=fv["fitspace"], @@ -177,16 +162,30 @@ rmax=fv["rmax"], ) - dset = FeffitDataSet( - data=data_group, pathlist=selected_paths, transform=trans - ) + datasets = [] + for i, data_group in enumerate(data_groups): + if selected_path_ids: + selected_paths = [] + for path_id in selected_path_ids[i]: + selected_paths.append(pathlist[path_id - 1]) - out = feffit(gds, dset) - return dset, out + dataset = FeffitDataSet( + data=data_group, pathlist=selected_paths, transform=trans + ) + + else: + dataset = FeffitDataSet( + data=data_group, pathlist=pathlist, transform=trans + ) + + datasets.append(dataset) + + out = feffit(gds, datasets) + return datasets, out def main( - prj_file: str, + prj_file: list, gds_file: str, sp_file: str, fit_vars: dict, @@ -197,26 +196,40 @@ rmr_path = f"rmr/rmr{series_id}.png" chikr_path = f"chikr/chikr{series_id}.png" + gds = read_gds(gds_file) + pathlist = read_selected_paths_list(sp_file) + # calc_with_defaults will hang indefinitely (>6 hours recorded) if the # data contains any NaNs - consider adding an early error here if this is # not fixed in Larch? - data_group = read_group(prj_file) - - print(f"Fitting project from file {data_group.filename}") + selected_path_ids = [] + if isinstance(prj_file[0], dict): + data_groups = [] + for dataset in prj_file: + data_groups.append(read_group(dataset["prj_file"])) + selected_path_ids.append([p["path_id"] for p in dataset["paths"]]) + else: + data_groups = [read_group(p) for p in prj_file] - gds = read_gds(gds_file) - selected_paths = read_selected_paths_list(sp_file) - dset, out = run_fit(data_group, gds, selected_paths, fit_vars) + print(f"Fitting project from file {[d.filename for d in data_groups]}") + + datasets, out = run_fit( + data_groups, + gds, + pathlist, + fit_vars, + selected_path_ids=selected_path_ids, + ) fit_report = feffit_report(out) with open(report_path, "w") as fit_report_file: fit_report_file.write(fit_report) if plot_graph: - plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"]) + plot_rmr(rmr_path, datasets, fit_vars["rmin"], fit_vars["rmax"]) plot_chikr( chikr_path, - dset, + datasets, fit_vars["rmin"], fit_vars["rmax"], fit_vars["kmin"], @@ -266,7 +279,7 @@ series_id = str(series_index).zfill(id_length) try: out = main( - series_file, + [series_file], gds_file, sp_file, fit_vars, @@ -343,8 +356,10 @@ plot_graph = input_values["plot_graph"] if input_values["execution"]["execution"] == "parallel": - main(prj_file, gds_file, sp_file, fit_vars, plot_graph) - + main([prj_file], gds_file, sp_file, fit_vars, plot_graph) + elif input_values["execution"]["execution"] == "simultaneous": + dataset_dicts = input_values["execution"]["simultaneous"] + main(dataset_dicts, gds_file, sp_file, fit_vars, plot_graph) else: if os.path.isdir(prj_file): # Sort the unzipped directory, all filenames should be zero-padded