view larch_athena.py @ 0:ae2f265ecf8e draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 5be486890442dedfb327289d597e1c8110240735
author muon-spectroscopy-computational-project
date Tue, 14 Nov 2023 15:34:40 +0000
parents
children 2b3115342fef
line wrap: on
line source

import gc
import json
import os
import re
import sys

from common import read_group

from larch.io import (
    create_athena,
    h5group,
    merge_groups,
    read_ascii,
    set_array_labels,
)
from larch.symboltable import Group
from larch.xafs import autobk, pre_edge, rebin_xafs, xftf

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


class Reader:
    """Load XAS data into larch Groups from Athena, ASCII or h5 files.

    The constructor only stores its arguments:

    * ``energy_column`` / ``mu_column``: optional names of the columns to
      rename to "energy" / "mu" (see ``rename_cols``).
    * ``xftf_params`` / ``extract_group``: passed to ``read_group`` when
      ``data_format`` is "athena".
    * ``data_format``: "athena" selects ``read_group``; any other value
      tries plaintext first and falls back to h5.
    """

    def __init__(
        self,
        energy_column: str,
        mu_column: str,
        xftf_params: dict,
        data_format: str,
        extract_group: str = None,
    ):
        self.energy_column = energy_column
        self.mu_column = mu_column
        self.xftf_params = xftf_params
        self.data_format = data_format
        self.extract_group = extract_group

    def load_data(
        self,
        dat_file: str,
        merge_inputs: bool,
        is_zipped: bool,
    ) -> "dict[str, Group]":
        """Load the input, optionally merging everything into one Group.

        Returns a dict of Groups: ``{"out": merged}`` when ``merge_inputs``
        is truthy, otherwise whatever ``load_single_file`` returns.
        """
        if merge_inputs:
            out_group = self.merge_files(
                dat_files=dat_file, is_zipped=is_zipped
            )
            return {"out": out_group}

        return self.load_single_file(filepath=dat_file, is_zipped=is_zipped)

    def merge_files(
        self,
        dat_files: str,
        is_zipped: bool,
    ) -> "Group":
        """Merge several inputs into a single Group with ``merge_groups``.

        ``dat_files`` is a comma separated list of paths; it is ignored
        when ``is_zipped`` is truthy, in which case the Groups come from
        ``load_zipped_files``.
        """
        if is_zipped:
            all_groups = list(self.load_zipped_files().values())
        else:
            all_groups = [
                self.load_single_file(filepath)["out"]
                for filepath in dat_files.split(",")
            ]

        return merge_groups(all_groups, xarray="energy", yarray="mu")

    def load_single_file(
        self,
        filepath: str,
        is_zipped: bool = False,
    ) -> "dict[str,Group]":
        """Load one file, or every extracted file when ``is_zipped``.

        Returns ``{"out": group}`` for a single file, or the keyed dict
        from ``load_zipped_files`` when ``is_zipped`` is truthy.
        """
        if is_zipped:
            return self.load_zipped_files()

        print(f"Attempting to read from {filepath}")
        if self.data_format == "athena":
            group = read_group(filepath, self.extract_group, self.xftf_params)
        else:
            # Try ascii anyway
            try:
                group = self.load_ascii(filepath)
                if not group.array_labels:
                    # In later versions of larch, won't get a type error it
                    # will just fail to load any data
                    group = self.load_h5(filepath)
            except (UnicodeDecodeError, TypeError):
                # Indicates this isn't plaintext, try h5
                group = self.load_h5(filepath)
        return {"out": group}

    def load_ascii(self, dat_file):
        """Read a plaintext file with ``read_ascii`` and rename its columns.

        The last "#" line before the first non-"#" line is treated as the
        column header when it contains a tab (not as its first character);
        its tab separated, stripped fields are passed to ``read_ascii`` as
        ``labels``. Otherwise ``labels`` is None.
        """
        labels = None
        last_line = None
        with open(dat_file) as f:
            for line in f:
                if not line.startswith("#"):
                    if last_line is not None and last_line.find("\t") > 0:
                        labels = [
                            label.strip() for label in last_line.split("\t")
                        ]
                    break

                last_line = line

        xas_data = read_ascii(filename=dat_file, labels=labels)
        return self.rename_cols(xas_data)

    def load_h5(self, dat_file):
        """Build an energy/mu Group from fixed paths inside an h5 file.

        Energy is read from ``entry1.instrument.qexafs_energy.qexafs_energy``
        and mu from ``entry1.instrument.qexafs_counterTimer01.lnI0It``.
        """
        h5_group = h5group(fname=dat_file, mode="r")
        energy = h5_group.entry1.instrument.qexafs_energy.qexafs_energy
        mu = h5_group.entry1.instrument.qexafs_counterTimer01.lnI0It
        xafs_group = Group(data=np.array([energy[:], mu[:]]))
        set_array_labels(xafs_group, ["energy", "mu"])
        return xafs_group

    @staticmethod
    def _numeric_sort_key(filename: str) -> int:
        """Return the last run of digits in ``filename`` as an integer.

        Comparing integers (rather than digit strings) makes "file_9" sort
        before "file_10". Raises IndexError if there are no digits.
        """
        return int(re.findall(r"\d+", filename)[-1])

    def load_zipped_files(self) -> "dict[str, Group]":
        """Load every file found under the "dat_files" directory.

        Directories are visited in path order. Within a directory, files
        are ordered by the last number in their name, falling back to
        alphabetical order if any name has no digits. Returns a dict whose
        keys are zero padded running indices in load order.
        """
        all_paths = sorted(os.walk("dat_files"), key=lambda entry: entry[0])
        file_total = sum(len(files) for _, _, files in all_paths)
        print(f"{file_total} files found")
        key_length = len(str(file_total))
        i = 0
        keyed_data = {}
        for dirpath, _, filenames in all_paths:
            try:
                filenames.sort(key=self._numeric_sort_key)
            except IndexError as e:
                print(
                    "WARNING: Unable to sort files numerically, "
                    f"defaulting to sorting alphabetically:\n{e}"
                )
                filenames.sort()

            for filename in filenames:
                key = str(i).zfill(key_length)
                filepath = os.path.join(dirpath, filename)
                xas_data = self.load_single_file(filepath)
                keyed_data[key] = xas_data["out"]
                i += 1

        return keyed_data

    def rename_cols(self, xafs_group: "Group") -> "Group":
        """Ensure the Group has columns labelled "energy" and "mu".

        Labels are compared in lower case. For each of energy and mu: an
        existing label of that name is kept; otherwise the user supplied
        column name is renamed (ValueError if it is not present); otherwise
        the first label matching a built-in list of candidates is renamed.
        The Group is only relabelled if the resulting labels differ from
        its current ``array_labels``.
        """
        labels = [label.lower() for label in xafs_group.array_labels]
        print(f"Read columns: {labels}")

        if "energy" in labels:
            print("'energy' present in column headers")
        elif self.energy_column is not None:
            if self.energy_column.lower() in labels:
                labels[labels.index(self.energy_column.lower())] = "energy"
            else:
                raise ValueError(f"{self.energy_column} not found in {labels}")
        else:
            for i, label in enumerate(labels):
                if label == "col1" or label.endswith("energy"):
                    labels[i] = "energy"
                    break

        if "mu" in labels:
            print("'mu' present in column headers")
        elif self.mu_column is not None:
            if self.mu_column.lower() in labels:
                labels[labels.index(self.mu_column.lower())] = "mu"
            else:
                raise ValueError(f"{self.mu_column} not found in {labels}")
        else:
            for i, label in enumerate(labels):
                if label in ["col2", "xmu", "lni0it", "ffi0"]:
                    labels[i] = "mu"
                    break

        if labels != xafs_group.array_labels:
            print(f"Renaming columns to: {labels}")
            return set_array_labels(xafs_group, labels)

        return xafs_group


def calibrate_energy(
    xafs_group: "Group",
    energy_0: float,
    energy_min: float,
    energy_max: float,
    energy_format: str,
):
    """Shift the energy axis to a new edge energy and crop to a range.

    Parameters:
        xafs_group: group with ``energy`` and ``mu`` arrays; ``e0`` is
            read when recalibrating or when bounds are relative.
        energy_0: new edge energy, or None to skip recalibration.
        energy_min / energy_max: crop bounds, or None for no bound on
            that side. A value of 0 is a real bound (for "relative"
            format it means "at the edge").
        energy_format: "relative" means the bounds are offsets from
            ``xafs_group.e0``; any other value leaves them as given.

    Returns:
        The same group, modified in place.

    Raises:
        ValueError: if cropping leaves no data points.
    """
    if energy_0 is not None:
        print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}")
        xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0
        xafs_group.e0 = energy_0

    # Compare against None rather than truthiness so that a bound of 0 is
    # not mistaken for "no bound".
    if energy_min is None and energy_max is None:
        return xafs_group

    if energy_min is not None:
        if energy_format == "relative":
            energy_min += xafs_group.e0
        index_min = np.searchsorted(xafs_group.energy, energy_min)
    else:
        index_min = 0

    if energy_max is not None:
        if energy_format == "relative":
            energy_max += xafs_group.e0
        index_max = np.searchsorted(xafs_group.energy, energy_max)
    else:
        index_max = len(xafs_group.energy)

    print(
        f"Cropping energy range from {energy_min} to {energy_max}, "
        f"index {index_min} to {index_max}"
    )
    # Crop each derived array that is present on its own, so one missing
    # attribute does not leave the others at the uncropped length.
    for name in ("dmude", "pre_edge", "post_edge", "flat"):
        if hasattr(xafs_group, name):
            cropped = getattr(xafs_group, name)[index_min:index_max]
            setattr(xafs_group, name, cropped)

    xafs_group.energy = xafs_group.energy[index_min:index_max]
    xafs_group.mu = xafs_group.mu[index_min:index_max]

    # Sanity check
    if len(xafs_group.energy) == 0:
        raise ValueError("Energy cropping led to an empty array")

    return xafs_group


def main(
    xas_data: Group,
    input_values: dict,
    path_key: str = "out",
):
    """Process one group and save it as an Athena project.

    Steps, in order: pre-edge fit, energy recalibration/cropping, optional
    rebinning (followed by a fresh pre-edge fit with default arguments),
    background removal (autobk), forward Fourier transform (xftf),
    optional plots, and finally writing ``prj/{path_key}.prj``.

    Parameters:
        xas_data: group with ``energy`` and ``mu`` arrays.
        input_values: tool inputs; reads the "variables", "rebin",
            "plot_graph" and "annotation" keys.
        path_key: file stem used for the project and plot files. The
            ``prj``, ``edge``, ``flat`` and ``derivative`` directories are
            presumably created by the caller -- not verified here.

    Raises:
        ValueError: if autobk fails, or from ``validate_pre`` /
            ``calibrate_energy``.
    """
    variables = input_values["variables"]

    energy_0 = variables["energy_0"]
    if energy_0 is None and hasattr(xas_data, "e0"):
        energy_0 = xas_data.e0

    energy_format = variables["energy_format"]
    pre1 = validate_pre(variables["pre1"], energy_0, energy_format)
    pre2 = validate_pre(variables["pre2"], energy_0, energy_format)

    pre_edge(
        energy=xas_data.energy,
        mu=xas_data.mu,
        group=xas_data,
        e0=energy_0,
        pre1=pre1,
        pre2=pre2,
    )

    energy_min = variables["energy_min"]
    energy_max = variables["energy_max"]
    xas_data = calibrate_energy(
        xas_data,
        energy_0,
        energy_min,
        energy_max,
        energy_format=energy_format,
    )

    if input_values["rebin"]:
        print(xas_data.energy, xas_data.mu)
        rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
        xas_data = xas_data.rebinned
        pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)

    try:
        autobk(xas_data)
    except ValueError as e:
        raise ValueError(
            f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n"
            "This may occur if the edge is not included in the above ranges."
        ) from e
    # Take the xftf settings from input_values rather than a module-level
    # global, so this function also works when the module is imported.
    xftf(xas_data, **variables["xftf"])

    if input_values["plot_graph"]:
        plot_edge_fits(f"edge/{path_key}.png", xas_data)
        plot_flattened(f"flat/{path_key}.png", xas_data)
        plot_derivative(f"derivative/{path_key}.png", xas_data)

    xas_project = create_athena(f"prj/{path_key}.prj")
    xas_project.add_group(xas_data)
    if input_values["annotation"]:
        # Only one group was added above, so the first one is ours.
        group = next(iter(xas_project.groups.values()))
        group.args["annotation"] = input_values["annotation"]
    xas_project.save()

    # Ensure that we do not run out of memory when running on large zips
    gc.collect()


def validate_pre(pre, energy_0, energy_format):
    """Convert a pre-edge bound to an offset from the edge energy.

    ``pre`` is returned unchanged when it is None or when
    ``energy_format`` is not "absolute". For "absolute" format the edge
    energy ``energy_0`` is subtracted from it.

    Raises:
        ValueError: if an absolute bound is given but ``energy_0`` is None.
    """
    needs_conversion = energy_format == "absolute" and pre is not None
    if not needs_conversion:
        return pre

    if energy_0 is None:
        raise ValueError(
            "Edge energy must be set manually or be present in the "
            "existing Athena project if using absolute format."
        )

    pre -= energy_0
    return pre


def plot_derivative(plot_path: str, xafs_group: Group):
    """Plot ``xafs_group.dmude`` against energy and save it as a PNG.

    Parameters:
        plot_path: destination file path.
        xafs_group: group providing ``energy`` and ``dmude``.
    """
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.dmude)
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    # Raw string: "\m" is not a valid escape sequence in a normal literal.
    plt.ylabel(r"Derivative normalised to x$\mu$(E)")
    plt.savefig(plot_path, format="png")
    # Closes every open figure, not only the one created above.
    plt.close("all")


def plot_edge_fits(plot_path: str, xafs_group: Group):
    """Plot the pre-edge and post-edge curves over mu and save as a PNG.

    Parameters:
        plot_path: destination file path.
        xafs_group: group providing ``energy``, ``pre_edge``,
            ``post_edge`` and ``mu``.
    """
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.pre_edge, "g", label="pre-edge")
    plt.plot(xafs_group.energy, xafs_group.post_edge, "r", label="post-edge")
    plt.plot(xafs_group.energy, xafs_group.mu, "b", label="fit data")
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    # Raw strings: "\m" is not a valid escape sequence in a normal literal.
    plt.ylabel(r"x$\mu$(E)")
    plt.title(r"pre-edge and post_edge fitting to $\mu$")
    plt.legend()
    plt.savefig(plot_path, format="png")
    # Closes every open figure, not only the one created above.
    plt.close("all")


def plot_flattened(plot_path: str, xafs_group: Group):
    """Plot ``xafs_group.flat`` against energy and save it as a PNG.

    Parameters:
        plot_path: destination file path.
        xafs_group: group providing ``energy`` and ``flat``.
    """
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.flat)
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    # Raw string: "\m" is not a valid escape sequence in a normal literal.
    plt.ylabel(r"normalised x$\mu$(E)")
    plt.savefig(plot_path, format="png")
    # Closes every open figure, not only the one created above.
    plt.close("all")


if __name__ == "__main__":
    # Usage: <script> <dat_file> <inputs.json>
    # dat_file is the data path (or comma separated paths when merging);
    # inputs.json holds the tool parameters read below.

    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    dat_file = sys.argv[1]
    # Use a context manager so the JSON file handle is closed promptly.
    with open(sys.argv[2], "r", encoding="utf-8") as input_file:
        input_values = json.load(input_file)

    merge_inputs = input_values["merge_inputs"]["merge_inputs"]
    format_values = input_values["merge_inputs"]["format"]
    data_format = format_values["format"]
    if "is_zipped" in format_values:
        is_zipped = bool(format_values["is_zipped"]["is_zipped"])
    else:
        is_zipped = False
    # Kept as a module-level name: main() may look it up as a global.
    xftf_params = input_values["variables"]["xftf"]

    # These keys may be absent from the format section; absent keys are
    # treated as None.
    extract_group = format_values.get("extract_group")
    energy_column = format_values.get("energy_column")
    mu_column = format_values.get("mu_column")

    reader = Reader(
        energy_column=energy_column,
        mu_column=mu_column,
        xftf_params=xftf_params,
        data_format=data_format,
        extract_group=extract_group,
    )
    keyed_data = reader.load_data(
        dat_file=dat_file,
        merge_inputs=merge_inputs,
        is_zipped=is_zipped,
    )
    # One Athena project (and optional set of plots) is written per key.
    for key, group in keyed_data.items():
        main(
            group,
            input_values=input_values,
            path_key=key,
        )