diff larch_plot.py @ 0:886949a03377 draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 5be486890442dedfb327289d597e1c8110240735
author muon-spectroscopy-computational-project
date Tue, 14 Nov 2023 15:35:36 +0000
parents
children 002c18a3e642
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/larch_plot.py	Tue Nov 14 15:35:36 2023 +0000
@@ -0,0 +1,101 @@
+import json
+import sys
+
+from common import read_groups
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+import numpy as np
+
+
+Y_LABELS = {
+    "norm": r"x$\mu$(E), normalised",
+    "dmude": r"d(x$\mu$(E))/dE, normalised",
+    "chir_mag": r"|$\chi$(r)|",
+}
+
+
+def main(dat_files: "list[str]", plot_settings: "list[dict]"):
+    groups = list(read_groups(dat_files))
+
+    for i, settings in enumerate(plot_settings):
+        data_list = []
+        e0_min = None
+        e0_max = None
+        variable = settings["variable"]["variable"]
+        x_min = settings["variable"]["energy_min"]
+        x_max = settings["variable"]["energy_max"]
+        plot_path = f"plots/{i}_{variable}.png"
+        plt.figure()
+
+        for group in groups:
+            label = group.athena_params.annotation or group.athena_params.id
+            if variable == "chir_mag":
+                x = group.r
+                energy_format = None
+            else:
+                x = group.energy
+                energy_format = settings["variable"]["energy_format"]
+                if energy_format == "relative":
+                    e0 = group.athena_params.bkg.e0
+                    e0_min = find_relative_limit(e0_min, e0, min)
+                    e0_max = find_relative_limit(e0_max, e0, max)
+
+            y = getattr(group, variable)
+            if x_min is None and x_max is None:
+                plt.plot(x, y, label=label)
+            else:
+                data_list.append({"x": x, "y": y, "label": label})
+
+        if variable != "chir_mag" and energy_format == "relative":
+            if x_min is not None:
+                x_min += e0_min
+            if x_max is not None:
+                x_max += e0_max
+
+        if x_min is not None or x_max is not None:
+            for data in data_list:
+                index_min = None
+                index_max = None
+                x = data["x"]
+                if x_min is not None:
+                    index_min = max(np.searchsorted(x, x_min) - 1, 0)
+                if x_max is not None:
+                    index_max = min(np.searchsorted(x, x_max) + 1, len(x))
+                plt.plot(
+                    x[index_min:index_max],
+                    data["y"][index_min:index_max],
+                    label=data["label"],
+                )
+
+        plt.xlim(x_min, x_max)
+
+        save_plot(variable, plot_path)
+
+
+def find_relative_limit(e0_min: "float|None", e0: float, function: callable):
+    if e0_min is None:
+        e0_min = e0
+    else:
+        e0_min = function(e0_min, e0)
+    return e0_min
+
+
+def save_plot(y_type: str, plot_path: str):
+    plt.grid(color="r", linestyle=":", linewidth=1)
+    plt.xlabel("Energy (eV)")
+    plt.ylabel(Y_LABELS[y_type])
+    plt.legend()
+    plt.savefig(plot_path, format="png")
+    plt.close("all")
+
+
+if __name__ == "__main__":
+    # larch imports set this to an interactive backend, so need to change it
+    matplotlib.use("Agg")
+
+    dat_files = sys.argv[1]
+    input_values = json.load(open(sys.argv[2], "r", encoding="utf-8"))
+
+    main(dat_files.split(","), input_values["plots"])