comparison maplot.py @ 0:e9212adafd7a draft default tip

planemo upload for repository https://github.com/galaxyproject/tools-iuc commit d5065f0bdf2d38c2344d96d68537223c1096daab
author iuc
date Thu, 15 May 2025 12:55:13 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9212adafd7a
1 import argparse
2 from typing import Dict, List, Tuple
3
4 import matplotlib.pyplot as plt
5 import numpy as np
6 import pandas as pd
7 import plotly.graph_objects as go
8 import plotly.io as pio
9 import plotly.subplots as sp
10 import statsmodels.api as sm # to build a LOWESS model
11 from scipy.stats import gaussian_kde
12
13
# subplot titles
def make_subplot_titles(sample_names: List[str]) -> List[str]:
    """Build one title per cell of the square sample-by-sample grid.

    Titles are listed in row-major order. A diagonal cell is titled with the
    bare sample name; an off-diagonal cell reads
    ``"<row sample> vs. <column sample>"``.

    Args:
        sample_names (list): Names of the samples, in grid order.

    Returns:
        list: ``len(sample_names) ** 2`` subplot titles.
    """
    return [
        f"{row_name}" if row_idx == col_idx else f"{row_name} vs. {col_name}"
        for row_idx, row_name in enumerate(sample_names)
        for col_idx, col_name in enumerate(sample_names)
    ]
33
34
def densities(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Estimate the local point density at every (x, y) pair.

    A Gaussian kernel density estimate is fitted to the 2-D point cloud and
    then evaluated back at the same points. This gives one density value per
    point, which callers use to colour scatter markers.

    Args:
        x (array-like): X-axis values.
        y (array-like): Y-axis values, same length as ``x``.

    Returns:
        array: Density value for each point.
    """
    points = np.vstack((x, y))
    kde = gaussian_kde(points)
    return kde(points)
47
48
def movingaverage(data: np.ndarray, window_width: int) -> np.ndarray:
    """Calculates the simple (unweighted) moving average of the data.

    Only "valid" windows are returned, i.e. windows that lie entirely inside
    ``data``. The result therefore has ``len(data) - window_width + 1``
    elements, or none when the window is longer than the data.

    Args:
        data (array-like): Input data.
        window_width (int): Width of the moving window; must be >= 1.

    Returns:
        array: Moving average values.

    Raises:
        ValueError: If ``window_width`` is smaller than 1.
    """
    if window_width < 1:
        # 0 would otherwise fail with an opaque broadcast error, and negative
        # widths would silently produce meaningless values.
        raise ValueError(
            f"window_width must be a positive integer, got {window_width}"
        )
    # Prepending a 0 makes every window sum a single difference of two
    # cumulative sums.
    cumsum_vec = np.cumsum(np.insert(data, 0, 0))
    return (cumsum_vec[window_width:] - cumsum_vec[:-window_width]) / window_width
62
63
def update_max(current: float, values: np.ndarray) -> float:
    """Fold the largest entry of ``values`` into a running maximum.

    Args:
        current (float): Maximum seen so far.
        values (array-like): New values to consider.

    Returns:
        float: The larger of ``current`` and ``np.max(values)``; ``current``
        is kept on ties.
    """
    candidate = np.max(values)
    return max(current, candidate)
75
76
def get_indices(
    num_samples: int, num_cols: int, plot_num: int
) -> Tuple[int, int, int, int]:
    """Map a flat plot number to sample indices and a grid position.

    The sample pair comes from laying the plots out as a
    ``num_samples x num_samples`` matrix. The grid position comes from
    wrapping the plots row-major into ``num_cols`` columns.

    Args:
        num_samples (int): Number of samples.
        num_cols (int): Number of columns in the subplot grid.
        plot_num (int): Zero-based plot number.

    Returns:
        tuple: ``(i, j, col, row)`` where ``i``/``j`` are zero-based sample
        indices and ``col``/``row`` are one-based grid coordinates.
    """
    i, j = divmod(plot_num, num_samples)
    row_offset, col_offset = divmod(plot_num, num_cols)
    return i, j, col_offset + 1, row_offset + 1
95
96
def create_subplot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    i: int,
    j: int,
) -> Dict:
    """Compute everything needed to draw grid cell (i, j).

    Every cell stores ``"mean"``: the natural log of the row-wise mean of
    columns ``i`` and ``j``.

    Diagonal cells (``i == j``) additionally store a histogram of that value
    (``"bins"``, ``"counts"``), its moving average (``"counts_smoothed"``)
    and the tallest bar (``"max_counts"``).

    Off-diagonal cells additionally store:

    - ``"log_fold_change"``: the log2 ratio of column ``i`` over column ``j``;
    - ``"max_log_fold_change"``: its maximum;
    - ``"densities"``: one KDE point density per row;
    - ``"regression"``: a LOWESS fit of the ratio against the log mean.

    Args:
        frac (float): LOESS smoothing parameter.
        it (int): Number of iterations for LOESS smoothing.
        num_bins (int): Number of bins for histogram.
        window_width (int): Window width for moving average.
        samples (DataFrame): DataFrame containing sample data.
        i (int): Index of the first sample.
        j (int): Index of the second sample.

    Returns:
        dict: Data for the subplot.
    """
    log_mean = np.log(samples.iloc[:, [i, j]].mean(axis=1))

    if i != j:
        ratio = np.log2(samples.iloc[:, i] / samples.iloc[:, j])
        return {
            "mean": log_mean,
            "log_fold_change": ratio,
            "max_log_fold_change": np.max(ratio),
            "densities": densities(log_mean, ratio),
            # lowess takes (endog, exog): the ratio is regressed on the mean.
            "regression": sm.nonparametric.lowess(
                ratio, log_mean, frac=frac, it=it
            ),
        }

    counts, bins = np.histogram(log_mean, bins=num_bins)
    return {
        "mean": log_mean,
        "bins": bins,
        "counts": counts,
        "counts_smoothed": movingaverage(counts, window_width),
        "max_counts": np.max(counts),
    }
140
141
def create_plot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    num_samples: int,
    num_plots: int,
    num_cols: int,
) -> List[Dict]:
    """Compute the per-cell data for every plot in the grid.

    Plot numbers ``0 .. num_plots - 1`` are turned into sample pairs with
    ``get_indices``. Each pair is passed to ``create_subplot_data``.

    Args:
        frac (float): LOESS smoothing parameter.
        it (int): Number of iterations for LOESS smoothing.
        num_bins (int): Number of bins for histogram.
        window_width (int): Window width for moving average.
        samples (DataFrame): DataFrame containing sample data.
        num_samples (int): Number of samples.
        num_plots (int): Number of plots.
        num_cols (int): Number of columns in the subplot grid.

    Returns:
        list: One data dict per plot, in plot-number order.
    """
    # Only the sample indices matter here; the grid position is recomputed
    # by the plotting functions.
    sample_pairs = (
        get_indices(num_samples, num_cols, plot_num)[:2]
        for plot_num in range(num_plots)
    )
    return [
        create_subplot_data(frac, it, num_bins, window_width, samples, i, j)
        for i, j in sample_pairs
    ]
175
176
def ma_plots_plotly(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    plots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    features: np.ndarray,
) -> go.Figure:
    """Generates MA plots using Plotly.

    Diagonal cells show a histogram of the log mean intensity with its moving
    average overlaid in red. Off-diagonal cells show a density-coloured
    scatter of log2 fold change against log mean, with the LOWESS fit in red.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        plots_data (list): List of data for each subplot.
        sample_names (list): List of sample names.
        size (int): Size of each subplot in pixels; the figure is
            ``size * num_cols`` wide and ``size * num_rows`` high.
        ylim_hist (float): Y-axis limit for histograms.
        ylim_ma (float): Symmetric y-axis limit for MA plots.
        features (array-like): Feature names, shown as hover text.

    Returns:
        Figure: Plotly figure object.
    """
    fig = sp.make_subplots(
        rows=num_rows,
        cols=num_cols,
        shared_xaxes="all",
        subplot_titles=make_subplot_titles(sample_names),
    )

    for plot_num in range(num_plots):
        i, j, col, row = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = plots_data[plot_num]

        if i == j:
            # Plot histogram on the diagonal. np.histogram returns one more
            # edge than counts, so draw each bar centred on its bin with the
            # bin's width rather than centred on its left edge.
            bins = np.asarray(subplot_data["bins"])
            bin_centers = (bins[:-1] + bins[1:]) / 2
            hist_bar = go.Bar(
                x=bin_centers,
                y=subplot_data["counts"],
                width=np.diff(bins),
            )
            fig.add_trace(hist_bar, row=row, col=col)

            # Each moving-average value spans `offset` consecutive bins
            # (offset equals the window width). Place it at the centre of that
            # span so x and y have equal length and the line sits over the
            # bars it averages.
            smoothed = subplot_data["counts_smoothed"]
            offset = len(bins) - len(smoothed)
            window_centers = (bins[: len(smoothed)] + bins[offset:]) / 2
            hist_line = go.Scatter(
                x=window_centers,
                y=smoothed,
                marker=dict(
                    color="red",
                ),
            )
            fig.add_trace(hist_line, row=row, col=col)
            fig.update_yaxes(
                title_text="Counts",
                range=[0, ylim_hist],
                matches="y1",
                showticklabels=True,
                row=row,
                col=col,
            )
        else:
            mean = subplot_data["mean"]
            log_fold_change = subplot_data["log_fold_change"]
            scatter = go.Scatter(
                x=mean,
                y=log_fold_change,
                mode="markers",
                marker=dict(
                    color=subplot_data["densities"], symbol="circle", colorscale="jet"
                ),
                name=f"{sample_names[i]} vs {sample_names[j]}",
                text=features,
                hovertemplate="<b>%{text}</b><br>Log Mean: %{x}<br>Log2 Fold Change: %{y}<extra></extra>",
            )
            fig.add_trace(scatter, row=row, col=col)

            regression = subplot_data["regression"]
            line = go.Scatter(
                x=regression[:, 0],
                y=regression[:, 1],
                mode="lines",
                line=dict(color="red"),
                name=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )
            fig.add_trace(line, row=row, col=col)

            fig.update_yaxes(
                title_text="Log2 Fold Change",
                range=[-ylim_ma, ylim_ma],
                matches="y2",
                showticklabels=True,
                row=row,
                col=col,
            )
        fig.update_xaxes(
            title_text="Log Mean Intensity", showticklabels=True, row=row, col=col
        )

    # Update layout for the entire figure
    fig.update_layout(
        height=size * num_rows,
        width=size * num_cols,
        showlegend=False,
        template="simple_white",  # Apply the 'simple_white' template
    )
    return fig
286
287
def ma_plots_matplotlib(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    pots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    window_width: int,
) -> plt.Figure:
    """Generates MA plots using Matplotlib.

    Diagonal cells show a histogram of the log mean intensity with its moving
    average overlaid in red. Off-diagonal cells show a density-coloured
    scatter of log2 fold change against log mean, with the LOWESS fit in red.
    Grid cells beyond ``num_plots`` are hidden.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        pots_data (list): List of data for each subplot.
        sample_names (list): List of sample names.
        size (int): Size of each subplot; the figure is
            ``size * num_cols / 100`` by ``size * num_rows / 100`` inches.
        ylim_hist (float): Y-axis limit for histograms.
        ylim_ma (float): Symmetric y-axis limit for MA plots.
        window_width (int): Window width for moving average. Kept for
            interface compatibility. The smoothed line's positions are
            derived from the lengths of ``bins`` and ``counts_smoothed``, so
            they always match the data.

    Returns:
        Figure: Matplotlib figure object.
    """
    subplot_titles = make_subplot_titles(sample_names)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid (a single
    # sample). Without it plt.subplots returns a bare Axes, which has no
    # .flatten().
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(size * num_cols / 100, size * num_rows / 100),
        dpi=300,
        sharex="all",
        squeeze=False,
    )
    axes = axes.flatten()

    for plot_num in range(num_plots):
        i, j, _, _ = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = pots_data[plot_num]
        ax = axes[plot_num]

        if i == j:
            # Plot histogram on the diagonal
            bins = np.asarray(subplot_data["bins"])
            ax.bar(
                bins[:-1],
                subplot_data["counts"],
                width=np.diff(bins),
                edgecolor="black",
                align="edge",
            )

            # Moving-average line. Each value averages `offset` consecutive
            # bins, so plot it at the centre of that span rather than at a
            # bin's left edge.
            smoothed = subplot_data["counts_smoothed"]
            offset = len(bins) - len(smoothed)
            ax.plot(
                (bins[: len(smoothed)] + bins[offset:]) / 2,
                smoothed,
                color="red",
            )

            ax.set_ylabel("Counts")
            ax.set_ylim(0, ylim_hist)
        else:
            # Scatter plot
            ax.scatter(
                subplot_data["mean"],
                subplot_data["log_fold_change"],
                c=subplot_data["densities"],
                cmap="jet",
                edgecolor="black",
                label=f"{sample_names[i]} vs {sample_names[j]}",
            )

            # Regression line
            regression = subplot_data["regression"]
            ax.plot(
                regression[:, 0],
                regression[:, 1],
                color="red",
                label=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )

            ax.set_ylabel("Log2 Fold Change")
            ax.set_ylim(-ylim_ma, ylim_ma)

        ax.set_xlabel("Log Mean Intensity")
        ax.tick_params(labelbottom=True)  # Force showing x-tick labels
        ax.set_title(subplot_titles[plot_num])  # Add subplot title

    # When num_cols does not divide num_plots the last row has spare cells;
    # hide them instead of leaving empty framed axes.
    for spare_ax in axes[num_plots:]:
        spare_ax.set_visible(False)

    # Adjust layout
    plt.tight_layout()
    return fig
382
383
def main():
    """Parse command-line arguments, build the MA plot grid and save it.

    The first column of the input table holds the feature names; every
    remaining column is one sample. With ``--interactive`` the figure is
    built with Plotly, otherwise with Matplotlib. Either way it is written to
    ``--output_file`` in ``--output_format`` before it is displayed.

    Returns:
        int: 0 on success.

    Raises:
        ValueError: If the input format is unsupported or the table has no
            sample columns.
    """
    # Local import: only needed to infer the input format from the path.
    from pathlib import Path

    parser = argparse.ArgumentParser(description="Generate MA plots.")
    parser.add_argument(
        "--file_path", type=str, required=True, help="Path to the input CSV file"
    )
    parser.add_argument(
        "--file_extension",
        type=str,
        help="File extension (inferred from --file_path when omitted)",
    )
    parser.add_argument(
        "--frac", type=float, default=4 / 5, help="LOESS smoothing parameter"
    )
    parser.add_argument(
        "--it", type=int, default=5, help="Number of iterations for LOESS smoothing"
    )
    parser.add_argument(
        "--num_bins", type=int, default=100, help="Number of bins for histogram"
    )
    parser.add_argument(
        "--window_width", type=int, default=5, help="Window width for moving average"
    )
    parser.add_argument("--size", type=int, default=500, help="Size of the plot")
    parser.add_argument(
        "--scale", type=int, default=3, help="Scale factor for the plot"
    )
    parser.add_argument(
        "--y_scale_factor", type=float, default=1.1, help="Y-axis scale factor"
    )
    parser.add_argument(
        "--max_num_cols",
        type=int,
        default=100,
        help="Maximum number of columns in the plot",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Generate interactive plot using Plotly",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        default="pdf",
        choices=["pdf", "png", "html"],
        help="Output format for the plot",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="ma_plot",
        help="Output file name without extension",
    )

    args = parser.parse_args()

    # Reject bad settings up front with a clear message, instead of failing
    # deep inside numpy/matplotlib after all the fitting has been done.
    if args.num_bins < 1:
        parser.error("--num_bins must be at least 1")
    if args.window_width < 1:
        parser.error("--window_width must be at least 1")
    if args.max_num_cols < 1:
        parser.error("--max_num_cols must be at least 1")
    if args.output_format == "html" and not args.interactive:
        # Matplotlib cannot write HTML; only the Plotly path supports it.
        parser.error("--output_format html requires --interactive")

    # Load the data; fall back to the path's suffix when no format is given.
    file_extension = (
        args.file_extension or Path(args.file_path).suffix.lstrip(".")
    ).lower()
    if file_extension == "csv":
        data = pd.read_csv(args.file_path)
    elif file_extension in ["txt", "tsv", "tabular"]:
        data = pd.read_csv(args.file_path, sep="\t")
    elif file_extension == "parquet":
        data = pd.read_parquet(args.file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")

    features = data.iloc[:, 0]  # Assuming the first column is the feature names
    samples = data.iloc[:, 1:]  # and the rest are samples

    # Create a subplot figure
    num_samples = samples.shape[1]
    if num_samples == 0:
        raise ValueError(
            "Input table must contain at least one sample column "
            "after the feature column"
        )
    sample_names = samples.columns
    num_plots = num_samples**2
    num_cols = min(num_samples, args.max_num_cols)
    num_rows = int(np.ceil(num_plots / num_cols))

    plots_data = create_plot_data(
        args.frac,
        args.it,
        args.num_bins,
        args.window_width,
        samples,
        num_samples,
        num_plots,
        num_cols,
    )

    # Keys are absent on the other kind of cell (histogram vs. MA), hence 0.
    count_max = np.max([x.get("max_counts", 0) for x in plots_data])
    log_fold_change_max = np.max([x.get("max_log_fold_change", 0) for x in plots_data])

    ylim_hist = count_max * args.y_scale_factor
    ylim_ma = log_fold_change_max * args.y_scale_factor

    # Save before showing, so the output exists even if the display step
    # blocks or fails (e.g. in a headless environment).
    if args.interactive:
        fig = ma_plots_plotly(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            features,
        )
        if args.output_format == "html":
            fig.write_html(f"{args.output_file}")
        else:
            pio.write_image(
                fig,
                f"{args.output_file}",
                format=args.output_format,
                width=args.size * num_cols,
                height=args.size * num_rows,
                scale=args.scale,
            )
        fig.show()
    else:
        fig = ma_plots_matplotlib(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            args.window_width,
        )
        fig.savefig(f"{args.output_file}", format=args.output_format, dpi=300)
        plt.show()
    return 0
512
513
if __name__ == "__main__":
    # main() returns 0 on success, so the process exits with status 0 as before.
    raise SystemExit(main())