# larch_artemis.py @ 6:d17c5d62802f
# planemo upload for repository
# https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis
# commit be7fc75b35a77b50a588575ca0749c43f8fd7cf5
# author: muon-spectroscopy-computational-project
# date: Fri, 23 Aug 2024 16:44:13 +0000
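"""Fit EXAFS data with Larch, driven by a project file (or series of files)
plus GDS and selected-paths tables, mirroring the Artemis workflow in
Demeter.

Usage, inferred from the argument handling at the bottom of this file:

    python larch_artemis.py <prj_file(s)> <gds_csv> <paths_csv> <inputs_json>
"""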

import csv
import faulthandler
import gc
import json
import os
import sys

from common import read_group, sorting_key

from larch.fitting import guess, param, param_group
from larch.symboltable import Group
from larch.xafs import (
    FeffPathGroup,
    FeffitDataSet,
    TransformGroup,
    feffit,
    feffit_report,
)

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


def read_csv_data(input_file, id_field="id"):
    csv_data = {}
    try:
        with open(input_file, encoding="utf8") as csvfile:
            reader = csv.DictReader(csvfile, skipinitialspace=True)
            for row in reader:
                csv_data[int(row[id_field])] = row
    except FileNotFoundError:
        print("The specified file does not exist")
    return csv_data


def dict_to_gds(data_dict):
    dgs_group = param_group()
    for par_idx in data_dict:
        # gds file structure: name, value, expr, vary
        gds_name = data_dict[par_idx]["name"]
        try:
            gds_val = float(data_dict[par_idx]["value"])
        except ValueError:
            # skip rows without a numeric starting value
            continue
        gds_expr = data_dict[par_idx]["expr"]
        gds_vary = (
            str(data_dict[par_idx]["vary"]).strip().capitalize() == "True"
        )
        if gds_vary:
            # equivalent to a guess parameter in Demeter
            one_par = guess(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        else:
            # equivalent to a defined parameter in Demeter
            one_par = param(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        setattr(dgs_group, gds_name, one_par)
    return dgs_group
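
# Illustrative GDS table as consumed by read_csv_data/dict_to_gds above
# (column names come from the dictionary lookups; the parameter names and
# values are examples only, not taken from this repository):
#
#   id, name, value, expr, vary
#   1,  amp,  1.0,   ,     True
#   2,  enot, 0.0,   ,     True
#   3,  delr, 0.0,   ,     True
#   4,  ss,   0.003, ,     True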


def plot_rmr(path: str, datasets: list, rmin, rmax):
    plt.figure()
    for i, dataset in enumerate(datasets):
        plt.subplot(len(datasets), 1, i + 1)
        data = dataset.data
        model = dataset.model
        plt.plot(data.r, data.chir_mag, color="b")
        plt.plot(data.r, data.chir_re, color="b", label="expt.")
        plt.plot(model.r, model.chir_mag, color="r")
        plt.plot(model.r, model.chir_re, color="r", label="fit")
        plt.ylabel(
            "Magnitude of Fourier Transform of "
            r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
        )
        plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
        plt.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1)
        plt.legend()

    plt.savefig(path, format="png")
    plt.close("all")


def plot_chikr(path: str, datasets, rmin, rmax, kmin, kmax):
    fig = plt.figure(figsize=(16, 4))
    for i, dataset in enumerate(datasets):
        data = dataset.data
        model = dataset.model
        ax1 = fig.add_subplot(len(datasets), 2, 2*i + 1)
        ax2 = fig.add_subplot(len(datasets), 2, 2*i + 2)
        ax1.plot(data.k, data.chi * data.k**2, color="b", label="expt.")
        ax1.plot(model.k, model.chi * model.k**2, color="r", label="fit")
        ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
        ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")
        ax1.axvspan(xmin=kmin, xmax=kmax, color="g", alpha=0.1)
        ax1.legend()

        ax2.plot(data.r, data.chir_mag, color="b", label="expt.")
        ax2.plot(model.r, model.chir_mag, color="r", label="fit")
        ax2.set_xlim(0, 5)
        ax2.set_xlabel(r"$R(\mathrm{\AA})$")
        ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
        ax2.legend(loc="upper right")
        ax2.axvspan(xmin=rmin, xmax=rmax, color="g", alpha=0.1)

    fig.savefig(path, format="png")
    plt.close("all")


def read_gds(gds_file):
    gds_pars = read_csv_data(gds_file)
    dgs_group = dict_to_gds(gds_pars)
    return dgs_group


def read_selected_paths_list(file_name):
    sp_dict = read_csv_data(file_name)
    sp_list = []
    for path_dict in sp_dict.values():
        filename = path_dict["filename"]
        if not os.path.isfile(filename):
            raise FileNotFoundError(
                f"{filename} not found, check paths in the Selected Paths "
                "table match those in the zipped directory structure."
            )

        print(f"Reading selected path for file {filename}")
        new_path = FeffPathGroup(
            filename=filename,
            label=path_dict["label"],
            degen=path_dict["degen"] if path_dict["degen"] != "" else None,
            s02=path_dict["s02"],
            e0=path_dict["e0"],
            sigma2=path_dict["sigma2"],
            deltar=path_dict["deltar"],
        )
        sp_list.append(new_path)
    return sp_list
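
# Illustrative selected paths table as consumed above (column names come
# from the dictionary lookups in read_selected_paths_list; the FEFF file
# name and the parameter expressions are examples only):
#
#   id, filename,          label, degen, s02, e0,   sigma2, deltar
#   1,  feff/feff0001.dat, O.1,   ,      amp, enot, ss,     delr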


def run_fit(
    data_groups: list,
    gds,
    pathlist,
    fv,
    selected_path_ids: "list[list[int]]" = None,
):
    # create the transform group (prepare the fit space).
    trans = TransformGroup(
        fitspace=fv["fitspace"],
        kmin=fv["kmin"],
        kmax=fv["kmax"],
        kweight=fv["kweight"],
        dk=fv["dk"],
        window=fv["window"],
        rmin=fv["rmin"],
        rmax=fv["rmax"],
    )

    datasets = []
    for i, data_group in enumerate(data_groups):
        if selected_path_ids:
            selected_paths = []
            for path_id in selected_path_ids[i]:
                selected_paths.append(pathlist[path_id - 1])

            dataset = FeffitDataSet(
                data=data_group, pathlist=selected_paths, transform=trans
            )

        else:
            dataset = FeffitDataSet(
                data=data_group, pathlist=pathlist, transform=trans
            )

        datasets.append(dataset)

    out = feffit(gds, datasets)
    return datasets, out
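
# Illustrative fit_vars dict for run_fit (key names come from the
# TransformGroup construction above; the values are typical EXAFS choices,
# not defaults of this tool):
#
#   fv = {"fitspace": "r", "kmin": 3, "kmax": 14, "kweight": 2,
#         "dk": 1, "window": "hanning", "rmin": 1.4, "rmax": 3.0}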


def main(
    prj_file: list,
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    series_id: str = "",
) -> Group:
    report_path = f"report/fit_report{series_id}.txt"
    rmr_path = f"rmr/rmr{series_id}.png"
    chikr_path = f"chikr/chikr{series_id}.png"

    gds = read_gds(gds_file)
    pathlist = read_selected_paths_list(sp_file)

    # calc_with_defaults will hang indefinitely (>6 hours recorded) if the
    # data contains any NaNs - consider adding an early error here if this
    # is not fixed in Larch (a commented-out sketch follows the group
    # reading below)
    selected_path_ids = []
    if isinstance(prj_file[0], dict):
        data_groups = []
        for dataset in prj_file:
            data_groups.append(read_group(dataset["prj_file"]))
            selected_path_ids.append([p["path_id"] for p in dataset["paths"]])
    else:
        data_groups = [read_group(p) for p in prj_file]
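
    # Commented-out sketch of the early NaN check suggested above (the
    # `mu` attribute is an assumption about the groups returned by
    # read_group, not confirmed here):
    # for data_group in data_groups:
    #     if np.isnan(np.asarray(data_group.mu)).any():
    #         raise ValueError(f"NaN in data for {data_group.filename}")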

    print(f"Fitting project from file {[d.filename for d in data_groups]}")

    datasets, out = run_fit(
        data_groups,
        gds,
        pathlist,
        fit_vars,
        selected_path_ids=selected_path_ids,
    )

    fit_report = feffit_report(out)
    with open(report_path, "w") as fit_report_file:
        fit_report_file.write(fit_report)

    if plot_graph:
        plot_rmr(rmr_path, datasets, fit_vars["rmin"], fit_vars["rmax"])
        plot_chikr(
            chikr_path,
            datasets,
            fit_vars["rmin"],
            fit_vars["rmax"],
            fit_vars["kmin"],
            fit_vars["kmax"],
        )
    return out


def check_threshold(
    series_id: str,
    threshold: float,
    variable: str,
    value: float,
    early_stopping: bool = False,
):
    if abs(value) > threshold:
        if early_stopping:
            message = (
                "ERROR: Stopping series fit after project "
                f"{series_id} as {variable} > {threshold}"
            )
        else:
            message = (
                f"WARNING: Project {series_id} has {variable} > {threshold}"
            )

        print(message)
        return early_stopping

    return False


def series_execution(
    filepaths: "list[str]",
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    report_criteria: "list[dict]",
    stop_on_error: bool,
) -> "list[list[str]]":
    id_length = len(str(len(filepaths)))
    stop = False
    rows = [[f"{c['variable']:>12s}" for c in report_criteria]]
    for series_index, series_file in enumerate(filepaths):
        series_id = str(series_index).zfill(id_length)
        try:
            out = main(
                [series_file],
                gds_file,
                sp_file,
                fit_vars,
                plot_graph,
                f"_{series_id}",
            )
        except ValueError as e:
            rows.append([np.nan for _ in report_criteria])
            if stop_on_error:
                print(
                    f"ERROR: fitting failed for {series_id}"
                    f" due to the following error, stopping:\n{e}"
                )
                break
            else:
                print(
                    f"WARNING: fitting failed for {series_id} due to the"
                    f" following error, continuing to next project:\n{e}"
                )
                continue

        row = []
        for criterium in report_criteria:
            stop = parse_row(series_id, out, row, criterium) or stop
        rows.append(row)

        gc.collect()

        if stop:
            break

    return rows
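
# The returned rows feed criteria_report.csv: a header row of right-aligned
# variable names, then one row of formatted values (or NaNs) per fit, e.g.
# [["     rfactor"], ["    0.021000"], ...] - values here are illustrative.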


def parse_row(
    series_id: str, group: Group, row: "list[str]", criterium: dict
) -> bool:
    action = criterium["action"]["action"]
    variable = criterium["variable"]
    try:
        value = getattr(group, variable)
    except AttributeError:
        value = group.params[variable].value

    row.append(f"{value:>12f}")
    if action == "stop":
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            True,
        )
    elif action == "warn":
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            False,
        )

    return False
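
# Illustrative report_criteria entry as consumed by parse_row (key names
# come from the lookups above; the variable name and threshold are made up):
#
#   {"variable": "rfactor", "action": {"action": "stop", "threshold": 0.05}}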


if __name__ == "__main__":
    faulthandler.enable()
    # larch imports set this to an interactive backend, so change it here
    matplotlib.use("Agg")

    prj_file = sys.argv[1]
    gds_file = sys.argv[2]
    sp_file = sys.argv[3]
    with open(sys.argv[4], encoding="utf-8") as config_file:
        input_values = json.load(config_file)
    fit_vars = input_values["fit_vars"]
    plot_graph = input_values["plot_graph"]

    if input_values["execution"]["execution"] == "parallel":
        main([prj_file], gds_file, sp_file, fit_vars, plot_graph)
    elif input_values["execution"]["execution"] == "simultaneous":
        dataset_dicts = input_values["execution"]["simultaneous"]
        main(dataset_dicts, gds_file, sp_file, fit_vars, plot_graph)
    else:
        if os.path.isdir(prj_file):
            # Sort the unzipped directory; all filenames should be
            # zero-padded
            filepaths = [
                os.path.join(prj_file, p) for p in os.listdir(prj_file)
            ]
            filepaths.sort(key=sorting_key)
        else:
            # DO NOT sort if we have multiple Galaxy datasets - the
            # filenames are arbitrary but should already be in order
            filepaths = prj_file.split(",")

        rows = series_execution(
            filepaths,
            gds_file,
            sp_file,
            fit_vars,
            plot_graph,
            input_values["execution"]["report_criteria"],
            input_values["execution"]["stop_on_error"],
        )
        if len(rows[0]) > 0:
            with open("criteria_report.csv", "w") as f:
                writer = csv.writer(f)
                writer.writerows(rows)