changeset 7:68a2b5445558 draft

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit b01245159f9cb67101497bb974b2c13bcee019b7
author ebi-gxa
date Tue, 16 Apr 2024 11:49:25 +0000
parents ed2a77422e00
children 93f61ea19336
files decoupler_aucell_score.py
diffstat 1 files changed, 72 insertions(+), 62 deletions(-) [+]
line wrap: on
line diff
--- a/decoupler_aucell_score.py	Mon Apr 15 13:20:32 2024 +0000
+++ b/decoupler_aucell_score.py	Tue Apr 16 11:49:25 2024 +0000
@@ -5,11 +5,12 @@
 import anndata
 import decoupler as dc
 import pandas as pd
+import numba as nb
 
 
-def read_gmt(gmt_file):
+def read_gmt_long(gmt_file):
     """
-    Reads a GMT file into a Pandas DataFrame.
+    Reads a GMT file and produce a Pandas DataFrame in long format, ready to be passed to the AUCell method.
 
     Parameters
     ----------
@@ -19,7 +20,7 @@
     Returns
     -------
     pd.DataFrame
-        A DataFrame with the gene sets. Each row represents a gene set, and the columns are "gene_set_name", "gene_set_url", and "genes".
+        A DataFrame with the gene sets. Each row represents a gene set to gene assignment, and the columns are "gene_set_name" and "genes".
     >>> line = "HALLMARK_NOTCH_SIGNALING\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\\tJAG1\\tNOTCH3\\tNOTCH2\\tAPH1A\\tHES1\\tCCND1\\tFZD1\\tPSEN2\\tFZD7\\tDTX1\\tDLL1\\tFZD5\\tMAML2\\tNOTCH1\\tPSENEN\\tWNT5A\\tCUL1\\tWNT2\\tDTX4\\tSAP30\\tPPARD\\tKAT2A\\tHEYL\\tSKP1\\tRBX1\\tTCF7L2\\tARRB1\\tLFNG\\tPRKCA\\tDTX2\\tST3GAL6\\tFBXW11\\n"
     >>> line2 = "HALLMARK_APICAL_SURFACE\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\\tB4GALT1\\tRHCG\\tMAL\\tLYPD3\\tPKHD1\\tATP6V0A4\\tCRYBG1\\tSHROOM2\\tSRPX\\tMDGA1\\tTMEM8B\\tTHY1\\tPCSK9\\tEPHB4\\tDCBLD2\\tGHRL\\tLYN\\tGAS1\\tFLOT2\\tPLAUR\\tAKAP7\\tATP8B1\\tEFNA5\\tSLC34A3\\tAPP\\tGSTM3\\tHSPB1\\tSLC2A4\\tIL2RB\\tRTN4RL1\\tNCOA6\\tSULF2\\tADAM10\\tBRCA1\\tGATA3\\tAFAP1L2\\tIL2RG\\tCD160\\tADIPOR2\\tSLC22A12\\tNTNG1\\tSCUBE1\\tCX3CL1\\tCROCC\\n"
     >>> temp_dir = tempfile.gettempdir()
@@ -29,81 +30,91 @@
     ...   f.write(line2)
     288
     380
-    >>> df = read_gmt(temp_gmt)
+    >>> df = read_gmt_long(temp_gmt)
     >>> df.shape[0]
-    2
-    >>> df.columns == ["gene_set_name", "genes"]
-    array([ True,  True])
-    >>> df.loc[df["gene_set_name"] == "HALLMARK_APICAL_SURFACE"].genes.tolist()[0].startswith("B4GALT1")
-    True
+    76
+    >>> len(df.loc[df["gene_set"] == "HALLMARK_APICAL_SURFACE"].gene.tolist())
+    44
     """
+    # Create a list of dictionaries, where each dictionary represents a gene set
+    gene_sets = {}
+
     # Read the GMT file into a list of lines
     with open(gmt_file, "r") as f:
-        lines = f.readlines()
+        while True:
+            line = f.readline()
+            if not line:
+                break
+            fields = line.strip().split("\t")
+            gene_sets[fields[0]]= fields[2:]
 
-    # Create a list of dictionaries, where each dictionary represents a gene set
-    gene_sets = []
-    for line in lines:
-        fields = line.strip().split("\t")
-        gene_set = {"gene_set_name": fields[0], "genes": ",".join(fields[2:])}
-        gene_sets.append(gene_set)
-
-    # Convert the list of dictionaries to a DataFrame
-    return pd.DataFrame(gene_sets)
+    return pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items())
 
 
-def score_genes_aucell(
-    adata: anndata.AnnData, gene_list: list, score_name: str, use_raw=False, min_n_genes=5
-):
+def score_genes_aucell_mt(adata: anndata.AnnData, gene_set_gene: pd.DataFrame, use_raw=False, min_n_genes=5, var_gene_symbols_field=None):
     """Score genes using Aucell.
 
     Parameters
     ----------
     adata : anndata.AnnData
-    gene_list : list
-    score_names : str
-    use_raw : bool, optional
+    gene_set_gene: pd.DataFrame with columns gene_set and gene
+    use_raw : bool, optional, False by default.
+    min_n_genes : int, optional, 5 by default.
+    var_gene_symbols_field : str, optional, None by default. The field in var where gene symbols are stored
 
     >>> import scanpy as sc
     >>> import decoupler as dc
     >>> adata = sc.datasets.pbmc68k_reduced()
-    >>> gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist()
-    >>> score_genes_aucell(adata, gene_list, "ribosomal_aucell", use_raw=False)
-    >>> "ribosomal_aucell" in adata.obs.columns
+    >>> r_gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist()
+    >>> m_gene_list = adata.var[adata.var.index.str.startswith("M")].index.tolist()
+    >>> gene_set = {}
+    >>> gene_set["m"] = m_gene_list
+    >>> gene_set["r"] = r_gene_list
+    >>> gene_set_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_set.items())
+    >>> score_genes_aucell_mt(adata, gene_set_df, use_raw=False)
+    >>> "AUCell_m" in adata.obs.columns
+    True
+    >>> "AUCell_r" in adata.obs.columns
     True
     """
-    # make a data.frame with two columns, geneset and gene_id, geneset filled with score_names and gene_id with gene_list, one row per element
-    geneset_df = pd.DataFrame(
-        {
-            "gene_id": gene_list,
-            "geneset": score_name,
-        }
-    )
+
+    # if var_gene_symbols_fiels is provided, transform gene_set_gene df so that gene contains gene ids instead of gene symbols
+    if var_gene_symbols_field:
+        # merge the index of var to gene_set_gene df based on var_gene_symbols_field
+        var_id_symbols = adata.var[[var_gene_symbols_field]]
+        var_id_symbols['gene_id'] = var_id_symbols.index
+
+        gene_set_gene = gene_set_gene.merge(var_id_symbols, left_on='gene', right_on=var_gene_symbols_field, how='left')
+        # this will still produce some empty gene_ids (genes in the gene_set_gene df that are not in the var df), fill those
+        # with the original gene symbol from the gene_set to avoid deforming the AUCell calculation
+        gene_set_gene['gene_id'] = gene_set_gene['gene_id'].fillna(gene_set_gene['gene'])
+        gene_set_gene['gene'] = gene_set_gene['gene_id']
+    
     # run decoupler's run_aucell
-    # catch the value error
-    try:
-        dc.run_aucell(
-            adata, net=geneset_df, source="geneset", target="gene_id", use_raw=use_raw
+    dc.run_aucell(
+            adata, net=gene_set_gene, source="gene_set", target="gene", use_raw=use_raw, min_n=min_n_genes
         )
-        # copy .obsm['aucell_estimate'] matrix columns to adata.obs using the column names
-        adata.obs[score_name] = adata.obsm["aucell_estimate"][score_name]
-    except ValueError as ve:
-        print(f"Gene list {score_name} failed, skipping: {str(ve)}")
+    for gs in gene_set_gene.gene_set.unique():
+        if gs in adata.obsm['aucell_estimate'].keys():
+            adata.obs[f"AUCell_{gs}"] = adata.obsm["aucell_estimate"][gs]
 
 
 def run_for_genelists(
-    adata, gene_lists, score_names, use_raw=False, gene_symbols_field="gene_symbols", min_n_genes=5
+    adata, gene_lists, score_names, use_raw=False, gene_symbols_field=None, min_n_genes=5
 ):
     if len(gene_lists) == len(score_names):
         for gene_list, score_names in zip(gene_lists, score_names):
             genes = gene_list.split(",")
-            ens_gene_ids = adata.var[adata.var[gene_symbols_field].isin(genes)].index
-            score_genes_aucell(
+            gene_sets = {}
+            gene_sets[score_names] = genes
+            gene_set_gene_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items())
+            
+            score_genes_aucell_mt(
                 adata,
-                ens_gene_ids,
-                f"AUCell_{score_names}",
+                gene_set_gene_df,
                 use_raw,
-                min_n_genes
+                min_n_genes,
+                var_gene_symbols_field=gene_symbols_field
             )
     else:
         raise ValueError(
@@ -160,32 +171,31 @@
     parser.add_argument(
         "--write_anndata", action="store_true", help="Write the modified AnnData object"
     )
+    # argument for number of max concurrent processes
+    parser.add_argument("--max_threads", type=int, required=False, default=1, help="Number of max concurrent threads")
+
 
     # Parse command-line arguments
     args = parser.parse_args()
 
+    nb.set_num_threads(n=args.max_threads)
+
     # Load input AnnData object
     adata = anndata.read_h5ad(args.input_file)
 
     if args.gmt_file is not None:
         # Load MSigDB file in GMT format
-        msigdb = read_gmt(args.gmt_file)
+        # msigdb = read_gmt(args.gmt_file)
+        msigdb = read_gmt_long(args.gmt_file)
 
         gene_sets_to_score = (
             args.gene_sets_to_score.split(",") if args.gene_sets_to_score else []
         )
-        # Score genes by their ensembl ids using the score_genes_aucell function
-        for _, row in msigdb.iterrows():
-            gene_set_name = row["gene_set_name"]
-            if not gene_sets_to_score or gene_set_name in gene_sets_to_score:
-                genes = row["genes"].split(",")
-                # Convert gene symbols to ensembl ids by using the columns gene_symbols and index in adata.var specific to the gene set
-                ens_gene_ids = adata.var[
-                    adata.var[args.gene_symbols_field].isin(genes)
-                ].index
-                score_genes_aucell(
-                    adata, ens_gene_ids, f"AUCell_{gene_set_name}", args.use_raw, args.min_n
-                )
+        if gene_sets_to_score:
+            # we limit the GMT file read to the genesets specified in the gene_sets_to_score argument
+            msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)]
+        
+        score_genes_aucell_mt(adata, msigdb, args.use_raw, args.min_n, var_gene_symbols_field=args.gene_symbols_field)
     elif args.gene_lists_to_score is not None and args.score_names is not None:
         gene_lists = args.gene_lists_to_score.split(":")
         score_names = args.score_names.split(",")