view maplot.py @ 0:e9212adafd7a draft default tip

planemo upload for repository https://github.com/galaxyproject/tools-iuc commit d5065f0bdf2d38c2344d96d68537223c1096daab
author iuc
date Thu, 15 May 2025 12:55:13 +0000
parents
children
line wrap: on
line source

import argparse
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import plotly.subplots as sp
import statsmodels.api as sm  # to build a LOWESS model
from scipy.stats import gaussian_kde


# subplot titles
def make_subplot_titles(sample_names: List[str]) -> List[str]:
    """Build the row-major list of titles for the sample-by-sample plot grid.

    Args:
        sample_names (list): Names of the samples, in column order.

    Returns:
        list: One title per ordered pair of sample positions. A sample paired
            with itself is titled with its name alone; every other pair is
            titled "<first> vs. <second>".
    """
    # Positions (not names) decide whether a cell is on the diagonal, so
    # duplicated sample names are still treated as distinct samples.
    return [
        f"{first}" if row == col else f"{first} vs. {second}"
        for row, first in enumerate(sample_names)
        for col, second in enumerate(sample_names)
    ]


def densities(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Estimate the 2-D point density at every (x, y) point of a scatter plot.

    A Gaussian kernel density estimate is fitted to the points and then
    evaluated at those same points.

    Args:
        x (array-like): X-axis values.
        y (array-like): Y-axis values.

    Returns:
        array: Estimated density at each input point, in input order.
    """
    points = np.vstack([x, y])
    estimator = gaussian_kde(points)
    return estimator(points)


def movingaverage(data: np.ndarray, window_width: int) -> np.ndarray:
    """Compute the simple moving average of a 1-D sequence.

    Only windows that fit completely inside the data are reported, so the
    result has ``len(data) - window_width + 1`` entries.

    Args:
        data (array-like): Input data.
        window_width (int): Number of consecutive values averaged per window.

    Returns:
        array: Mean of each full window, in order of window start.
    """
    # Prefix sums with a leading zero: the total of any window is the
    # difference between two entries that lie window_width apart.
    prefix_sums = np.cumsum(np.insert(data, 0, 0))
    window_totals = prefix_sums[window_width:] - prefix_sums[:-window_width]
    return window_totals / window_width


def update_max(current: float, values: np.ndarray) -> float:
    """Return the larger of a running maximum and the peak of new values.

    Args:
        current (float): Maximum seen so far.
        values (array-like): New values to take into account.

    Returns:
        float: ``current`` if nothing in ``values`` exceeds it, otherwise the
            largest element of ``values``.
    """
    largest_new = np.max(values)
    return max(current, largest_new)


def get_indices(
    num_samples: int, num_cols: int, plot_num: int
) -> Tuple[int, int, int, int]:
    """Translate a flat plot number into sample indices and a grid position.

    Args:
        num_samples (int): Number of samples.
        num_cols (int): Number of columns in the subplot grid.
        plot_num (int): Zero-based plot number, counted row by row.

    Returns:
        tuple: ``(i, j, col, row)`` where ``i`` and ``j`` are the zero-based
            indices of the first and second sample, and ``col`` and ``row``
            are the one-based grid coordinates of the subplot.
    """
    # The sample pair is laid out on a num_samples-wide square ...
    i, j = divmod(plot_num, num_samples)
    # ... while the drawing grid may be narrower (num_cols wide).
    grid_row, grid_col = divmod(plot_num, num_cols)
    return i, j, grid_col + 1, grid_row + 1


def create_subplot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    i: int,
    j: int,
) -> Dict:
    """Compute everything needed to draw the subplot for samples ``i`` and ``j``.

    Args:
        frac (float): LOWESS smoothing parameter (fraction of the data used
            for each local fit).
        it (int): Number of iterations for LOWESS smoothing.
        num_bins (int): Number of bins for the histogram.
        window_width (int): Window width for the histogram moving average.
        samples (DataFrame): One column per sample.
        i (int): Column index of the first sample.
        j (int): Column index of the second sample.

    Returns:
        dict: Always contains ``"mean"``, the natural log of the row-wise
            mean of the two columns. When ``i == j`` it also holds the
            histogram of that mean (``"bins"``, ``"counts"``,
            ``"counts_smoothed"``, ``"max_counts"``). Otherwise it holds
            ``"log_fold_change"`` (log2 of column ``i`` over column ``j``),
            ``"max_log_fold_change"``, ``"densities"`` and the LOWESS
            ``"regression"`` of fold change on mean.
    """
    log_mean = np.log(samples.iloc[:, [i, j]].mean(axis=1))

    # Diagonal cell: distribution of a single sample's log intensity.
    if i == j:
        counts, bins = np.histogram(log_mean, bins=num_bins)
        return {
            "mean": log_mean,
            "bins": bins,
            "counts": counts,
            "counts_smoothed": movingaverage(counts, window_width),
            "max_counts": np.max(counts),
        }

    # Off-diagonal cell: MA plot comparing two different samples.
    log_fold_change = np.log2(samples.iloc[:, i] / samples.iloc[:, j])
    return {
        "mean": log_mean,
        "log_fold_change": log_fold_change,
        "max_log_fold_change": np.max(log_fold_change),
        "densities": densities(log_mean, log_fold_change),
        "regression": sm.nonparametric.lowess(
            log_fold_change, log_mean, frac=frac, it=it
        ),
    }


def create_plot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    num_samples: int,
    num_plots: int,
    num_cols: int,
) -> List[Dict]:
    """Compute the subplot data for every plot, in plot-number order.

    Args:
        frac (float): LOWESS smoothing parameter.
        it (int): Number of iterations for LOWESS smoothing.
        num_bins (int): Number of bins for the histograms.
        window_width (int): Window width for the histogram moving average.
        samples (DataFrame): One column per sample.
        num_samples (int): Number of samples.
        num_plots (int): Number of plots to prepare.
        num_cols (int): Number of columns in the subplot grid.

    Returns:
        list: One subplot-data dict per plot number from 0 to
            ``num_plots - 1``.
    """
    # Only the sample pair matters here; the grid position is ignored.
    sample_pairs = (
        get_indices(num_samples, num_cols, plot_num)[:2]
        for plot_num in range(num_plots)
    )
    return [
        create_subplot_data(frac, it, num_bins, window_width, samples, first, second)
        for first, second in sample_pairs
    ]


def ma_plots_plotly(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    plots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    features: np.ndarray,
) -> go.Figure:
    """Generates MA plots using Plotly.

    Diagonal cells (a sample against itself) show a histogram of the log
    mean intensity with its moving average; every other cell shows the MA
    scatter plot of the two samples, coloured by point density, with its
    LOWESS regression line.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        plots_data (list): Data for each subplot, indexed by plot number.
        sample_names (list): List of sample names.
        size (int): Edge length allotted to each subplot; the figure is
            ``size * num_cols`` wide and ``size * num_rows`` tall.
        ylim_hist (float): Upper y-axis limit for histograms.
        ylim_ma (float): Symmetric y-axis limit for MA plots.
        features (array-like): Feature names, shown as hover text.

    Returns:
        Figure: Plotly figure object.
    """
    fig = sp.make_subplots(
        rows=num_rows,
        cols=num_cols,
        shared_xaxes="all",
        subplot_titles=make_subplot_titles(sample_names),
    )

    for plot_num in range(num_plots):
        i, j, col, row = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = plots_data[plot_num]

        if i == j:
            # Plot histogram on the diagonal
            bin_edges = np.asarray(subplot_data["bins"])
            counts_smoothed = subplot_data["counts_smoothed"]

            # There is one more bin edge than there are counts, and a Plotly
            # bar is centred on its x value, so place bars at bin centres.
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

            # Each smoothed value averages a run of consecutive bins; put it
            # midway between the first and last bin centre of that run so
            # the curve lines up with the bars it summarises.
            num_smoothed = len(counts_smoothed)
            smoothed_x = (
                bin_centers[:num_smoothed]
                + bin_centers[len(bin_centers) - num_smoothed:]
            ) / 2

            hist_bar = go.Bar(
                x=bin_centers,
                y=subplot_data["counts"],
            )
            fig.add_trace(hist_bar, row=row, col=col)

            # Explicit line mode and line colour: with many points Plotly
            # draws lines only, so a marker colour alone would not show.
            hist_line = go.Scatter(
                x=smoothed_x,
                y=counts_smoothed,
                mode="lines",
                line=dict(color="red"),
            )
            fig.add_trace(hist_line, row=row, col=col)
            fig.update_yaxes(
                title_text="Counts",
                range=[0, ylim_hist],
                # Tie all histogram y-axes to axis "y1" (presumably the
                # first subplot, which is a histogram).
                matches="y1",
                showticklabels=True,
                row=row,
                col=col,
            )
        else:
            mean = subplot_data["mean"]
            log_fold_change = subplot_data["log_fold_change"]
            scatter = go.Scatter(
                x=mean,
                y=log_fold_change,
                mode="markers",
                marker=dict(
                    color=subplot_data["densities"], symbol="circle", colorscale="jet"
                ),
                name=f"{sample_names[i]} vs {sample_names[j]}",
                text=features,
                hovertemplate="<b>%{text}</b><br>Log Mean: %{x}<br>Log2 Fold Change: %{y}<extra></extra>",
            )
            fig.add_trace(scatter, row=row, col=col)

            # Column 0 holds the x positions, column 1 the fitted values.
            regression = subplot_data["regression"]
            line = go.Scatter(
                x=regression[:, 0],
                y=regression[:, 1],
                mode="lines",
                line=dict(color="red"),
                name=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )
            fig.add_trace(line, row=row, col=col)

            fig.update_yaxes(
                title_text="Log2 Fold Change",
                range=[-ylim_ma, ylim_ma],
                # Tie all MA y-axes to axis "y2" (presumably the second
                # subplot, an MA plot when the grid has 2+ columns).
                matches="y2",
                showticklabels=True,
                row=row,
                col=col,
            )
        fig.update_xaxes(
            title_text="Log Mean Intensity", showticklabels=True, row=row, col=col
        )

    # Update layout for the entire figure
    fig.update_layout(
        height=size * num_rows,
        width=size * num_cols,
        showlegend=False,
        template="simple_white",  # Apply the 'simple_white' template
    )
    return fig


def ma_plots_matplotlib(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    pots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    window_width: int,
) -> plt.Figure:
    """Generates MA plots using Matplotlib.

    Diagonal cells (a sample against itself) show a histogram of the log
    mean intensity with its moving average; every other cell shows the MA
    scatter plot of the two samples, coloured by point density, with its
    LOWESS regression line.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        pots_data (list): Data for each subplot, indexed by plot number
            (the parameter name is kept as-is for existing callers).
        sample_names (list): List of sample names.
        size (int): Edge length allotted to each subplot; the figure size
            in inches is ``size / 100`` per grid cell.
        ylim_hist (float): Upper y-axis limit for histograms.
        ylim_ma (float): Symmetric y-axis limit for MA plots.
        window_width (int): Window width that was used for the moving
            average. The curve position is derived from the lengths of the
            stored data, so this value is accepted for interface
            compatibility only.

    Returns:
        Figure: Matplotlib figure object.
    """
    subplot_titles = make_subplot_titles(sample_names)
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(size * num_cols / 100, size * num_rows / 100),
        dpi=300,
        sharex="all",
        # Always get a 2-D array of axes; without this a 1x1 grid returns a
        # bare Axes object and the flatten() below fails.
        squeeze=False,
    )
    axes = axes.flatten()

    for plot_num in range(num_plots):
        i, j, _, _ = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = pots_data[plot_num]

        ax = axes[plot_num]

        if i == j:
            # Plot histogram on the diagonal
            bin_edges = np.asarray(subplot_data["bins"])
            counts_smoothed = subplot_data["counts_smoothed"]

            ax.bar(
                bin_edges[:-1],
                subplot_data["counts"],
                width=np.diff(bin_edges),
                edgecolor="black",
                align="edge",
            )

            # Plot moving average line. Each smoothed value averages a run
            # of consecutive bins; draw it midway between the first and
            # last bin centre of that run.
            bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
            num_smoothed = len(counts_smoothed)
            smoothed_x = (
                bin_centers[:num_smoothed]
                + bin_centers[len(bin_centers) - num_smoothed:]
            ) / 2
            ax.plot(
                smoothed_x,
                counts_smoothed,
                color="red",
            )

            ax.set_ylabel("Counts")
            ax.set_ylim(0, ylim_hist)
        else:
            # Scatter plot
            ax.scatter(
                subplot_data["mean"],
                subplot_data["log_fold_change"],
                c=subplot_data["densities"],
                cmap="jet",
                edgecolor="black",
                label=f"{sample_names[i]} vs {sample_names[j]}",
            )

            # Regression line: column 0 is x, column 1 the fitted value
            regression = subplot_data["regression"]
            ax.plot(
                regression[:, 0],
                regression[:, 1],
                color="red",
                label=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )

            ax.set_ylabel("Log2 Fold Change")
            ax.set_ylim(-ylim_ma, ylim_ma)

        ax.set_xlabel("Log Mean Intensity")
        ax.tick_params(labelbottom=True)  # Force showing x-tick labels
        ax.set_title(subplot_titles[plot_num])  # Add subplot title

    # Grid cells beyond the last plot would otherwise show as empty frames.
    for unused_ax in axes[num_plots:]:
        unused_ax.set_visible(False)

    # Adjust layout
    plt.tight_layout()
    return fig


def main():
    """Main function to generate MA plots.

    Parses the command line, loads the intensity table (first column:
    feature names, remaining columns: one sample each), computes the data
    for the full sample-by-sample grid and writes the figure with either
    Plotly (``--interactive``) or Matplotlib.

    Returns:
        int: 0 on success.

    Raises:
        ValueError: If the file extension is not supported or the table
            has no sample columns.
    """
    parser = argparse.ArgumentParser(description="Generate MA plots.")
    # Both input arguments are needed unconditionally below, so let argparse
    # report a missing one instead of failing later on a None value.
    parser.add_argument(
        "--file_path", type=str, required=True, help="Path to the input CSV file"
    )
    parser.add_argument(
        "--file_extension", type=str, required=True, help="File extension"
    )
    parser.add_argument(
        "--frac", type=float, default=4 / 5, help="LOESS smoothing parameter"
    )
    parser.add_argument(
        "--it", type=int, default=5, help="Number of iterations for LOESS smoothing"
    )
    parser.add_argument(
        "--num_bins", type=int, default=100, help="Number of bins for histogram"
    )
    parser.add_argument(
        "--window_width", type=int, default=5, help="Window width for moving average"
    )
    parser.add_argument("--size", type=int, default=500, help="Size of the plot")
    parser.add_argument(
        "--scale", type=int, default=3, help="Scale factor for the plot"
    )
    parser.add_argument(
        "--y_scale_factor", type=float, default=1.1, help="Y-axis scale factor"
    )
    parser.add_argument(
        "--max_num_cols",
        type=int,
        default=100,
        help="Maximum number of columns in the plot",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Generate interactive plot using Plotly",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        default="pdf",
        choices=["pdf", "png", "html"],
        help="Output format for the plot",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="ma_plot",
        help="Output file name without extension",
    )

    args = parser.parse_args()

    # Reject impossible combinations before doing any expensive work.
    if args.output_format == "html" and not args.interactive:
        # Only the Plotly code path below can write HTML.
        parser.error("--output_format html requires --interactive")
    if args.max_num_cols < 1:
        parser.error("--max_num_cols must be at least 1")

    # Load the data
    file_extension = args.file_extension.lower()
    if file_extension == "csv":
        data = pd.read_csv(args.file_path)
    elif file_extension in ["txt", "tsv", "tabular"]:
        data = pd.read_csv(args.file_path, sep="\t")
    elif file_extension == "parquet":
        data = pd.read_parquet(args.file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")

    features = data.iloc[:, 0]  # Assuming the first column is the feature names
    samples = data.iloc[:, 1:]  # and the rest are samples

    # Create a subplot figure
    num_samples = samples.shape[1]
    if num_samples < 1:
        # Without this the grid would have zero columns and the row count
        # below would be computed from a division by zero.
        raise ValueError(
            "Input must contain a feature column followed by at least one "
            "sample column"
        )
    sample_names = list(samples.columns)
    num_plots = num_samples**2
    num_cols = min(num_samples, args.max_num_cols)
    num_rows = int(np.ceil(num_plots / num_cols))

    plots_data = create_plot_data(
        args.frac,
        args.it,
        args.num_bins,
        args.window_width,
        samples,
        num_samples,
        num_plots,
        num_cols,
    )

    # Histogram cells carry "max_counts", MA cells "max_log_fold_change";
    # the 0 default lets each maximum skip the cells of the other kind.
    count_max = np.max([x.get("max_counts", 0) for x in plots_data])
    log_fold_change_max = np.max([x.get("max_log_fold_change", 0) for x in plots_data])

    ylim_hist = count_max * args.y_scale_factor
    ylim_ma = log_fold_change_max * args.y_scale_factor

    if args.interactive:
        fig = ma_plots_plotly(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            features,
        )
        # NOTE(review): show() is kept from the original; confirm it is
        # wanted when this script runs as a non-interactive batch job.
        fig.show()
        if args.output_format == "html":
            fig.write_html(args.output_file)
        else:
            pio.write_image(
                fig,
                args.output_file,
                format=args.output_format,
                width=args.size * num_cols,
                height=args.size * num_rows,
                scale=args.scale,
            )
    else:
        fig = ma_plots_matplotlib(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            args.window_width,
        )
        # NOTE(review): show() is kept from the original; confirm it is
        # wanted when this script runs as a non-interactive batch job.
        plt.show()
        fig.savefig(args.output_file, format=args.output_format, dpi=300)
    return 0


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()