Mercurial > repos > ebi-gxa > decoupler_pseudobulk
comparison decoupler_pseudobulk.py @ 0:59a7f3f83aec draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 20f4a739092bd05106d5de170523ad61d66e41fc
author | ebi-gxa |
---|---|
date | Sun, 24 Sep 2023 08:44:24 +0000 |
parents | |
children | 046d8ff974ff |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:59a7f3f83aec |
---|---|
1 import argparse | |
2 | |
3 import anndata | |
4 import decoupler | |
5 import pandas as pd | |
6 | |
7 | |
def get_pseudobulk(
    adata,
    sample_col,
    groups_col,
    layer=None,
    mode="sum",
    min_cells=10,
    min_counts=1000,
    use_raw=False,
):
    """
    Aggregate single-cell counts into pseudobulk profiles with decoupler.

    Thin wrapper around ``decoupler.get_pseudobulk``.

    Parameters
    ----------
    adata : anndata.AnnData
        Cell-level AnnData object.
    sample_col : str
        Column of ``adata.obs`` identifying the sample of each cell.
    groups_col : str
        Column of ``adata.obs`` identifying the group of each cell.
    layer : str or None
        Layer to aggregate; forwarded unchanged to decoupler.
    mode : str
        Aggregation mode forwarded to decoupler (the CLI offers
        "sum", "mean" and "median").
    min_cells : int or None
        Minimum-cells threshold forwarded to decoupler. ``None`` means
        "use decoupler's own default".
    min_counts : int or None
        Minimum-counts threshold forwarded to decoupler. ``None`` means
        "use decoupler's own default" (the CLI leaves ``--min_counts``
        unset as ``None`` when the flag is omitted).
    use_raw : bool
        Forwarded to decoupler.

    Returns
    -------
    Whatever ``decoupler.get_pseudobulk`` returns (a pseudobulk AnnData).

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    """
    # Drop thresholds that are None instead of forwarding them: decoupler
    # compares counts against these values and is not expected to accept
    # None (NOTE(review): confirm for the pinned decoupler version).
    thresholds = {
        name: value
        for name, value in (("min_cells", min_cells), ("min_counts", min_counts))
        if value is not None
    }

    return decoupler.get_pseudobulk(
        adata,
        sample_col=sample_col,
        groups_col=groups_col,
        layer=layer,
        mode=mode,
        use_raw=use_raw,
        **thresholds,
    )
35 | |
36 | |
def prepend_c_to_index(index_value):
    """
    Return ``index_value`` with a leading "C" if it starts with a digit.

    Used to turn sample names into strings that are safe as R column
    names. Empty or falsy values are returned unchanged.
    """
    needs_prefix = bool(index_value) and index_value[0].isdigit()
    return "C" + index_value if needs_prefix else index_value
41 | |
42 | |
# write results for loading into DESeq2
def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None):
    """
    Write a pseudobulk AnnData as CSV inputs for DESeq2.

    Three files are written into ``output_dir``:

    - ``col_metadata.csv``: ``pdata.obs`` (only ``factor_fields`` if given),
      one row per pseudobulk sample;
    - ``gene_metadata.csv``: ``pdata.var``;
    - ``counts_matrix.csv``: genes x samples counts taken from ``pdata.X``
      or ``pdata.layers[layer]``.

    Sample names are sanitised for R in both col_metadata and the counts
    columns. A leading digit gets a "C" prefix, and "-" and " " become "_".

    Parameters
    ----------
    pdata : anndata.AnnData
        Pseudobulk data.
    layer : str or None
        Layer to export as counts; ``None`` exports ``pdata.X``.
    output_dir : str or None
        Directory prefix for the output files; empty/None means the current
        working directory. A trailing "/" is added when missing.
    factor_fields : list of str or None
        If given, only these obs columns (in this order) are written to
        col_metadata.csv.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> write_DESeq2_inputs(pseudobulk)
    """
    # Normalise the prefix: append "/" only when output_dir is non-empty AND
    # does not already end with one; None is treated like "".
    output_dir = output_dir or ""
    if output_dir and not output_dir.endswith("/"):
        output_dir += "/"

    obs_for_deseq = pdata.obs.copy()
    # replace any index starting with digits to start with C instead.
    obs_for_deseq.rename(index=prepend_c_to_index, inplace=True)
    # avoid dashes/spaces, which R turns into dots in column names.
    obs_for_deseq.index = obs_for_deseq.index.str.replace("-", "_")
    obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_")

    # write obs (optionally restricted to factor_fields) to col_metadata
    col_metadata = obs_for_deseq[factor_fields] if factor_fields else obs_for_deseq
    col_metadata.to_csv(f"{output_dir}col_metadata.csv", sep=",", index=True)

    # write var to a gene_metadata file
    pdata.var.to_csv(f"{output_dir}gene_metadata.csv", sep=",", index=True)

    # write the counts matrix (X or the requested layer), transposed so genes
    # are rows and samples are columns
    counts = pdata.X if layer is None else pdata.layers[layer]
    # The pandas DataFrame constructor does not accept scipy sparse matrices;
    # densify them first (pseudobulk matrices are small: samples x genes).
    if hasattr(counts, "toarray"):
        counts = counts.toarray()
    df = pd.DataFrame(counts.T, index=pdata.var.index, columns=obs_for_deseq.index)
    df.to_csv(f"{output_dir}counts_matrix.csv", sep=",", index_label="")
79 | |
80 | |
def plot_pseudobulk_samples(
    pseudobulk_data,
    groupby,
    figsize=(10, 10),
    save_path=None,
):
    """
    Draw decoupler's pseudobulk-samples overview plot.

    The figure is written to ``<save_path>/pseudobulk_samples.png`` when
    ``save_path`` is truthy; otherwise ``fig.show()`` is called.

    Parameters
    ----------
    pseudobulk_data : anndata.AnnData
        Pseudobulk data as returned by ``get_pseudobulk``.
    groupby : str or list of str
        obs column(s) used to colour/group the plot; forwarded to decoupler.
    figsize : tuple
        Figure size forwarded to decoupler.
    save_path : str or None
        Output directory for the PNG.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
    """
    figure = decoupler.plot_psbulk_samples(
        pseudobulk_data,
        groupby=groupby,
        figsize=figsize,
        return_fig=True,
    )
    if not save_path:
        figure.show()
        return
    figure.savefig(f"{save_path}/pseudobulk_samples.png")
101 | |
102 | |
def plot_filter_by_expr(
    pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None
):
    """
    Draw decoupler's filter-by-expression diagnostic plot.

    The figure is written to ``<save_path>/filter_by_expr.png`` when
    ``save_path`` is truthy; otherwise ``fig.show()`` is called.

    Parameters
    ----------
    pseudobulk_data : anndata.AnnData
        Pseudobulk data as returned by ``get_pseudobulk``.
    group : str
        obs column defining the groups; forwarded to decoupler.
    min_count, min_total_count : int or None
        Thresholds forwarded to decoupler. ``None`` (the default) means
        "use decoupler's own default" rather than passing None through.
    save_path : str or None
        Output directory for the PNG.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
    """
    # Only forward thresholds that were actually set; forwarding None would
    # override decoupler's numeric defaults with a value it computes with.
    thresholds = {
        name: value
        for name, value in (
            ("min_count", min_count),
            ("min_total_count", min_total_count),
        )
        if value is not None
    }
    fig = decoupler.plot_filter_by_expr(
        pseudobulk_data,
        group=group,
        return_fig=True,
        **thresholds,
    )
    if save_path:
        fig.savefig(f"{save_path}/filter_by_expr.png")
    else:
        fig.show()
124 | |
125 | |
def filter_by_expr(pdata, min_count=None, min_total_count=None):
    """
    Keep only the genes that pass decoupler's filter-by-expression test.

    Parameters
    ----------
    pdata : anndata.AnnData
        Pseudobulk data.
    min_count, min_total_count : int or None
        Thresholds forwarded to ``decoupler.filter_by_expr``. ``None`` (the
        default) means "use decoupler's own default" rather than passing
        None through.

    Returns
    -------
    anndata.AnnData
        A copy of ``pdata`` restricted to the genes decoupler returned.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
    """
    # Only forward thresholds that were actually set; forwarding None would
    # override decoupler's numeric defaults with a value it computes with.
    thresholds = {
        name: value
        for name, value in (
            ("min_count", min_count),
            ("min_total_count", min_total_count),
        )
        if value is not None
    }
    genes = decoupler.filter_by_expr(pdata, **thresholds)
    # .copy() so the result is a real AnnData, not a view on pdata
    return pdata[:, genes].copy()
138 | |
139 | |
def check_fields(fields, adata, obs=True, context=None):
    """
    Raise ValueError unless every name in ``fields`` is a column of
    ``adata.obs`` (``obs=True``) or ``adata.var`` (``obs=False``).

    Parameters
    ----------
    fields : list of str
        Column names that must be present.
    adata : anndata.AnnData
        Object whose ``obs``/``var`` columns are checked.
    obs : bool
        Check ``adata.obs`` when True, ``adata.var`` otherwise.
    context : str or None
        Optional description of where the fields came from (e.g. a CLI
        option), included in the error message.

    Raises
    ------
    ValueError
        Naming the requested fields, the missing ones, and the available
        columns (in their original column order).

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> check_fields(["bulk_labels", "louvain"], adata, obs=True)
    """
    attr = "obs" if obs else "var"
    available = list(getattr(adata, attr).columns)
    available_set = set(available)
    missing = [field for field in fields if field not in available_set]
    if not missing:
        return

    # Leading comma attaches directly to "fields" so the sentence reads
    # "fields, passed in X, are not present".
    legend = f", passed in {context}," if context else ""
    raise ValueError(
        f"Some of the following fields{legend} are not present in adata.{attr}: "
        f"{fields}. Missing fields: {missing}. Possible fields are: {available}"
    )
160 | |
161 | |
def main(args):
    """
    Run the pseudobulk workflow described by the parsed command-line args.

    Steps, in order:

    1. Read the input h5ad.
    2. Optionally merge obs fields.
    3. Validate the group, sample and factor columns.
    4. Build the pseudobulk.
    5. Draw the samples plot and the filter-by-expression plot.
    6. Optionally drop lowly expressed genes.
    7. Write the pseudobulk h5ad (if requested) and the DESeq2 CSV inputs.

    Parameters
    ----------
    args : argparse.Namespace
        Namespace produced by the parser in the ``__main__`` section.

    Notes
    -----
    ``args.min_counts`` is used both as the pseudobulk ``min_counts``
    threshold and as the gene-level ``min_count`` for plotting/filtering.
    ``args.plot_filtering_figsize`` is parsed but not read here.
    """
    # Load AnnData object from file
    adata = anndata.read_h5ad(args.adata_file)

    # Merge adata.obs fields specified in args.adata_obs_fields_to_merge
    if args.adata_obs_fields_to_merge:
        fields = args.adata_obs_fields_to_merge.split(",")
        check_fields(fields, adata)
        adata = merge_adata_obs_fields(fields, adata)

    check_fields([args.groupby, args.sample_key], adata)

    factor_fields = None
    if args.factor_fields:
        factor_fields = args.factor_fields.split(",")
        check_fields(factor_fields, adata)

    print(f"Using mode: {args.mode}")
    # Perform pseudobulk analysis
    pseudobulk_data = get_pseudobulk(
        adata,
        sample_col=args.sample_key,
        groups_col=args.groupby,
        layer=args.layer,
        mode=args.mode,
        use_raw=args.use_raw,
        min_cells=args.min_cells,
        min_counts=args.min_counts,
    )

    if factor_fields:
        # write_DESeq2_inputs selects factor_fields from the *pseudobulk*
        # obs, which decoupler rebuilds per sample/group. A column present
        # on the cells may presumably be absent there (e.g. if it varies
        # within a pseudobulk; confirm for the decoupler version in use).
        # Check now so the job fails early with a readable message instead
        # of a KeyError after the plots and h5ad have been written.
        check_fields(factor_fields, pseudobulk_data, context="--factor_fields")

    # Plot pseudobulk samples
    plot_pseudobulk_samples(
        pseudobulk_data,
        args.groupby,
        save_path=args.save_path,
        figsize=args.plot_samples_figsize,
    )

    plot_filter_by_expr(
        pseudobulk_data,
        group=args.groupby,
        min_count=args.min_counts,
        min_total_count=args.min_total_counts,
        save_path=args.save_path,
    )

    # Filter by expression if enabled
    if args.filter_expr:
        pseudobulk_data = filter_by_expr(
            pseudobulk_data,
            min_count=args.min_counts,
            min_total_count=args.min_total_counts,
        )

    # Save the pseudobulk data
    if args.anndata_output_path:
        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")

    write_DESeq2_inputs(
        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
    )
225 | |
226 | |
def merge_adata_obs_fields(obs_fields_to_merge, adata):
    """
    Merge several adata.obs fields into one "_"-joined column.

    The new column is named by joining the field names with "_". For every
    cell it holds the string values of those fields joined with "_". The
    column is always rebuilt from scratch, so a pre-existing column of the
    same name (e.g. from a previous run, or a single field whose merged
    name is itself) is overwritten, not appended to.

    Parameters
    ----------
    obs_fields_to_merge : list of str, or str
        Names of adata.obs columns to merge, in order. A comma-separated
        string is also accepted and split on ",". An empty list leaves
        ``adata`` untouched.
    adata : anndata.AnnData
        The AnnData object; its ``obs`` is modified in place.

    Returns
    -------
    anndata.AnnData
        The same AnnData object, with the merged column in ``obs``.

    Raises
    ------
    ValueError
        If any requested field is missing from adata.obs. All fields are
        checked before anything is written.

    docstring tests:
    >>> import scanpy as sc
    >>> ad = sc.datasets.pbmc68k_reduced()
    >>> ad = merge_adata_obs_fields(["bulk_labels","louvain"], ad)
    >>> ad.obs.columns
    Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score',
           'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'],
          dtype='object')
    """
    if isinstance(obs_fields_to_merge, str):
        fields = obs_fields_to_merge.split(",")
    else:
        fields = list(obs_fields_to_merge)
    if not fields:
        return adata

    # Validate everything up front so a missing field cannot leave a
    # half-built column behind.
    for field in fields:
        if field not in adata.obs.columns:
            raise ValueError(f"The '{field}' column is not present in adata.obs.")

    # astype(str) also makes categorical columns concatenable.
    merged = adata.obs[fields[0]].astype(str)
    for field in fields[1:]:
        merged = merged + "_" + adata.obs[field].astype(str)
    adata.obs["_".join(fields)] = merged
    return adata
263 | |
264 | |
if __name__ == "__main__":
    # Create argument parser
    parser = argparse.ArgumentParser(
        description="Perform pseudobulk analysis on an AnnData object"
    )

    # Add arguments
    parser.add_argument("adata_file", type=str, help="Path to the AnnData file")
    parser.add_argument(
        "-m",
        "--adata_obs_fields_to_merge",
        type=str,
        help="Fields in adata.obs to merge, comma separated",
    )
    parser.add_argument(
        "--groupby",
        type=str,
        required=True,
        help="The column in adata.obs that defines the groups",
    )
    parser.add_argument(
        "--sample_key",
        required=True,
        type=str,
        help="The column in adata.obs that defines the samples",
    )
    # add argument for layer
    parser.add_argument(
        "--layer",
        type=str,
        default=None,
        help="The name of the layer of the AnnData object to use",
    )
    # add argument for mode
    parser.add_argument(
        "--mode",
        type=str,
        default="sum",
        help="The mode for Decoupler pseudobulk analysis",
        choices=["sum", "mean", "median"],
    )
    # add boolean argument for use_raw
    parser.add_argument(
        "--use_raw",
        action="store_true",
        default=False,
        help="Whether to use the raw part of the AnnData object",
    )
    # add argument for min_cells
    parser.add_argument(
        "--min_cells",
        type=int,
        default=10,
        help="Minimum number of cells for pseudobulk analysis",
    )
    parser.add_argument(
        "--save_path", type=str, help="Path to save the plot (optional)"
    )
    # main() uses this value twice: as the pseudobulk min_counts threshold
    # and as min_count for (plotting) filtering by expression.
    parser.add_argument(
        "--min_counts",
        type=int,
        help=(
            "Minimum count threshold, used both when building pseudobulk "
            "samples (min_counts) and for filtering by expression (min_count)"
        ),
    )
    parser.add_argument(
        "--min_total_counts",
        type=int,
        help="Minimum total count threshold for filtering by expression",
    )
    parser.add_argument(
        "--anndata_output_path",
        type=str,
        help="Path to save the filtered AnnData object or pseudobulk data",
    )
    parser.add_argument(
        "--filter_expr", action="store_true", help="Enable filtering by expression"
    )
    parser.add_argument(
        "--factor_fields",
        type=str,
        help="Comma separated list of fields for the factors",
    )
    parser.add_argument(
        "--deseq2_output_path",
        type=str,
        help="Path to save the DESeq2 inputs",
        required=True,
    )
    parser.add_argument(
        "--plot_samples_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the samples plot as a tuple (two arguments)",
    )
    # NOTE(review): main() never reads args.plot_filtering_figsize, so this
    # option currently has no effect on the filter-by-expression plot.
    parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2)

    # Parse the command line arguments
    args = parser.parse_args()

    # Call the main function
    main(args)