Mercurial > repos > muon-spectroscopy-computational-project > larch_artemis
comparison larch_artemis.py @ 5:7acb53ffb96f draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 4814f53888643f1d3667789050914675fffb7d59
author | muon-spectroscopy-computational-project |
---|---|
date | Fri, 23 Aug 2024 13:46:13 +0000 (7 months ago) |
parents | 39ab361e6d59 |
children |
comparison
equal
deleted
inserted
replaced
4:39ab361e6d59 | 5:7acb53ffb96f |
---|---|
43 gds_val = 0.0 | 43 gds_val = 0.0 |
44 gds_expr = "" | 44 gds_expr = "" |
45 try: | 45 try: |
46 gds_val = float(data_dict[par_idx]["value"]) | 46 gds_val = float(data_dict[par_idx]["value"]) |
47 except ValueError: | 47 except ValueError: |
48 gds_val = 0.00 | 48 continue |
49 gds_expr = data_dict[par_idx]["expr"] | 49 gds_expr = data_dict[par_idx]["expr"] |
50 gds_vary = ( | 50 gds_vary = ( |
51 True | 51 True |
52 if str(data_dict[par_idx]["vary"]).strip().capitalize() == "True" | 52 if str(data_dict[par_idx]["vary"]).strip().capitalize() == "True" |
53 else False | 53 else False |
66 if one_par is not None: | 66 if one_par is not None: |
67 dgs_group.__setattr__(gds_name, one_par) | 67 dgs_group.__setattr__(gds_name, one_par) |
68 return dgs_group | 68 return dgs_group |
69 | 69 |
70 | 70 |
def plot_rmr(path: str, datasets: list, rmin, rmax):
    """Plot the magnitude and real part of chi(R) for data and fit.

    One stacked subplot is drawn per dataset: the experimental
    ``chir_mag``/``chir_re`` in blue and the fitted model's in red,
    with the fit window [rmin, rmax] shaded in green. The figure is
    written as a PNG to ``path`` and then closed.

    Args:
        path: Output PNG file path.
        datasets: FeffitDataSet-like objects exposing ``.data`` and
            ``.model`` groups with ``r``, ``chir_mag`` and ``chir_re``
            arrays — TODO confirm against the larch FeffitDataSet API.
        rmin: Lower bound of the R fit window (Angstrom).
        rmax: Upper bound of the R fit window (Angstrom).
    """
    plt.figure()
    for i, dataset in enumerate(datasets):
        # One row per dataset, stacked vertically.
        plt.subplot(len(datasets), 1, i + 1)
        data = dataset.data
        model = dataset.model
        # Experiment in blue (magnitude unlabeled, real part labeled).
        plt.plot(data.r, data.chir_mag, color="b")
        plt.plot(data.r, data.chir_re, color="b", label="expt.")
        # Fitted model in red.
        plt.plot(model.r, model.chir_mag, color="r")
        plt.plot(model.r, model.chir_re, color="r", label="fit")
        plt.ylabel(
            "Magnitude of Fourier Transform of "
            r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
        )
        plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
        # Shade the R-range actually used by the fit.
        plt.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1)
        plt.legend()

    plt.savefig(path, format="png")
    plt.close("all")
94 | 91 |
95 | 92 |
def plot_chikr(path: str, datasets, rmin, rmax, kmin, kmax):
    """Plot k^2-weighted chi(k) and |chi(R)| for data and fit.

    Each dataset gets one row with two panels: k-space on the left and
    R-space on the right, experiment in blue and fit in red, with the
    respective fit windows shaded in green. The figure is written as a
    PNG to ``path`` and then closed.

    Args:
        path: Output PNG file path.
        datasets: FeffitDataSet-like objects exposing ``.data`` and
            ``.model`` groups with ``k``, ``chi``, ``r`` and
            ``chir_mag`` arrays — TODO confirm against larch.
        rmin: Lower bound of the R fit window (Angstrom).
        rmax: Upper bound of the R fit window (Angstrom).
        kmin: Lower bound of the k fit window (1/Angstrom).
        kmax: Upper bound of the k fit window (1/Angstrom).
    """
    fig = plt.figure(figsize=(16, 4))
    for i, dataset in enumerate(datasets):
        data = dataset.data
        model = dataset.model
        # Two columns per dataset row: k-space (left), R-space (right).
        ax1 = fig.add_subplot(len(datasets), 2, 2 * i + 1)
        ax2 = fig.add_subplot(len(datasets), 2, 2 * i + 2)
        ax1.plot(data.k, data.chi * data.k**2, color="b", label="expt.")
        # FIX: weight the model chi by its own k grid. The previous code
        # used data.k**2 here, which is only correct when the model and
        # data share an identical k grid (and raises otherwise).
        ax1.plot(model.k, model.chi * model.k**2, color="r", label="fit")
        ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
        ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")
        ax1.axvspan(xmin=kmin, xmax=kmax, color="g", alpha=0.1)
        ax1.legend()

        ax2.plot(data.r, data.chir_mag, color="b", label="expt.")
        ax2.plot(model.r, model.chir_mag, color="r", label="fit")
        ax2.set_xlim(0, 5)
        ax2.set_xlabel(r"$R(\mathrm{\AA})$")
        ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
        ax2.legend(loc="upper right")
        ax2.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1)

    fig.savefig(path, format="png")
    plt.close("all")
141 | 117 |
142 | 118 |
143 def read_gds(gds_file): | 119 def read_gds(gds_file): |
147 | 123 |
148 | 124 |
def read_selected_paths_list(file_name):
    """Build FeffPathGroup objects from a CSV table of selected paths.

    Each row of the CSV (parsed by ``read_csv_data``) must provide the
    FEFF path ``filename`` plus the ``label``, ``degen``, ``s02``,
    ``e0``, ``sigma2`` and ``deltar`` parameter expressions.

    Args:
        file_name: Path to the selected-paths CSV file.

    Returns:
        list: One FeffPathGroup per row, in table order.

    Raises:
        FileNotFoundError: If a referenced FEFF path file does not
            exist on disk.
    """
    sp_dict = read_csv_data(file_name)
    sp_list = []
    for path_dict in sp_dict.values():
        filename = path_dict["filename"]
        # Fail fast with a clear message rather than letting larch
        # error out deep inside FeffPathGroup.
        if not os.path.isfile(filename):
            raise FileNotFoundError(
                f"{filename} not found, check paths in the Selected Paths "
                "table match those in the zipped directory structure."
            )

        print(f"Reading selected path for file {filename}")
        new_path = FeffPathGroup(
            filename=filename,
            label=path_dict["label"],
            # An empty degen cell means "use the degeneracy from the
            # FEFF file" (None), not a degeneracy of zero.
            degen=path_dict["degen"] if path_dict["degen"] != "" else None,
            s02=path_dict["s02"],
            e0=path_dict["e0"],
            sigma2=path_dict["sigma2"],
            deltar=path_dict["deltar"],
        )
        sp_list.append(new_path)
    return sp_list
165 | 148 |
166 | 149 |
167 def run_fit(data_group, gds, selected_paths, fv): | 150 def run_fit( |
151 data_groups: list, gds, pathlist, fv, selected_path_ids: list = None | |
152 ): | |
168 # create the transform group (prepare the fit space). | 153 # create the transform group (prepare the fit space). |
169 trans = TransformGroup( | 154 trans = TransformGroup( |
170 fitspace=fv["fitspace"], | 155 fitspace=fv["fitspace"], |
171 kmin=fv["kmin"], | 156 kmin=fv["kmin"], |
172 kmax=fv["kmax"], | 157 kmax=fv["kmax"], |
175 window=fv["window"], | 160 window=fv["window"], |
176 rmin=fv["rmin"], | 161 rmin=fv["rmin"], |
177 rmax=fv["rmax"], | 162 rmax=fv["rmax"], |
178 ) | 163 ) |
179 | 164 |
180 dset = FeffitDataSet( | 165 datasets = [] |
181 data=data_group, pathlist=selected_paths, transform=trans | 166 for i, data_group in enumerate(data_groups): |
182 ) | 167 if selected_path_ids: |
183 | 168 selected_paths = [] |
184 out = feffit(gds, dset) | 169 for path_id in selected_path_ids[i]: |
185 return dset, out | 170 selected_paths.append(pathlist[path_id - 1]) |
171 | |
172 dataset = FeffitDataSet( | |
173 data=data_group, pathlist=selected_paths, transform=trans | |
174 ) | |
175 | |
176 else: | |
177 dataset = FeffitDataSet( | |
178 data=data_group, pathlist=pathlist, transform=trans | |
179 ) | |
180 | |
181 datasets.append(dataset) | |
182 | |
183 out = feffit(gds, datasets) | |
184 return datasets, out | |
186 | 185 |
187 | 186 |
188 def main( | 187 def main( |
189 prj_file: str, | 188 prj_file: list, |
190 gds_file: str, | 189 gds_file: str, |
191 sp_file: str, | 190 sp_file: str, |
192 fit_vars: dict, | 191 fit_vars: dict, |
193 plot_graph: bool, | 192 plot_graph: bool, |
194 series_id: str = "", | 193 series_id: str = "", |
195 ) -> Group: | 194 ) -> Group: |
196 report_path = f"report/fit_report{series_id}.txt" | 195 report_path = f"report/fit_report{series_id}.txt" |
197 rmr_path = f"rmr/rmr{series_id}.png" | 196 rmr_path = f"rmr/rmr{series_id}.png" |
198 chikr_path = f"chikr/chikr{series_id}.png" | 197 chikr_path = f"chikr/chikr{series_id}.png" |
199 | 198 |
199 gds = read_gds(gds_file) | |
200 pathlist = read_selected_paths_list(sp_file) | |
201 | |
200 # calc_with_defaults will hang indefinitely (>6 hours recorded) if the | 202 # calc_with_defaults will hang indefinitely (>6 hours recorded) if the |
201 # data contains any NaNs - consider adding an early error here if this is | 203 # data contains any NaNs - consider adding an early error here if this is |
202 # not fixed in Larch? | 204 # not fixed in Larch? |
203 data_group = read_group(prj_file) | 205 selected_path_ids = [] |
204 | 206 if isinstance(prj_file[0], dict): |
205 print(f"Fitting project from file {data_group.filename}") | 207 data_groups = [] |
206 | 208 for dataset in prj_file: |
207 gds = read_gds(gds_file) | 209 data_groups.append(read_group(dataset["prj_file"])) |
208 selected_paths = read_selected_paths_list(sp_file) | 210 selected_path_ids.append([p["path_id"] for p in dataset["paths"]]) |
209 dset, out = run_fit(data_group, gds, selected_paths, fit_vars) | 211 else: |
212 data_groups = [read_group(p) for p in prj_file] | |
213 | |
214 print(f"Fitting project from file {[d.filename for d in data_groups]}") | |
215 | |
216 datasets, out = run_fit( | |
217 data_groups, | |
218 gds, | |
219 pathlist, | |
220 fit_vars, | |
221 selected_path_ids=selected_path_ids, | |
222 ) | |
210 | 223 |
211 fit_report = feffit_report(out) | 224 fit_report = feffit_report(out) |
212 with open(report_path, "w") as fit_report_file: | 225 with open(report_path, "w") as fit_report_file: |
213 fit_report_file.write(fit_report) | 226 fit_report_file.write(fit_report) |
214 | 227 |
215 if plot_graph: | 228 if plot_graph: |
216 plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"]) | 229 plot_rmr(rmr_path, datasets, fit_vars["rmin"], fit_vars["rmax"]) |
217 plot_chikr( | 230 plot_chikr( |
218 chikr_path, | 231 chikr_path, |
219 dset, | 232 datasets, |
220 fit_vars["rmin"], | 233 fit_vars["rmin"], |
221 fit_vars["rmax"], | 234 fit_vars["rmax"], |
222 fit_vars["kmin"], | 235 fit_vars["kmin"], |
223 fit_vars["kmax"], | 236 fit_vars["kmax"], |
224 ) | 237 ) |
264 rows = [[f"{c['variable']:>12s}" for c in report_criteria]] | 277 rows = [[f"{c['variable']:>12s}" for c in report_criteria]] |
265 for series_index, series_file in enumerate(filepaths): | 278 for series_index, series_file in enumerate(filepaths): |
266 series_id = str(series_index).zfill(id_length) | 279 series_id = str(series_index).zfill(id_length) |
267 try: | 280 try: |
268 out = main( | 281 out = main( |
269 series_file, | 282 [series_file], |
270 gds_file, | 283 gds_file, |
271 sp_file, | 284 sp_file, |
272 fit_vars, | 285 fit_vars, |
273 plot_graph, | 286 plot_graph, |
274 f"_{series_id}", | 287 f"_{series_id}", |
341 input_values = json.load(open(sys.argv[4], "r", encoding="utf-8")) | 354 input_values = json.load(open(sys.argv[4], "r", encoding="utf-8")) |
342 fit_vars = input_values["fit_vars"] | 355 fit_vars = input_values["fit_vars"] |
343 plot_graph = input_values["plot_graph"] | 356 plot_graph = input_values["plot_graph"] |
344 | 357 |
345 if input_values["execution"]["execution"] == "parallel": | 358 if input_values["execution"]["execution"] == "parallel": |
346 main(prj_file, gds_file, sp_file, fit_vars, plot_graph) | 359 main([prj_file], gds_file, sp_file, fit_vars, plot_graph) |
347 | 360 elif input_values["execution"]["execution"] == "simultaneous": |
361 dataset_dicts = input_values["execution"]["simultaneous"] | |
362 main(dataset_dicts, gds_file, sp_file, fit_vars, plot_graph) | |
348 else: | 363 else: |
349 if os.path.isdir(prj_file): | 364 if os.path.isdir(prj_file): |
350 # Sort the unzipped directory, all filenames should be zero-padded | 365 # Sort the unzipped directory, all filenames should be zero-padded |
351 filepaths = [ | 366 filepaths = [ |
352 os.path.join(prj_file, p) for p in os.listdir(prj_file) | 367 os.path.join(prj_file, p) for p in os.listdir(prj_file) |