Mercurial > repos > iuc > maplot
comparison maplot.py @ 0:e9212adafd7a draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc commit d5065f0bdf2d38c2344d96d68537223c1096daab
author | iuc |
---|---|
date | Thu, 15 May 2025 12:55:13 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9212adafd7a |
---|---|
1 import argparse | |
2 from typing import Dict, List, Tuple | |
3 | |
4 import matplotlib.pyplot as plt | |
5 import numpy as np | |
6 import pandas as pd | |
7 import plotly.graph_objects as go | |
8 import plotly.io as pio | |
9 import plotly.subplots as sp | |
10 import statsmodels.api as sm # to build a LOWESS model | |
11 from scipy.stats import gaussian_kde | |
12 | |
13 | |
14 # subplot titles | |
def make_subplot_titles(sample_names: List[str]) -> List[str]:
    """Build the titles for every cell of the sample-by-sample plot grid.

    Titles are produced in row-major order. A diagonal cell (a sample paired
    with itself) is labelled with the sample name only; every other cell is
    labelled "<row sample> vs. <column sample>".

    Args:
        sample_names (list): Names of the samples.

    Returns:
        list: One title per (row sample, column sample) pair.
    """
    count = len(sample_names)
    return [
        f"{sample_names[row]}"
        if row == col
        else f"{sample_names[row]} vs. {sample_names[col]}"
        for row in range(count)
        for col in range(count)
    ]
33 | |
34 | |
def densities(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Estimate the local point density at each (x, y) pair.

    A Gaussian kernel density estimate is fitted to the points and then
    evaluated at those same points, giving one density value per point.

    Args:
        x (array-like): X-axis values.
        y (array-like): Y-axis values.

    Returns:
        array: Density estimate for each input point.
    """
    points = np.vstack((x, y))
    kde = gaussian_kde(points)
    return kde.evaluate(points)
47 | |
48 | |
def movingaverage(data: np.ndarray, window_width: int) -> np.ndarray:
    """Calculates the moving average of the data.

    The average is computed from differences of the cumulative sum, so only
    complete windows are produced: the result has
    ``len(data) - window_width + 1`` entries, and is empty when the window is
    wider than the data.

    Args:
        data (array-like): Input data.
        window_width (int): Width of the moving window; must be at least 1.

    Returns:
        array: Moving average values.

    Raises:
        ValueError: If ``window_width`` is smaller than 1.
    """
    # A zero or negative width would otherwise produce a cryptic broadcast
    # error or, worse, silently wrong numbers from the negative slices below.
    if window_width < 1:
        raise ValueError(
            f"window_width must be a positive integer, got {window_width}"
        )
    # Prepend a zero so that cumsum_vec[k] is the sum of the first k values.
    cumsum_vec = np.cumsum(np.insert(data, 0, 0))
    ma_vec = (cumsum_vec[window_width:] - cumsum_vec[:-window_width]) / window_width
    return ma_vec
62 | |
63 | |
def update_max(current: float, values: np.ndarray) -> float:
    """Return the running maximum after taking new values into account.

    Args:
        current (float): Maximum seen so far.
        values (array-like): New values to compare against.

    Returns:
        float: ``current`` unless the largest entry of ``values`` exceeds it.
    """
    peak = np.max(values)
    return peak if peak > current else current
75 | |
76 | |
def get_indices(
    num_samples: int, num_cols: int, plot_num: int
) -> Tuple[int, int, int, int]:
    """Map a flat plot number to its sample pair and its grid position.

    Args:
        num_samples (int): Number of samples.
        num_cols (int): Number of columns in the subplot grid.
        plot_num (int): Zero-based plot number.

    Returns:
        tuple: ``(i, j, col, row)`` where ``i`` and ``j`` are the zero-based
        sample indices and ``col`` and ``row`` are the one-based grid
        coordinates.
    """
    first_sample, second_sample = divmod(plot_num, num_samples)
    grid_row, grid_col = divmod(plot_num, num_cols)
    return first_sample, second_sample, grid_col + 1, grid_row + 1
95 | |
96 | |
def create_subplot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    i: int,
    j: int,
) -> Dict:
    """Compute everything one grid cell needs to be drawn.

    Every cell carries the natural log of the per-row mean of samples ``i``
    and ``j``. A diagonal cell (``i == j``) additionally gets a histogram of
    that log mean together with its moving average; an off-diagonal cell gets
    the log2 fold change, per-point densities and a LOWESS fit.

    Args:
        frac (float): LOESS smoothing parameter.
        it (int): Number of iterations for LOESS smoothing.
        num_bins (int): Number of bins for histogram.
        window_width (int): Window width for moving average.
        samples (DataFrame): DataFrame containing sample data.
        i (int): Index of the first sample.
        j (int): Index of the second sample.

    Returns:
        dict: Keys ``mean``, ``bins``, ``counts``, ``counts_smoothed`` and
        ``max_counts`` for a diagonal cell; keys ``mean``,
        ``log_fold_change``, ``max_log_fold_change``, ``densities`` and
        ``regression`` otherwise.
    """
    log_mean = np.log(samples.iloc[:, [i, j]].mean(axis=1))

    if i == j:
        # Diagonal cell: distribution of a single sample's log intensity.
        counts, bins = np.histogram(log_mean, bins=num_bins)
        return {
            "mean": log_mean,
            "bins": bins,
            "counts": counts,
            "counts_smoothed": movingaverage(counts, window_width),
            "max_counts": np.max(counts),
        }

    # Off-diagonal cell: compare sample i against sample j.
    log_fold_change = np.log2(samples.iloc[:, i] / samples.iloc[:, j])
    return {
        "mean": log_mean,
        "log_fold_change": log_fold_change,
        "max_log_fold_change": np.max(log_fold_change),
        "densities": densities(log_mean, log_fold_change),
        "regression": sm.nonparametric.lowess(
            log_fold_change, log_mean, frac=frac, it=it
        ),
    }
140 | |
141 | |
def create_plot_data(
    frac: float,
    it: int,
    num_bins: int,
    window_width: int,
    samples: pd.DataFrame,
    num_samples: int,
    num_plots: int,
    num_cols: int,
) -> List[Dict]:
    """Compute the per-cell data for the whole plot grid.

    Args:
        frac (float): LOESS smoothing parameter.
        it (int): Number of iterations for LOESS smoothing.
        num_bins (int): Number of bins for histogram.
        window_width (int): Window width for moving average.
        samples (DataFrame): DataFrame containing sample data.
        num_samples (int): Number of samples.
        num_plots (int): Number of plots.
        num_cols (int): Number of columns in the subplot grid.

    Returns:
        list: One data dictionary per plot, ordered by plot number.
    """
    # Only the sample indices matter here; the grid position is ignored.
    sample_pairs = (
        get_indices(num_samples, num_cols, plot_num)[:2]
        for plot_num in range(num_plots)
    )
    return [
        create_subplot_data(
            frac, it, num_bins, window_width, samples, first, second
        )
        for first, second in sample_pairs
    ]
175 | |
176 | |
def ma_plots_plotly(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    plots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    features: np.ndarray,
) -> go.Figure:
    """Generates MA plots using Plotly.

    Diagonal cells show a histogram of the log mean intensity with its
    moving average; all other cells show a density-coloured scatter of log2
    fold change against log mean intensity with a LOWESS line.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        plots_data (list): List of data for each subplot.
        sample_names (list): List of sample names.
        size (int): Size of one subplot in pixels.
        ylim_hist (float): Y-axis limit for histograms.
        ylim_ma (float): Y-axis limit for MA plots.
        features (array-like): Feature names, used as hover text.

    Returns:
        Figure: Plotly figure object.
    """
    fig = sp.make_subplots(
        rows=num_rows,
        cols=num_cols,
        shared_xaxes="all",
        subplot_titles=make_subplot_titles(sample_names),
    )

    for plot_num in range(num_plots):
        i, j, col, row = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = plots_data[plot_num]

        mean = subplot_data["mean"]

        if i == j:
            # Plot histogram on the diagonal
            bins = subplot_data["bins"]
            hist_bar = go.Bar(
                x=bins,
                y=subplot_data["counts"],
            )
            fig.add_trace(hist_bar, row=row, col=col)

            # The moving average only covers complete windows, so it is
            # shorter than the bin array. Center it over the bins exactly as
            # ma_plots_matplotlib does (bins[w // 2: -w // 2]) instead of
            # drawing it shifted to the left edge of the histogram.
            counts_smoothed = subplot_data["counts_smoothed"]
            start = (len(bins) - len(counts_smoothed)) // 2
            hist_line = go.Scatter(
                x=bins[start:start + len(counts_smoothed)],
                y=counts_smoothed,
                marker=dict(
                    color="red",
                ),
            )
            fig.add_trace(hist_line, row=row, col=col)
            # All histograms share the y-axis of the first subplot.
            fig.update_yaxes(
                title_text="Counts",
                range=[0, ylim_hist],
                matches="y1",
                showticklabels=True,
                row=row,
                col=col,
            )
        else:
            log_fold_change = subplot_data["log_fold_change"]
            scatter = go.Scatter(
                x=mean,
                y=log_fold_change,
                mode="markers",
                marker=dict(
                    color=subplot_data["densities"], symbol="circle", colorscale="jet"
                ),
                name=f"{sample_names[i]} vs {sample_names[j]}",
                text=features,
                hovertemplate="<b>%{text}</b><br>Log Mean: %{x}<br>Log2 Fold Change: %{y}<extra></extra>",
            )
            fig.add_trace(scatter, row=row, col=col)

            regression = subplot_data["regression"]
            line = go.Scatter(
                x=regression[:, 0],
                y=regression[:, 1],
                mode="lines",
                line=dict(color="red"),
                name=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )
            fig.add_trace(line, row=row, col=col)

            # All MA plots share the y-axis of the second subplot.
            fig.update_yaxes(
                title_text="Log2 Fold Change",
                range=[-ylim_ma, ylim_ma],
                matches="y2",
                showticklabels=True,
                row=row,
                col=col,
            )
        fig.update_xaxes(
            title_text="Log Mean Intensity", showticklabels=True, row=row, col=col
        )

    # Update layout for the entire figure
    fig.update_layout(
        height=size * num_rows,
        width=size * num_cols,
        showlegend=False,
        template="simple_white",  # Apply the 'simple_white' template
    )
    return fig
286 | |
287 | |
def ma_plots_matplotlib(
    num_rows: int,
    num_cols: int,
    num_plots: int,
    pots_data: List[Dict],
    sample_names: List[str],
    size: int,
    ylim_hist: float,
    ylim_ma: float,
    window_width: int,
) -> plt.Figure:
    """Generates MA plots using Matplotlib.

    Diagonal cells show a histogram of the log mean intensity with its
    moving average; all other cells show a density-coloured scatter of log2
    fold change against log mean intensity with a LOWESS line.

    Args:
        num_rows (int): Number of rows in the subplot grid.
        num_cols (int): Number of columns in the subplot grid.
        num_plots (int): Number of plots.
        pots_data (list): List of data for each subplot.
        sample_names (list): List of sample names.
        size (int): Size of one subplot; the figure is ``size / 100`` inches
            per cell and is created at 300 dpi.
        ylim_hist (float): Y-axis limit for histograms.
        ylim_ma (float): Y-axis limit for MA plots.
        window_width (int): Window width for moving average.

    Returns:
        Figure: Matplotlib figure object.
    """
    subplot_titles = make_subplot_titles(sample_names)
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(size * num_cols / 100, size * num_rows / 100),
        dpi=300,
        sharex="all",
        # Always return a 2-D array of axes. Without this a 1x1 grid (single
        # sample) yields a bare Axes object and the flatten() below fails.
        squeeze=False,
    )
    axes = axes.flatten()

    for plot_num in range(num_plots):
        i, j, _, _ = get_indices(len(sample_names), num_cols, plot_num)
        subplot_data = pots_data[plot_num]

        mean = subplot_data["mean"]

        ax = axes[plot_num]

        if i == j:
            # Plot histogram on the diagonal
            ax.bar(
                subplot_data["bins"][:-1],
                subplot_data["counts"],
                width=np.diff(subplot_data["bins"]),
                edgecolor="black",
                align="edge",
            )

            # Plot moving average line; trimming the bin edges on both sides
            # centers the shorter smoothed series over the histogram.
            ax.plot(
                subplot_data["bins"][window_width // 2: -window_width // 2],
                subplot_data["counts_smoothed"],
                color="red",
            )

            ax.set_ylabel("Counts")
            ax.set_ylim(0, ylim_hist)
        else:
            # Scatter plot
            ax.scatter(
                mean,
                subplot_data["log_fold_change"],
                c=subplot_data["densities"],
                cmap="jet",
                edgecolor="black",
                label=f"{sample_names[i]} vs {sample_names[j]}",
            )

            # Regression line
            regression = subplot_data["regression"]
            ax.plot(
                regression[:, 0],
                regression[:, 1],
                color="red",
                label=f"LOWESS {sample_names[i]} vs. {sample_names[j]}",
            )

            ax.set_ylabel("Log2 Fold Change")
            ax.set_ylim(-ylim_ma, ylim_ma)

        ax.set_xlabel("Log Mean Intensity")
        ax.tick_params(labelbottom=True)  # Force showing x-tick labels
        ax.set_title(subplot_titles[plot_num])  # Add subplot title

    # Adjust layout
    plt.tight_layout()
    return fig
382 | |
383 | |
def main():
    """Main function to generate MA plots.

    Parses the command line, loads the input table (first column: feature
    names, remaining columns: one sample each), computes the data for the
    sample-by-sample grid and renders it with Plotly (``--interactive``) or
    Matplotlib, writing the result to ``--output_file``.

    Returns:
        int: 0 on success.

    Raises:
        ValueError: If the file extension is not supported or the input
            contains no sample columns.
        SystemExit: Raised by argparse for invalid command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Generate MA plots.")
    # Both inputs are mandatory; without them the script cannot load data.
    parser.add_argument(
        "--file_path", type=str, required=True, help="Path to the input CSV file"
    )
    parser.add_argument(
        "--file_extension", type=str, required=True, help="File extension"
    )
    parser.add_argument(
        "--frac", type=float, default=4 / 5, help="LOESS smoothing parameter"
    )
    parser.add_argument(
        "--it", type=int, default=5, help="Number of iterations for LOESS smoothing"
    )
    parser.add_argument(
        "--num_bins", type=int, default=100, help="Number of bins for histogram"
    )
    parser.add_argument(
        "--window_width", type=int, default=5, help="Window width for moving average"
    )
    parser.add_argument("--size", type=int, default=500, help="Size of the plot")
    parser.add_argument(
        "--scale", type=int, default=3, help="Scale factor for the plot"
    )
    parser.add_argument(
        "--y_scale_factor", type=float, default=1.1, help="Y-axis scale factor"
    )
    parser.add_argument(
        "--max_num_cols",
        type=int,
        default=100,
        help="Maximum number of columns in the plot",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Generate interactive plot using Plotly",
    )
    parser.add_argument(
        "--output_format",
        type=str,
        default="pdf",
        choices=["pdf", "png", "html"],
        help="Output format for the plot",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="ma_plot",
        help="Output file name without extension",
    )

    args = parser.parse_args()

    # Reject argument combinations that can only fail later, after all the
    # expensive computation has already been done.
    if args.output_format == "html" and not args.interactive:
        parser.error("--output_format html requires --interactive")
    if args.max_num_cols < 1:
        parser.error("--max_num_cols must be at least 1")

    # Load the data
    file_extension = args.file_extension.lower()
    if file_extension == "csv":
        data = pd.read_csv(args.file_path)
    elif file_extension in ["txt", "tsv", "tabular"]:
        data = pd.read_csv(args.file_path, sep="\t")
    elif file_extension == "parquet":
        data = pd.read_parquet(args.file_path)
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")

    features = data.iloc[:, 0]  # Assuming the first column is the feature names
    samples = data.iloc[:, 1:]  # and the rest are samples

    # Create a subplot figure
    num_samples = samples.shape[1]
    if num_samples == 0:
        # Without samples the grid would have zero columns and the row
        # computation below would divide by zero.
        raise ValueError("Input file must contain at least one sample column")
    sample_names = list(samples.columns)
    num_plots = num_samples**2
    num_cols = min(num_samples, args.max_num_cols)
    num_rows = int(np.ceil(num_plots / num_cols))

    plots_data = create_plot_data(
        args.frac,
        args.it,
        args.num_bins,
        args.window_width,
        samples,
        num_samples,
        num_plots,
        num_cols,
    )

    # Shared axis limits: the largest value over all cells, padded by the
    # y scale factor. Cells of the other kind contribute 0.
    count_max = np.max([x.get("max_counts", 0) for x in plots_data])
    log_fold_change_max = np.max([x.get("max_log_fold_change", 0) for x in plots_data])

    ylim_hist = count_max * args.y_scale_factor
    ylim_ma = log_fold_change_max * args.y_scale_factor

    if args.interactive:
        fig = ma_plots_plotly(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            features,
        )
        fig.show()
        if args.output_format == "html":
            fig.write_html(f"{args.output_file}")
        else:
            pio.write_image(
                fig,
                f"{args.output_file}",
                format=args.output_format,
                width=args.size * num_cols,
                height=args.size * num_rows,
                scale=args.scale,
            )
    else:
        fig = ma_plots_matplotlib(
            num_rows,
            num_cols,
            num_plots,
            plots_data,
            sample_names,
            args.size,
            ylim_hist,
            ylim_ma,
            args.window_width,
        )
        plt.show()
        fig.savefig(f"{args.output_file}", format=args.output_format, dpi=300)
    return 0
512 | |
513 | |
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()