view larch_select_paths.py @ 4:204c4afe2f1e draft default tip

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_select_paths commit 4814f53888643f1d3667789050914675fffb7d59
author muon-spectroscopy-computational-project
date Fri, 23 Aug 2024 14:10:59 +0000
parents 7fdca938d90c
children
line wrap: on
line source

import csv
import json
import os
import re
import sys
from itertools import combinations
from zipfile import ZIP_DEFLATED, ZipFile


class CriteriaSelector:
    """Select paths from a CSV summary using numeric limits.

    Any limit that is falsy (e.g. 0 or None) is treated as disabled.

    Args:
        criteria (dict[str, int|float]): Must provide "max_number",
            "max_path_length", "min_amplitude_ratio" and "max_degeneracy".
    """

    def __init__(self, criteria: "dict[str, int|float]"):
        self.max_number = criteria["max_number"]
        self.max_path_length = criteria["max_path_length"]
        self.min_amp_ratio = criteria["min_amplitude_ratio"]
        self.max_degeneracy = criteria["max_degeneracy"]
        # Number of paths accepted so far, compared against `max_number`.
        self.path_count = 0

    def evaluate(self, path_id: int, row: "list[str]") -> (bool, None):
        """Decide whether the path described by `row` is selected.

        Checks, in order: the running count against `max_number`, then the
        value in column 5 against `max_path_length`, column 2 against
        `min_amp_ratio` and column 3 against `max_degeneracy`. A column is
        only parsed once every earlier check has passed.

        Args:
            path_id (int): Id of the path (not used by this selector).
            row (list[str]): The path's row from the CSV summary.

        Returns:
            tuple: `(True, None)` if accepted (the count is incremented),
                otherwise `(False, None)`.
        """
        if self.max_number and self.path_count >= self.max_number:
            print(f"Reject path: {self.max_number} paths already reached")
            return (False, None)

        # (column, limit, symbol shown when rejecting, rejection test)
        limit_checks = (
            (5, self.max_path_length, ">", lambda v, lim: v > lim),
            (2, self.min_amp_ratio, "<", lambda v, lim: v < lim),
            (3, self.max_degeneracy, ">", lambda v, lim: v > lim),
        )
        for column, limit, symbol, violates in limit_checks:
            value = float(row[column].strip())
            if limit and violates(value, limit):
                print(f"Reject path: {value} {symbol} {limit}")
                return (False, None)

        self.path_count += 1
        return (True, None)


class ManualSelector:
    """Select paths that the user listed explicitly, or that are flagged
    in the CSV summary.

    Args:
        selection (dict): Must provide "selection" (the value "all" selects
            every path) and "paths", a list of dicts each carrying an "id".
    """

    def __init__(self, selection: dict):
        self.select_all = selection["selection"] == "all"
        self.paths = selection["paths"]
        self.path_values_ids = [entry["id"] for entry in self.paths]

    def evaluate(self, path_id: int, row: "list[str]") -> (bool, "None|dict"):
        """Decide whether the path described by `row` is selected.

        Args:
            path_id (int): Id parsed from the row's filename.
            row (list[str]): The path's row from the CSV summary; its last
                column is read as an integer flag.

        Returns:
            tuple: `(True, path_dict)` for an explicitly listed path (the
                first listed entry with that id), `(True, None)` when
                selecting all or the row's flag is non-zero, otherwise
                `(False, None)`.
        """
        try:
            position = self.path_values_ids.index(path_id)
        except ValueError:
            position = None

        if position is not None:
            return (True, self.paths[position])

        # `or` short-circuits, so the flag column is not parsed when every
        # path is selected anyway.
        flagged = self.select_all or bool(int(row[-1]))
        return (flagged, None)


class GDSWriter:
    """Collect GDS variable definitions and write them to `gds.csv`.

    Args:
        default_variables (dict[str, dict]): Settings for each of the
            properties "degen", "s02", "e0", "deltar" and "sigma2". Each entry
            must provide "value", "vary" and "is_common".
    """

    def __init__(self, default_variables: "dict[str, dict]"):
        # Base variable name per property. `deltar` uses "alpha" because the
        # selected-paths file writes it as `<name>*reff`.
        self.default_properties = {
            "degen": {"name": "degen"},
            "s02": {"name": "s02"},
            "e0": {"name": "e0"},
            "deltar": {"name": "alpha"},
            "sigma2": {"name": "sigma2"},
        }
        self.rows = [
            f"{'id':>4s}, {'name':>24s}, {'value':>5s}, {'expr':>4s}, "
            f"{'vary':>4s}\n"
        ]
        # Every variable name defined so far, used to reject duplicates.
        self.names = set()

        for property_name, defaults in self.default_properties.items():
            settings = default_variables[property_name]
            defaults["value"] = settings["value"]
            defaults["vary"] = settings["vary"]
            defaults["is_common"] = settings["is_common"]

            # Common variables are defined once here and shared by all paths;
            # the others are generated per path by `parse_gds`.
            if settings["is_common"]:
                self.append_gds(
                    name=defaults["name"],
                    value=settings["value"],
                    vary=settings["vary"],
                )

    def append_gds(
        self,
        name: str,
        value: float = 0.0,
        expr: str = None,
        vary: bool = True,
        label: str = "",
    ):
        """Append a single GDS variable to the list of rows, later to be
        written to file.

        Args:
            name (str): Name of the GDS variable. Any "*reff" is removed.
            value (float, optional): Starting value for variable.
                Defaults to 0.
            expr (str, optional): Expression for setting the variable.
                Defaults to None.
            vary (bool, optional): Whether the variable is optimised during the
                fit. Defaults to True.
            label (str, optional): Label to keep variables for different FEFF
                directories distinct. Defaults to "".

        Raises:
            ValueError: If the (formatted) name has already been defined.
        """
        formatted_name = name if (label is None) else label + name
        formatted_name = formatted_name.replace("*reff", "")
        if not expr:
            # Keep the column width even when there is no expression.
            expr = "    "

        if formatted_name in self.names:
            raise ValueError(f"{formatted_name} already used as variable name")
        self.names.add(formatted_name)

        if value is not None:
            formatted_value = str(value)
        else:
            formatted_value = ""

        # The id is the row's index; row 0 is the header.
        self.rows.append(
            f"{len(self.rows):4d}, {formatted_name:>24s}, "
            f"{formatted_value:>5s}, {expr:>4s}, {str(vary):>4s}\n"
        )

    def parse_gds(
        self,
        property_name: str,
        variable_name: str = None,
        path_variable: dict = None,
        directory_label: str = None,
        path_label: str = None,
    ) -> str:
        """Parse and append a row defining a GDS variable for a particular
        path.

        Args:
            property_name (str): The property to which the variable
                corresponds. Should be a key in `self.default_properties`.
            variable_name (str, optional): Custom name for this variable.
                Defaults to None.
            path_variable (dict, optional): Dictionary defining the GDS
                settings for this path's variable ("value", "expr", "vary").
                Required when `variable_name` is given. Defaults to None.
            directory_label (str, optional): Label to indicate paths from a
                separate directory. Defaults to None.
            path_label (str, optional): Label indicating the atoms involved in
                this path. Defaults to None.

        Returns:
            str: Either `variable_name`, the name used as a default globally
                for this `property_name`, or an automatically generated unique
                name.

        Raises:
            ValueError: If a custom `variable_name` is already defined.
        """
        if variable_name:
            self.append_gds(
                name=variable_name,
                value=path_variable["value"],
                expr=path_variable["expr"],
                vary=path_variable["vary"],
            )
            return variable_name

        defaults = self.default_properties[property_name]
        if defaults["is_common"]:
            return defaults["name"]

        base_name = defaults["name"]
        if directory_label:
            base_name += f"_{directory_label}"
        if path_label:
            base_name += f"_{path_label.lower().replace('.', '')}"

        # If two selected paths carry the same label, the generated name
        # would collide and `append_gds` would raise. The user cannot choose
        # these names, so add a numeric suffix until the name is unused.
        auto_name = base_name
        suffix = 2
        while auto_name in self.names:
            auto_name = f"{base_name}_{suffix}"
            suffix += 1

        self.append_gds(
            name=auto_name,
            value=defaults["value"],
            vary=defaults["vary"],
        )
        return auto_name

    def write(self):
        """Write GDS rows to file."""
        with open("gds.csv", "w") as out:
            out.writelines(self.rows)


class PathsWriter:
    """Collect selected FEFF paths and write them to `sp.csv`, or to one file
    per path combination under `sp/`, alongside the GDS variables they use.

    Args:
        default_variables (dict[str, dict]): Default GDS settings, passed to
            the underlying `GDSWriter`.
    """

    def __init__(self, default_variables: "dict[str, dict]"):
        self.rows = [
            f"{'id':>4s}, {'filename':>24s}, {'label':>24s}, {'degen':>5s}, "
            f"{'s02':>3s}, {'e0':>4s}, {'sigma2':>24s}, {'deltar':>10s}\n"
        ]
        self.gds_writer = GDSWriter(default_variables=default_variables)
        # Each entry lists the row ids written to one output file.
        self.all_combinations = [[0]]  # 0 corresponds to the header row

    def parse_feff_output(
        self,
        paths_file: str,
        selection: "dict[str, str|list]",
        directory_label: str = "",
    ):
        """Parse selected paths from CSV summary and define GDS variables.

        Args:
            paths_file (str): CSV summary filename.
            selection (dict[str, str|list]): Dictionary indicating which paths
                to select, and how to define their variables.
            directory_label (str, optional): Label to indicate paths from a
                separate directory. Defaults to "".

        Raises:
            ValueError: In "combinations" mode, if the minimum combination
                size exceeds the maximum. Otherwise no combination would
                exist and every output file would be silently dropped.
        """
        if selection["selection"] in {"criteria", "combinations"}:
            selector = CriteriaSelector(selection)
        else:
            selector = ManualSelector(selection)

        selected_ids = self.select_rows(paths_file, directory_label, selector)

        if selection["selection"] != "combinations":
            # Every existing combination gains all of the selected paths.
            for combination in self.all_combinations:
                combination.extend(selected_ids)
            return

        min_size = min(selection["min_combination_size"], len(selected_ids))
        max_size = selection["max_combination_size"]
        # A falsy maximum means "no limit".
        if not max_size or max_size > len(selected_ids):
            max_size = len(selected_ids)
        if min_size > max_size:
            raise ValueError(
                f"Minimum combination size ({min_size}) exceeds maximum "
                f"combination size ({max_size}) for {directory_label}"
            )

        combinations_list = [
            combination
            for size in range(min_size, max_size + 1)
            for combination in combinations(selected_ids, size)
        ]
        print(
            f"{len(combinations_list)} combinations for {directory_label}:\n"
            f"{combinations_list}"
        )
        # Cartesian product with the combinations from earlier directories.
        # The outer loop runs over the new combinations, matching the
        # original block-wise ordering.
        self.all_combinations = [
            existing + list(new)
            for new in combinations_list
            for existing in self.all_combinations
        ]

    def select_rows(
        self,
        paths_file: str,
        directory_label: str,
        selector: "CriteriaSelector|ManualSelector",
    ) -> "list[int]":
        """Evaluate each row in turn to decide whether or not it should be
        included in the final output. Does not account for combinations.

        Args:
            paths_file (str): CSV summary filename.
            directory_label (str): Label to indicate paths from a separate
                directory.
            selector (CriteriaSelector|ManualSelector): Object to evaluate
                whether to select each path or not.

        Returns:
            list[int]: The ids of the selected rows.
        """
        row_ids = []
        with open(paths_file) as file:
            for row in csv.reader(file):
                # csv yields an empty list for a blank line; skip it rather
                # than failing with IndexError on row[0].
                if not row:
                    continue
                # Rows whose first column has no digits (e.g. a header) are
                # not paths.
                id_match = re.search(r"\d+", row[0])
                if not id_match:
                    continue

                selected, path_value = selector.evaluate(
                    path_id=int(id_match.group()),
                    row=row,
                )
                if selected:
                    filename = row[0].strip()
                    path_label = row[-2].strip()
                    row_id = self.parse_row(
                        directory_label, filename, path_label, path_value
                    )
                    row_ids.append(row_id)

        return row_ids

    def parse_row(
        self,
        directory_label: str,
        filename: str,
        path_label: str,
        path_value: "None|dict",
    ) -> int:
        """Parse row for GDS and path information.

        Args:
            directory_label (str): Label to indicate paths from a separate
                directory.
            filename (str): Filename for the FEFF path, extracted from row.
            path_label (str): Label for the FEFF path, extracted from row.
            path_value (None|dict): The values associated with the selected
                FEFF path. May be None in which case defaults are used.

        Returns:
            int: The id of the added row.
        """
        variables = {}
        for property_name in self.gds_writer.default_properties:
            if path_value is not None:
                variables[property_name] = self.gds_writer.parse_gds(
                    property_name=property_name,
                    variable_name=path_value[property_name]["name"],
                    path_variable=path_value[property_name],
                    directory_label=directory_label,
                    path_label=path_label,
                )
            else:
                variables[property_name] = self.gds_writer.parse_gds(
                    property_name=property_name,
                    directory_label=directory_label,
                    path_label=path_label,
                )

        return self.parse_selected_path(
            filename=filename,
            path_label=path_label,
            directory_label=directory_label,
            **variables,
        )

    def parse_selected_path(
        self,
        filename: str,
        path_label: str,
        directory_label: str = "",
        degen: str = "degen",
        s02: str = "s02",
        e0: str = "e0",
        sigma2: str = "sigma2",
        deltar: str = "alpha",
    ) -> int:
        """Format and append row representing a selected FEFF path.

        Args:
            filename (str): Name of the underlying FEFF path file, without
                parent directory.
            path_label (str): Label indicating the atoms involved in this path.
            directory_label (str, optional): Label to indicate paths from a
                separate directory. Defaults to "".
            degen (str, optional): Path degeneracy variable name.
                Defaults to "degen".
            s02 (str, optional): Electron screening factor variable name.
                Defaults to "s02".
            e0 (str, optional): Energy shift variable name. Defaults to "e0".
            sigma2 (str, optional): Mean squared displacement variable name.
                Defaults to "sigma2".
            deltar (str, optional): Change in path length variable; written
                to the file multiplied by "reff". Defaults to "alpha".

        Returns:
            int: The id of the added row.
        """
        if directory_label:
            filename = os.path.join(directory_label, filename)
            label = f"{directory_label}.{path_label}"
        else:
            filename = os.path.join("feff", filename)
            label = path_label

        row_id = len(self.rows)
        self.rows.append(
            f"{row_id:>4d}, {filename:>24s}, {label:>24s}, {degen:>5s}, "
            f"{s02:>3s}, {e0:>4s}, {sigma2:>24s}, {deltar + '*reff':>10s}\n"
        )

        return row_id

    def write(self):
        """Write selected path and GDS rows to file.

        A single combination is written to `sp.csv`. Otherwise each
        combination is written to `sp/<ids>.csv`; the `sp` directory is
        assumed to exist already.
        """
        self.gds_writer.write()

        if len(self.all_combinations) == 1:
            with open("sp.csv", "w") as out:
                out.writelines(self.rows)
            return

        for combination in self.all_combinations:
            filename = "_".join(str(c) for c in combination[1:])
            print(f"Writing combination {filename}")
            # A set gives O(1) membership checks for each row.
            wanted_ids = set(combination)
            with open(f"sp/{filename}.csv", "w") as out:
                out.writelines(
                    row
                    for row_id, row in enumerate(self.rows)
                    if row_id in wanted_ids
                )


def main(input_values: dict):
    """Select FEFF paths and define their GDS parameters.

    With exactly one FEFF output, its paths are parsed directly. Otherwise
    each output gets a directory label: the user-supplied label, or its
    1-based position zero-padded to the width of the output count. Each
    output is parsed under that label, and its path files are extracted
    under the label and added to a new `merged.zip`. Finally all
    selected-path and GDS files are written.

    Args:
        input_values (dict): All input values from the Galaxy tool UI.

    Raises:
        ValueError: If a FEFF label is not unique.
    """
    writer = PathsWriter(default_variables=input_values["variables"])
    feff_outputs = input_values["feff_outputs"]

    if len(feff_outputs) == 1:
        only_output = feff_outputs[0]
        writer.parse_feff_output(
            paths_file=only_output["paths_file"],
            selection=only_output["selection"],
        )
        writer.write()
        return

    pad_width = len(str(len(feff_outputs)))
    used_labels = set()
    with ZipFile("merged.zip", "x", ZIP_DEFLATED) as merged_zip:
        for position, feff_output in enumerate(feff_outputs, start=1):
            label = feff_output["label"] or str(position).zfill(pad_width)
            if label in used_labels:
                raise ValueError(f"Label '{label}' is not unique")
            used_labels.add(label)

            writer.parse_feff_output(
                directory_label=label,
                paths_file=feff_output["paths_file"],
                selection=feff_output["selection"],
            )

            with ZipFile(feff_output["paths_zip"]) as paths_zip:
                for member in paths_zip.infolist():
                    if member.filename == "feff/":
                        continue
                    # Drop the first 5 characters (the "feff/" prefix) so
                    # files are extracted directly under `label`.
                    member.filename = member.filename[5:]
                    paths_zip.extract(member=member, path=label)
                    merged_zip.write(os.path.join(label, member.filename))

    writer.write()


if __name__ == "__main__":
    # Read the Galaxy tool inputs (JSON) named by the first CLI argument.
    # A `with` block closes the file instead of leaking the handle.
    with open(sys.argv[1], "r", encoding="utf-8") as input_file:
        input_values = json.load(input_file)
    main(input_values)