Mercurial > repos > ebi-gxa > decoupler_pseudobulk
comparison decoupler_aucell_score.py @ 7:68a2b5445558 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit b01245159f9cb67101497bb974b2c13bcee019b7
author | ebi-gxa |
---|---|
date | Tue, 16 Apr 2024 11:49:25 +0000 |
parents | ed2a77422e00 |
children | 93f61ea19336 |
comparison
equal
deleted
inserted
replaced
6:ed2a77422e00 | 7:68a2b5445558 |
---|---|
3 import tempfile | 3 import tempfile |
4 | 4 |
5 import anndata | 5 import anndata |
6 import decoupler as dc | 6 import decoupler as dc |
7 import pandas as pd | 7 import pandas as pd |
8 | 8 import numba as nb |
9 | 9 |
10 def read_gmt(gmt_file): | 10 |
11 def read_gmt_long(gmt_file): | |
11 """ | 12 """ |
12 Reads a GMT file into a Pandas DataFrame. | 13 Reads a GMT file and produce a Pandas DataFrame in long format, ready to be passed to the AUCell method. |
13 | 14 |
14 Parameters | 15 Parameters |
15 ---------- | 16 ---------- |
16 gmt_file : str | 17 gmt_file : str |
17 Path to the GMT file. | 18 Path to the GMT file. |
18 | 19 |
19 Returns | 20 Returns |
20 ------- | 21 ------- |
21 pd.DataFrame | 22 pd.DataFrame |
22 A DataFrame with the gene sets. Each row represents a gene set, and the columns are "gene_set_name", "gene_set_url", and "genes". | 23 A DataFrame with the gene sets. Each row represents a gene set to gene assignment, and the columns are "gene_set_name" and "genes". |
23 >>> line = "HALLMARK_NOTCH_SIGNALING\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\\tJAG1\\tNOTCH3\\tNOTCH2\\tAPH1A\\tHES1\\tCCND1\\tFZD1\\tPSEN2\\tFZD7\\tDTX1\\tDLL1\\tFZD5\\tMAML2\\tNOTCH1\\tPSENEN\\tWNT5A\\tCUL1\\tWNT2\\tDTX4\\tSAP30\\tPPARD\\tKAT2A\\tHEYL\\tSKP1\\tRBX1\\tTCF7L2\\tARRB1\\tLFNG\\tPRKCA\\tDTX2\\tST3GAL6\\tFBXW11\\n" | 24 >>> line = "HALLMARK_NOTCH_SIGNALING\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\\tJAG1\\tNOTCH3\\tNOTCH2\\tAPH1A\\tHES1\\tCCND1\\tFZD1\\tPSEN2\\tFZD7\\tDTX1\\tDLL1\\tFZD5\\tMAML2\\tNOTCH1\\tPSENEN\\tWNT5A\\tCUL1\\tWNT2\\tDTX4\\tSAP30\\tPPARD\\tKAT2A\\tHEYL\\tSKP1\\tRBX1\\tTCF7L2\\tARRB1\\tLFNG\\tPRKCA\\tDTX2\\tST3GAL6\\tFBXW11\\n" |
24 >>> line2 = "HALLMARK_APICAL_SURFACE\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\\tB4GALT1\\tRHCG\\tMAL\\tLYPD3\\tPKHD1\\tATP6V0A4\\tCRYBG1\\tSHROOM2\\tSRPX\\tMDGA1\\tTMEM8B\\tTHY1\\tPCSK9\\tEPHB4\\tDCBLD2\\tGHRL\\tLYN\\tGAS1\\tFLOT2\\tPLAUR\\tAKAP7\\tATP8B1\\tEFNA5\\tSLC34A3\\tAPP\\tGSTM3\\tHSPB1\\tSLC2A4\\tIL2RB\\tRTN4RL1\\tNCOA6\\tSULF2\\tADAM10\\tBRCA1\\tGATA3\\tAFAP1L2\\tIL2RG\\tCD160\\tADIPOR2\\tSLC22A12\\tNTNG1\\tSCUBE1\\tCX3CL1\\tCROCC\\n" | 25 >>> line2 = "HALLMARK_APICAL_SURFACE\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\\tB4GALT1\\tRHCG\\tMAL\\tLYPD3\\tPKHD1\\tATP6V0A4\\tCRYBG1\\tSHROOM2\\tSRPX\\tMDGA1\\tTMEM8B\\tTHY1\\tPCSK9\\tEPHB4\\tDCBLD2\\tGHRL\\tLYN\\tGAS1\\tFLOT2\\tPLAUR\\tAKAP7\\tATP8B1\\tEFNA5\\tSLC34A3\\tAPP\\tGSTM3\\tHSPB1\\tSLC2A4\\tIL2RB\\tRTN4RL1\\tNCOA6\\tSULF2\\tADAM10\\tBRCA1\\tGATA3\\tAFAP1L2\\tIL2RG\\tCD160\\tADIPOR2\\tSLC22A12\\tNTNG1\\tSCUBE1\\tCX3CL1\\tCROCC\\n" |
25 >>> temp_dir = tempfile.gettempdir() | 26 >>> temp_dir = tempfile.gettempdir() |
26 >>> temp_gmt = os.path.join(temp_dir, "temp_file.gmt") | 27 >>> temp_gmt = os.path.join(temp_dir, "temp_file.gmt") |
27 >>> with open(temp_gmt, "w") as f: | 28 >>> with open(temp_gmt, "w") as f: |
28 ... f.write(line) | 29 ... f.write(line) |
29 ... f.write(line2) | 30 ... f.write(line2) |
30 288 | 31 288 |
31 380 | 32 380 |
32 >>> df = read_gmt(temp_gmt) | 33 >>> df = read_gmt_long(temp_gmt) |
33 >>> df.shape[0] | 34 >>> df.shape[0] |
34 2 | 35 76 |
35 >>> df.columns == ["gene_set_name", "genes"] | 36 >>> len(df.loc[df["gene_set"] == "HALLMARK_APICAL_SURFACE"].gene.tolist()) |
36 array([ True, True]) | 37 44 |
37 >>> df.loc[df["gene_set_name"] == "HALLMARK_APICAL_SURFACE"].genes.tolist()[0].startswith("B4GALT1") | |
38 True | |
39 """ | 38 """ |
39 # Create a list of dictionaries, where each dictionary represents a gene set | |
40 gene_sets = {} | |
41 | |
40 # Read the GMT file into a list of lines | 42 # Read the GMT file into a list of lines |
41 with open(gmt_file, "r") as f: | 43 with open(gmt_file, "r") as f: |
42 lines = f.readlines() | 44 while True: |
43 | 45 line = f.readline() |
44 # Create a list of dictionaries, where each dictionary represents a gene set | 46 if not line: |
45 gene_sets = [] | 47 break |
46 for line in lines: | 48 fields = line.strip().split("\t") |
47 fields = line.strip().split("\t") | 49 gene_sets[fields[0]]= fields[2:] |
48 gene_set = {"gene_set_name": fields[0], "genes": ",".join(fields[2:])} | 50 |
49 gene_sets.append(gene_set) | 51 return pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items()) |
50 | 52 |
51 # Convert the list of dictionaries to a DataFrame | 53 |
52 return pd.DataFrame(gene_sets) | 54 def score_genes_aucell_mt(adata: anndata.AnnData, gene_set_gene: pd.DataFrame, use_raw=False, min_n_genes=5, var_gene_symbols_field=None): |
53 | |
54 | |
55 def score_genes_aucell( | |
56 adata: anndata.AnnData, gene_list: list, score_name: str, use_raw=False, min_n_genes=5 | |
57 ): | |
58 """Score genes using Aucell. | 55 """Score genes using Aucell. |
59 | 56 |
60 Parameters | 57 Parameters |
61 ---------- | 58 ---------- |
62 adata : anndata.AnnData | 59 adata : anndata.AnnData |
63 gene_list : list | 60 gene_set_gene: pd.DataFrame with columns gene_set and gene |
64 score_names : str | 61 use_raw : bool, optional, False by default. |
65 use_raw : bool, optional | 62 min_n_genes : int, optional, 5 by default. |
63 var_gene_symbols_field : str, optional, None by default. The field in var where gene symbols are stored | |
66 | 64 |
67 >>> import scanpy as sc | 65 >>> import scanpy as sc |
68 >>> import decoupler as dc | 66 >>> import decoupler as dc |
69 >>> adata = sc.datasets.pbmc68k_reduced() | 67 >>> adata = sc.datasets.pbmc68k_reduced() |
70 >>> gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist() | 68 >>> r_gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist() |
71 >>> score_genes_aucell(adata, gene_list, "ribosomal_aucell", use_raw=False) | 69 >>> m_gene_list = adata.var[adata.var.index.str.startswith("M")].index.tolist() |
72 >>> "ribosomal_aucell" in adata.obs.columns | 70 >>> gene_set = {} |
71 >>> gene_set["m"] = m_gene_list | |
72 >>> gene_set["r"] = r_gene_list | |
73 >>> gene_set_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_set.items()) | |
74 >>> score_genes_aucell_mt(adata, gene_set_df, use_raw=False) | |
75 >>> "AUCell_m" in adata.obs.columns | |
76 True | |
77 >>> "AUCell_r" in adata.obs.columns | |
73 True | 78 True |
74 """ | 79 """ |
75 # make a data.frame with two columns, geneset and gene_id, geneset filled with score_names and gene_id with gene_list, one row per element | 80 |
76 geneset_df = pd.DataFrame( | 81 # if var_gene_symbols_fiels is provided, transform gene_set_gene df so that gene contains gene ids instead of gene symbols |
77 { | 82 if var_gene_symbols_field: |
78 "gene_id": gene_list, | 83 # merge the index of var to gene_set_gene df based on var_gene_symbols_field |
79 "geneset": score_name, | 84 var_id_symbols = adata.var[[var_gene_symbols_field]] |
80 } | 85 var_id_symbols['gene_id'] = var_id_symbols.index |
81 ) | 86 |
87 gene_set_gene = gene_set_gene.merge(var_id_symbols, left_on='gene', right_on=var_gene_symbols_field, how='left') | |
88 # this will still produce some empty gene_ids (genes in the gene_set_gene df that are not in the var df), fill those | |
89 # with the original gene symbol from the gene_set to avoid deforming the AUCell calculation | |
90 gene_set_gene['gene_id'] = gene_set_gene['gene_id'].fillna(gene_set_gene['gene']) | |
91 gene_set_gene['gene'] = gene_set_gene['gene_id'] | |
92 | |
82 # run decoupler's run_aucell | 93 # run decoupler's run_aucell |
83 # catch the value error | 94 dc.run_aucell( |
84 try: | 95 adata, net=gene_set_gene, source="gene_set", target="gene", use_raw=use_raw, min_n=min_n_genes |
85 dc.run_aucell( | 96 ) |
86 adata, net=geneset_df, source="geneset", target="gene_id", use_raw=use_raw | 97 for gs in gene_set_gene.gene_set.unique(): |
87 ) | 98 if gs in adata.obsm['aucell_estimate'].keys(): |
88 # copy .obsm['aucell_estimate'] matrix columns to adata.obs using the column names | 99 adata.obs[f"AUCell_{gs}"] = adata.obsm["aucell_estimate"][gs] |
89 adata.obs[score_name] = adata.obsm["aucell_estimate"][score_name] | |
90 except ValueError as ve: | |
91 print(f"Gene list {score_name} failed, skipping: {str(ve)}") | |
92 | 100 |
93 | 101 |
94 def run_for_genelists( | 102 def run_for_genelists( |
95 adata, gene_lists, score_names, use_raw=False, gene_symbols_field="gene_symbols", min_n_genes=5 | 103 adata, gene_lists, score_names, use_raw=False, gene_symbols_field=None, min_n_genes=5 |
96 ): | 104 ): |
97 if len(gene_lists) == len(score_names): | 105 if len(gene_lists) == len(score_names): |
98 for gene_list, score_names in zip(gene_lists, score_names): | 106 for gene_list, score_names in zip(gene_lists, score_names): |
99 genes = gene_list.split(",") | 107 genes = gene_list.split(",") |
100 ens_gene_ids = adata.var[adata.var[gene_symbols_field].isin(genes)].index | 108 gene_sets = {} |
101 score_genes_aucell( | 109 gene_sets[score_names] = genes |
110 gene_set_gene_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items()) | |
111 | |
112 score_genes_aucell_mt( | |
102 adata, | 113 adata, |
103 ens_gene_ids, | 114 gene_set_gene_df, |
104 f"AUCell_{score_names}", | |
105 use_raw, | 115 use_raw, |
106 min_n_genes | 116 min_n_genes, |
117 var_gene_symbols_field=gene_symbols_field | |
107 ) | 118 ) |
108 else: | 119 else: |
109 raise ValueError( | 120 raise ValueError( |
110 "The number of gene lists (separated by :) and score names (separated by :) must be the same" | 121 "The number of gene lists (separated by :) and score names (separated by :) must be the same" |
111 ) | 122 ) |
158 ) | 169 ) |
159 parser.add_argument("--use_raw", action="store_true", help="Use raw data") | 170 parser.add_argument("--use_raw", action="store_true", help="Use raw data") |
160 parser.add_argument( | 171 parser.add_argument( |
161 "--write_anndata", action="store_true", help="Write the modified AnnData object" | 172 "--write_anndata", action="store_true", help="Write the modified AnnData object" |
162 ) | 173 ) |
174 # argument for number of max concurrent processes | |
175 parser.add_argument("--max_threads", type=int, required=False, default=1, help="Number of max concurrent threads") | |
176 | |
163 | 177 |
164 # Parse command-line arguments | 178 # Parse command-line arguments |
165 args = parser.parse_args() | 179 args = parser.parse_args() |
166 | 180 |
181 nb.set_num_threads(n=args.max_threads) | |
182 | |
167 # Load input AnnData object | 183 # Load input AnnData object |
168 adata = anndata.read_h5ad(args.input_file) | 184 adata = anndata.read_h5ad(args.input_file) |
169 | 185 |
170 if args.gmt_file is not None: | 186 if args.gmt_file is not None: |
171 # Load MSigDB file in GMT format | 187 # Load MSigDB file in GMT format |
172 msigdb = read_gmt(args.gmt_file) | 188 # msigdb = read_gmt(args.gmt_file) |
189 msigdb = read_gmt_long(args.gmt_file) | |
173 | 190 |
174 gene_sets_to_score = ( | 191 gene_sets_to_score = ( |
175 args.gene_sets_to_score.split(",") if args.gene_sets_to_score else [] | 192 args.gene_sets_to_score.split(",") if args.gene_sets_to_score else [] |
176 ) | 193 ) |
177 # Score genes by their ensembl ids using the score_genes_aucell function | 194 if gene_sets_to_score: |
178 for _, row in msigdb.iterrows(): | 195 # we limit the GMT file read to the genesets specified in the gene_sets_to_score argument |
179 gene_set_name = row["gene_set_name"] | 196 msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)] |
180 if not gene_sets_to_score or gene_set_name in gene_sets_to_score: | 197 |
181 genes = row["genes"].split(",") | 198 score_genes_aucell_mt(adata, msigdb, args.use_raw, args.min_n, var_gene_symbols_field=args.gene_symbols_field) |
182 # Convert gene symbols to ensembl ids by using the columns gene_symbols and index in adata.var specific to the gene set | |
183 ens_gene_ids = adata.var[ | |
184 adata.var[args.gene_symbols_field].isin(genes) | |
185 ].index | |
186 score_genes_aucell( | |
187 adata, ens_gene_ids, f"AUCell_{gene_set_name}", args.use_raw, args.min_n | |
188 ) | |
189 elif args.gene_lists_to_score is not None and args.score_names is not None: | 199 elif args.gene_lists_to_score is not None and args.score_names is not None: |
190 gene_lists = args.gene_lists_to_score.split(":") | 200 gene_lists = args.gene_lists_to_score.split(":") |
191 score_names = args.score_names.split(",") | 201 score_names = args.score_names.split(",") |
192 run_for_genelists( | 202 run_for_genelists( |
193 adata, gene_lists, score_names, args.use_raw, args.gene_symbols_field, args.min_n | 203 adata, gene_lists, score_names, args.use_raw, args.gene_symbols_field, args.min_n |