Mercurial > repos > ebi-gxa > score_genes_aucell
comparison decoupler_aucell_score.py @ 5:c9aaac87c583 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author | ebi-gxa |
---|---|
date | Mon, 15 Jul 2024 10:56:46 +0000 |
parents | 515ac51db6e5 |
children |
comparison
equal
deleted
inserted
replaced
4:515ac51db6e5 | 5:c9aaac87c583 |
---|---|
1 import argparse | 1 import argparse |
2 import os | |
3 import tempfile | |
4 | 2 |
5 import anndata | 3 import anndata |
6 import decoupler as dc | 4 import decoupler as dc |
5 import numba as nb | |
7 import pandas as pd | 6 import pandas as pd |
8 import numba as nb | |
9 | 7 |
10 | 8 |
11 def read_gmt_long(gmt_file): | 9 def read_gmt_long(gmt_file): |
12 """ | 10 r""" |
13 Reads a GMT file and produces a Pandas DataFrame in long format, ready to be passed to the AUCell method. | 11 Reads a GMT file and produces a Pandas DataFrame in long format, ready to |
12 be passed to the AUCell method. | |
14 | 13 |
15 Parameters | 14 Parameters |
16 ---------- | 15 ---------- |
17 gmt_file : str | 16 gmt_file : str |
18 Path to the GMT file. | 17 Path to the GMT file. |
19 | 18 |
20 Returns | 19 Returns |
21 ------- | 20 ------- |
22 pd.DataFrame | 21 pd.DataFrame |
23 A DataFrame with the gene sets. Each row represents a gene set to gene assignment, and the columns are "gene_set" and "gene". | 22 A DataFrame with the gene sets. Each row represents a gene set to gene |
24 >>> line = "HALLMARK_NOTCH_SIGNALING\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\\tJAG1\\tNOTCH3\\tNOTCH2\\tAPH1A\\tHES1\\tCCND1\\tFZD1\\tPSEN2\\tFZD7\\tDTX1\\tDLL1\\tFZD5\\tMAML2\\tNOTCH1\\tPSENEN\\tWNT5A\\tCUL1\\tWNT2\\tDTX4\\tSAP30\\tPPARD\\tKAT2A\\tHEYL\\tSKP1\\tRBX1\\tTCF7L2\\tARRB1\\tLFNG\\tPRKCA\\tDTX2\\tST3GAL6\\tFBXW11\\n" | 23 assignment, and the columns are "gene_set" and "gene". |
25 >>> line2 = "HALLMARK_APICAL_SURFACE\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\\tB4GALT1\\tRHCG\\tMAL\\tLYPD3\\tPKHD1\\tATP6V0A4\\tCRYBG1\\tSHROOM2\\tSRPX\\tMDGA1\\tTMEM8B\\tTHY1\\tPCSK9\\tEPHB4\\tDCBLD2\\tGHRL\\tLYN\\tGAS1\\tFLOT2\\tPLAUR\\tAKAP7\\tATP8B1\\tEFNA5\\tSLC34A3\\tAPP\\tGSTM3\\tHSPB1\\tSLC2A4\\tIL2RB\\tRTN4RL1\\tNCOA6\\tSULF2\\tADAM10\\tBRCA1\\tGATA3\\tAFAP1L2\\tIL2RG\\tCD160\\tADIPOR2\\tSLC22A12\\tNTNG1\\tSCUBE1\\tCX3CL1\\tCROCC\\n" | 24 >>> import os |
25 >>> import tempfile | |
26 >>> line = "HALLMARK_NOTCH_SIGNALING\ | |
27 ... \thttp://www.gsea-msigdb.org/\ | |
28 ... gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\ | |
29 ... \tJAG1\tNOTCH3\tNOTCH2\tAPH1A\tHES1\tCCND1\ | |
30 ... \tFZD1\tPSEN2\tFZD7\tDTX1\tDLL1\tFZD5\tMAML2\ | |
31 ... \tNOTCH1\tPSENEN\tWNT5A\tCUL1\tWNT2\tDTX4\ | |
32 ... \tSAP30\tPPARD\tKAT2A\tHEYL\tSKP1\tRBX1\tTCF7L2\ | |
33 ... \tARRB1\tLFNG\tPRKCA\tDTX2\tST3GAL6\tFBXW11\n" | |
34 >>> line2 = "HALLMARK_APICAL_SURFACE\ | |
35 ... \thttp://www.gsea-msigdb.org/\ | |
36 ... gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\ | |
37 ... \tB4GALT1\tRHCG\tMAL\tLYPD3\tPKHD1\tATP6V0A4\ | |
38 ... \tCRYBG1\tSHROOM2\tSRPX\tMDGA1\tTMEM8B\tTHY1\ | |
39 ... \tPCSK9\tEPHB4\tDCBLD2\tGHRL\tLYN\tGAS1\tFLOT2\ | |
40 ... \tPLAUR\tAKAP7\tATP8B1\tEFNA5\tSLC34A3\tAPP\ | |
41 ... \tGSTM3\tHSPB1\tSLC2A4\tIL2RB\tRTN4RL1\tNCOA6\ | |
42 ... \tSULF2\tADAM10\tBRCA1\tGATA3\tAFAP1L2\tIL2RG\ | |
43 ... \tCD160\tADIPOR2\tSLC22A12\tNTNG1\tSCUBE1\tCX3CL1\ | |
44 ... \tCROCC\n" | |
26 >>> temp_dir = tempfile.gettempdir() | 45 >>> temp_dir = tempfile.gettempdir() |
27 >>> temp_gmt = os.path.join(temp_dir, "temp_file.gmt") | 46 >>> temp_gmt = os.path.join(temp_dir, "temp_file.gmt") |
28 >>> with open(temp_gmt, "w") as f: | 47 >>> with open(temp_gmt, "w") as f: |
29 ... f.write(line) | 48 ... f.write(line) |
30 ... f.write(line2) | 49 ... f.write(line2) |
34 >>> df.shape[0] | 53 >>> df.shape[0] |
35 76 | 54 76 |
36 >>> len(df.loc[df["gene_set"] == "HALLMARK_APICAL_SURFACE"].gene.tolist()) | 55 >>> len(df.loc[df["gene_set"] == "HALLMARK_APICAL_SURFACE"].gene.tolist()) |
37 44 | 56 44 |
38 """ | 57 """ |
39 # Build a dictionary mapping each gene set name to its list of genes | 58 # Build a dictionary mapping each gene set name to its list of |
59 # genes | |
40 gene_sets = {} | 60 gene_sets = {} |
41 | 61 |
42 # Read the GMT file into a list of lines | 62 # Read the GMT file into a list of lines |
43 with open(gmt_file, "r") as f: | 63 with open(gmt_file, "r") as f: |
44 while True: | 64 while True: |
45 line = f.readline() | 65 line = f.readline() |
46 if not line: | 66 if not line: |
47 break | 67 break |
48 fields = line.strip().split("\t") | 68 fields = line.strip().split("\t") |
49 gene_sets[fields[0]]= fields[2:] | 69 gene_sets[fields[0]] = fields[2:] |
50 | 70 |
51 return pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items()) | 71 return pd.concat( |
52 | 72 pd.DataFrame({"gene_set": k, "gene": v}) for k, v in gene_sets.items() |
53 | 73 ) |
54 def score_genes_aucell_mt(adata: anndata.AnnData, gene_set_gene: pd.DataFrame, use_raw=False, min_n_genes=5, var_gene_symbols_field=None): | 74 |
75 | |
76 def score_genes_aucell_mt( | |
77 adata: anndata.AnnData, | |
78 gene_set_gene: pd.DataFrame, | |
79 use_raw=False, | |
80 min_n_genes=5, | |
81 var_gene_symbols_field=None, | |
82 ): | |
55 """Score genes using Aucell. | 83 """Score genes using Aucell. |
56 | 84 |
57 Parameters | 85 Parameters |
58 ---------- | 86 ---------- |
59 adata : anndata.AnnData | 87 adata : anndata.AnnData |
60 gene_set_gene: pd.DataFrame with columns gene_set and gene | 88 gene_set_gene: pd.DataFrame with columns gene_set and gene |
61 use_raw : bool, optional, False by default. | 89 use_raw : bool, optional, False by default. |
62 min_n_genes : int, optional, 5 by default. | 90 min_n_genes : int, optional, 5 by default. |
63 var_gene_symbols_field : str, optional, None by default. The field in var where gene symbols are stored | 91 var_gene_symbols_field : str, optional, None by default. The field in var |
92 where gene symbols are stored | |
64 | 93 |
65 >>> import scanpy as sc | 94 >>> import scanpy as sc |
66 >>> import decoupler as dc | 95 >>> import decoupler as dc |
67 >>> adata = sc.datasets.pbmc68k_reduced() | 96 >>> adata = sc.datasets.pbmc68k_reduced() |
68 >>> r_gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist() | 97 >>> r_gene_list = adata.var[ |
69 >>> m_gene_list = adata.var[adata.var.index.str.startswith("M")].index.tolist() | 98 ... adata.var.index.str.startswith("RP")].index.tolist() |
99 >>> m_gene_list = adata.var[ | |
100 ... adata.var.index.str.startswith("M")].index.tolist() | |
70 >>> gene_set = {} | 101 >>> gene_set = {} |
71 >>> gene_set["m"] = m_gene_list | 102 >>> gene_set["m"] = m_gene_list |
72 >>> gene_set["r"] = r_gene_list | 103 >>> gene_set["r"] = r_gene_list |
73 >>> gene_set_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_set.items()) | 104 >>> gene_set_df = pd.concat( |
105 ... pd.DataFrame( | |
106 ... {'gene_set':k, 'gene':v} | |
107 ... ) for k, v in gene_set.items()) | |
74 >>> score_genes_aucell_mt(adata, gene_set_df, use_raw=False) | 108 >>> score_genes_aucell_mt(adata, gene_set_df, use_raw=False) |
75 >>> "AUCell_m" in adata.obs.columns | 109 >>> "AUCell_m" in adata.obs.columns |
76 True | 110 True |
77 >>> "AUCell_r" in adata.obs.columns | 111 >>> "AUCell_r" in adata.obs.columns |
78 True | 112 True |
79 """ | 113 """ |
80 | 114 |
81 # if var_gene_symbols_field is provided, transform gene_set_gene df so that gene contains gene ids instead of gene symbols | 115 # if var_gene_symbols_field is provided, transform gene_set_gene df so |
116 # that gene contains gene ids instead of gene symbols | |
82 if var_gene_symbols_field: | 117 if var_gene_symbols_field: |
83 # merge the index of var to gene_set_gene df based on var_gene_symbols_field | 118 # merge the index of var to gene_set_gene df based on |
119 # var_gene_symbols_field | |
84 var_id_symbols = adata.var[[var_gene_symbols_field]] | 120 var_id_symbols = adata.var[[var_gene_symbols_field]] |
85 var_id_symbols['gene_id'] = var_id_symbols.index | 121 var_id_symbols["gene_id"] = var_id_symbols.index |
86 | 122 |
87 gene_set_gene = gene_set_gene.merge(var_id_symbols, left_on='gene', right_on=var_gene_symbols_field, how='left') | 123 gene_set_gene = gene_set_gene.merge( |
88 # this will still produce some empty gene_ids (genes in the gene_set_gene df that are not in the var df), fill those | 124 var_id_symbols, |
89 # with the original gene symbol from the gene_set to avoid deforming the AUCell calculation | 125 left_on="gene", |
90 gene_set_gene['gene_id'] = gene_set_gene['gene_id'].fillna(gene_set_gene['gene']) | 126 right_on=var_gene_symbols_field, |
91 gene_set_gene['gene'] = gene_set_gene['gene_id'] | 127 how="left", |
92 | 128 ) |
129 # this will still produce some empty gene_ids (genes in the | |
130 # gene_set_gene df that are not in the var df), fill those | |
131 # with the original gene symbol from the gene_set to avoid | |
132 # deforming the AUCell calculation | |
133 gene_set_gene["gene_id"] = gene_set_gene["gene_id"].fillna( | |
134 gene_set_gene["gene"] | |
135 ) | |
136 gene_set_gene["gene"] = gene_set_gene["gene_id"] | |
137 | |
93 # run decoupler's run_aucell | 138 # run decoupler's run_aucell |
94 dc.run_aucell( | 139 dc.run_aucell( |
95 adata, net=gene_set_gene, source="gene_set", target="gene", use_raw=use_raw, min_n=min_n_genes | 140 adata, |
96 ) | 141 net=gene_set_gene, |
142 source="gene_set", | |
143 target="gene", | |
144 use_raw=use_raw, | |
145 min_n=min_n_genes, | |
146 ) | |
97 for gs in gene_set_gene.gene_set.unique(): | 147 for gs in gene_set_gene.gene_set.unique(): |
98 if gs in adata.obsm['aucell_estimate'].keys(): | 148 if gs in adata.obsm["aucell_estimate"].keys(): |
99 adata.obs[f"AUCell_{gs}"] = adata.obsm["aucell_estimate"][gs] | 149 adata.obs[f"AUCell_{gs}"] = adata.obsm["aucell_estimate"][gs] |
100 | 150 |
101 | 151 |
102 def run_for_genelists( | 152 def run_for_genelists( |
103 adata, gene_lists, score_names, use_raw=False, gene_symbols_field=None, min_n_genes=5 | 153 adata, |
154 gene_lists, | |
155 score_names, | |
156 use_raw=False, | |
157 gene_symbols_field=None, | |
158 min_n_genes=5, | |
104 ): | 159 ): |
105 if len(gene_lists) == len(score_names): | 160 if len(gene_lists) == len(score_names): |
106 for gene_list, score_names in zip(gene_lists, score_names): | 161 for gene_list, score_names in zip(gene_lists, score_names): |
107 genes = gene_list.split(",") | 162 genes = gene_list.split(",") |
108 gene_sets = {} | 163 gene_sets = {} |
109 gene_sets[score_names] = genes | 164 gene_sets[score_names] = genes |
110 gene_set_gene_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items()) | 165 gene_set_gene_df = pd.concat( |
111 | 166 pd.DataFrame({"gene_set": k, "gene": v}) |
167 for k, v in gene_sets.items() | |
168 ) | |
169 | |
112 score_genes_aucell_mt( | 170 score_genes_aucell_mt( |
113 adata, | 171 adata, |
114 gene_set_gene_df, | 172 gene_set_gene_df, |
115 use_raw, | 173 use_raw, |
116 min_n_genes, | 174 min_n_genes, |
117 var_gene_symbols_field=gene_symbols_field | 175 var_gene_symbols_field=gene_symbols_field, |
118 ) | 176 ) |
119 else: | 177 else: |
120 raise ValueError( | 178 raise ValueError( |
121 "The number of gene lists (separated by :) and score names (separated by :) must be the same" | 179 "The number of gene lists (separated by :) and score names \ |
180 (separated by :) must be the same" | |
122 ) | 181 ) |
123 | 182 |
124 | 183 |
125 if __name__ == "__main__": | 184 if __name__ == "__main__": |
126 # Create command-line arguments parser | 185 # Create command-line arguments parser |
127 parser = argparse.ArgumentParser(description="Score genes using Aucell") | 186 parser = argparse.ArgumentParser(description="Score genes using Aucell") |
128 parser.add_argument( | 187 parser.add_argument( |
129 "--input_file", type=str, help="Path to input AnnData file", required=True | 188 "--input_file", |
189 type=str, | |
190 help="Path to input AnnData file", | |
191 required=True, | |
130 ) | 192 ) |
131 parser.add_argument( | 193 parser.add_argument( |
132 "--output_file", type=str, help="Path to output file", required=True | 194 "--output_file", type=str, help="Path to output file", required=True |
133 ) | 195 ) |
134 parser.add_argument("--gmt_file", type=str, help="Path to GMT file", required=False) | 196 parser.add_argument( |
197 "--gmt_file", type=str, help="Path to GMT file", required=False | |
198 ) | |
135 # add argument for gene sets to score | 199 # add argument for gene sets to score |
136 parser.add_argument( | 200 parser.add_argument( |
137 "--gene_sets_to_score", | 201 "--gene_sets_to_score", |
138 type=str, | 202 type=str, |
139 required=False, | 203 required=False, |
140 help="Optional comma separated list of gene sets to score (the need to be in the gmt file)", | 204 help="Optional comma separated list of gene sets to score \ |
205 (the need to be in the gmt file)", | |
141 ) | 206 ) |
142 # add argument for gene list (comma separated) to score | 207 # add argument for gene list (comma separated) to score |
143 parser.add_argument( | 208 parser.add_argument( |
144 "--gene_lists_to_score", | 209 "--gene_lists_to_score", |
145 type=str, | 210 type=str, |
146 required=False, | 211 required=False, |
147 help="Comma separated list of genes to score. You can have more than one set of genes, separated by colon :", | 212 help="Comma separated list of genes to score. You can have more \ |
213 than one set of genes, separated by colon :", | |
148 ) | 214 ) |
149 # argument for the score name when using the gene list | 215 # argument for the score name when using the gene list |
150 parser.add_argument( | 216 parser.add_argument( |
151 "--score_names", | 217 "--score_names", |
152 type=str, | 218 type=str, |
153 required=False, | 219 required=False, |
154 help="Name of the score column when using the gene list. You can have more than one set of score names, separated by colon :. It should be the same length as the number of gene lists.", | 220 help="Name of the score column when using the gene list. You can \ |
221 have more than one set of score names, separated by colon :. \ | |
222 It should be the same length as the number of gene lists.", | |
155 ) | 223 ) |
156 parser.add_argument( | 224 parser.add_argument( |
157 "--gene_symbols_field", | 225 "--gene_symbols_field", |
158 type=str, | 226 type=str, |
159 help="Name of the gene symbols field in the AnnData object", | 227 help="Name of the gene symbols field in the AnnData object", |
160 required=True, | 228 required=True, |
161 ) | 229 ) |
162 # argument for min_n Minimum of targets per source. If less, sources are removed. | 230 # argument for min_n Minimum of targets per source. If less, sources |
231 # are removed. | |
163 parser.add_argument( | 232 parser.add_argument( |
164 "--min_n", | 233 "--min_n", |
165 type=int, | 234 type=int, |
166 required=False, | 235 required=False, |
167 default=5, | 236 default=5, |
168 help="Minimum of targets per source. If less, sources are removed.", | 237 help="Minimum of targets per source. If less, sources are removed.", |
169 ) | 238 ) |
170 parser.add_argument("--use_raw", action="store_true", help="Use raw data") | 239 parser.add_argument("--use_raw", action="store_true", help="Use raw data") |
171 parser.add_argument( | 240 parser.add_argument( |
172 "--write_anndata", action="store_true", help="Write the modified AnnData object" | 241 "--write_anndata", |
242 action="store_true", | |
243 help="Write the modified AnnData object", | |
173 ) | 244 ) |
174 # argument for number of max concurrent processes | 245 # argument for number of max concurrent processes |
175 parser.add_argument("--max_threads", type=int, required=False, default=1, help="Number of max concurrent threads") | 246 parser.add_argument( |
176 | 247 "--max_threads", |
248 type=int, | |
249 required=False, | |
250 default=1, | |
251 help="Number of max concurrent threads", | |
252 ) | |
177 | 253 |
178 # Parse command-line arguments | 254 # Parse command-line arguments |
179 args = parser.parse_args() | 255 args = parser.parse_args() |
180 | 256 |
181 nb.set_num_threads(n=args.max_threads) | 257 nb.set_num_threads(n=args.max_threads) |
187 # Load MSigDB file in GMT format | 263 # Load MSigDB file in GMT format |
188 # msigdb = read_gmt(args.gmt_file) | 264 # msigdb = read_gmt(args.gmt_file) |
189 msigdb = read_gmt_long(args.gmt_file) | 265 msigdb = read_gmt_long(args.gmt_file) |
190 | 266 |
191 gene_sets_to_score = ( | 267 gene_sets_to_score = ( |
192 args.gene_sets_to_score.split(",") if args.gene_sets_to_score else [] | 268 args.gene_sets_to_score.split(",") |
269 if args.gene_sets_to_score | |
270 else [] | |
193 ) | 271 ) |
194 if gene_sets_to_score: | 272 if gene_sets_to_score: |
195 # we limit the GMT file read to the genesets specified in the gene_sets_to_score argument | 273 # we limit the GMT file read to the genesets specified in the |
274 # gene_sets_to_score argument | |
196 msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)] | 275 msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)] |
197 | 276 |
198 score_genes_aucell_mt(adata, msigdb, args.use_raw, args.min_n, var_gene_symbols_field=args.gene_symbols_field) | 277 score_genes_aucell_mt( |
278 adata, | |
279 msigdb, | |
280 args.use_raw, | |
281 args.min_n, | |
282 var_gene_symbols_field=args.gene_symbols_field, | |
283 ) | |
199 elif args.gene_lists_to_score is not None and args.score_names is not None: | 284 elif args.gene_lists_to_score is not None and args.score_names is not None: |
200 gene_lists = args.gene_lists_to_score.split(":") | 285 gene_lists = args.gene_lists_to_score.split(":") |
201 score_names = args.score_names.split(",") | 286 score_names = args.score_names.split(",") |
202 run_for_genelists( | 287 run_for_genelists( |
203 adata, gene_lists, score_names, args.use_raw, args.gene_symbols_field, args.min_n | 288 adata, |
204 ) | 289 gene_lists, |
205 | 290 score_names, |
206 # Save the modified AnnData object or generate a file with cells as rows and the new score_names columns | 291 args.use_raw, |
292 args.gene_symbols_field, | |
293 args.min_n, | |
294 ) | |
295 | |
296 # Save the modified AnnData object or generate a file with cells as rows | |
297 # and the new score_names columns | |
207 if args.write_anndata: | 298 if args.write_anndata: |
208 adata.write_h5ad(args.output_file) | 299 adata.write_h5ad(args.output_file) |
209 else: | 300 else: |
210 new_columns = [col for col in adata.obs.columns if col.startswith("AUCell_")] | 301 new_columns = [ |
302 col for col in adata.obs.columns if col.startswith("AUCell_") | |
303 ] | |
211 adata.obs[new_columns].to_csv(args.output_file, sep="\t", index=True) | 304 adata.obs[new_columns].to_csv(args.output_file, sep="\t", index=True) |