comparison decoupler_pseudobulk.py @ 0:59a7f3f83aec draft

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit 20f4a739092bd05106d5de170523ad61d66e41fc
author ebi-gxa
date Sun, 24 Sep 2023 08:44:24 +0000
parents
children 046d8ff974ff
comparison
equal deleted inserted replaced
-1:000000000000 0:59a7f3f83aec
1 import argparse
2
3 import anndata
4 import decoupler
5 import pandas as pd
6
7
def get_pseudobulk(
    adata,
    sample_col,
    groups_col,
    layer=None,
    mode="sum",
    min_cells=10,
    min_counts=1000,
    use_raw=False,
):
    """
    Build a pseudobulk object by delegating to ``decoupler.get_pseudobulk``.

    Every argument is forwarded unchanged; this function adds no logic of
    its own.

    Parameters
    ----------
    adata
        Object handed to decoupler as the first positional argument
        (``main`` passes the AnnData read from the input file).
    sample_col
        Forwarded as ``sample_col`` (``main`` passes ``--sample_key``).
    groups_col
        Forwarded as ``groups_col`` (``main`` passes ``--groupby``).
    layer, mode, min_cells, min_counts, use_raw
        Forwarded under the same keyword names.

    Returns
    -------
    Whatever ``decoupler.get_pseudobulk`` returns.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    """
    forwarded_options = dict(
        sample_col=sample_col,
        groups_col=groups_col,
        layer=layer,
        mode=mode,
        use_raw=use_raw,
        min_cells=min_cells,
        min_counts=min_counts,
    )
    return decoupler.get_pseudobulk(adata, **forwarded_options)
35
36
def prepend_c_to_index(index_value):
    """
    Return ``index_value`` with a leading "C" when it starts with a digit.

    Falsy values (such as the empty string) and values whose first
    character is not a digit are returned unchanged.
    """
    if not index_value:
        return index_value
    leading_char = index_value[0]
    if not leading_char.isdigit():
        return index_value
    return "C" + index_value
41
42
# write results for loading into DESeq2
def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None):
    """
    Write three CSV files derived from ``pdata`` for loading into DESeq2.

    Files written (all prefixed with ``output_dir``):

    * ``col_metadata.csv``: ``pdata.obs`` with a sanitised index, limited
      to ``factor_fields`` (in that order) when given.
    * ``gene_metadata.csv``: ``pdata.var`` as is.
    * ``counts_matrix.csv``: the transposed matrix (rows indexed by
      ``pdata.var.index``, columns by the sanitised obs index).

    Parameters
    ----------
    pdata
        Object exposing ``obs``, ``var``, ``X`` and ``layers``.
    layer : str, optional
        Key in ``pdata.layers`` to export; ``pdata.X`` is used when None.
    output_dir : str
        Prefix for the output files; a trailing "/" is appended when the
        value is non-empty and lacks one.
    factor_fields : list of str, optional
        Columns of ``pdata.obs`` to keep in ``col_metadata.csv``.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> write_DESeq2_inputs(pseudobulk)
    """
    # add / to output_dir only when it is non-empty and doesn't end with /
    if output_dir != "" and not output_dir.endswith("/"):
        output_dir = output_dir + "/"

    obs_for_deseq = pdata.obs.copy()
    # replace any index starting with digits to start with C instead.
    obs_for_deseq.rename(index=prepend_c_to_index, inplace=True)
    # avoid dashes (read as points in R colnames) and spaces. regex=False
    # makes the replacement literal on every pandas version; the default
    # of `regex` changed between pandas 1.x and 2.x, and single-character
    # patterns trigger a FutureWarning on 1.x when it is left implicit.
    for unsafe_char in ("-", " "):
        obs_for_deseq.index = obs_for_deseq.index.str.replace(
            unsafe_char, "_", regex=False
        )

    # write obs to a col_metadata file
    col_metadata_file = f"{output_dir}col_metadata.csv"
    if factor_fields:
        # only output the index plus the columns in factor_fields in that order
        obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep=",", index=True)
    else:
        obs_for_deseq.to_csv(col_metadata_file, sep=",", index=True)

    # write var to a gene_metadata file
    pdata.var.to_csv(f"{output_dir}gene_metadata.csv", sep=",", index=True)

    # write the counts matrix (X, or the requested layer) transposed, so
    # that rows follow var and columns follow the sanitised obs index
    matrix = pdata.X if layer is None else pdata.layers[layer]
    df = pd.DataFrame(matrix.T, index=pdata.var.index, columns=obs_for_deseq.index)
    df.to_csv(f"{output_dir}counts_matrix.csv", sep=",", index_label="")
79
80
def plot_pseudobulk_samples(
    pseudobulk_data,
    groupby,
    figsize=(10, 10),
    save_path=None,
):
    """
    Draw the figure produced by ``decoupler.plot_psbulk_samples``.

    The figure is saved as ``<save_path>/pseudobulk_samples.png`` when
    ``save_path`` is truthy; otherwise it is displayed with ``fig.show()``.
    Nothing is returned.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
    """
    samples_figure = decoupler.plot_psbulk_samples(
        pseudobulk_data,
        groupby=groupby,
        figsize=figsize,
        return_fig=True,
    )
    # No destination given: show the figure instead of writing it.
    if not save_path:
        samples_figure.show()
        return
    samples_figure.savefig(f"{save_path}/pseudobulk_samples.png")
101
102
def plot_filter_by_expr(
    pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None
):
    """
    Draw the figure produced by ``decoupler.plot_filter_by_expr``.

    The figure is saved as ``<save_path>/filter_by_expr.png`` when
    ``save_path`` is truthy; otherwise it is displayed with ``fig.show()``.

    Parameters
    ----------
    pseudobulk_data
        Object handed to decoupler as the first positional argument.
    group
        Forwarded as ``group``.
    min_count, min_total_count : int, optional
        Thresholds forwarded to decoupler. A value of None is *not*
        forwarded, so that decoupler's own default applies instead of an
        explicit ``None`` (decoupler documents numeric defaults for both;
        confirm the values against the installed version).
    save_path : str, optional
        Directory in which to write ``filter_by_expr.png``.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
    """
    # Forward only the thresholds that were actually set: the CLI leaves
    # --min_counts/--min_total_counts as None when omitted, and passing None
    # explicitly would override decoupler's numeric defaults.
    thresholds = {
        name: value
        for name, value in (
            ("min_count", min_count),
            ("min_total_count", min_total_count),
        )
        if value is not None
    }
    fig = decoupler.plot_filter_by_expr(
        pseudobulk_data,
        group=group,
        return_fig=True,
        **thresholds,
    )
    if save_path:
        fig.savefig(f"{save_path}/filter_by_expr.png")
    else:
        fig.show()
124
125
def filter_by_expr(pdata, min_count=None, min_total_count=None):
    """
    Return a copy of ``pdata`` restricted to the genes selected by
    ``decoupler.filter_by_expr``.

    Parameters
    ----------
    pdata
        Object passed to decoupler and then indexed as ``pdata[:, genes]``.
    min_count, min_total_count : int, optional
        Thresholds forwarded to decoupler. A value of None is *not*
        forwarded, so that decoupler's own default applies instead of an
        explicit ``None`` (decoupler documents numeric defaults for both;
        confirm the values against the installed version).

    Returns
    -------
    A copy of the column-subset of ``pdata``.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
    """
    # Forward only the thresholds that were actually set, so an omitted CLI
    # option does not replace decoupler's numeric default with None.
    thresholds = {
        name: value
        for name, value in (
            ("min_count", min_count),
            ("min_total_count", min_total_count),
        )
        if value is not None
    }
    genes = decoupler.filter_by_expr(pdata, **thresholds)
    return pdata[:, genes].copy()
138
139
def check_fields(fields, adata, obs=True, context=None):
    """
    Raise ``ValueError`` unless every entry of ``fields`` is a column name.

    The columns checked are ``adata.obs.columns`` when ``obs`` is true and
    ``adata.var.columns`` otherwise. When ``context`` is truthy it is
    mentioned in the error message. Returns None when all fields exist.

    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> check_fields(["bulk_labels", "louvain"], adata, obs=True)
    """
    legend = f", passed in {context}," if context else ""
    requested = set(fields)

    if obs:
        if requested.issubset(set(adata.obs.columns)):
            return
        raise ValueError(
            f"Some of the following fields {legend} are not present in adata.obs: {fields}. Possible fields are: {list(set(adata.obs.columns))}"
        )

    if requested.issubset(set(adata.var.columns)):
        return
    raise ValueError(
        f"Some of the following fields {legend} are not present in adata.var: {fields}. Possible fields are: {list(set(adata.var.columns))}"
    )
160
161
def main(args):
    """
    Run the pseudobulk workflow described by the parsed command line.

    Steps, in order: read the AnnData file, optionally merge obs fields,
    validate the requested obs columns, compute the pseudobulk, draw the
    two diagnostic plots, optionally filter genes by expression, optionally
    write the result as h5ad, and write the DESeq2 input CSV files.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options as defined by the parser in the ``__main__`` section
        of this script.
    """
    # Load AnnData object from file
    adata = anndata.read_h5ad(args.adata_file)

    # Merge adata.obs fields specified in args.adata_obs_fields_to_merge
    if args.adata_obs_fields_to_merge:
        fields = args.adata_obs_fields_to_merge.split(",")
        check_fields(fields, adata, context="--adata_obs_fields_to_merge")
        adata = merge_adata_obs_fields(fields, adata)

    # Pass the option names as context so that a failure tells the user
    # which command-line argument named a missing column.
    check_fields(
        [args.groupby, args.sample_key], adata, context="--groupby/--sample_key"
    )

    factor_fields = None
    if args.factor_fields:
        factor_fields = args.factor_fields.split(",")
        check_fields(factor_fields, adata, context="--factor_fields")

    print(f"Using mode: {args.mode}")
    # Perform pseudobulk analysis
    pseudobulk_options = dict(
        sample_col=args.sample_key,
        groups_col=args.groupby,
        layer=args.layer,
        mode=args.mode,
        use_raw=args.use_raw,
        min_cells=args.min_cells,
    )
    # --min_counts has no argparse default, so it is None when omitted.
    # Only forward it when set; otherwise get_pseudobulk's own default
    # (min_counts=1000) applies instead of an explicit None.
    if args.min_counts is not None:
        pseudobulk_options["min_counts"] = args.min_counts
    pseudobulk_data = get_pseudobulk(adata, **pseudobulk_options)

    # Plot pseudobulk samples
    plot_pseudobulk_samples(
        pseudobulk_data,
        args.groupby,
        save_path=args.save_path,
        figsize=args.plot_samples_figsize,
    )

    # NOTE(review): --plot_filtering_figsize is parsed but not used here;
    # plot_filter_by_expr takes no figsize argument.
    plot_filter_by_expr(
        pseudobulk_data,
        group=args.groupby,
        min_count=args.min_counts,
        min_total_count=args.min_total_counts,
        save_path=args.save_path,
    )

    # Filter by expression if enabled
    # NOTE(review): the plot above is drawn with group=args.groupby, while
    # the filter itself is applied without a group - confirm this is intended.
    if args.filter_expr:
        pseudobulk_data = filter_by_expr(
            pseudobulk_data,
            min_count=args.min_counts,
            min_total_count=args.min_total_counts,
        )

    # Save the pseudobulk data
    if args.anndata_output_path:
        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")

    write_DESeq2_inputs(
        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
    )
225
226
def merge_adata_obs_fields(obs_fields_to_merge, adata):
    """
    Add to adata.obs a column that concatenates the given obs fields.

    The new column is named by joining the field names with "_" and holds,
    for every row, the string values of those fields joined with "_".
    All fields are validated before adata.obs is touched, and the merged
    column is always rebuilt from the source fields, so calling the
    function again with the same fields gives the same result.

    Parameters
    ----------
    obs_fields_to_merge : list of str
        Names of the columns in adata.obs to merge, in order
    adata : anndata.AnnData
        The AnnData object; its obs is modified in place

    Returns
    -------
    anndata.AnnData
        The same AnnData object, with the merged column added

    Raises
    ------
    ValueError
        If any of the fields is not a column of adata.obs

    docstring tests:
    >>> import scanpy as sc
    >>> ad = sc.datasets.pbmc68k_reduced()
    >>> ad = merge_adata_obs_fields(["bulk_labels","louvain"], ad)
    >>> ad.obs.columns
    Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score',
           'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'],
          dtype='object')
    """
    # Validate everything up front so a bad field name cannot leave a
    # half-built merged column behind.
    for field in obs_fields_to_merge:
        if field not in adata.obs.columns:
            raise ValueError(f"The '{field}' column is not present in adata.obs.")

    # Nothing to merge: leave adata untouched.
    if not obs_fields_to_merge:
        return adata

    field_name = "_".join(obs_fields_to_merge)
    # Build the merged values separately instead of appending to an
    # existing column of the same name, which would duplicate content on a
    # second call or when a single field is given.
    merged = adata.obs[obs_fields_to_merge[0]].astype(str)
    for field in obs_fields_to_merge[1:]:
        merged = merged + "_" + adata.obs[field].astype(str)
    adata.obs[field_name] = merged
    return adata
263
264
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them and hand
    # the resulting namespace to main().
    cli = argparse.ArgumentParser(
        description="Perform pseudobulk analysis on an AnnData object"
    )

    # --- input and obs columns ---
    cli.add_argument("adata_file", type=str, help="Path to the AnnData file")
    cli.add_argument(
        "-m",
        "--adata_obs_fields_to_merge",
        type=str,
        help="Fields in adata.obs to merge, comma separated",
    )
    cli.add_argument(
        "--groupby",
        type=str,
        required=True,
        help="The column in adata.obs that defines the groups",
    )
    cli.add_argument(
        "--sample_key",
        required=True,
        type=str,
        help="The column in adata.obs that defines the samples",
    )

    # --- pseudobulk options ---
    cli.add_argument(
        "--layer",
        type=str,
        default=None,
        help="The name of the layer of the AnnData object to use",
    )
    cli.add_argument(
        "--mode",
        type=str,
        default="sum",
        help="The mode for Decoupler pseudobulk analysis",
        choices=["sum", "mean", "median"],
    )
    cli.add_argument(
        "--use_raw",
        action="store_true",
        default=False,
        help="Whether to use the raw part of the AnnData object",
    )
    cli.add_argument(
        "--min_cells",
        type=int,
        default=10,
        help="Minimum number of cells for pseudobulk analysis",
    )

    # --- plotting, filtering and outputs ---
    cli.add_argument("--save_path", type=str, help="Path to save the plot (optional)")
    cli.add_argument(
        "--min_counts",
        type=int,
        help="Minimum count threshold for filtering by expression",
    )
    cli.add_argument(
        "--min_total_counts",
        type=int,
        help="Minimum total count threshold for filtering by expression",
    )
    cli.add_argument(
        "--anndata_output_path",
        type=str,
        help="Path to save the filtered AnnData object or pseudobulk data",
    )
    cli.add_argument(
        "--filter_expr", action="store_true", help="Enable filtering by expression"
    )
    cli.add_argument(
        "--factor_fields",
        type=str,
        help="Comma separated list of fields for the factors",
    )
    cli.add_argument(
        "--deseq2_output_path",
        type=str,
        help="Path to save the DESeq2 inputs",
        required=True,
    )
    cli.add_argument(
        "--plot_samples_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the samples plot as a tuple (two arguments)",
    )
    cli.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2)

    # Parse the command line and run the workflow
    args = cli.parse_args()
    main(args)