comparison larch_select_paths.py @ 1:7fdca938d90c draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_select_paths commit 1cf6d7160497ba58fe16a51f00d088a20934eba6
author muon-spectroscopy-computational-project
date Wed, 06 Dec 2023 13:04:15 +0000
parents 2e827836f0ad
children 204c4afe2f1e
comparison
equal deleted inserted replaced
0:2e827836f0ad 1:7fdca938d90c
1 import csv 1 import csv
2 import json 2 import json
3 import os 3 import os
4 import re 4 import re
5 import sys 5 import sys
6 from itertools import combinations
6 from zipfile import ZIP_DEFLATED, ZipFile 7 from zipfile import ZIP_DEFLATED, ZipFile
8
9
class CriteriaSelector:
    """Accept or reject FEFF paths against user supplied numeric limits.

    Every limit is optional: a falsy value (e.g. ``0`` or ``None``) disables
    the corresponding check. Accepted paths are counted, and once
    ``max_number`` paths have been accepted every further path is rejected.

    Args:
        criteria (dict[str, int|float]): Must provide the keys
            ``max_number``, ``max_path_length``, ``min_amplitude_ratio`` and
            ``max_degeneracy``.
    """

    def __init__(self, criteria: "dict[str, int|float]"):
        # Cap on the total number of accepted paths.
        self.max_number = criteria["max_number"]
        # Upper bound compared against the row's column 5 (effective length).
        self.max_path_length = criteria["max_path_length"]
        # Lower bound compared against the row's column 2 (amplitude ratio).
        self.min_amp_ratio = criteria["min_amplitude_ratio"]
        # Upper bound compared against the row's column 3 (degeneracy).
        self.max_degeneracy = criteria["max_degeneracy"]
        # How many paths have been accepted so far.
        self.path_count = 0

    @staticmethod
    def _reject(message: str) -> "tuple[bool, None]":
        """Print why a path was rejected and return the rejection result."""
        print(message)
        return (False, None)

    def evaluate(self, path_id: int, row: "list[str]") -> (bool, None):
        """Decide whether the path described by ``row`` should be selected.

        Args:
            path_id (int): Numeric id of the path. Not used by this selector;
                accepted so the call matches ``ManualSelector.evaluate``.
            row (list[str]): One row of the FEFF paths CSV summary.

        Returns:
            (bool, None): Whether the path is selected. The second element
            is always None because criteria selection carries no per-path
            variable definitions.
        """
        if self.max_number and self.path_count >= self.max_number:
            # Checked first, so a full selection never parses the row.
            return self._reject(
                f"Reject path: {self.max_number} paths already reached"
            )

        # Each column is parsed only once the preceding checks have passed.
        r_eff = float(row[5].strip())
        if self.max_path_length and r_eff > self.max_path_length:
            return self._reject(
                f"Reject path: {r_eff} > {self.max_path_length}"
            )

        amp_ratio = float(row[2].strip())
        if self.min_amp_ratio and (amp_ratio < self.min_amp_ratio):
            return self._reject(
                f"Reject path: {amp_ratio} < {self.min_amp_ratio}"
            )

        degen = float(row[3].strip())
        if self.max_degeneracy and degen > self.max_degeneracy:
            return self._reject(
                f"Reject path: {degen} > {self.max_degeneracy}"
            )

        self.path_count += 1
        return (True, None)
40
41
class ManualSelector:
    """Select FEFF paths chosen explicitly by the user.

    A path is selected when its id is listed in ``selection["paths"]`` (in
    which case the matching entry is returned alongside it), when
    ``selection["selection"]`` equals ``"all"``, or when the last column of
    its CSV row parses to a non-zero integer.

    Args:
        selection (dict): Must provide ``"selection"`` and ``"paths"``; every
            entry of ``"paths"`` must have an ``"id"`` key.
    """

    def __init__(self, selection: dict):
        # True when every path should be taken regardless of its row flag.
        self.select_all = selection["selection"] == "all"
        # User supplied per-path definitions, in their original order.
        self.paths = selection["paths"]
        # Ids of those definitions, index-aligned with ``self.paths``.
        self.path_values_ids = [entry["id"] for entry in self.paths]

    def evaluate(self, path_id: int, row: "list[str]") -> (bool, "None|dict"):
        """Decide whether a path is selected and which definition applies.

        Args:
            path_id (int): Numeric id of the path.
            row (list[str]): One row of the FEFF paths CSV summary; only its
                last column is read, and only when needed.

        Returns:
            (bool, None|dict): Whether the path is selected, and the first
            user definition whose id matches ``path_id`` (None otherwise).
        """
        try:
            position = self.path_values_ids.index(path_id)
        except ValueError:
            position = None

        if position is not None:
            return (True, self.paths[position])

        # ``or`` short-circuits, so the row flag is only parsed when needed.
        return (self.select_all or bool(int(row[-1])), None)
7 56
8 57
9 class GDSWriter: 58 class GDSWriter:
10 def __init__(self, default_variables: "dict[str, dict]"): 59 def __init__(self, default_variables: "dict[str, dict]"):
11 self.default_properties = { 60 self.default_properties = {
34 self.append_gds(name=name, value=value, vary=vary) 83 self.append_gds(name=name, value=value, vary=vary)
35 84
36 def append_gds( 85 def append_gds(
37 self, 86 self,
38 name: str, 87 name: str,
39 value: float = 0., 88 value: float = 0.0,
40 expr: str = None, 89 expr: str = None,
41 vary: bool = True, 90 vary: bool = True,
42 label: str = "", 91 label: str = "",
43 ): 92 ):
44 """Append a single GDS variable to the list of rows, later to be 93 """Append a single GDS variable to the list of rows, later to be
120 vary=self.default_properties[property_name]["vary"], 169 vary=self.default_properties[property_name]["vary"],
121 ) 170 )
122 return auto_name 171 return auto_name
123 172
def write(self):
    """Save the accumulated GDS rows to ``gds.csv``.

    The file is created in the current working directory, replacing any
    existing one. Rows are written verbatim, one after another, with no
    separator added between them.
    """
    with open("gds.csv", "w") as gds_file:
        for line in self.rows:
            gds_file.write(line)
129 177
130 178
131 class PathsWriter: 179 class PathsWriter:
133 self.rows = [ 181 self.rows = [
134 f"{'id':>4s}, {'filename':>24s}, {'label':>24s}, {'s02':>3s}, " 182 f"{'id':>4s}, {'filename':>24s}, {'label':>24s}, {'s02':>3s}, "
135 f"{'e0':>4s}, {'sigma2':>24s}, {'deltar':>10s}\n" 183 f"{'e0':>4s}, {'sigma2':>24s}, {'deltar':>10s}\n"
136 ] 184 ]
137 self.gds_writer = GDSWriter(default_variables=default_variables) 185 self.gds_writer = GDSWriter(default_variables=default_variables)
186 self.all_combinations = [[0]] # 0 corresponds to the header row
138 187
139 def parse_feff_output( 188 def parse_feff_output(
140 self, 189 self,
141 paths_file: str, 190 paths_file: str,
142 selection: "dict[str, str|list]", 191 selection: "dict[str, str|list]",
149 selection (dict[str, str|list]): Dictionary indicating which paths 198 selection (dict[str, str|list]): Dictionary indicating which paths
150 to select, and how to define their variables. 199 to select, and how to define their variables.
151 directory_label (str, optional): Label to indicate paths from a 200 directory_label (str, optional): Label to indicate paths from a
152 separate directory. Defaults to "". 201 separate directory. Defaults to "".
153 """ 202 """
154 paths = selection["paths"] 203 combinations_list = []
155 path_values_ids = [path_value["id"] for path_value in paths] 204 if selection["selection"] in {"criteria", "combinations"}:
156 205 selector = CriteriaSelector(selection)
206 else:
207 selector = ManualSelector(selection)
208
209 selected_ids = self.select_rows(paths_file, directory_label, selector)
210
211 if selection["selection"] == "combinations":
212 min_number = selection["min_combination_size"]
213 min_number = min(min_number, len(selected_ids))
214 max_number = selection["max_combination_size"]
215 if not max_number or max_number > len(selected_ids):
216 max_number = len(selected_ids)
217
218 for number_of_paths in range(min_number, max_number + 1):
219 for combination in combinations(selected_ids, number_of_paths):
220 combinations_list.append(combination)
221
222 new_combinations = len(combinations_list)
223 print(
224 f"{new_combinations} combinations for {directory_label}:\n"
225 f"{combinations_list}"
226 )
227 old_combinations_len = len(self.all_combinations)
228 self.all_combinations *= new_combinations
229 for i, combination in enumerate(self.all_combinations):
230 new_combinations = combinations_list[i // old_combinations_len]
231 self.all_combinations[i] = combination + list(new_combinations)
232 else:
233 for combination in self.all_combinations:
234 combination.extend(selected_ids)
235
def select_rows(
    self,
    paths_file: str,
    directory_label: str,
    selector: "CriteriaSelector|ManualSelector",
) -> "list[int]":
    """Evaluate each row in turn to decide whether or not it should be
    included in the final output. Does not account for combinations.

    Rows whose first column contains no digits (such as a header) are
    skipped, as are blank lines. For every other row, the first run of
    digits in the first column is used as the path id passed to
    ``selector.evaluate``. Selected rows are handed to ``parse_row``.

    Args:
        paths_file (str): CSV summary filename.
        directory_label (str): Label to indicate paths from a separate
            directory.
        selector (CriteriaSelector|ManualSelector): Object to evaluate
            whether to select each path or not.

    Returns:
        list[int]: The ids of the selected rows, in file order.
    """
    row_ids = []
    # The csv module expects files it reads to be opened with newline="".
    with open(paths_file, newline="") as csv_file:
        for row in csv.reader(csv_file):
            # csv.reader yields [] for blank lines (e.g. a trailing newline);
            # indexing row[0] on those would raise IndexError.
            if not row:
                continue

            id_match = re.search(r"\d+", row[0])
            if id_match is None:
                continue

            path_id = int(id_match.group())
            selected, path_value = selector.evaluate(path_id=path_id, row=row)
            if not selected:
                continue

            filename = row[0].strip()
            # The label is read from the second-to-last column.
            path_label = row[-2].strip()
            row_ids.append(
                self.parse_row(
                    directory_label, filename, path_label, path_value
                )
            )

    return row_ids
def parse_row(
    self,
    directory_label: str,
    filename: str,
    path_label: str,
    path_value: "None|dict",
) -> int:
    """Resolve the GDS variable names for one selected path and record it.

    ``self.gds_writer.parse_gds`` is called once for every property in
    ``self.gds_writer.default_properties``. When ``path_value`` is given,
    that property's entry is forwarded too: its ``"name"`` becomes
    ``variable_name`` and the whole entry becomes ``path_variable``.
    Otherwise only the property name and the labels are passed, so the
    defaults are used.

    Args:
        directory_label (str): Label to indicate paths from a separate
            directory.
        filename (str): Filename for the FEFF path, extracted from row.
        path_label (str): Label for the FEFF path, extracted from row.
        path_value (None|dict): The values associated with the selected
            FEFF path, keyed by property name. May be None.

    Returns:
        int: The id of the added row, as returned by
        ``parse_selected_path``.
    """
    gds_writer = self.gds_writer
    resolved = {}
    for prop_name in gds_writer.default_properties:
        overrides = {}
        if path_value is not None:
            entry = path_value[prop_name]
            overrides = {"variable_name": entry["name"], "path_variable": entry}
        resolved[prop_name] = gds_writer.parse_gds(
            property_name=prop_name,
            directory_label=directory_label,
            path_label=path_label,
            **overrides,
        )

    return self.parse_selected_path(
        filename=filename,
        path_label=path_label,
        directory_label=directory_label,
        **resolved,
    )
197 320
def parse_selected_path(
    self,
    filename: str,
    path_label: str,
    directory_label: str = "",
    s02: str = "s02",
    e0: str = "e0",
    sigma2: str = "sigma2",
    deltar: str = "alpha*reff",
) -> int:
    """Format and append row representing a selected FEFF path.

    Args:
        filename (str): Name of the underlying FEFF path file, without
            parent directory.
        path_label (str): Label of the path as read from the CSV summary.
        directory_label (str, optional): When non-empty, used as the parent
            directory of ``filename`` and as a ``"<directory_label>."``
            prefix on the label. When empty, the parent directory is
            ``"feff"`` and the label is used unchanged. Defaults to "".
        s02 (str, optional): Amplitude reduction variable name.
            Defaults to "s02".
        e0 (str, optional): Energy shift variable name. Defaults to "e0".
        sigma2 (str, optional): Mean squared displacement variable name.
            Defaults to "sigma2".
        deltar (str, optional): Change in path length variable.
            Defaults to "alpha*reff".

    Returns:
        int: The id of the added row, i.e. its index in ``self.rows``.
    """
    label = f"{directory_label}.{path_label}" if directory_label else path_label
    filename = os.path.join(directory_label or "feff", filename)

    # The new row's id is its position once appended.
    row_id = len(self.rows)
    self.rows.append(
        f"{row_id:>4d}, {filename:>24s}, {label:>24s}, "
        f"{s02:>3s}, {e0:>4s}, {sigma2:>24s}, {deltar:>10s}\n"
    )
    return row_id
364
def write(self):
    """Write selected path and GDS rows to file.

    The GDS writer writes its own file first. With exactly one path
    combination, every row in ``self.rows`` goes to ``sp.csv``. Otherwise
    one CSV per combination is written to ``sp/``. Each file is named by
    the combination's row ids after the first (id 0 is the header row, as
    set up in ``PathsWriter.__init__``) joined with ``_``, and contains
    only the rows whose ids are in that combination, in row order.

    The ``sp/`` directory is created if it does not already exist.
    """
    self.gds_writer.write()

    if len(self.all_combinations) == 1:
        with open("sp.csv", "w") as out:
            out.writelines(self.rows)
        return

    # Opening "sp/<name>.csv" raises FileNotFoundError unless the directory
    # exists, so create it up front (a no-op if it is already there).
    os.makedirs("sp", exist_ok=True)
    for combination in self.all_combinations:
        # Skip the leading header id when building the file name.
        filename = "_".join(str(row_id) for row_id in combination[1:])
        print(f"Writing combination {filename}")
        # A set gives O(1) membership checks while scanning every row.
        included = set(combination)
        with open(f"sp/{filename}.csv", "w") as out:
            out.writelines(
                row
                for row_id, row in enumerate(self.rows)
                if row_id in included
            )
242 380
243 381
244 def main(input_values: dict): 382 def main(input_values: dict):
245 """Select paths and define GDS parameters. 383 """Select paths and define GDS parameters.
246 384
263 else: 401 else:
264 zfill_length = len(str(len(input_values["feff_outputs"]))) 402 zfill_length = len(str(len(input_values["feff_outputs"])))
265 labels = set() 403 labels = set()
266 with ZipFile("merged.zip", "x", ZIP_DEFLATED) as zipfile_out: 404 with ZipFile("merged.zip", "x", ZIP_DEFLATED) as zipfile_out:
267 for i, feff_output in enumerate(input_values["feff_outputs"]): 405 for i, feff_output in enumerate(input_values["feff_outputs"]):
268 label = feff_output.pop("label") or str(i + 1).zfill( 406 label = feff_output["label"]
269 zfill_length 407 if not label:
270 ) 408 label = str(i + 1).zfill(zfill_length)
271 if label in labels: 409 if label in labels:
272 raise ValueError(f"Label '{label}' is not unique") 410 raise ValueError(f"Label '{label}' is not unique")
273 labels.add(label) 411 labels.add(label)
274 412
275 writer.parse_feff_output( 413 writer.parse_feff_output(
281 with ZipFile(feff_output["paths_zip"]) as z: 419 with ZipFile(feff_output["paths_zip"]) as z:
282 for zipinfo in z.infolist(): 420 for zipinfo in z.infolist():
283 if zipinfo.filename != "feff/": 421 if zipinfo.filename != "feff/":
284 zipinfo.filename = zipinfo.filename[5:] 422 zipinfo.filename = zipinfo.filename[5:]
285 z.extract(member=zipinfo, path=label) 423 z.extract(member=zipinfo, path=label)
286 zipfile_out.write( 424 filename = os.path.join(label, zipinfo.filename)
287 os.path.join(label, zipinfo.filename) 425 zipfile_out.write(filename)
288 )
289 426
290 writer.write() 427 writer.write()
291 428
292 429
293 if __name__ == "__main__": 430 if __name__ == "__main__":