comparison larch_artemis.py @ 5:7acb53ffb96f draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 4814f53888643f1d3667789050914675fffb7d59
author muon-spectroscopy-computational-project
date Fri, 23 Aug 2024 13:46:13 +0000 (7 months ago)
parents 39ab361e6d59
children
comparison
equal deleted inserted replaced
4:39ab361e6d59 5:7acb53ffb96f
43 gds_val = 0.0 43 gds_val = 0.0
44 gds_expr = "" 44 gds_expr = ""
45 try: 45 try:
46 gds_val = float(data_dict[par_idx]["value"]) 46 gds_val = float(data_dict[par_idx]["value"])
47 except ValueError: 47 except ValueError:
48 gds_val = 0.00 48 continue
49 gds_expr = data_dict[par_idx]["expr"] 49 gds_expr = data_dict[par_idx]["expr"]
50 gds_vary = ( 50 gds_vary = (
51 True 51 True
52 if str(data_dict[par_idx]["vary"]).strip().capitalize() == "True" 52 if str(data_dict[par_idx]["vary"]).strip().capitalize() == "True"
53 else False 53 else False
66 if one_par is not None: 66 if one_par is not None:
67 dgs_group.__setattr__(gds_name, one_par) 67 dgs_group.__setattr__(gds_name, one_par)
68 return dgs_group 68 return dgs_group
69 69
70 70
71 def plot_rmr(path: str, data_set, rmin, rmax): 71 def plot_rmr(path: str, datasets: list, rmin, rmax):
72 plt.figure() 72 plt.figure()
73 plt.plot(data_set.data.r, data_set.data.chir_mag, color="b") 73 for i, dataset in enumerate(datasets):
74 plt.plot(data_set.data.r, data_set.data.chir_re, color="b", label="expt.") 74 plt.subplot(len(datasets), 1, i + 1)
75 plt.plot(data_set.model.r, data_set.model.chir_mag, color="r") 75 data = dataset.data
76 plt.plot(data_set.model.r, data_set.model.chir_re, color="r", label="fit") 76 model = dataset.model
77 plt.ylabel( 77 plt.plot(data.r, data.chir_mag, color="b")
78 "Magnitude of Fourier Transform of " 78 plt.plot(data.r, data.chir_re, color="b", label="expt.")
79 r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$" 79 plt.plot(model.r, model.chir_mag, color="r")
80 ) 80 plt.plot(model.r, model.chir_re, color="r", label="fit")
81 plt.xlabel(r"Radial distance/$\mathrm{\AA}$") 81 plt.ylabel(
82 plt.xlim(0, 5) 82 "Magnitude of Fourier Transform of "
83 83 r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
84 plt.fill( 84 )
85 [rmin, rmin, rmax, rmax], 85 plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
86 [-rmax, rmax, rmax, -rmax], 86 plt.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1)
87 color="g", 87 plt.legend()
88 alpha=0.1, 88
89 )
90 plt.text(rmax - 0.65, -rmax + 0.5, "fit range")
91 plt.legend()
92 plt.savefig(path, format="png") 89 plt.savefig(path, format="png")
93 plt.close("all") 90 plt.close("all")
94 91
95 92
96 def plot_chikr(path: str, data_set, rmin, rmax, kmin, kmax): 93 def plot_chikr(path: str, datasets, rmin, rmax, kmin, kmax):
97 fig = plt.figure(figsize=(16, 4)) 94 fig = plt.figure(figsize=(16, 4))
98 ax1 = fig.add_subplot(121) 95 for i, dataset in enumerate(datasets):
99 ax2 = fig.add_subplot(122) 96 data = dataset.data
100 ax1.plot( 97 model = dataset.model
101 data_set.data.k, 98 ax1 = fig.add_subplot(len(datasets), 2, 2*i + 1)
102 data_set.data.chi * data_set.data.k**2, 99 ax2 = fig.add_subplot(len(datasets), 2, 2*i + 2)
103 color="b", 100 ax1.plot(data.k, data.chi * data.k**2, color="b", label="expt.")
104 label="expt.", 101 ax1.plot(model.k, model.chi * data.k**2, color="r", label="fit")
105 ) 102 ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
106 ax1.plot( 103 ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")
107 data_set.model.k, 104 ax1.axvspan(xmin=kmin, xmax=kmax, color="g", alpha=0.1)
108 data_set.model.chi * data_set.data.k**2, 105 ax1.legend()
109 color="r", 106
110 label="fit", 107 ax2.plot(data.r, data.chir_mag, color="b", label="expt.")
111 ) 108 ax2.plot(model.r, model.chir_mag, color="r", label="fit")
112 ax1.set_xlim(0, 15) 109 ax2.set_xlim(0, 5)
113 ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$") 110 ax2.set_xlabel(r"$R(\mathrm{\AA})$")
114 ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$") 111 ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
115 112 ax2.legend(loc="upper right")
116 ax1.fill( 113 ax2.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1)
117 [kmin, kmin, kmax, kmax], 114
118 [-rmax, rmax, rmax, -rmax],
119 color="g",
120 alpha=0.1,
121 )
122 ax1.text(kmax - 1.65, -rmax + 0.5, "fit range")
123 ax1.legend()
124
125 ax2.plot(data_set.data.r, data_set.data.chir_mag, color="b", label="expt.")
126 ax2.plot(data_set.model.r, data_set.model.chir_mag, color="r", label="fit")
127 ax2.set_xlim(0, 5)
128 ax2.set_xlabel(r"$R(\mathrm{\AA})$")
129 ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
130 ax2.legend(loc="upper right")
131
132 ax2.fill(
133 [rmin, rmin, rmax, rmax],
134 [-rmax, rmax, rmax, -rmax],
135 color="g",
136 alpha=0.1,
137 )
138 ax2.text(rmax - 0.65, -rmax + 0.5, "fit range")
139 fig.savefig(path, format="png") 115 fig.savefig(path, format="png")
140 plt.close("all") 116 plt.close("all")
141 117
142 118
143 def read_gds(gds_file): 119 def read_gds(gds_file):
147 123
148 124
149 def read_selected_paths_list(file_name): 125 def read_selected_paths_list(file_name):
150 sp_dict = read_csv_data(file_name) 126 sp_dict = read_csv_data(file_name)
151 sp_list = [] 127 sp_list = []
152 for path_id in sp_dict: 128 for path_dict in sp_dict.values():
153 filename = sp_dict[path_id]["filename"] 129 filename = path_dict["filename"]
130 if not os.path.isfile(filename):
131 raise FileNotFoundError(
132 f"{filename} not found, check paths in the Selected Paths "
133 "table match those in the zipped directory structure."
134 )
135
154 print(f"Reading selected path for file {filename}") 136 print(f"Reading selected path for file {filename}")
155 new_path = FeffPathGroup( 137 new_path = FeffPathGroup(
156 filename=filename, 138 filename=filename,
157 label=sp_dict[path_id]["label"], 139 label=path_dict["label"],
158 s02=sp_dict[path_id]["s02"], 140 degen=path_dict["degen"] if path_dict["degen"] != "" else None,
159 e0=sp_dict[path_id]["e0"], 141 s02=path_dict["s02"],
160 sigma2=sp_dict[path_id]["sigma2"], 142 e0=path_dict["e0"],
161 deltar=sp_dict[path_id]["deltar"], 143 sigma2=path_dict["sigma2"],
144 deltar=path_dict["deltar"],
162 ) 145 )
163 sp_list.append(new_path) 146 sp_list.append(new_path)
164 return sp_list 147 return sp_list
165 148
166 149
167 def run_fit(data_group, gds, selected_paths, fv): 150 def run_fit(
151 data_groups: list, gds, pathlist, fv, selected_path_ids: list = None
152 ):
168 # create the transform group (prepare the fit space). 153 # create the transform group (prepare the fit space).
169 trans = TransformGroup( 154 trans = TransformGroup(
170 fitspace=fv["fitspace"], 155 fitspace=fv["fitspace"],
171 kmin=fv["kmin"], 156 kmin=fv["kmin"],
172 kmax=fv["kmax"], 157 kmax=fv["kmax"],
175 window=fv["window"], 160 window=fv["window"],
176 rmin=fv["rmin"], 161 rmin=fv["rmin"],
177 rmax=fv["rmax"], 162 rmax=fv["rmax"],
178 ) 163 )
179 164
180 dset = FeffitDataSet( 165 datasets = []
181 data=data_group, pathlist=selected_paths, transform=trans 166 for i, data_group in enumerate(data_groups):
182 ) 167 if selected_path_ids:
183 168 selected_paths = []
184 out = feffit(gds, dset) 169 for path_id in selected_path_ids[i]:
185 return dset, out 170 selected_paths.append(pathlist[path_id - 1])
171
172 dataset = FeffitDataSet(
173 data=data_group, pathlist=selected_paths, transform=trans
174 )
175
176 else:
177 dataset = FeffitDataSet(
178 data=data_group, pathlist=pathlist, transform=trans
179 )
180
181 datasets.append(dataset)
182
183 out = feffit(gds, datasets)
184 return datasets, out
186 185
187 186
188 def main( 187 def main(
189 prj_file: str, 188 prj_file: list,
190 gds_file: str, 189 gds_file: str,
191 sp_file: str, 190 sp_file: str,
192 fit_vars: dict, 191 fit_vars: dict,
193 plot_graph: bool, 192 plot_graph: bool,
194 series_id: str = "", 193 series_id: str = "",
195 ) -> Group: 194 ) -> Group:
196 report_path = f"report/fit_report{series_id}.txt" 195 report_path = f"report/fit_report{series_id}.txt"
197 rmr_path = f"rmr/rmr{series_id}.png" 196 rmr_path = f"rmr/rmr{series_id}.png"
198 chikr_path = f"chikr/chikr{series_id}.png" 197 chikr_path = f"chikr/chikr{series_id}.png"
199 198
199 gds = read_gds(gds_file)
200 pathlist = read_selected_paths_list(sp_file)
201
200 # calc_with_defaults will hang indefinitely (>6 hours recorded) if the 202 # calc_with_defaults will hang indefinitely (>6 hours recorded) if the
201 # data contains any NaNs - consider adding an early error here if this is 203 # data contains any NaNs - consider adding an early error here if this is
202 # not fixed in Larch? 204 # not fixed in Larch?
203 data_group = read_group(prj_file) 205 selected_path_ids = []
204 206 if isinstance(prj_file[0], dict):
205 print(f"Fitting project from file {data_group.filename}") 207 data_groups = []
206 208 for dataset in prj_file:
207 gds = read_gds(gds_file) 209 data_groups.append(read_group(dataset["prj_file"]))
208 selected_paths = read_selected_paths_list(sp_file) 210 selected_path_ids.append([p["path_id"] for p in dataset["paths"]])
209 dset, out = run_fit(data_group, gds, selected_paths, fit_vars) 211 else:
212 data_groups = [read_group(p) for p in prj_file]
213
214 print(f"Fitting project from file {[d.filename for d in data_groups]}")
215
216 datasets, out = run_fit(
217 data_groups,
218 gds,
219 pathlist,
220 fit_vars,
221 selected_path_ids=selected_path_ids,
222 )
210 223
211 fit_report = feffit_report(out) 224 fit_report = feffit_report(out)
212 with open(report_path, "w") as fit_report_file: 225 with open(report_path, "w") as fit_report_file:
213 fit_report_file.write(fit_report) 226 fit_report_file.write(fit_report)
214 227
215 if plot_graph: 228 if plot_graph:
216 plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"]) 229 plot_rmr(rmr_path, datasets, fit_vars["rmin"], fit_vars["rmax"])
217 plot_chikr( 230 plot_chikr(
218 chikr_path, 231 chikr_path,
219 dset, 232 datasets,
220 fit_vars["rmin"], 233 fit_vars["rmin"],
221 fit_vars["rmax"], 234 fit_vars["rmax"],
222 fit_vars["kmin"], 235 fit_vars["kmin"],
223 fit_vars["kmax"], 236 fit_vars["kmax"],
224 ) 237 )
264 rows = [[f"{c['variable']:>12s}" for c in report_criteria]] 277 rows = [[f"{c['variable']:>12s}" for c in report_criteria]]
265 for series_index, series_file in enumerate(filepaths): 278 for series_index, series_file in enumerate(filepaths):
266 series_id = str(series_index).zfill(id_length) 279 series_id = str(series_index).zfill(id_length)
267 try: 280 try:
268 out = main( 281 out = main(
269 series_file, 282 [series_file],
270 gds_file, 283 gds_file,
271 sp_file, 284 sp_file,
272 fit_vars, 285 fit_vars,
273 plot_graph, 286 plot_graph,
274 f"_{series_id}", 287 f"_{series_id}",
341 input_values = json.load(open(sys.argv[4], "r", encoding="utf-8")) 354 input_values = json.load(open(sys.argv[4], "r", encoding="utf-8"))
342 fit_vars = input_values["fit_vars"] 355 fit_vars = input_values["fit_vars"]
343 plot_graph = input_values["plot_graph"] 356 plot_graph = input_values["plot_graph"]
344 357
345 if input_values["execution"]["execution"] == "parallel": 358 if input_values["execution"]["execution"] == "parallel":
346 main(prj_file, gds_file, sp_file, fit_vars, plot_graph) 359 main([prj_file], gds_file, sp_file, fit_vars, plot_graph)
347 360 elif input_values["execution"]["execution"] == "simultaneous":
361 dataset_dicts = input_values["execution"]["simultaneous"]
362 main(dataset_dicts, gds_file, sp_file, fit_vars, plot_graph)
348 else: 363 else:
349 if os.path.isdir(prj_file): 364 if os.path.isdir(prj_file):
350 # Sort the unzipped directory, all filenames should be zero-padded 365 # Sort the unzipped directory, all filenames should be zero-padded
351 filepaths = [ 366 filepaths = [
352 os.path.join(prj_file, p) for p in os.listdir(prj_file) 367 os.path.join(prj_file, p) for p in os.listdir(prj_file)