view larch_artemis.py @ 4:39ab361e6d59 draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_artemis commit 71cee2ed96b69a2e78a1eb3dadbd2e81bf332798
author muon-spectroscopy-computational-project
date Mon, 17 Jun 2024 13:54:30 +0000
parents 84c8e04bc1a1
children 7acb53ffb96f

import csv
import faulthandler
import gc
import json
import os
import sys

from common import read_group, sorting_key

from larch.fitting import guess, param, param_group
from larch.symboltable import Group
from larch.xafs import (
    FeffPathGroup,
    FeffitDataSet,
    TransformGroup,
    feffit,
    feffit_report,
)

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


def read_csv_data(input_file, id_field="id"):
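    """Read CSV rows into a dict keyed on the integer id column.

    Returns an empty dict if the file does not exist.
    """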
    csv_data = {}
    try:
        with open(input_file, encoding="utf8") as csvfile:
            reader = csv.DictReader(csvfile, skipinitialspace=True)
            for row in reader:
                csv_data[int(row[id_field])] = row
    except FileNotFoundError:
        print("The specified file does not exist")
    return csv_data


def dict_to_gds(data_dict):
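    """Build a larch parameter group from parsed GDS rows.

    Rows with vary=True become guess parameters, all others become
    defined (param) parameters, mirroring Demeter's guess/def types.
    """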
    dgs_group = param_group()
    for par_idx in data_dict:
        # each gds row provides the fields name, value, expr and vary
        gds_name = data_dict[par_idx]["name"]
        try:
            gds_val = float(data_dict[par_idx]["value"])
        except ValueError:
            gds_val = 0.0
        gds_expr = data_dict[par_idx]["expr"]
        gds_vary = (
            str(data_dict[par_idx]["vary"]).strip().capitalize() == "True"
        )
        if gds_vary:
            # equivalent to a guess parameter in Demeter
            one_par = guess(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        else:
            # equivalent to a defined parameter in Demeter
            one_par = param(
                name=gds_name, value=gds_val, vary=gds_vary, expr=gds_expr
            )
        setattr(dgs_group, gds_name, one_par)
    return dgs_group


def plot_rmr(path: str, data_set, rmin, rmax):
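    """Plot data and fit chi(R) (magnitude and real part) to a PNG file."""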
    plt.figure()
    plt.plot(data_set.data.r, data_set.data.chir_mag, color="b")
    plt.plot(data_set.data.r, data_set.data.chir_re, color="b", label="expt.")
    plt.plot(data_set.model.r, data_set.model.chir_mag, color="r")
    plt.plot(data_set.model.r, data_set.model.chir_re, color="r", label="fit")
    plt.ylabel(
        "Magnitude of Fourier Transform of "
        r"$k^2 \cdot \chi$/$\mathrm{\AA}^{-3}$"
    )
    plt.xlabel(r"Radial distance/$\mathrm{\AA}$")
    plt.xlim(0, 5)

    plt.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    plt.text(rmax - 0.65, -rmax + 0.5, "fit range")
    plt.legend()
    plt.savefig(path, format="png")
    plt.close("all")


def plot_chikr(path: str, data_set, rmin, rmax, kmin, kmax):
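    """Plot data and fit in both k-space and R-space to a PNG file."""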
    fig = plt.figure(figsize=(16, 4))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.plot(
        data_set.data.k,
        data_set.data.chi * data_set.data.k**2,
        color="b",
        label="expt.",
    )
    ax1.plot(
        data_set.model.k,
        data_set.model.chi * data_set.model.k**2,
        color="r",
        label="fit",
    )
    ax1.set_xlim(0, 15)
    ax1.set_xlabel(r"$k (\mathrm{\AA})^{-1}$")
    ax1.set_ylabel(r"$k^2$ $\chi (k)(\mathrm{\AA})^{-2}$")

    ax1.fill(
        [kmin, kmin, kmax, kmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax1.text(kmax - 1.65, -rmax + 0.5, "fit range")
    ax1.legend()

    ax2.plot(data_set.data.r, data_set.data.chir_mag, color="b", label="expt.")
    ax2.plot(data_set.model.r, data_set.model.chir_mag, color="r", label="fit")
    ax2.set_xlim(0, 5)
    ax2.set_xlabel(r"$R(\mathrm{\AA})$")
    ax2.set_ylabel(r"$|\chi(R)|(\mathrm{\AA}^{-3})$")
    ax2.legend(loc="upper right")

    ax2.fill(
        [rmin, rmin, rmax, rmax],
        [-rmax, rmax, rmax, -rmax],
        color="g",
        alpha=0.1,
    )
    ax2.text(rmax - 0.65, -rmax + 0.5, "fit range")
    fig.savefig(path, format="png")
    plt.close("all")


def read_gds(gds_file):
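    """Read a GDS CSV file and return the corresponding parameter group."""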
    gds_pars = read_csv_data(gds_file)
    dgs_group = dict_to_gds(gds_pars)
    return dgs_group


def read_selected_paths_list(file_name):
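    """Build FeffPathGroups from a CSV of selected paths.

    Each row must provide filename, label, s02, e0, sigma2 and deltar.
    """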
    sp_dict = read_csv_data(file_name)
    sp_list = []
    for path_id in sp_dict:
        filename = sp_dict[path_id]["filename"]
        print(f"Reading selected path for file {filename}")
        new_path = FeffPathGroup(
            filename=filename,
            label=sp_dict[path_id]["label"],
            s02=sp_dict[path_id]["s02"],
            e0=sp_dict[path_id]["e0"],
            sigma2=sp_dict[path_id]["sigma2"],
            deltar=sp_dict[path_id]["deltar"],
        )
        sp_list.append(new_path)
    return sp_list


def run_fit(data_group, gds, selected_paths, fv):
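    """Run feffit on the data group with the given parameters and paths."""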
    # create the transform group (prepare the fit space).
    trans = TransformGroup(
        fitspace=fv["fitspace"],
        kmin=fv["kmin"],
        kmax=fv["kmax"],
        kweight=fv["kweight"],
        dk=fv["dk"],
        window=fv["window"],
        rmin=fv["rmin"],
        rmax=fv["rmax"],
    )

    dset = FeffitDataSet(
        data=data_group, pathlist=selected_paths, transform=trans
    )

    out = feffit(gds, dset)
    return dset, out


def main(
    prj_file: str,
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    series_id: str = "",
) -> Group:
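    """Fit a single Athena project, writing the report and optional plots."""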
    report_path = f"report/fit_report{series_id}.txt"
    rmr_path = f"rmr/rmr{series_id}.png"
    chikr_path = f"chikr/chikr{series_id}.png"

    # calc_with_defaults will hang indefinitely (>6 hours recorded) if the
    # data contains any NaNs - consider adding an early error here if this
    # is not fixed in Larch.
    data_group = read_group(prj_file)
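    # A minimal guard (sketch, deliberately left commented out): fail fast
    # rather than hang. Assumes the group exposes `energy` and `mu` arrays,
    # which may not hold for every project file.
    # if np.isnan(data_group.energy).any() or np.isnan(data_group.mu).any():
    #     raise ValueError(f"NaN values found in {prj_file}, aborting fit")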

    print(f"Fitting project from file {data_group.filename}")

    gds = read_gds(gds_file)
    selected_paths = read_selected_paths_list(sp_file)
    dset, out = run_fit(data_group, gds, selected_paths, fit_vars)

    fit_report = feffit_report(out)
    with open(report_path, "w") as fit_report_file:
        fit_report_file.write(fit_report)

    if plot_graph:
        plot_rmr(rmr_path, dset, fit_vars["rmin"], fit_vars["rmax"])
        plot_chikr(
            chikr_path,
            dset,
            fit_vars["rmin"],
            fit_vars["rmax"],
            fit_vars["kmin"],
            fit_vars["kmax"],
        )
    return out


def check_threshold(
    series_id: str,
    threshold: float,
    variable: str,
    value: float,
    early_stopping: bool = False,
):
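    """Warn (or error, when early_stopping) if abs(value) exceeds threshold.

    Returns True only when early_stopping is set and the threshold is
    exceeded, signalling that the series fit should stop.
    """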
    if abs(value) > threshold:
        if early_stopping:
            message = (
                "ERROR: Stopping series fit after project "
                f"{series_id} as {variable} > {threshold}"
            )
        else:
            message = (
                f"WARNING: Project {series_id} has {variable} > {threshold}"
            )

        print(message)
        return early_stopping

    return False


def series_execution(
    filepaths: "list[str]",
    gds_file: str,
    sp_file: str,
    fit_vars: dict,
    plot_graph: bool,
    report_criteria: "list[dict]",
    stop_on_error: bool,
) -> "list[list[str]]":
    id_length = len(str(len(filepaths)))
    stop = False
    rows = [[f"{c['variable']:>12s}" for c in report_criteria]]
    for series_index, series_file in enumerate(filepaths):
        series_id = str(series_index).zfill(id_length)
        try:
            out = main(
                series_file,
                gds_file,
                sp_file,
                fit_vars,
                plot_graph,
                f"_{series_id}",
            )
        except ValueError as e:
            rows.append([np.nan for _ in report_criteria])
            if stop_on_error:
                print(
                    f"ERROR: fitting failed for {series_id}"
                    f" due to following error, stopping:\n{e}"
                )
                break
            else:
                print(
                    f"WARNING: fitting failed for {series_id} due to following"
                    f" error, continuing to next project:\n{e}"
                )
                continue

        row = []
        for criterion in report_criteria:
            stop = parse_row(series_id, out, row, criterion) or stop
        rows.append(row)

        gc.collect()

        if stop:
            break

    return rows


def parse_row(series_id: str, group: Group, row: "list[str]", criterion: dict):
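    """Append the criterion's formatted value to the row and apply its action.

    Returns True only for a "stop" action whose threshold is exceeded.
    """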
    action = criterium["action"]["action"]
    variable = criterium["variable"]
    try:
        value = group.__getattribute__(variable)
    except AttributeError:
        value = group.params[variable].value

    row.append(f"{value:>12f}")
    if action == "stop":
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            True,
        )
    elif action == "warn":
        return check_threshold(
            series_id,
            criterium["action"]["threshold"],
            variable,
            value,
            False,
        )

    return False


if __name__ == "__main__":
    faulthandler.enable()
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    prj_file = sys.argv[1]
    gds_file = sys.argv[2]
    sp_file = sys.argv[3]
    with open(sys.argv[4], "r", encoding="utf-8") as f:
        input_values = json.load(f)
    fit_vars = input_values["fit_vars"]
    plot_graph = input_values["plot_graph"]

    if input_values["execution"]["execution"] == "parallel":
        main(prj_file, gds_file, sp_file, fit_vars, plot_graph)

    else:
        if os.path.isdir(prj_file):
            # Sort the unzipped directory, all filenames should be zero-padded
            filepaths = [
                os.path.join(prj_file, p) for p in os.listdir(prj_file)
            ]
            filepaths.sort(key=sorting_key)
        else:
            # DO NOT sort if we have multiple Galaxy datasets - the filenames
            # are arbitrary but should be in order
            filepaths = prj_file.split(",")

        rows = series_execution(
            filepaths,
            gds_file,
            sp_file,
            fit_vars,
            plot_graph,
            input_values["execution"]["report_criteria"],
            input_values["execution"]["stop_on_error"],
        )
        if len(rows[0]) > 0:
            with open("criteria_report.csv", "w") as f:
                writer = csv.writer(f)
                writer.writerows(rows)
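
# Example invocation (illustrative file names; in series mode the first
# argument may also be a directory or a comma-separated list of files):
#   python larch_artemis.py athena.prj gds.csv paths.csv inputs.json
# where inputs.json is shaped like (a minimal sketch, values are examples):
#   {
#       "fit_vars": {"fitspace": "r", "kmin": 3, "kmax": 14, "kweight": 2,
#                    "dk": 1, "window": "hanning", "rmin": 1.4, "rmax": 3.0},
#       "plot_graph": true,
#       "execution": {
#           "execution": "series",
#           "report_criteria": [
#               {"variable": "rfactor",
#                "action": {"action": "warn", "threshold": 0.05}}
#           ],
#           "stop_on_error": false
#       }
#   }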