Mercurial > repos > ebi-gxa > decoupler_pathway_inference
comparison decoupler_pseudobulk.py @ 0:77d680b36e23 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 1034a450c97dcbb77871050cf0c6d3da90dac823
author | ebi-gxa |
---|---|
date | Fri, 15 Mar 2024 12:17:49 +0000 |
parents | |
children | c6787c2aee46 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:77d680b36e23 |
---|---|
1 import argparse | |
2 | |
3 import anndata | |
4 import decoupler | |
5 import pandas as pd | |
6 | |
7 | |
def get_pseudobulk(
    adata,
    sample_col,
    groups_col,
    layer=None,
    mode="sum",
    min_cells=10,
    min_counts=1000,
    use_raw=False,
):
    """
    Compute a pseudobulk object by delegating to ``decoupler.get_pseudobulk``.

    Every argument is forwarded unchanged: ``adata`` positionally and the
    remaining ones as keyword arguments of the same name. The value returned
    by decoupler is returned as is.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    """
    pseudobulk_options = {
        "sample_col": sample_col,
        "groups_col": groups_col,
        "layer": layer,
        "mode": mode,
        "use_raw": use_raw,
        "min_cells": min_cells,
        "min_counts": min_counts,
    }
    return decoupler.get_pseudobulk(adata, **pseudobulk_options)
35 | |
36 | |
def prepend_c_to_index(index_value):
    """
    Return ``index_value`` prefixed with "C" when it starts with a digit.

    Empty (falsy) values and values whose first character is not a digit are
    returned unchanged.
    """
    # Nothing to inspect on an empty value; hand it back untouched.
    if not index_value:
        return index_value
    first_char = index_value[0]
    if first_char.isdigit():
        return "C" + index_value
    return index_value
41 | |
42 | |
43 # write results for loading into DESeq2 | |
def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None):
    """
    Write the tab-separated tables used to load pseudobulk data into DESeq2.

    Three files are written under ``output_dir``:

    * ``col_metadata.tsv``: ``pdata.obs`` (optionally restricted to
      ``factor_fields``, in that order) with a sanitised index.
    * ``gene_metadata.tsv``: ``pdata.var``.
    * ``counts_matrix.tsv``: the transposed counts (genes as rows, sanitised
      sample names as columns), taken from ``pdata.X`` or from
      ``pdata.layers[layer]`` when ``layer`` is given.

    Sample names are sanitised by prefixing "C" to names that start with a
    digit and replacing "-" and " " with "_".

    Parameters
    ----------
    pdata : anndata.AnnData
        Pseudobulk object providing ``obs``, ``var``, ``X`` and ``layers``.
    layer : str, optional
        Name of the layer to write instead of ``X``.
    output_dir : str
        Directory prefix for the output files; "" writes to the current
        working directory.
    factor_fields : list of str, optional
        Columns of ``pdata.obs`` to keep in ``col_metadata.tsv``.

    Raises
    ------
    KeyError
        If any of ``factor_fields`` is not a column of ``pdata.obs``.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> write_DESeq2_inputs(pseudobulk)
    """
    # Turn output_dir into a usable prefix: append "/" when it is non-empty
    # and does not already end with one.
    if output_dir != "" and not output_dir.endswith("/"):
        output_dir = output_dir + "/"

    obs_for_deseq = pdata.obs.copy()
    # replace any index starting with digits to start with C instead.
    obs_for_deseq.rename(index=prepend_c_to_index, inplace=True)
    # avoid dash that is read as point on R colnames; regex=False keeps the
    # replacement literal regardless of the pandas version's default.
    obs_for_deseq.index = obs_for_deseq.index.str.replace("-", "_", regex=False)
    obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_", regex=False)

    col_metadata_file = f"{output_dir}col_metadata.tsv"
    if factor_fields:
        # NOTE(review): the pseudobulk obs may hold fewer columns than the
        # single-cell obs the fields were validated against - confirm against
        # decoupler. Fail with an explicit message instead of a bare lookup
        # error when a requested column is absent.
        missing = [
            field for field in factor_fields if field not in obs_for_deseq.columns
        ]
        if missing:
            raise KeyError(
                f"Factor fields {missing} are not present in the pseudobulk obs. "
                f"Available fields are: {list(obs_for_deseq.columns)}"
            )
        # only output the index plus the columns in factor_fields in that order
        obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep="\t", index=True)
    else:
        obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True)

    # write var to a gene_metadata file
    pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True)

    # pick the counts matrix: X by default, otherwise the requested layer
    counts = pdata.X if layer is None else pdata.layers[layer]
    # Sparse matrices cannot be handed to the DataFrame constructor directly;
    # densify anything exposing toarray() (numpy arrays do not).
    if hasattr(counts, "toarray"):
        counts = counts.toarray()
    df = pd.DataFrame(counts.T, index=pdata.var.index, columns=obs_for_deseq.index)
    df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="")
79 | |
80 | |
def plot_pseudobulk_samples(
    pseudobulk_data,
    groupby,
    figsize=(10, 10),
    save_path=None,
):
    """
    Draw the figure produced by ``decoupler.plot_psbulk_samples``.

    When ``save_path`` is truthy the figure is saved as
    ``<save_path>/pseudobulk_samples.png``; otherwise it is shown.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
    """
    figure = decoupler.plot_psbulk_samples(
        pseudobulk_data,
        groupby=groupby,
        figsize=figsize,
        return_fig=True,
    )
    # No destination given: display instead of writing to disk.
    if not save_path:
        figure.show()
        return
    figure.savefig(f"{save_path}/pseudobulk_samples.png")
101 | |
102 | |
def plot_filter_by_expr(
    pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None
):
    """
    Draw the figure produced by ``decoupler.plot_filter_by_expr``.

    ``group``, ``min_count`` and ``min_total_count`` are forwarded unchanged.
    When ``save_path`` is truthy the figure is saved as
    ``<save_path>/filter_by_expr.png``; otherwise it is shown.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
    """
    plot_options = {
        "group": group,
        "min_count": min_count,
        "min_total_count": min_total_count,
        "return_fig": True,
    }
    figure = decoupler.plot_filter_by_expr(pseudobulk_data, **plot_options)
    # No destination given: display instead of writing to disk.
    if not save_path:
        figure.show()
        return
    figure.savefig(f"{save_path}/filter_by_expr.png")
124 | |
125 | |
def filter_by_expr(pdata, min_count=None, min_total_count=None):
    """
    Return a copy of ``pdata`` restricted to the genes selected by
    ``decoupler.filter_by_expr``.

    ``min_count`` and ``min_total_count`` are forwarded unchanged to decoupler.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
    """
    thresholds = {"min_count": min_count, "min_total_count": min_total_count}
    kept_genes = decoupler.filter_by_expr(pdata, **thresholds)
    # Subset the variable axis, then copy so the result is not a view of pdata.
    subset = pdata[:, kept_genes]
    return subset.copy()
138 | |
139 | |
def check_fields(fields, adata, obs=True, context=None):
    """
    Check that every name in ``fields`` is a column of ``adata.obs`` (or of
    ``adata.var`` when ``obs`` is False).

    Parameters
    ----------
    fields : list of str
        Column names that must be present.
    adata : anndata.AnnData
        Object whose ``obs.columns`` / ``var.columns`` are inspected.
    obs : bool
        Check ``adata.obs`` when True (default), ``adata.var`` otherwise.
    context : str, optional
        Where the fields came from (e.g. an option name); included in the
        error message when given.

    Raises
    ------
    ValueError
        If at least one field is missing. The message lists the requested
        fields, the missing ones and the available columns.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> check_fields(["bulk_labels", "louvain"], adata, obs=True)
    """
    legend = f", passed in {context}," if context else ""
    if obs:
        location, columns = "adata.obs", adata.obs.columns
    else:
        location, columns = "adata.var", adata.var.columns

    available = set(columns)
    missing = [field for field in fields if field not in available]
    if missing:
        # list(columns) keeps the original column order so the message is
        # deterministic between runs.
        raise ValueError(
            f"Some of the following fields{legend} are not present in {location}: {fields}. "
            f"Missing fields: {missing}. Possible fields are: {list(columns)}"
        )
160 | |
161 | |
def main(args):
    """
    Run the pseudobulk workflow described by the parsed command line ``args``.

    Steps, in order: read the AnnData file, optionally merge groups of obs
    fields, validate the requested obs columns, compute the pseudobulk,
    draw the samples and filter-by-expression plots, optionally filter genes
    by expression, optionally write the pseudobulk AnnData, and write the
    DESeq2 input tables.

    Raises
    ------
    ValueError
        If a requested obs column is not present in ``adata.obs``.
    """
    # Load AnnData object from file
    adata = anndata.read_h5ad(args.adata_file)

    # Merge adata.obs fields specified in args.adata_obs_fields_to_merge
    if args.adata_obs_fields_to_merge:
        # first split potential groups by ":" and iterate over them
        for group in args.adata_obs_fields_to_merge.split(":"):
            fields = group.split(",")
            check_fields(fields, adata, context="--adata_obs_fields_to_merge")
            adata = merge_adata_obs_fields(fields, adata)

    # Checked after merging so that merged columns can be used here; the
    # context tells the user which option named the missing column.
    check_fields([args.groupby], adata, context="--groupby")
    check_fields([args.sample_key], adata, context="--sample_key")

    factor_fields = None
    if args.factor_fields:
        factor_fields = args.factor_fields.split(",")
        check_fields(factor_fields, adata, context="--factor_fields")

    print(f"Using mode: {args.mode}")
    pseudobulk_kwargs = {
        "sample_col": args.sample_key,
        "groups_col": args.groupby,
        "layer": args.layer,
        "mode": args.mode,
        "use_raw": args.use_raw,
        "min_cells": args.min_cells,
    }
    # --min_counts has no argparse default, so it is None when not given.
    # Leave it out in that case so get_pseudobulk's own default applies
    # instead of forwarding None.
    if args.min_counts is not None:
        pseudobulk_kwargs["min_counts"] = args.min_counts

    # Perform pseudobulk analysis
    pseudobulk_data = get_pseudobulk(adata, **pseudobulk_kwargs)

    # Plot pseudobulk samples
    plot_pseudobulk_samples(
        pseudobulk_data,
        args.groupby,
        save_path=args.save_path,
        figsize=args.plot_samples_figsize,
    )

    plot_filter_by_expr(
        pseudobulk_data,
        group=args.groupby,
        min_count=args.min_counts,
        min_total_count=args.min_total_counts,
        save_path=args.save_path,
    )

    # Filter by expression if enabled
    if args.filter_expr:
        pseudobulk_data = filter_by_expr(
            pseudobulk_data,
            min_count=args.min_counts,
            min_total_count=args.min_total_counts,
        )

    # Save the pseudobulk data
    if args.anndata_output_path:
        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")

    write_DESeq2_inputs(
        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
    )
227 | |
228 | |
def merge_adata_obs_fields(obs_fields_to_merge, adata):
    """
    Add an obs column that joins the values of several obs columns with "_".

    The new column is named after the merged fields joined with "_" and is
    always rebuilt from the source columns, so calling this repeatedly with
    the same fields gives the same result. An empty list leaves ``adata``
    untouched.

    Parameters
    ----------
    obs_fields_to_merge : list of str
        Names of the columns in adata.obs to merge, in order
    adata : anndata.AnnData
        The AnnData object; its ``obs`` is modified in place

    Returns
    -------
    anndata.AnnData
        The same AnnData object, with the merged column added

    Raises
    ------
    ValueError
        If any of the fields is not a column of adata.obs. Nothing is
        modified in that case.

    docstring tests:
    >>> import scanpy as sc
    >>> ad = sc.datasets.pbmc68k_reduced()
    >>> ad = merge_adata_obs_fields(["bulk_labels","louvain"], ad)
    >>> ad.obs.columns
    Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score',
    'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'],
    dtype='object')
    """
    fields = list(obs_fields_to_merge)
    # Validate everything up front so a missing field cannot leave a
    # half-built merged column behind.
    for field in fields:
        if field not in adata.obs.columns:
            raise ValueError(f"The '{field}' column is not present in adata.obs.")
    if not fields:
        return adata

    field_name = "_".join(fields)
    # Build the merged values from scratch rather than appending to whatever
    # is already stored under field_name (which may be a previous result or,
    # for a single field, the source column itself).
    merged = adata.obs[fields[0]].astype(str)
    for field in fields[1:]:
        merged = merged + "_" + adata.obs[field].astype(str)
    adata.obs[field_name] = merged
    return adata
265 | |
266 | |
if __name__ == "__main__":
    # Command line entry point: declare the options, parse them and hand the
    # resulting namespace to main().
    parser = argparse.ArgumentParser(
        description="Perform pseudobulk analysis on an AnnData object"
    )

    # Add arguments
    parser.add_argument("adata_file", type=str, help="Path to the AnnData file")
    # main() splits this value on ":" to get the groups and on "," to get the
    # fields of each group, so the help text names the colon as separator.
    parser.add_argument(
        "-m",
        "--adata_obs_fields_to_merge",
        type=str,
        help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by colon :",
    )
    parser.add_argument(
        "--groupby",
        type=str,
        required=True,
        help="The column in adata.obs that defines the groups",
    )
    parser.add_argument(
        "--sample_key",
        required=True,
        type=str,
        help="The column in adata.obs that defines the samples",
    )
    # add argument for layer
    parser.add_argument(
        "--layer",
        type=str,
        default=None,
        help="The name of the layer of the AnnData object to use",
    )
    # add argument for mode
    parser.add_argument(
        "--mode",
        type=str,
        default="sum",
        help="The mode for Decoupler pseudobulk analysis",
        choices=["sum", "mean", "median"],
    )
    # add boolean argument for use_raw
    parser.add_argument(
        "--use_raw",
        action="store_true",
        default=False,
        help="Whether to use the raw part of the AnnData object",
    )
    # add argument for min_cells
    parser.add_argument(
        "--min_cells",
        type=int,
        default=10,
        help="Minimum number of cells for pseudobulk analysis",
    )
    parser.add_argument(
        "--save_path", type=str, help="Path to save the plot (optional)"
    )
    parser.add_argument(
        "--min_counts",
        type=int,
        help="Minimum count threshold for filtering by expression",
    )
    parser.add_argument(
        "--min_total_counts",
        type=int,
        help="Minimum total count threshold for filtering by expression",
    )
    parser.add_argument(
        "--anndata_output_path",
        type=str,
        help="Path to save the filtered AnnData object or pseudobulk data",
    )
    parser.add_argument(
        "--filter_expr", action="store_true", help="Enable filtering by expression"
    )
    parser.add_argument(
        "--factor_fields",
        type=str,
        help="Comma separated list of fields for the factors",
    )
    parser.add_argument(
        "--deseq2_output_path",
        type=str,
        help="Path to save the DESeq2 inputs",
        required=True,
    )
    parser.add_argument(
        "--plot_samples_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the samples plot as a tuple (two arguments)",
    )
    # NOTE(review): this option is parsed but main() never reads
    # args.plot_filtering_figsize - confirm whether it should be wired into
    # the filter-by-expression plot.
    parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2)

    # Parse the command line arguments
    args = parser.parse_args()

    # Call the main function
    main(args)