changeset 4:35d24102cefd draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 3fe6078868efd0fcea0fb5eea8dcd4b152d9c0a8
author muon-spectroscopy-computational-project
date Thu, 11 Apr 2024 09:02:24 +0000
parents 5b993aff09e3
children 3584db5902b5
files common.py larch_plot.py larch_plot.xml
diffstat 3 files changed, 84 insertions(+), 52 deletions(-) [+]
line wrap: on
line diff
--- a/common.py	Fri Mar 22 14:23:33 2024 +0000
+++ b/common.py	Thu Apr 11 09:02:24 2024 +0000
@@ -7,14 +7,14 @@
 
 
 def get_group(athena_group: AthenaGroup, key: str = None) -> Group:
-    group_keys = list(athena_group._athena_groups.keys())
+    group_keys = list(athena_group.keys())
     if key is None:
         key = group_keys[0]
     else:
         key = key.replace("-", "_")
 
     try:
-        return extract_athenagroup(athena_group._athena_groups[key])
+        return extract_athenagroup(athena_group.groups[key])
     except KeyError as e:
         raise KeyError(f"{key} not in {group_keys}") from e
 
@@ -28,7 +28,7 @@
         do_fft=False,
     )
     all_groups = {}
-    for key in athena_group._athena_groups.keys():
+    for key in athena_group.keys():
         print(f"\nExtracting group {key}")
         group = get_group(athena_group, key)
         pre_edge_with_defaults(group=group)
@@ -52,13 +52,22 @@
     return group
 
 
-def pre_edge_with_defaults(group: Group, settings: dict = None):
+def pre_edge_with_defaults(
+    group: Group, settings: dict = None, ref_channel: str = None
+):
     merged_settings = {}
-    try:
-        bkg_parameters = group.athena_params.bkg
-    except AttributeError as e:
-        print(f"Cannot load group.athena_params.bkg from group:\n{e}")
-        bkg_parameters = None
+    if ref_channel is not None:
+        print(f"Performing pre-edge with reference channel {ref_channel}")
+        ref = getattr(group, ref_channel.lower())
+        group.e0 = None
+        pre_edge(energy=group.energy, mu=ref, group=group)
+        bkg_parameters = group.pre_edge_details
+    else:
+        try:
+            bkg_parameters = group.athena_params.bkg
+        except AttributeError as e:
+            print(f"Cannot load group.athena_params.bkg from group:\n{e}")
+            bkg_parameters = None
 
     keys = (
         ("e0", "e0", None),
--- a/larch_plot.py	Fri Mar 22 14:23:33 2024 +0000
+++ b/larch_plot.py	Thu Apr 11 09:02:24 2024 +0000
@@ -3,6 +3,8 @@
 
 from common import read_groups
 
+from larch.symboltable import Group
+
 import matplotlib
 import matplotlib.pyplot as plt
 
@@ -10,67 +12,84 @@
 
 
 AXIS_LABELS = {
+    "energy": "Energy (eV)",
+    "distance": "r (ang)",
+    "sample": "Sample",
     "flat": r"x$\mu$(E), flattened",
     "dmude": r"d(x$\mu$(E))/dE, normalised",
     "chir_mag": r"|$\chi$(r)|",
-    "energy": "Energy (eV)",
-    "distance": "r (ang)",
+    "e0": "Edge Energy (eV)",
 }
 
 
+def sample_plot(groups: "list[Group]", y_variable: str):
+    x = [get_label(group) for group in groups]
+    y = [getattr(group, y_variable) for group in groups]
+    plt.scatter(x, y)
+
+
 def main(dat_files: "list[str]", plot_settings: "list[dict]"):
     groups = list(read_groups(dat_files))
 
     for i, settings in enumerate(plot_settings):
         data_list = []
-        x_variable = "energy"
         y_variable = settings["variable"]["variable"]
-        x_min = settings["variable"]["x_limit_min"]
-        x_max = settings["variable"]["x_limit_max"]
-        y_min = settings["variable"]["y_limit_min"]
-        y_max = settings["variable"]["y_limit_max"]
         plot_path = f"plots/{i}_{y_variable}.png"
         plt.figure()
 
-        for group in groups:
-            params = group.athena_params
-            annotation = getattr(params, "annotation", None)
-            file = getattr(params, "file", None)
-            params_id = getattr(params, "id", None)
-            label = annotation or file or params_id
-            if y_variable == "chir_mag":
-                x_variable = "distance"
-                x = group.r
-            else:
-                x = group.energy
-
-            y = getattr(group, y_variable)
-            if x_min is None and x_max is None:
-                plt.plot(x, y, label=label)
-            else:
-                data_list.append({"x": x, "y": y, "label": label})
+        if y_variable == "e0":
+            x_variable = "sample"
+            sample_plot(groups, y_variable)
+        else:
+            x_variable = "energy"
+            x_min = settings["variable"]["x_limit_min"]
+            x_max = settings["variable"]["x_limit_max"]
+            y_min = settings["variable"]["y_limit_min"]
+            y_max = settings["variable"]["y_limit_max"]
+            for group in groups:
+                label = get_label(group)
+                if y_variable == "chir_mag":
+                    x_variable = "distance"
+                    x = group.r
+                else:
+                    x = group.energy
 
-        if x_min is not None or x_max is not None:
-            for data in data_list:
-                index_min = None
-                index_max = None
-                x = data["x"]
-                if x_min is not None:
-                    index_min = max(np.searchsorted(x, x_min) - 1, 0)
-                if x_max is not None:
-                    index_max = min(np.searchsorted(x, x_max) + 1, len(x))
-                plt.plot(
-                    x[index_min:index_max],
-                    data["y"][index_min:index_max],
-                    label=data["label"],
-                )
+                y = getattr(group, y_variable)
+                if x_min is None and x_max is None:
+                    plt.plot(x, y, label=label)
+                else:
+                    data_list.append({"x": x, "y": y, "label": label})
 
-        plt.xlim(x_min, x_max)
-        plt.ylim(y_min, y_max)
+            if x_min is not None or x_max is not None:
+                for data in data_list:
+                    index_min = None
+                    index_max = None
+                    x = data["x"]
+                    if x_min is not None:
+                        index_min = max(np.searchsorted(x, x_min) - 1, 0)
+                    if x_max is not None:
+                        index_max = min(np.searchsorted(x, x_max) + 1, len(x))
+                    plt.plot(
+                        x[index_min:index_max],
+                        data["y"][index_min:index_max],
+                        label=data["label"],
+                    )
+
+            plt.xlim(x_min, x_max)
+            plt.ylim(y_min, y_max)
 
         save_plot(x_variable, y_variable, plot_path)
 
 
+def get_label(group: Group) -> str:
+    params = group.athena_params
+    annotation = getattr(params, "annotation", None)
+    file = getattr(params, "file", None)
+    params_id = getattr(params, "id", None)
+    label = annotation or file or params_id
+    return label
+
+
 def save_plot(x_type: str, y_type: str, plot_path: str):
     plt.grid(color="r", linestyle=":", linewidth=1)
     plt.xlabel(AXIS_LABELS[x_type])
--- a/larch_plot.xml	Fri Mar 22 14:23:33 2024 +0000
+++ b/larch_plot.xml	Thu Apr 11 09:02:24 2024 +0000
@@ -2,9 +2,9 @@
     <description>plot Athena projects</description>
     <macros>
         <!-- version of underlying tool (PEP 440) -->
-        <token name="@TOOL_VERSION@">0.9.74</token>
+        <token name="@TOOL_VERSION@">0.9.75</token>
         <!-- version of this tool wrapper (integer) -->
-        <token name="@WRAPPER_VERSION@">1</token>
+        <token name="@WRAPPER_VERSION@">0</token>
         <!-- citation should be updated with every underlying tool version -->
         <!-- typical fields to update are version, month, year, and doi -->
         <token name="@TOOL_CITATION@">10.1088/1742-6596/430/1/012007</token>
@@ -36,6 +36,7 @@
                     <option value="flat" selected="true">Flattened xμ</option>
                     <option value="dmude">Derivative of xμ</option>
                     <option value="chir_mag">Magnitude of χ(r)</option>
+                    <option value="e0">E0</option>
                 </param>
                 <when value="flat">
                     <expand macro="plot_limits_energy"/>
@@ -53,6 +54,8 @@
                     <param name="y_limit_min" type="float" label="Minimum |χ(r)|" optional="true" help="If set, plot will be limited to this value on the y axis."/>
                     <param name="y_limit_max" type="float" label="Maximum |χ(r)|" optional="true" help="If set, plot will be limited to this value on the y axis."/>
                 </when>
+                <when value="e0">
+                </when>
             </conditional>
         </repeat>
     </inputs>
@@ -68,7 +71,8 @@
             <param name="variable" value="flat"/>
             <param name="variable" value="dmude"/>
             <param name="variable" value="chir_mag"/>
-            <output_collection name="plot_collection" type="list" count="3"/>
+            <param name="variable" value="e0"/>
+            <output_collection name="plot_collection" type="list" count="4"/>
         </test>
         <!-- 2: plot limits -->
         <test expect_num_outputs="1">