diff larch_select_paths.py @ 1:7fdca938d90c draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_select_paths commit 1cf6d7160497ba58fe16a51f00d088a20934eba6
author muon-spectroscopy-computational-project
date Wed, 06 Dec 2023 13:04:15 +0000 (16 months ago)
parents 2e827836f0ad
children 204c4afe2f1e
line wrap: on
line diff
--- a/larch_select_paths.py	Tue Nov 14 15:35:52 2023 +0000
+++ b/larch_select_paths.py	Wed Dec 06 13:04:15 2023 +0000
@@ -3,9 +3,58 @@
 import os
 import re
 import sys
+from itertools import combinations
 from zipfile import ZIP_DEFLATED, ZipFile
 
 
+class CriteriaSelector:
+    def __init__(self, criteria: "dict[str, int|float]"):
+        self.max_number = criteria["max_number"]
+        self.max_path_length = criteria["max_path_length"]
+        self.min_amp_ratio = criteria["min_amplitude_ratio"]
+        self.max_degeneracy = criteria["max_degeneracy"]
+        self.path_count = 0
+
+    def evaluate(self, path_id: int, row: "list[str]") -> (bool, None):
+        if self.max_number and self.path_count >= self.max_number:
+            print(f"Reject path: {self.max_number} paths already reached")
+            return (False, None)
+
+        r_effective = float(row[5].strip())
+        if self.max_path_length and r_effective > self.max_path_length:
+            print(f"Reject path: {r_effective} > {self.max_path_length}")
+            return (False, None)
+
+        amplitude_ratio = float(row[2].strip())
+        if self.min_amp_ratio and (amplitude_ratio < self.min_amp_ratio):
+            print(f"Reject path: {amplitude_ratio} < {self.min_amp_ratio}")
+            return (False, None)
+
+        degeneracy = float(row[3].strip())
+        if self.max_degeneracy and degeneracy > self.max_degeneracy:
+            print(f"Reject path: {degeneracy} > {self.max_degeneracy}")
+            return (False, None)
+
+        self.path_count += 1
+        return (True, None)
+
+
+class ManualSelector:
+    def __init__(self, selection: dict):
+        self.select_all = selection["selection"] == "all"
+        self.paths = selection["paths"]
+        self.path_values_ids = [path_value["id"] for path_value in self.paths]
+
+    def evaluate(self, path_id: int, row: "list[str]") -> (bool, "None|dict"):
+        if path_id in self.path_values_ids:
+            return (True, self.paths[self.path_values_ids.index(path_id)])
+
+        if self.select_all or int(row[-1]):
+            return (True, None)
+
+        return (False, None)
+
+
 class GDSWriter:
     def __init__(self, default_variables: "dict[str, dict]"):
         self.default_properties = {
@@ -36,7 +85,7 @@
     def append_gds(
         self,
         name: str,
-        value: float = 0.,
+        value: float = 0.0,
         expr: str = None,
         vary: bool = True,
         label: str = "",
@@ -122,8 +171,7 @@
             return auto_name
 
     def write(self):
-        """Write GDS rows to file.
-        """
+        """Write GDS rows to file."""
         with open("gds.csv", "w") as out:
             out.writelines(self.rows)
 
@@ -135,6 +183,7 @@
             f"{'e0':>4s}, {'sigma2':>24s}, {'deltar':>10s}\n"
         ]
         self.gds_writer = GDSWriter(default_variables=default_variables)
+        self.all_combinations = [[0]]  # 0 corresponds to the header row
 
     def parse_feff_output(
         self,
@@ -151,49 +200,123 @@
             directory_label (str, optional): Label to indicate paths from a
                 separate directory. Defaults to "".
         """
-        paths = selection["paths"]
-        path_values_ids = [path_value["id"] for path_value in paths]
+        combinations_list = []
+        if selection["selection"] in {"criteria", "combinations"}:
+            selector = CriteriaSelector(selection)
+        else:
+            selector = ManualSelector(selection)
+
+        selected_ids = self.select_rows(paths_file, directory_label, selector)
+
+        if selection["selection"] == "combinations":
+            min_number = selection["min_combination_size"]
+            min_number = min(min_number, len(selected_ids))
+            max_number = selection["max_combination_size"]
+            if not max_number or max_number > len(selected_ids):
+                max_number = len(selected_ids)
+
+            for number_of_paths in range(min_number, max_number + 1):
+                for combination in combinations(selected_ids, number_of_paths):
+                    combinations_list.append(combination)
 
+            new_combinations = len(combinations_list)
+            print(
+                f"{new_combinations} combinations for {directory_label}:\n"
+                f"{combinations_list}"
+            )
+            old_combinations_len = len(self.all_combinations)
+            self.all_combinations *= new_combinations
+            for i, combination in enumerate(self.all_combinations):
+                new_combinations = combinations_list[i // old_combinations_len]
+                self.all_combinations[i] = combination + list(new_combinations)
+        else:
+            for combination in self.all_combinations:
+                combination.extend(selected_ids)
+
+    def select_rows(
+        self,
+        paths_file: str,
+        directory_label: str,
+        selector: "CriteriaSelector|ManualSelector",
+    ) -> "list[int]":
+        """Evaluate each row in turn to decide whether or not it should be
+        included in the final output. Does not account for combinations.
+
+        Args:
+            paths_file (str): CSV summary filename.
+            directory_label (str): Label to indicate paths from a separate
+                directory.
+            selector (CriteriaSelector|ManualSelector): Object to evaluate
+                whether to select each path or not.
+
+        Returns:
+            list[int]: The ids of the selected rows.
+        """
+        row_ids = []
         with open(paths_file) as file:
             reader = csv.reader(file)
             for row in reader:
                 id_match = re.search(r"\d+", row[0])
                 if id_match:
                     path_id = int(id_match.group())
-                    filename = row[0].strip()
-                    path_label = row[-2].strip()
-                    variables = {}
+                    selected, path_value = selector.evaluate(
+                        path_id=path_id,
+                        row=row,
+                    )
+                    if selected:
+                        filename = row[0].strip()
+                        path_label = row[-2].strip()
+                        row_id = self.parse_row(
+                            directory_label, filename, path_label, path_value
+                        )
+                        row_ids.append(row_id)
+
+        return row_ids
+
+    def parse_row(
+        self,
+        directory_label: str,
+        filename: str,
+        path_label: str,
+        path_value: "None|dict",
+    ) -> int:
+        """Parse row for GDS and path information.
 
-                    if path_id in path_values_ids:
-                        path_value = paths[path_values_ids.index(path_id)]
-                        for property in self.gds_writer.default_properties:
-                            variables[property] = self.gds_writer.parse_gds(
-                                property_name=property,
-                                variable_name=path_value[property]["name"],
-                                path_variable=path_value[property],
-                                directory_label=directory_label,
-                                path_label=path_label,
-                            )
-                        self.parse_selected_path(
-                            filename=filename,
-                            path_label=path_label,
-                            directory_label=directory_label,
-                            **variables,
-                        )
-                    elif selection["selection"] == "all" or int(row[-1]):
-                        path_value = None
-                        for property in self.gds_writer.default_properties:
-                            variables[property] = self.gds_writer.parse_gds(
-                                property_name=property,
-                                directory_label=directory_label,
-                                path_label=path_label,
-                            )
-                        self.parse_selected_path(
-                            filename=filename,
-                            path_label=path_label,
-                            directory_label=directory_label,
-                            **variables,
-                        )
+        Args:
+            directory_label (str): Label to indicate paths from a separate
+                directory.
+            filename (str): Filename for the FEFF path, extracted from row.
+            path_label (str): Label for the FEFF path, extracted from row.
+            path_value (None|dict): The values associated with the selected
+                FEFF path. May be None in which case defaults are used.
+
+        Returns:
+            int: The id of the added row.
+        """
+        variables = {}
+        if path_value is not None:
+            for property in self.gds_writer.default_properties:
+                variables[property] = self.gds_writer.parse_gds(
+                    property_name=property,
+                    variable_name=path_value[property]["name"],
+                    path_variable=path_value[property],
+                    directory_label=directory_label,
+                    path_label=path_label,
+                )
+        else:
+            for property in self.gds_writer.default_properties:
+                variables[property] = self.gds_writer.parse_gds(
+                    property_name=property,
+                    directory_label=directory_label,
+                    path_label=path_label,
+                )
+
+        return self.parse_selected_path(
+            filename=filename,
+            path_label=path_label,
+            directory_label=directory_label,
+            **variables,
+        )
 
     def parse_selected_path(
         self,
@@ -204,7 +327,7 @@
         e0: str = "e0",
         sigma2: str = "sigma2",
         deltar: str = "alpha*reff",
-    ):
+    ) -> int:
         """Format and append row representing a selected FEFF path.
 
         Args:
@@ -220,6 +343,9 @@
                 Defaults to "sigma2".
             deltar (str, optional): Change in path length variable.
                 Defaults to "alpha*reff".
+
+        Returns:
+            int: The id of the added row.
         """
         if directory_label:
             filename = os.path.join(directory_label, filename)
@@ -228,17 +354,29 @@
             filename = os.path.join("feff", filename)
             label = path_label
 
+        row_id = len(self.rows)
         self.rows.append(
-            f"{len(self.rows):>4d}, {filename:>24s}, {label:>24s}, "
+            f"{row_id:>4d}, {filename:>24s}, {label:>24s}, "
             f"{s02:>3s}, {e0:>4s}, {sigma2:>24s}, {deltar:>10s}\n"
         )
 
+        return row_id
+
     def write(self):
-        """Write selected path and GDS rows to file.
-        """
+        """Write selected path and GDS rows to file."""
         self.gds_writer.write()
-        with open("sp.csv", "w") as out:
-            out.writelines(self.rows)
+
+        if len(self.all_combinations) == 1:
+            with open("sp.csv", "w") as out:
+                out.writelines(self.rows)
+        else:
+            for combination in self.all_combinations:
+                filename = "_".join([str(c) for c in combination[1:]])
+                print(f"Writing combination {filename}")
+                with open(f"sp/{filename}.csv", "w") as out:
+                    for row_id, row in enumerate(self.rows):
+                        if row_id in combination:
+                            out.write(row)
 
 
 def main(input_values: dict):
@@ -265,9 +403,9 @@
         labels = set()
         with ZipFile("merged.zip", "x", ZIP_DEFLATED) as zipfile_out:
             for i, feff_output in enumerate(input_values["feff_outputs"]):
-                label = feff_output.pop("label") or str(i + 1).zfill(
-                    zfill_length
-                )
+                label = feff_output["label"]
+                if not label:
+                    label = str(i + 1).zfill(zfill_length)
                 if label in labels:
                     raise ValueError(f"Label '{label}' is not unique")
                 labels.add(label)
@@ -283,9 +421,8 @@
                         if zipinfo.filename != "feff/":
                             zipinfo.filename = zipinfo.filename[5:]
                             z.extract(member=zipinfo, path=label)
-                            zipfile_out.write(
-                                os.path.join(label, zipinfo.filename)
-                            )
+                            filename = os.path.join(label, zipinfo.filename)
+                            zipfile_out.write(filename)
 
     writer.write()