Mercurial > repos > muon-spectroscopy-computational-project > larch_select_paths
comparison larch_select_paths.py @ 1:7fdca938d90c draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_select_paths commit 1cf6d7160497ba58fe16a51f00d088a20934eba6
author | muon-spectroscopy-computational-project |
---|---|
date | Wed, 06 Dec 2023 13:04:15 +0000 |
parents | 2e827836f0ad |
children | 204c4afe2f1e |
comparison
equal
deleted
inserted
replaced
0:2e827836f0ad | 1:7fdca938d90c |
---|---|
1 import csv | 1 import csv |
2 import json | 2 import json |
3 import os | 3 import os |
4 import re | 4 import re |
5 import sys | 5 import sys |
6 from itertools import combinations | |
6 from zipfile import ZIP_DEFLATED, ZipFile | 7 from zipfile import ZIP_DEFLATED, ZipFile |
8 | |
9 | |
10 class CriteriaSelector: | |
11 def __init__(self, criteria: "dict[str, int|float]"): | |
12 self.max_number = criteria["max_number"] | |
13 self.max_path_length = criteria["max_path_length"] | |
14 self.min_amp_ratio = criteria["min_amplitude_ratio"] | |
15 self.max_degeneracy = criteria["max_degeneracy"] | |
16 self.path_count = 0 | |
17 | |
18 def evaluate(self, path_id: int, row: "list[str]") -> (bool, None): | |
19 if self.max_number and self.path_count >= self.max_number: | |
20 print(f"Reject path: {self.max_number} paths already reached") | |
21 return (False, None) | |
22 | |
23 r_effective = float(row[5].strip()) | |
24 if self.max_path_length and r_effective > self.max_path_length: | |
25 print(f"Reject path: {r_effective} > {self.max_path_length}") | |
26 return (False, None) | |
27 | |
28 amplitude_ratio = float(row[2].strip()) | |
29 if self.min_amp_ratio and (amplitude_ratio < self.min_amp_ratio): | |
30 print(f"Reject path: {amplitude_ratio} < {self.min_amp_ratio}") | |
31 return (False, None) | |
32 | |
33 degeneracy = float(row[3].strip()) | |
34 if self.max_degeneracy and degeneracy > self.max_degeneracy: | |
35 print(f"Reject path: {degeneracy} > {self.max_degeneracy}") | |
36 return (False, None) | |
37 | |
38 self.path_count += 1 | |
39 return (True, None) | |
40 | |
41 | |
42 class ManualSelector: | |
43 def __init__(self, selection: dict): | |
44 self.select_all = selection["selection"] == "all" | |
45 self.paths = selection["paths"] | |
46 self.path_values_ids = [path_value["id"] for path_value in self.paths] | |
47 | |
48 def evaluate(self, path_id: int, row: "list[str]") -> (bool, "None|dict"): | |
49 if path_id in self.path_values_ids: | |
50 return (True, self.paths[self.path_values_ids.index(path_id)]) | |
51 | |
52 if self.select_all or int(row[-1]): | |
53 return (True, None) | |
54 | |
55 return (False, None) | |
7 | 56 |
8 | 57 |
9 class GDSWriter: | 58 class GDSWriter: |
10 def __init__(self, default_variables: "dict[str, dict]"): | 59 def __init__(self, default_variables: "dict[str, dict]"): |
11 self.default_properties = { | 60 self.default_properties = { |
34 self.append_gds(name=name, value=value, vary=vary) | 83 self.append_gds(name=name, value=value, vary=vary) |
35 | 84 |
36 def append_gds( | 85 def append_gds( |
37 self, | 86 self, |
38 name: str, | 87 name: str, |
39 value: float = 0., | 88 value: float = 0.0, |
40 expr: str = None, | 89 expr: str = None, |
41 vary: bool = True, | 90 vary: bool = True, |
42 label: str = "", | 91 label: str = "", |
43 ): | 92 ): |
44 """Append a single GDS variable to the list of rows, later to be | 93 """Append a single GDS variable to the list of rows, later to be |
120 vary=self.default_properties[property_name]["vary"], | 169 vary=self.default_properties[property_name]["vary"], |
121 ) | 170 ) |
122 return auto_name | 171 return auto_name |
123 | 172 |
124 def write(self): | 173 def write(self): |
125 """Write GDS rows to file. | 174 """Write GDS rows to file.""" |
126 """ | |
127 with open("gds.csv", "w") as out: | 175 with open("gds.csv", "w") as out: |
128 out.writelines(self.rows) | 176 out.writelines(self.rows) |
129 | 177 |
130 | 178 |
131 class PathsWriter: | 179 class PathsWriter: |
133 self.rows = [ | 181 self.rows = [ |
134 f"{'id':>4s}, {'filename':>24s}, {'label':>24s}, {'s02':>3s}, " | 182 f"{'id':>4s}, {'filename':>24s}, {'label':>24s}, {'s02':>3s}, " |
135 f"{'e0':>4s}, {'sigma2':>24s}, {'deltar':>10s}\n" | 183 f"{'e0':>4s}, {'sigma2':>24s}, {'deltar':>10s}\n" |
136 ] | 184 ] |
137 self.gds_writer = GDSWriter(default_variables=default_variables) | 185 self.gds_writer = GDSWriter(default_variables=default_variables) |
186 self.all_combinations = [[0]] # 0 corresponds to the header row | |
138 | 187 |
139 def parse_feff_output( | 188 def parse_feff_output( |
140 self, | 189 self, |
141 paths_file: str, | 190 paths_file: str, |
142 selection: "dict[str, str|list]", | 191 selection: "dict[str, str|list]", |
149 selection (dict[str, str|list]): Dictionary indicating which paths | 198 selection (dict[str, str|list]): Dictionary indicating which paths |
150 to select, and how to define their variables. | 199 to select, and how to define their variables. |
151 directory_label (str, optional): Label to indicate paths from a | 200 directory_label (str, optional): Label to indicate paths from a |
152 separate directory. Defaults to "". | 201 separate directory. Defaults to "". |
153 """ | 202 """ |
154 paths = selection["paths"] | 203 combinations_list = [] |
155 path_values_ids = [path_value["id"] for path_value in paths] | 204 if selection["selection"] in {"criteria", "combinations"}: |
156 | 205 selector = CriteriaSelector(selection) |
206 else: | |
207 selector = ManualSelector(selection) | |
208 | |
209 selected_ids = self.select_rows(paths_file, directory_label, selector) | |
210 | |
211 if selection["selection"] == "combinations": | |
212 min_number = selection["min_combination_size"] | |
213 min_number = min(min_number, len(selected_ids)) | |
214 max_number = selection["max_combination_size"] | |
215 if not max_number or max_number > len(selected_ids): | |
216 max_number = len(selected_ids) | |
217 | |
218 for number_of_paths in range(min_number, max_number + 1): | |
219 for combination in combinations(selected_ids, number_of_paths): | |
220 combinations_list.append(combination) | |
221 | |
222 new_combinations = len(combinations_list) | |
223 print( | |
224 f"{new_combinations} combinations for {directory_label}:\n" | |
225 f"{combinations_list}" | |
226 ) | |
227 old_combinations_len = len(self.all_combinations) | |
228 self.all_combinations *= new_combinations | |
229 for i, combination in enumerate(self.all_combinations): | |
230 new_combinations = combinations_list[i // old_combinations_len] | |
231 self.all_combinations[i] = combination + list(new_combinations) | |
232 else: | |
233 for combination in self.all_combinations: | |
234 combination.extend(selected_ids) | |
235 | |
236 def select_rows( | |
237 self, | |
238 paths_file: str, | |
239 directory_label: str, | |
240 selector: "CriteriaSelector|ManualSelector", | |
241 ) -> "list[int]": | |
242 """Evaluate each row in turn to decide whether or not it should be | |
243 included in the final output. Does not account for combinations. | |
244 | |
245 Args: | |
246 paths_file (str): CSV summary filename. | |
247 directory_label (str): Label to indicate paths from a separate | |
248 directory. | |
249 selector (CriteriaSelector|ManualSelector): Object to evaluate | |
250 whether to select each path or not. | |
251 | |
252 Returns: | |
253 list[int]: The ids of the selected rows. | |
254 """ | |
255 row_ids = [] | |
157 with open(paths_file) as file: | 256 with open(paths_file) as file: |
158 reader = csv.reader(file) | 257 reader = csv.reader(file) |
159 for row in reader: | 258 for row in reader: |
160 id_match = re.search(r"\d+", row[0]) | 259 id_match = re.search(r"\d+", row[0]) |
161 if id_match: | 260 if id_match: |
162 path_id = int(id_match.group()) | 261 path_id = int(id_match.group()) |
163 filename = row[0].strip() | 262 selected, path_value = selector.evaluate( |
164 path_label = row[-2].strip() | 263 path_id=path_id, |
165 variables = {} | 264 row=row, |
166 | 265 ) |
167 if path_id in path_values_ids: | 266 if selected: |
168 path_value = paths[path_values_ids.index(path_id)] | 267 filename = row[0].strip() |
169 for property in self.gds_writer.default_properties: | 268 path_label = row[-2].strip() |
170 variables[property] = self.gds_writer.parse_gds( | 269 row_id = self.parse_row( |
171 property_name=property, | 270 directory_label, filename, path_label, path_value |
172 variable_name=path_value[property]["name"], | |
173 path_variable=path_value[property], | |
174 directory_label=directory_label, | |
175 path_label=path_label, | |
176 ) | |
177 self.parse_selected_path( | |
178 filename=filename, | |
179 path_label=path_label, | |
180 directory_label=directory_label, | |
181 **variables, | |
182 ) | 271 ) |
183 elif selection["selection"] == "all" or int(row[-1]): | 272 row_ids.append(row_id) |
184 path_value = None | 273 |
185 for property in self.gds_writer.default_properties: | 274 return row_ids |
186 variables[property] = self.gds_writer.parse_gds( | 275 |
187 property_name=property, | 276 def parse_row( |
188 directory_label=directory_label, | 277 self, |
189 path_label=path_label, | 278 directory_label: str, |
190 ) | 279 filename: str, |
191 self.parse_selected_path( | 280 path_label: str, |
192 filename=filename, | 281 path_value: "None|dict", |
193 path_label=path_label, | 282 ) -> int: |
194 directory_label=directory_label, | 283 """Parse row for GDS and path information. |
195 **variables, | 284 |
196 ) | 285 Args: |
286 directory_label (str): Label to indicate paths from a separate | |
287 directory. | |
288 filename (str): Filename for the FEFF path, extracted from row. | |
289 path_label (str): Label for the FEFF path, extracted from row. | |
290 path_value (None|dict): The values associated with the selected | |
291 FEFF path. May be None in which case defaults are used. | |
292 | |
293 Returns: | |
294 int: The id of the added row. | |
295 """ | |
296 variables = {} | |
297 if path_value is not None: | |
298 for property in self.gds_writer.default_properties: | |
299 variables[property] = self.gds_writer.parse_gds( | |
300 property_name=property, | |
301 variable_name=path_value[property]["name"], | |
302 path_variable=path_value[property], | |
303 directory_label=directory_label, | |
304 path_label=path_label, | |
305 ) | |
306 else: | |
307 for property in self.gds_writer.default_properties: | |
308 variables[property] = self.gds_writer.parse_gds( | |
309 property_name=property, | |
310 directory_label=directory_label, | |
311 path_label=path_label, | |
312 ) | |
313 | |
314 return self.parse_selected_path( | |
315 filename=filename, | |
316 path_label=path_label, | |
317 directory_label=directory_label, | |
318 **variables, | |
319 ) | |
197 | 320 |
198 def parse_selected_path( | 321 def parse_selected_path( |
199 self, | 322 self, |
200 filename: str, | 323 filename: str, |
201 path_label: str, | 324 path_label: str, |
202 directory_label: str = "", | 325 directory_label: str = "", |
203 s02: str = "s02", | 326 s02: str = "s02", |
204 e0: str = "e0", | 327 e0: str = "e0", |
205 sigma2: str = "sigma2", | 328 sigma2: str = "sigma2", |
206 deltar: str = "alpha*reff", | 329 deltar: str = "alpha*reff", |
207 ): | 330 ) -> int: |
208 """Format and append row representing a selected FEFF path. | 331 """Format and append row representing a selected FEFF path. |
209 | 332 |
210 Args: | 333 Args: |
211 filename (str): Name of the underlying FEFF path file, without | 334 filename (str): Name of the underlying FEFF path file, without |
212 parent directory. | 335 parent directory. |
218 e0 (str, optional): Energy shift variable name. Defaults to "e0". | 341 e0 (str, optional): Energy shift variable name. Defaults to "e0". |
219 sigma2 (str, optional): Mean squared displacement variable name. | 342 sigma2 (str, optional): Mean squared displacement variable name. |
220 Defaults to "sigma2". | 343 Defaults to "sigma2". |
221 deltar (str, optional): Change in path length variable. | 344 deltar (str, optional): Change in path length variable. |
222 Defaults to "alpha*reff". | 345 Defaults to "alpha*reff". |
346 | |
347 Returns: | |
348 int: The id of the added row. | |
223 """ | 349 """ |
224 if directory_label: | 350 if directory_label: |
225 filename = os.path.join(directory_label, filename) | 351 filename = os.path.join(directory_label, filename) |
226 label = f"{directory_label}.{path_label}" | 352 label = f"{directory_label}.{path_label}" |
227 else: | 353 else: |
228 filename = os.path.join("feff", filename) | 354 filename = os.path.join("feff", filename) |
229 label = path_label | 355 label = path_label |
230 | 356 |
357 row_id = len(self.rows) | |
231 self.rows.append( | 358 self.rows.append( |
232 f"{len(self.rows):>4d}, {filename:>24s}, {label:>24s}, " | 359 f"{row_id:>4d}, {filename:>24s}, {label:>24s}, " |
233 f"{s02:>3s}, {e0:>4s}, {sigma2:>24s}, {deltar:>10s}\n" | 360 f"{s02:>3s}, {e0:>4s}, {sigma2:>24s}, {deltar:>10s}\n" |
234 ) | 361 ) |
235 | 362 |
363 return row_id | |
364 | |
236 def write(self): | 365 def write(self): |
237 """Write selected path and GDS rows to file. | 366 """Write selected path and GDS rows to file.""" |
238 """ | |
239 self.gds_writer.write() | 367 self.gds_writer.write() |
240 with open("sp.csv", "w") as out: | 368 |
241 out.writelines(self.rows) | 369 if len(self.all_combinations) == 1: |
370 with open("sp.csv", "w") as out: | |
371 out.writelines(self.rows) | |
372 else: | |
373 for combination in self.all_combinations: | |
374 filename = "_".join([str(c) for c in combination[1:]]) | |
375 print(f"Writing combination {filename}") | |
376 with open(f"sp/{filename}.csv", "w") as out: | |
377 for row_id, row in enumerate(self.rows): | |
378 if row_id in combination: | |
379 out.write(row) | |
242 | 380 |
243 | 381 |
244 def main(input_values: dict): | 382 def main(input_values: dict): |
245 """Select paths and define GDS parameters. | 383 """Select paths and define GDS parameters. |
246 | 384 |
263 else: | 401 else: |
264 zfill_length = len(str(len(input_values["feff_outputs"]))) | 402 zfill_length = len(str(len(input_values["feff_outputs"]))) |
265 labels = set() | 403 labels = set() |
266 with ZipFile("merged.zip", "x", ZIP_DEFLATED) as zipfile_out: | 404 with ZipFile("merged.zip", "x", ZIP_DEFLATED) as zipfile_out: |
267 for i, feff_output in enumerate(input_values["feff_outputs"]): | 405 for i, feff_output in enumerate(input_values["feff_outputs"]): |
268 label = feff_output.pop("label") or str(i + 1).zfill( | 406 label = feff_output["label"] |
269 zfill_length | 407 if not label: |
270 ) | 408 label = str(i + 1).zfill(zfill_length) |
271 if label in labels: | 409 if label in labels: |
272 raise ValueError(f"Label '{label}' is not unique") | 410 raise ValueError(f"Label '{label}' is not unique") |
273 labels.add(label) | 411 labels.add(label) |
274 | 412 |
275 writer.parse_feff_output( | 413 writer.parse_feff_output( |
281 with ZipFile(feff_output["paths_zip"]) as z: | 419 with ZipFile(feff_output["paths_zip"]) as z: |
282 for zipinfo in z.infolist(): | 420 for zipinfo in z.infolist(): |
283 if zipinfo.filename != "feff/": | 421 if zipinfo.filename != "feff/": |
284 zipinfo.filename = zipinfo.filename[5:] | 422 zipinfo.filename = zipinfo.filename[5:] |
285 z.extract(member=zipinfo, path=label) | 423 z.extract(member=zipinfo, path=label) |
286 zipfile_out.write( | 424 filename = os.path.join(label, zipinfo.filename) |
287 os.path.join(label, zipinfo.filename) | 425 zipfile_out.write(filename) |
288 ) | |
289 | 426 |
290 writer.write() | 427 writer.write() |
291 | 428 |
292 | 429 |
293 if __name__ == "__main__": | 430 if __name__ == "__main__": |