comparison decoupler_aucell_score.py @ 3:c6787c2aee46 draft default tip

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author ebi-gxa
date Mon, 15 Jul 2024 10:56:37 +0000
parents 82b7cd3e1bbd
children
comparison
equal deleted inserted replaced
2:82b7cd3e1bbd 3:c6787c2aee46
1 import argparse 1 import argparse
2 import os
3 import tempfile
4 2
5 import anndata 3 import anndata
6 import decoupler as dc 4 import decoupler as dc
5 import numba as nb
7 import pandas as pd 6 import pandas as pd
8 import numba as nb
9 7
10 8
def read_gmt_long(gmt_file):
    r"""
    Read a GMT file into a long-format Pandas DataFrame, ready to be passed
    to the AUCell method.

    Every non-blank line is stripped and split on tabs. The first field is
    the gene set name, the second field is skipped (in MSigDB files it holds
    the gene set URL) and all remaining fields are the genes of that set.
    When a gene set name appears on more than one line, the genes of the
    last such line are kept.

    Parameters
    ----------
    gmt_file : str
        Path to the GMT file.

    Returns
    -------
    pd.DataFrame
        A DataFrame with one row per gene set to gene assignment and the
        columns "gene_set" and "gene". The frame is empty, but still has
        both columns, when the file contains no gene assignments.

    >>> import os
    >>> import tempfile
    >>> with tempfile.TemporaryDirectory() as temp_dir:
    ...     temp_gmt = os.path.join(temp_dir, "temp_file.gmt")
    ...     with open(temp_gmt, "w") as f:
    ...         _ = f.write("SET_A\thttp://example.org/SET_A\tG1\tG2\n")
    ...         _ = f.write("SET_B\thttp://example.org/SET_B\tG3\n")
    ...     df = read_gmt_long(temp_gmt)
    >>> df.shape[0]
    3
    >>> df.loc[df["gene_set"] == "SET_A"].gene.tolist()
    ['G1', 'G2']
    """
    # Map each gene set name to the list of genes assigned to it
    gene_sets = {}

    with open(gmt_file, "r") as f:
        for line in f:
            fields = line.strip().split("\t")
            if not fields[0]:
                # blank (or whitespace-only) line: there is no gene set name
                continue
            gene_sets[fields[0]] = fields[2:]

    # Gene sets without genes contribute no rows, so they are left out
    # instead of being concatenated as empty frames.
    frames = [
        pd.DataFrame({"gene_set": name, "gene": genes})
        for name, genes in gene_sets.items()
        if genes
    ]
    if not frames:
        # pd.concat raises ValueError when there is nothing to concatenate;
        # hand back an empty frame with the expected columns instead.
        return pd.DataFrame(columns=["gene_set", "gene"])

    return pd.concat(frames)
def score_genes_aucell_mt(
    adata: anndata.AnnData,
    gene_set_gene: pd.DataFrame,
    use_raw=False,
    min_n_genes=5,
    var_gene_symbols_field=None,
):
    """Score gene sets using decoupler's AUCell.

    Runs ``dc.run_aucell`` on ``adata`` and then copies, for every gene set
    found in ``adata.obsm["aucell_estimate"]``, its scores into a new
    ``adata.obs`` column named ``AUCell_<gene set>``. ``adata`` is modified
    in place and nothing is returned. The ``gene_set_gene`` frame passed by
    the caller is not modified.

    Parameters
    ----------
    adata : anndata.AnnData
    gene_set_gene: pd.DataFrame with columns gene_set and gene
    use_raw : bool, optional, False by default. Forwarded to
        ``dc.run_aucell``.
    min_n_genes : int, optional, 5 by default. Forwarded to
        ``dc.run_aucell`` as ``min_n``.
    var_gene_symbols_field : str, optional, None by default. The field in var
        where gene symbols are stored. When given, the genes of
        ``gene_set_gene`` are treated as symbols and translated to the
        matching ``adata.var`` index values before scoring.

    >>> import scanpy as sc
    >>> import decoupler as dc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> r_gene_list = adata.var[
    ...     adata.var.index.str.startswith("RP")].index.tolist()
    >>> m_gene_list = adata.var[
    ...     adata.var.index.str.startswith("M")].index.tolist()
    >>> gene_set = {}
    >>> gene_set["m"] = m_gene_list
    >>> gene_set["r"] = r_gene_list
    >>> gene_set_df = pd.concat(
    ...     pd.DataFrame(
    ...         {'gene_set':k, 'gene':v}
    ...     ) for k, v in gene_set.items())
    >>> score_genes_aucell_mt(adata, gene_set_df, use_raw=False)
    >>> "AUCell_m" in adata.obs.columns
    True
    >>> "AUCell_r" in adata.obs.columns
    True
    """

    # if var_gene_symbols_field is provided, transform gene_set_gene df so
    # that gene contains gene ids instead of gene symbols
    if var_gene_symbols_field:
        # Build the symbol -> id lookup with assign(), which returns a new
        # frame; writing a column into the adata.var[[...]] selection
        # directly is what triggers pandas' SettingWithCopyWarning.
        var_id_symbols = adata.var[[var_gene_symbols_field]].assign(
            gene_id=adata.var.index
        )

        # merge the index of var to gene_set_gene df based on
        # var_gene_symbols_field (merge returns a new frame, so the
        # caller's gene_set_gene is left untouched)
        gene_set_gene = gene_set_gene.merge(
            var_id_symbols,
            left_on="gene",
            right_on=var_gene_symbols_field,
            how="left",
        )
        # this will still produce some empty gene_ids (genes in the
        # gene_set_gene df that are not in the var df), fill those
        # with the original gene symbol from the gene_set to avoid
        # deforming the AUCell calculation
        gene_set_gene["gene_id"] = gene_set_gene["gene_id"].fillna(
            gene_set_gene["gene"]
        )
        gene_set_gene["gene"] = gene_set_gene["gene_id"]

    # run decoupler's run_aucell
    dc.run_aucell(
        adata,
        net=gene_set_gene,
        source="gene_set",
        target="gene",
        use_raw=use_raw,
        min_n=min_n_genes,
    )

    # Copy each available per-set estimate into obs. Gene sets without an
    # entry in aucell_estimate are skipped (presumably those dropped by
    # run_aucell for having fewer than min_n genes -- confirm against the
    # decoupler documentation).
    estimates = adata.obsm["aucell_estimate"]
    for gs in gene_set_gene.gene_set.unique():
        if gs in estimates.keys():
            adata.obs[f"AUCell_{gs}"] = estimates[gs]
100 150
101 151
def run_for_genelists(
    adata,
    gene_lists,
    score_names,
    use_raw=False,
    gene_symbols_field=None,
    min_n_genes=5,
):
    """Score each user supplied gene list with AUCell under its own name.

    The i-th entry of ``gene_lists`` is paired with the i-th entry of
    ``score_names``; every pair is turned into a single gene set and handed
    to ``score_genes_aucell_mt``, one call per pair.

    Parameters
    ----------
    adata : object passed through to ``score_genes_aucell_mt``
    gene_lists : list of str, each one a comma separated list of genes
    score_names : list of str, the gene set name to use for each gene list
    use_raw : bool, optional, False by default.
    gene_symbols_field : str, optional, None by default. Passed to
        ``score_genes_aucell_mt`` as ``var_gene_symbols_field``.
    min_n_genes : int, optional, 5 by default.

    Raises
    ------
    ValueError
        If ``gene_lists`` and ``score_names`` differ in length.
    """
    if len(gene_lists) != len(score_names):
        # Adjacent string literals keep the message free of the indentation
        # that a backslash continuation inside the literal would embed.
        raise ValueError(
            "The number of gene lists (separated by :) and score names "
            "(separated by :) must be the same"
        )

    # score_name is deliberately distinct from the score_names parameter so
    # that the argument is not rebound while it is being iterated.
    for gene_list, score_name in zip(gene_lists, score_names):
        gene_set_gene_df = pd.DataFrame(
            {"gene_set": score_name, "gene": gene_list.split(",")}
        )

        score_genes_aucell_mt(
            adata,
            gene_set_gene_df,
            use_raw,
            min_n_genes,
            var_gene_symbols_field=gene_symbols_field,
        )
123 182
124 183
125 if __name__ == "__main__": 184 if __name__ == "__main__":
126 # Create command-line arguments parser 185 # Create command-line arguments parser
127 parser = argparse.ArgumentParser(description="Score genes using Aucell") 186 parser = argparse.ArgumentParser(description="Score genes using Aucell")
128 parser.add_argument( 187 parser.add_argument(
129 "--input_file", type=str, help="Path to input AnnData file", required=True 188 "--input_file",
189 type=str,
190 help="Path to input AnnData file",
191 required=True,
130 ) 192 )
131 parser.add_argument( 193 parser.add_argument(
132 "--output_file", type=str, help="Path to output file", required=True 194 "--output_file", type=str, help="Path to output file", required=True
133 ) 195 )
134 parser.add_argument("--gmt_file", type=str, help="Path to GMT file", required=False) 196 parser.add_argument(
197 "--gmt_file", type=str, help="Path to GMT file", required=False
198 )
135 # add argument for gene sets to score 199 # add argument for gene sets to score
136 parser.add_argument( 200 parser.add_argument(
137 "--gene_sets_to_score", 201 "--gene_sets_to_score",
138 type=str, 202 type=str,
139 required=False, 203 required=False,
140 help="Optional comma separated list of gene sets to score (the need to be in the gmt file)", 204 help="Optional comma separated list of gene sets to score \
205 (the need to be in the gmt file)",
141 ) 206 )
142 # add argument for gene list (comma separated) to score 207 # add argument for gene list (comma separated) to score
143 parser.add_argument( 208 parser.add_argument(
144 "--gene_lists_to_score", 209 "--gene_lists_to_score",
145 type=str, 210 type=str,
146 required=False, 211 required=False,
147 help="Comma separated list of genes to score. You can have more than one set of genes, separated by colon :", 212 help="Comma separated list of genes to score. You can have more \
213 than one set of genes, separated by colon :",
148 ) 214 )
149 # argument for the score name when using the gene list 215 # argument for the score name when using the gene list
150 parser.add_argument( 216 parser.add_argument(
151 "--score_names", 217 "--score_names",
152 type=str, 218 type=str,
153 required=False, 219 required=False,
154 help="Name of the score column when using the gene list. You can have more than one set of score names, separated by colon :. It should be the same length as the number of gene lists.", 220 help="Name of the score column when using the gene list. You can \
221 have more than one set of score names, separated by colon :. \
222 It should be the same length as the number of gene lists.",
155 ) 223 )
156 parser.add_argument( 224 parser.add_argument(
157 "--gene_symbols_field", 225 "--gene_symbols_field",
158 type=str, 226 type=str,
159 help="Name of the gene symbols field in the AnnData object", 227 help="Name of the gene symbols field in the AnnData object",
160 required=True, 228 required=True,
161 ) 229 )
162 # argument for min_n Minimum of targets per source. If less, sources are removed. 230 # argument for min_n Minimum of targets per source. If less, sources
231 # are removed.
163 parser.add_argument( 232 parser.add_argument(
164 "--min_n", 233 "--min_n",
165 type=int, 234 type=int,
166 required=False, 235 required=False,
167 default=5, 236 default=5,
168 help="Minimum of targets per source. If less, sources are removed.", 237 help="Minimum of targets per source. If less, sources are removed.",
169 ) 238 )
170 parser.add_argument("--use_raw", action="store_true", help="Use raw data") 239 parser.add_argument("--use_raw", action="store_true", help="Use raw data")
171 parser.add_argument( 240 parser.add_argument(
172 "--write_anndata", action="store_true", help="Write the modified AnnData object" 241 "--write_anndata",
242 action="store_true",
243 help="Write the modified AnnData object",
173 ) 244 )
174 # argument for number of max concurrent processes 245 # argument for number of max concurrent processes
175 parser.add_argument("--max_threads", type=int, required=False, default=1, help="Number of max concurrent threads") 246 parser.add_argument(
176 247 "--max_threads",
248 type=int,
249 required=False,
250 default=1,
251 help="Number of max concurrent threads",
252 )
177 253
178 # Parse command-line arguments 254 # Parse command-line arguments
179 args = parser.parse_args() 255 args = parser.parse_args()
180 256
181 nb.set_num_threads(n=args.max_threads) 257 nb.set_num_threads(n=args.max_threads)
187 # Load MSigDB file in GMT format 263 # Load MSigDB file in GMT format
188 # msigdb = read_gmt(args.gmt_file) 264 # msigdb = read_gmt(args.gmt_file)
189 msigdb = read_gmt_long(args.gmt_file) 265 msigdb = read_gmt_long(args.gmt_file)
190 266
191 gene_sets_to_score = ( 267 gene_sets_to_score = (
192 args.gene_sets_to_score.split(",") if args.gene_sets_to_score else [] 268 args.gene_sets_to_score.split(",")
269 if args.gene_sets_to_score
270 else []
193 ) 271 )
194 if gene_sets_to_score: 272 if gene_sets_to_score:
195 # we limit the GMT file read to the genesets specified in the gene_sets_to_score argument 273 # we limit the GMT file read to the genesets specified in the
274 # gene_sets_to_score argument
196 msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)] 275 msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)]
197 276
198 score_genes_aucell_mt(adata, msigdb, args.use_raw, args.min_n, var_gene_symbols_field=args.gene_symbols_field) 277 score_genes_aucell_mt(
278 adata,
279 msigdb,
280 args.use_raw,
281 args.min_n,
282 var_gene_symbols_field=args.gene_symbols_field,
283 )
199 elif args.gene_lists_to_score is not None and args.score_names is not None: 284 elif args.gene_lists_to_score is not None and args.score_names is not None:
200 gene_lists = args.gene_lists_to_score.split(":") 285 gene_lists = args.gene_lists_to_score.split(":")
201 score_names = args.score_names.split(",") 286 score_names = args.score_names.split(",")
202 run_for_genelists( 287 run_for_genelists(
203 adata, gene_lists, score_names, args.use_raw, args.gene_symbols_field, args.min_n 288 adata,
204 ) 289 gene_lists,
205 290 score_names,
206 # Save the modified AnnData object or generate a file with cells as rows and the new score_names columns 291 args.use_raw,
292 args.gene_symbols_field,
293 args.min_n,
294 )
295
296 # Save the modified AnnData object or generate a file with cells as rows
297 # and the new score_names columns
207 if args.write_anndata: 298 if args.write_anndata:
208 adata.write_h5ad(args.output_file) 299 adata.write_h5ad(args.output_file)
209 else: 300 else:
210 new_columns = [col for col in adata.obs.columns if col.startswith("AUCell_")] 301 new_columns = [
302 col for col in adata.obs.columns if col.startswith("AUCell_")
303 ]
211 adata.obs[new_columns].to_csv(args.output_file, sep="\t", index=True) 304 adata.obs[new_columns].to_csv(args.output_file, sep="\t", index=True)