diff table_scipy_interpolate.py @ 0:0112f08c95ed draft default tip

planemo upload for repository https://github.com/RECETOX/galaxytools/tree/master/tools/tables commit d0ff40eb2b536fec6c973c3a9ea8e7f31cd9a0d6
author recetox
date Wed, 29 Jan 2025 15:36:02 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/table_scipy_interpolate.py	Wed Jan 29 15:36:02 2025 +0000
@@ -0,0 +1,177 @@
+import argparse
+import logging
+from typing import Callable, Tuple
+
+
+import numpy as np
+import pandas as pd
+from scipy.interpolate import Akima1DInterpolator, CubicSpline, PchipInterpolator
+from utils import LoadDataAction, StoreOutputAction
+
+
class InterpolationModelAction(argparse.Action):
    """Argparse action that resolves an interpolation method name to a callable."""

    # User-facing method names mapped to their interpolator callables.
    # np.interp is a plain function; the scipy entries are interpolator classes.
    _INTERPOLATORS = {
        "linear": np.interp,
        "cubic": CubicSpline,
        "pchip": PchipInterpolator,
        "akima": Akima1DInterpolator,
    }

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: str,
        option_string: str = None,
    ) -> None:
        """
        Store the interpolator callable for the given method name on the namespace.

        Parameters:
        parser (argparse.ArgumentParser): The argument parser instance.
        namespace (argparse.Namespace): The namespace to hold the parsed values.
        values (str): The interpolation method name.
        option_string (str): The option string.

        Raises:
        ValueError: If the method name is not one of the supported keys.
        """
        interpolator = self._INTERPOLATORS.get(values)
        if interpolator is None:
            raise ValueError(f"Unknown interpolation method: {values}")
        setattr(namespace, self.dest, interpolator)
+
+
def interpolate_data(
    reference: pd.DataFrame,
    query: pd.DataFrame,
    x_col: int,
    y_col: int,
    xnew_col: int,
    model: Callable,
    output_dataset: Tuple[Callable[[pd.DataFrame, str], None], str],
) -> None:
    """
    Interpolate data using the specified model and write the result.

    The interpolated values are added to ``query`` in place as a new column
    named after the reference y column; the augmented query table is then
    written out via the supplied writer callable.

    Parameters:
    reference (pd.DataFrame): The reference dataset providing (x, y) samples.
    query (pd.DataFrame): The query dataset holding the new x values.
    x_col (int): The 1-based index of the x column in the reference dataset.
    y_col (int): The 1-based index of the y column in the reference dataset.
    xnew_col (int): The 1-based index of the x column in the query dataset.
    model (Callable): The interpolation model to use (``np.interp`` or a
        scipy interpolator class such as ``CubicSpline``).
    output_dataset (Tuple[Callable[[pd.DataFrame, str], None], str]):
        Writer callable and the output file path.

    Raises:
    ValueError: If the reference y column already exists in the query dataset.
    """
    try:
        # Convert 1-based indices to 0-based indices
        x_col_name = reference.columns[x_col - 1]
        y_col_name = reference.columns[y_col - 1]
        xnew_col_name = query.columns[xnew_col - 1]

        # Refuse to silently overwrite an existing column in the query dataset.
        if y_col_name in query.columns:
            raise ValueError(
                f"Column '{y_col_name}' already exists in the query dataset."
            )

        # np.interp is a plain function taking (xnew, xp, fp); dispatch on the
        # callable with identity ('is'), not '==' — equality on function
        # objects is an identity check anyway and 'is' states the intent.
        if model is np.interp:
            query[y_col_name] = model(
                query[xnew_col_name], reference[x_col_name], reference[y_col_name]
            )
        else:
            # scipy interpolators are classes: fit on the reference samples,
            # then evaluate at the query x values.
            model_instance = model(reference[x_col_name], reference[y_col_name])
            query[y_col_name] = model_instance(query[xnew_col_name]).astype(float)

        write_func, file_path = output_dataset
        write_func(query, file_path)
    except Exception as e:
        # Lazy %-formatting: the message is only built if the record is emitted.
        logging.error("Error in interpolate_data function: %s", e)
        raise
+
+
def main(
    reference_dataset: pd.DataFrame,
    query_dataset: pd.DataFrame,
    x_col: int,
    y_col: int,
    xnew_col: int,
    model: Callable,
    output_dataset: Tuple[Callable[[pd.DataFrame, str], None], str],
) -> None:
    """
    Run the interpolation pipeline and save the result.

    Thin entry point that delegates to :func:`interpolate_data` and logs any
    failure before re-raising it.

    Parameters:
    reference_dataset (pd.DataFrame): The reference dataset.
    query_dataset (pd.DataFrame): The query dataset.
    x_col (int): The 1-based index of the x column in the reference dataset.
    y_col (int): The 1-based index of the y column in the reference dataset.
    xnew_col (int): The 1-based index of the x column in the query dataset.
    model (Callable): The interpolation model to use.
    output_dataset (Tuple[Callable[[pd.DataFrame, str], None], str]): The output dataset and its file extension.
    """
    try:
        interpolate_data(
            reference_dataset,
            query_dataset,
            x_col,
            y_col,
            xnew_col,
            model,
            output_dataset,
        )
    except Exception as e:
        logging.error(f"Error in main function: {e}")
        raise
+
+
if __name__ == "__main__":
    # CLI entry point: wire up argument parsing and hand off to main().
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(
        description="Interpolate data using various methods."
    )
    # LoadDataAction takes (path, extension) and — presumably — loads the file
    # into a DataFrame stored on the namespace; defined in utils, not shown
    # here. TODO confirm the loaded value's type against utils.LoadDataAction.
    parser.add_argument(
        "--reference_dataset",
        nargs=2,
        action=LoadDataAction,
        required=True,
        help="Path to the reference dataset and its file extension (csv, tsv, parquet)",
    )
    parser.add_argument(
        "--query_dataset",
        nargs=2,
        action=LoadDataAction,
        required=True,
        help="Path to the query dataset and its file extension (csv, tsv, parquet)",
    )
    # Column selectors are 1-based indices (converted to 0-based downstream).
    parser.add_argument(
        "--x_col",
        type=int,
        required=True,
        help="1-based index of the x column in the reference dataset",
    )
    parser.add_argument(
        "--y_col",
        type=int,
        required=True,
        help="1-based index of the y column in the reference dataset",
    )
    parser.add_argument(
        "--xnew_col",
        type=int,
        required=True,
        help="1-based index of the x column in the query dataset",
    )
    # InterpolationModelAction maps the method name to the interpolator
    # callable, so args.model is already a callable, not a string.
    parser.add_argument(
        "--model",
        type=str,
        action=InterpolationModelAction,
        required=True,
        help="Interpolation model to use (linear, cubic, pchip, akima)",
    )
    # StoreOutputAction packages a writer callable with the output path;
    # defined in utils, not shown here.
    parser.add_argument(
        "--output_dataset",
        nargs=2,
        action=StoreOutputAction,
        required=True,
        help="Path to the output dataset and its file extension (csv, tsv, parquet)",
    )

    args = parser.parse_args()
    main(
        args.reference_dataset,
        args.query_dataset,
        args.x_col,
        args.y_col,
        args.xnew_col,
        args.model,
        args.output_dataset,
    )