Mercurial > repos > iuc > maplot
diff maplot.py @ 0:e9212adafd7a draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc commit d5065f0bdf2d38c2344d96d68537223c1096daab
author | iuc |
---|---|
date | Thu, 15 May 2025 12:55:13 +0000 |
parents | |
children |
line wrap: on
line diff
"""Generate a grid of pairwise MA plots for the samples of an intensity table.

The input table holds feature names in its first column and one column per
sample after that. For every ordered pair of samples ``(i, j)`` one subplot is
produced:

* ``i == j`` (diagonal): histogram of the log mean intensity of that sample,
  overlaid with a moving average of the bin counts.
* ``i != j``: MA plot, i.e. log2 fold change against log mean intensity,
  coloured by point density and overlaid with a LOWESS regression line.

Plots are rendered either with Matplotlib (static) or with Plotly
(``--interactive``).
"""
import argparse
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import plotly.subplots as sp
import statsmodels.api as sm  # to build a LOWESS model
from scipy.stats import gaussian_kde


# subplot titles
def make_subplot_titles(sample_names: List[str]) -> List[str]:
    """Generates subplot titles for the MA plot grid.

    Titles are produced in row-major order over all ordered sample pairs:
    the sample name itself for a pair of identical samples, and
    "<first> vs. <second>" otherwise.

    Args:
        sample_names (list): List of sample names.

    Returns:
        list: ``len(sample_names) ** 2`` subplot titles.
    """
    return [
        f"{first}" if i == j else f"{first} vs. {second}"
        for i, first in enumerate(sample_names)
        for j, second in enumerate(sample_names)
    ]


def densities(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Calculates the density of points for a scatter plot.

    A Gaussian kernel density estimate is fitted to the (x, y) points and
    evaluated at those same points.

    Args:
        x (array-like): X-axis values.
        y (array-like): Y-axis values.

    Returns:
        array: Density values for the points.
    """
    values = np.vstack([x, y])
    return gaussian_kde(values)(values)


def movingaverage(data: np.ndarray, window_width: int) -> np.ndarray:
    """Calculates the moving average of the data.

    Args:
        data (array-like): Input data.
        window_width (int): Width of the moving window (at least 1).

    Returns:
        array: Moving average values; ``len(data) - window_width + 1`` entries
        (empty if the window is wider than the data).

    Raises:
        ValueError: If ``window_width`` is smaller than 1.
    """
    if window_width < 1:
        raise ValueError(f"window_width must be at least 1, got {window_width}")
    # Prefix sums with a leading 0 turn every window sum into one subtraction.
    cumsum_vec = np.cumsum(np.insert(data, 0, 0))
    return (cumsum_vec[window_width:] - cumsum_vec[:-window_width]) / window_width


def update_max(current: float, values: np.ndarray) -> float:
    """Updates the maximum value.

    Args:
        current (float): Current maximum value.
        values (array-like): Array of values to compare.

    Returns:
        float: Updated maximum value.
    """
    return max(current, np.max(values))


def get_indices(
    num_samples: int, num_cols: int, plot_num: int
) -> Tuple[int, int, int, int]:
    """Calculates the indices for subplot placement.

    Args:
        num_samples (int): Number of samples.
        num_cols (int): Number of columns in the subplot grid.
        plot_num (int): Plot number (0-based, row-major).

    Returns:
        tuple: ``(i, j, col, row)`` where ``i`` and ``j`` are the 0-based sample
        indices of the pair and ``col`` and ``row`` are the 1-based grid
        position of the subplot.
    """
    i, j = divmod(plot_num, num_samples)
    row_index, col_index = divmod(plot_num, num_cols)
    return i, j, col_index + 1, row_index + 1


def _to_float_array(values: pd.Series) -> np.ndarray:
    """Converts a Series to a float array, mapping missing values to NaN."""
    return values.to_numpy(dtype=float, na_value=np.nan)


def create_subplot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    i: int,
    j: int,
) -> Dict:
    """Creates data for a single subplot.

    Features whose log mean (or, for ``i != j``, log2 fold change) is not
    finite, e.g. because an intensity is zero, negative or missing, are
    dropped; they cannot be placed on a log axis and would otherwise make the
    histogram, density and LOWESS computations fail.

    Args:
        frac (float): LOESS smoothing parameter.
        it (int): Number of iterations for LOESS smoothing.
        num_bins (int): Number of bins for histogram.
        window_width (int): Window width for moving average.
        samples (DataFrame): DataFrame containing sample data.
        i (int): Index of the first sample.
        j (int): Index of the second sample.

    Returns:
        dict: Data for the subplot. Always contains ``mean`` (natural log of
        the mean of the two samples) and ``finite_mask`` (boolean array over
        the rows of ``samples`` marking the features that were kept). For
        ``i == j`` it also contains ``bins``, ``counts``, ``counts_smoothed``
        and ``max_counts``; otherwise ``log_fold_change``,
        ``max_log_fold_change``, ``densities`` and ``regression``.

    Raises:
        ValueError: If no feature has finite values for this pair of samples.
    """
    log_fold_change = None
    # Non-finite results are expected for zero/missing intensities and are
    # filtered out right below, so NumPy's warnings would only be noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        mean = _to_float_array(np.log(samples.iloc[:, [i, j]].mean(axis=1)))
        finite = np.isfinite(mean)
        if i != j:
            log_fold_change = _to_float_array(
                np.log2(samples.iloc[:, i] / samples.iloc[:, j])
            )
            finite &= np.isfinite(log_fold_change)

    if not finite.any():
        raise ValueError(
            f"No finite log-transformed values for sample columns {i} and {j}"
        )

    mean = mean[finite]
    subplot_data: Dict = {"mean": mean, "finite_mask": finite}
    if i == j:
        counts, bins = np.histogram(mean, bins=num_bins)
        subplot_data["bins"] = bins
        subplot_data["counts"] = counts
        subplot_data["counts_smoothed"] = movingaverage(counts, window_width)
        subplot_data["max_counts"] = np.max(counts)
    else:
        log_fold_change = log_fold_change[finite]
        subplot_data["log_fold_change"] = log_fold_change
        subplot_data["max_log_fold_change"] = np.max(log_fold_change)
        subplot_data["densities"] = densities(mean, log_fold_change)
        subplot_data["regression"] = sm.nonparametric.lowess(
            log_fold_change, mean, frac=frac, it=it
        )
    return subplot_data


def create_plot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    num_samples: int,
    num_plots: int,
    num_cols: int,
) -> List[Dict]:
    """Creates data for all subplots.

    Args:
        frac (float): LOESS smoothing parameter.
        it (int): Number of iterations for LOESS smoothing.
        num_bins (int): Number of bins for histogram.
        window_width (int): Window width for moving average.
        samples (DataFrame): DataFrame containing sample data.
        num_samples (int): Number of samples.
        num_plots (int): Number of plots.
        num_cols (int): Number of columns in the subplot grid.

    Returns:
        list: List of data for each subplot, indexed by plot number.
    """
    plots_data = []
    for plot_num in range(num_plots):
        i, j, _, _ = get_indices(num_samples, num_cols, plot_num)
        plots_data.append(
            create_subplot_data(frac, it, num_bins, window_width, samples, i, j)
        )
    return plots_data


def ma_plots_plotly(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    plots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    features: np.ndarray,
) -> go.Figure:
    """Generates MA plots using Plotly.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        plots_data (list): List of data for each subplot.
        sample_names (list): List of sample names.
        size (int): Size (height and width) of a single subplot.
        ylim_hist (float): Y-axis limit for histograms.
        ylim_ma (float): Y-axis limit for MA plots.
        features (array-like): Feature names, one per row of the input table;
            used as hover text of the scatter points.

    Returns:
        Figure: Plotly figure object.
    """
    fig = sp.make_subplots(
        rows=num_rows,
        cols=num_cols,
        shared_xaxes="all",
        subplot_titles=make_subplot_titles(sample_names),
    )
    feature_labels = np.asarray(features)

    for plot_num in range(num_plots):
        i, j, col, row = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = plots_data[plot_num]

        mean = subplot_data["mean"]

        if i == j:
            # Plot histogram on the diagonal
            bins = np.asarray(subplot_data["bins"])
            counts = subplot_data["counts"]
            counts_smoothed = subplot_data["counts_smoothed"]

            # There is one more bin edge than there are counts: centre every
            # bar on its bin and make it span the bin.
            bin_widths = np.diff(bins)
            hist_bar = go.Bar(
                x=bins[:-1] + bin_widths / 2,
                y=counts,
                width=bin_widths,
            )
            fig.add_trace(hist_bar, row=row, col=col)

            # The moving average is shorter than the counts by
            # window_width - 1; place it on the same bin edges as the
            # Matplotlib rendering (bins[window_width // 2:...]).
            window_width = len(counts) - len(counts_smoothed) + 1
            start = window_width // 2
            hist_line = go.Scatter(
                x=bins[start:start + len(counts_smoothed)],
                y=counts_smoothed,
                marker=dict(
                    color="red",
                ),
            )
            fig.add_trace(hist_line, row=row, col=col)
            fig.update_yaxes(
                title_text="Counts",
                range=[0, ylim_hist],
                matches="y1",
                showticklabels=True,
                row=row,
                col=col,
            )
        else:
            log_fold_change = subplot_data["log_fold_change"]
            # Keep hover labels aligned with the points that survived the
            # finite-value filter (absent when plots_data was built elsewhere).
            finite_mask = subplot_data.get("finite_mask")
            point_labels = (
                feature_labels if finite_mask is None else feature_labels[finite_mask]
            )
            scatter = go.Scatter(
                x=mean,
                y=log_fold_change,
                mode="markers",
                marker=dict(
                    color=subplot_data["densities"], symbol="circle", colorscale="jet"
                ),
                name=f"{sample_names[i]} vs {sample_names[j]}",
                text=point_labels,
                hovertemplate="<b>%{text}</b><br>Log Mean: %{x}<br>Log2 Fold Change: %{y}<extra></extra>",
            )
            fig.add_trace(scatter, row=row, col=col)

            regression = subplot_data["regression"]
            line = go.Scatter(
                x=regression[:, 0],
                y=regression[:, 1],
                mode="lines",
                line=dict(color="red"),
                name=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )
            fig.add_trace(line, row=row, col=col)

            fig.update_yaxes(
                title_text="Log2 Fold Change",
                range=[-ylim_ma, ylim_ma],
                matches="y2",
                showticklabels=True,
                row=row,
                col=col,
            )
        fig.update_xaxes(
            title_text="Log Mean Intensity", showticklabels=True, row=row, col=col
        )

    # Update layout for the entire figure
    fig.update_layout(
        height=size * num_rows,
        width=size * num_cols,
        showlegend=False,
        template="simple_white",  # Apply the 'simple_white' template
    )
    return fig


def ma_plots_matplotlib(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    pots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    window_width: int,
) -> plt.Figure:
    """Generates MA plots using Matplotlib.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        pots_data (list): List of data for each subplot. (The parameter name
            is kept as is so that existing keyword callers keep working.)
        sample_names (list): List of sample names.
        size (int): Size of a single subplot; the figure is created with
            ``size / 100`` inches per subplot at 300 dpi.
        ylim_hist (float): Y-axis limit for histograms.
        ylim_ma (float): Y-axis limit for MA plots.
        window_width (int): Window width for moving average.

    Returns:
        Figure: Matplotlib figure object.
    """
    subplot_titles = make_subplot_titles(sample_names)
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(size * num_cols / 100, size * num_rows / 100),
        dpi=300,
        sharex="all",
        # Always return a 2-D array of axes; a 1x1 grid would otherwise yield
        # a bare Axes object that has no flatten().
        squeeze=False,
    )
    axes = axes.flatten()

    for plot_num in range(num_plots):
        i, j, _, _ = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = pots_data[plot_num]

        mean = subplot_data["mean"]

        ax = axes[plot_num]

        if i == j:
            # Plot histogram on the diagonal
            ax.bar(
                subplot_data["bins"][:-1],
                subplot_data["counts"],
                width=np.diff(subplot_data["bins"]),
                edgecolor="black",
                align="edge",
            )

            # Plot moving average line. The slice drops window_width bin
            # edges in total so that its length matches counts_smoothed.
            ax.plot(
                subplot_data["bins"][window_width // 2: -window_width // 2],
                subplot_data["counts_smoothed"],
                color="red",
            )

            ax.set_ylabel("Counts")
            ax.set_ylim(0, ylim_hist)
        else:
            # Scatter plot
            ax.scatter(
                mean,
                subplot_data["log_fold_change"],
                c=subplot_data["densities"],
                cmap="jet",
                edgecolor="black",
                label=f"{sample_names[i]} vs {sample_names[j]}",
            )

            # Regression line
            regression = subplot_data["regression"]
            ax.plot(
                regression[:, 0],
                regression[:, 1],
                color="red",
                label=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )

            ax.set_ylabel("Log2 Fold Change")
            ax.set_ylim(-ylim_ma, ylim_ma)

        ax.set_xlabel("Log Mean Intensity")
        ax.tick_params(labelbottom=True)  # Force showing x-tick labels
        ax.set_title(subplot_titles[plot_num])  # Add subplot title

    # Grid cells beyond the last plot (possible when the number of columns is
    # capped) would otherwise show up as empty axes.
    for unused_ax in axes[num_plots:]:
        unused_ax.set_visible(False)

    # Adjust layout
    plt.tight_layout()
    return fig


def main():
    """Main function to generate MA plots.

    Parses the command line, loads the input table, computes the data of all
    subplots and writes the figure to ``--output_file``.

    Returns:
        int: 0 on success.

    Raises:
        ValueError: If the file extension is not supported or the table has no
            sample columns.
    """
    parser = argparse.ArgumentParser(description="Generate MA plots.")
    # Both arguments are dereferenced unconditionally below, so let argparse
    # report a missing one instead of failing later on None.
    parser.add_argument(
        "--file_path", type=str, required=True, help="Path to the input CSV file"
    )
    parser.add_argument(
        "--file_extension", type=str, required=True, help="File extension"
    )
    parser.add_argument(
        "--frac", type=float, default=4 / 5, help="LOESS smoothing parameter"
    )
    parser.add_argument(
        "--it", type=int, default=5, help="Number of iterations for LOESS smoothing"
    )
    parser.add_argument(
        "--num_bins", type=int, default=100, help="Number of bins for histogram"
    )
    parser.add_argument(
        "--window_width", type=int, default=5, help="Window width for moving average"
    )
    parser.add_argument("--size", type=int, default=500, help="Size of the plot")
    parser.add_argument(
        "--scale", type=int, default=3, help="Scale factor for the plot"
    )
    parser.add_argument(
        "--y_scale_factor", type=float, default=1.1, help="Y-axis scale factor"
    )
    parser.add_argument(
        "--max_num_cols",
        type=int,
        default=100,
        help="Maximum number of columns in the plot",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Generate interactive plot using Plotly",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        default="pdf",
        choices=["pdf", "png", "html"],
        help="Output format for the plot",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="ma_plot",
        help="Output file name without extension",
    )

    args = parser.parse_args()

    # Fail before any expensive computation on option combinations that
    # cannot work.
    if args.output_format == "html" and not args.interactive:
        parser.error("--output_format html is only available with --interactive")
    if args.max_num_cols < 1:
        parser.error("--max_num_cols must be at least 1")

    # Load the data
    file_extension = args.file_extension.lower()
    if file_extension == "csv":
        data = pd.read_csv(args.file_path)
    elif file_extension in ["txt", "tsv", "tabular"]:
        data = pd.read_csv(args.file_path, sep="\t")
    elif file_extension == "parquet":
        data = pd.read_parquet(args.file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")

    features = data.iloc[:, 0]  # Assuming the first column is the feature names
    samples = data.iloc[:, 1:]  # and the rest are samples

    # Create a subplot figure
    num_samples = samples.shape[1]
    if num_samples < 1:
        raise ValueError("The input table does not contain any sample columns")
    sample_names = samples.columns
    num_plots = num_samples**2
    num_cols = min(num_samples, args.max_num_cols)
    num_rows = int(np.ceil(num_plots / num_cols))

    plots_data = create_plot_data(
        args.frac,
        args.it,
        args.num_bins,
        args.window_width,
        samples,
        num_samples,
        num_plots,
        num_cols,
    )

    # Shared y-axis limits: the largest value over all subplots of a kind,
    # enlarged by the requested scale factor.
    count_max = np.max([x.get("max_counts", 0) for x in plots_data])
    log_fold_change_max = np.max([x.get("max_log_fold_change", 0) for x in plots_data])

    ylim_hist = count_max * args.y_scale_factor
    ylim_ma = log_fold_change_max * args.y_scale_factor

    if args.interactive:
        fig = ma_plots_plotly(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            features,
        )
        fig.show()
        if args.output_format == "html":
            fig.write_html(args.output_file)
        else:
            pio.write_image(
                fig,
                args.output_file,
                format=args.output_format,
                width=args.size * num_cols,
                height=args.size * num_rows,
                scale=args.scale,
            )
    else:
        fig = ma_plots_matplotlib(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            args.window_width,
        )
        plt.show()
        fig.savefig(args.output_file, format=args.output_format, dpi=300)
    return 0


if __name__ == "__main__":
    main()