Mercurial > repos > ebi-gxa > decoupler_pathway_inference
changeset 3:c6787c2aee46 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author | ebi-gxa |
---|---|
date | Mon, 15 Jul 2024 10:56:37 +0000 |
parents | 82b7cd3e1bbd |
children | 6c30272fb587 |
files | decoupler_aucell_score.py decoupler_pathway_inference.py decoupler_pseudobulk.py |
diffstat | 3 files changed, 354 insertions(+), 82 deletions(-) [+] |
line wrap: on
line diff
--- a/decoupler_aucell_score.py Tue Apr 16 11:49:19 2024 +0000 +++ b/decoupler_aucell_score.py Mon Jul 15 10:56:37 2024 +0000 @@ -1,16 +1,15 @@ import argparse -import os -import tempfile import anndata import decoupler as dc +import numba as nb import pandas as pd -import numba as nb def read_gmt_long(gmt_file): - """ - Reads a GMT file and produce a Pandas DataFrame in long format, ready to be passed to the AUCell method. + r""" + Reads a GMT file and produce a Pandas DataFrame in long format, ready to + be passed to the AUCell method. Parameters ---------- @@ -20,9 +19,29 @@ Returns ------- pd.DataFrame - A DataFrame with the gene sets. Each row represents a gene set to gene assignment, and the columns are "gene_set_name" and "genes". - >>> line = "HALLMARK_NOTCH_SIGNALING\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\\tJAG1\\tNOTCH3\\tNOTCH2\\tAPH1A\\tHES1\\tCCND1\\tFZD1\\tPSEN2\\tFZD7\\tDTX1\\tDLL1\\tFZD5\\tMAML2\\tNOTCH1\\tPSENEN\\tWNT5A\\tCUL1\\tWNT2\\tDTX4\\tSAP30\\tPPARD\\tKAT2A\\tHEYL\\tSKP1\\tRBX1\\tTCF7L2\\tARRB1\\tLFNG\\tPRKCA\\tDTX2\\tST3GAL6\\tFBXW11\\n" - >>> line2 = "HALLMARK_APICAL_SURFACE\\thttp://www.gsea-msigdb.org/gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\\tB4GALT1\\tRHCG\\tMAL\\tLYPD3\\tPKHD1\\tATP6V0A4\\tCRYBG1\\tSHROOM2\\tSRPX\\tMDGA1\\tTMEM8B\\tTHY1\\tPCSK9\\tEPHB4\\tDCBLD2\\tGHRL\\tLYN\\tGAS1\\tFLOT2\\tPLAUR\\tAKAP7\\tATP8B1\\tEFNA5\\tSLC34A3\\tAPP\\tGSTM3\\tHSPB1\\tSLC2A4\\tIL2RB\\tRTN4RL1\\tNCOA6\\tSULF2\\tADAM10\\tBRCA1\\tGATA3\\tAFAP1L2\\tIL2RG\\tCD160\\tADIPOR2\\tSLC22A12\\tNTNG1\\tSCUBE1\\tCX3CL1\\tCROCC\\n" + A DataFrame with the gene sets. Each row represents a gene set to gene + assignment, and the columns are "gene_set_name" and "genes". + >>> import os + >>> import tempfile + >>> line = "HALLMARK_NOTCH_SIGNALING\ + ... \thttp://www.gsea-msigdb.org/\ + ... gsea/msigdb/human/geneset/HALLMARK_NOTCH_SIGNALING\ + ... \tJAG1\tNOTCH3\tNOTCH2\tAPH1A\tHES1\tCCND1\ + ... \tFZD1\tPSEN2\tFZD7\tDTX1\tDLL1\tFZD5\tMAML2\ + ... \tNOTCH1\tPSENEN\tWNT5A\tCUL1\tWNT2\tDTX4\ + ... \tSAP30\tPPARD\tKAT2A\tHEYL\tSKP1\tRBX1\tTCF7L2\ + ... \tARRB1\tLFNG\tPRKCA\tDTX2\tST3GAL6\tFBXW11\n" + >>> line2 = "HALLMARK_APICAL_SURFACE\ + ... \thttp://www.gsea-msigdb.org/\ + ... gsea/msigdb/human/geneset/HALLMARK_APICAL_SURFACE\ + ... \tB4GALT1\tRHCG\tMAL\tLYPD3\tPKHD1\tATP6V0A4\ + ... \tCRYBG1\tSHROOM2\tSRPX\tMDGA1\tTMEM8B\tTHY1\ + ... \tPCSK9\tEPHB4\tDCBLD2\tGHRL\tLYN\tGAS1\tFLOT2\ + ... \tPLAUR\tAKAP7\tATP8B1\tEFNA5\tSLC34A3\tAPP\ + ... \tGSTM3\tHSPB1\tSLC2A4\tIL2RB\tRTN4RL1\tNCOA6\ + ... \tSULF2\tADAM10\tBRCA1\tGATA3\tAFAP1L2\tIL2RG\ + ... \tCD160\tADIPOR2\tSLC22A12\tNTNG1\tSCUBE1\tCX3CL1\ + ... \tCROCC\n" >>> temp_dir = tempfile.gettempdir() >>> temp_gmt = os.path.join(temp_dir, "temp_file.gmt") >>> with open(temp_gmt, "w") as f: @@ -36,7 +55,8 @@ >>> len(df.loc[df["gene_set"] == "HALLMARK_APICAL_SURFACE"].gene.tolist()) 44 """ - # Create a list of dictionaries, where each dictionary represents a gene set + # Create a list of dictionaries, where each dictionary represents a + # gene set gene_sets = {} # Read the GMT file into a list of lines @@ -46,12 +66,20 @@ if not line: break fields = line.strip().split("\t") - gene_sets[fields[0]]= fields[2:] + gene_sets[fields[0]] = fields[2:] - return pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items()) + return pd.concat( + pd.DataFrame({"gene_set": k, "gene": v}) for k, v in gene_sets.items() + ) -def score_genes_aucell_mt(adata: anndata.AnnData, gene_set_gene: pd.DataFrame, use_raw=False, min_n_genes=5, var_gene_symbols_field=None): +def score_genes_aucell_mt( + adata: anndata.AnnData, + gene_set_gene: pd.DataFrame, + use_raw=False, + min_n_genes=5, + var_gene_symbols_field=None, +): """Score genes using Aucell. Parameters @@ -60,17 +88,23 @@ gene_set_gene: pd.DataFrame with columns gene_set and gene use_raw : bool, optional, False by default. min_n_genes : int, optional, 5 by default. - var_gene_symbols_field : str, optional, None by default. The field in var where gene symbols are stored + var_gene_symbols_field : str, optional, None by default. The field in var + where gene symbols are stored >>> import scanpy as sc >>> import decoupler as dc >>> adata = sc.datasets.pbmc68k_reduced() - >>> r_gene_list = adata.var[adata.var.index.str.startswith("RP")].index.tolist() - >>> m_gene_list = adata.var[adata.var.index.str.startswith("M")].index.tolist() + >>> r_gene_list = adata.var[ + ... adata.var.index.str.startswith("RP")].index.tolist() + >>> m_gene_list = adata.var[ + ... adata.var.index.str.startswith("M")].index.tolist() >>> gene_set = {} >>> gene_set["m"] = m_gene_list >>> gene_set["r"] = r_gene_list - >>> gene_set_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_set.items()) + >>> gene_set_df = pd.concat( + ... pd.DataFrame( + ... {'gene_set':k, 'gene':v} + ... ) for k, v in gene_set.items()) >>> score_genes_aucell_mt(adata, gene_set_df, use_raw=False) >>> "AUCell_m" in adata.obs.columns True @@ -78,47 +112,72 @@ True """ - # if var_gene_symbols_fiels is provided, transform gene_set_gene df so that gene contains gene ids instead of gene symbols + # if var_gene_symbols_fiels is provided, transform gene_set_gene df so + # that gene contains gene ids instead of gene symbols if var_gene_symbols_field: - # merge the index of var to gene_set_gene df based on var_gene_symbols_field + # merge the index of var to gene_set_gene df based on + # var_gene_symbols_field var_id_symbols = adata.var[[var_gene_symbols_field]] - var_id_symbols['gene_id'] = var_id_symbols.index + var_id_symbols["gene_id"] = var_id_symbols.index - gene_set_gene = gene_set_gene.merge(var_id_symbols, left_on='gene', right_on=var_gene_symbols_field, how='left') - # this will still produce some empty gene_ids (genes in the gene_set_gene df that are not in the var df), fill those - # with the original gene symbol from the gene_set to avoid deforming the AUCell calculation - gene_set_gene['gene_id'] = gene_set_gene['gene_id'].fillna(gene_set_gene['gene']) - gene_set_gene['gene'] = gene_set_gene['gene_id'] - + gene_set_gene = gene_set_gene.merge( + var_id_symbols, + left_on="gene", + right_on=var_gene_symbols_field, + how="left", + ) + # this will still produce some empty gene_ids (genes in the + # gene_set_gene df that are not in the var df), fill those + # with the original gene symbol from the gene_set to avoid + # deforming the AUCell calculation + gene_set_gene["gene_id"] = gene_set_gene["gene_id"].fillna( + gene_set_gene["gene"] + ) + gene_set_gene["gene"] = gene_set_gene["gene_id"] + # run decoupler's run_aucell dc.run_aucell( - adata, net=gene_set_gene, source="gene_set", target="gene", use_raw=use_raw, min_n=min_n_genes - ) + adata, + net=gene_set_gene, + source="gene_set", + target="gene", + use_raw=use_raw, + min_n=min_n_genes, + ) for gs in gene_set_gene.gene_set.unique(): - if gs in adata.obsm['aucell_estimate'].keys(): + if gs in adata.obsm["aucell_estimate"].keys(): adata.obs[f"AUCell_{gs}"] = adata.obsm["aucell_estimate"][gs] def run_for_genelists( - adata, gene_lists, score_names, use_raw=False, gene_symbols_field=None, min_n_genes=5 + adata, + gene_lists, + score_names, + use_raw=False, + gene_symbols_field=None, + min_n_genes=5, ): if len(gene_lists) == len(score_names): for gene_list, score_names in zip(gene_lists, score_names): genes = gene_list.split(",") gene_sets = {} gene_sets[score_names] = genes - gene_set_gene_df = pd.concat(pd.DataFrame({'gene_set':k, 'gene':v}) for k, v in gene_sets.items()) - + gene_set_gene_df = pd.concat( + pd.DataFrame({"gene_set": k, "gene": v}) + for k, v in gene_sets.items() + ) + score_genes_aucell_mt( adata, gene_set_gene_df, use_raw, min_n_genes, - var_gene_symbols_field=gene_symbols_field + var_gene_symbols_field=gene_symbols_field, ) else: raise ValueError( - "The number of gene lists (separated by :) and score names (separated by :) must be the same" + "The number of gene lists (separated by :) and score names \ + (separated by :) must be the same" ) @@ -126,32 +185,41 @@ # Create command-line arguments parser parser = argparse.ArgumentParser(description="Score genes using Aucell") parser.add_argument( - "--input_file", type=str, help="Path to input AnnData file", required=True + "--input_file", + type=str, + help="Path to input AnnData file", + required=True, ) parser.add_argument( "--output_file", type=str, help="Path to output file", required=True ) - parser.add_argument("--gmt_file", type=str, help="Path to GMT file", required=False) + parser.add_argument( + "--gmt_file", type=str, help="Path to GMT file", required=False + ) # add argument for gene sets to score parser.add_argument( "--gene_sets_to_score", type=str, required=False, - help="Optional comma separated list of gene sets to score (the need to be in the gmt file)", + help="Optional comma separated list of gene sets to score \ + (the need to be in the gmt file)", ) # add argument for gene list (comma separated) to score parser.add_argument( "--gene_lists_to_score", type=str, required=False, - help="Comma separated list of genes to score. You can have more than one set of genes, separated by colon :", + help="Comma separated list of genes to score. You can have more \ + than one set of genes, separated by colon :", ) # argument for the score name when using the gene list parser.add_argument( "--score_names", type=str, required=False, - help="Name of the score column when using the gene list. You can have more than one set of score names, separated by colon :. It should be the same length as the number of gene lists.", + help="Name of the score column when using the gene list. You can \ + have more than one set of score names, separated by colon :. \ + It should be the same length as the number of gene lists.", ) parser.add_argument( "--gene_symbols_field", @@ -159,7 +227,8 @@ help="Name of the gene symbols field in the AnnData object", required=True, ) - # argument for min_n Minimum of targets per source. If less, sources are removed. + # argument for min_n Minimum of targets per source. If less, sources + # are removed. parser.add_argument( "--min_n", type=int, @@ -169,11 +238,18 @@ ) parser.add_argument("--use_raw", action="store_true", help="Use raw data") parser.add_argument( - "--write_anndata", action="store_true", help="Write the modified AnnData object" + "--write_anndata", + action="store_true", + help="Write the modified AnnData object", ) # argument for number of max concurrent processes - parser.add_argument("--max_threads", type=int, required=False, default=1, help="Number of max concurrent threads") - + parser.add_argument( + "--max_threads", + type=int, + required=False, + default=1, + help="Number of max concurrent threads", + ) # Parse command-line arguments args = parser.parse_args() @@ -189,23 +265,40 @@ msigdb = read_gmt_long(args.gmt_file) gene_sets_to_score = ( - args.gene_sets_to_score.split(",") if args.gene_sets_to_score else [] + args.gene_sets_to_score.split(",") + if args.gene_sets_to_score + else [] ) if gene_sets_to_score: - # we limit the GMT file read to the genesets specified in the gene_sets_to_score argument + # we limit the GMT file read to the genesets specified in the + # gene_sets_to_score argument msigdb = msigdb[msigdb["gene_set"].isin(gene_sets_to_score)] - - score_genes_aucell_mt(adata, msigdb, args.use_raw, args.min_n, var_gene_symbols_field=args.gene_symbols_field) + + score_genes_aucell_mt( + adata, + msigdb, + args.use_raw, + args.min_n, + var_gene_symbols_field=args.gene_symbols_field, + ) elif args.gene_lists_to_score is not None and args.score_names is not None: gene_lists = args.gene_lists_to_score.split(":") score_names = args.score_names.split(",") run_for_genelists( - adata, gene_lists, score_names, args.use_raw, args.gene_symbols_field, args.min_n + adata, + gene_lists, + score_names, + args.use_raw, + args.gene_symbols_field, + args.min_n, ) - # Save the modified AnnData object or generate a file with cells as rows and the new score_names columns + # Save the modified AnnData object or generate a file with cells as rows + # and the new score_names columns if args.write_anndata: adata.write_h5ad(args.output_file) else: - new_columns = [col for col in adata.obs.columns if col.startswith("AUCell_")] + new_columns = [ + col for col in adata.obs.columns if col.startswith("AUCell_") + ] adata.obs[new_columns].to_csv(args.output_file, sep="\t", index=True)
--- a/decoupler_pathway_inference.py Tue Apr 16 11:49:19 2024 +0000 +++ b/decoupler_pathway_inference.py Mon Jul 15 10:56:37 2024 +0000 @@ -20,24 +20,34 @@ # output file prefix parser.add_argument( - "-o", "--output", + "-o", + "--output", help="output files prefix", default=None, ) # path to save Activities AnnData file parser.add_argument( - "-a", "--activities_path", help="Path to save Activities AnnData file", default=None + "-a", + "--activities_path", + help="Path to save Activities AnnData file", + default=None, ) # Column name in net with source nodes parser.add_argument( - "-s", "--source", help="Column name in net with source nodes.", default="source" + "-s", + "--source", + help="Column name in net with source nodes.", + default="source", ) # Column name in net with target nodes parser.add_argument( - "-t", "--target", help="Column name in net with target nodes.", default="target" + "-t", + "--target", + help="Column name in net with target nodes.", + default="target", ) # Column name in net with weights. @@ -47,17 +57,27 @@ # add boolean argument for use_raw parser.add_argument( - "--use_raw", action="store_true", default=False, help="Whether to use the raw part of the AnnData object" + "--use_raw", + action="store_true", + default=False, + help="Whether to use the raw part of the AnnData object", ) # add argument for min_cells parser.add_argument( - "--min_n", help="Minimum of targets per source. If less, sources are removed.", default=5, type=int + "--min_n", + help="Minimum of targets per source. If less, sources are removed.", + default=5, + type=int, ) # add activity inference method option parser.add_argument( - "-m", "--method", help="Activity inference method", default="mlm", required=True + "-m", + "--method", + help="Activity inference method", + default="mlm", + required=True, ) args = parser.parse_args() @@ -69,7 +89,7 @@ adata = ad.read_h5ad(args.input_anndata) # read in the input file network input file -network = pd.read_csv(args.input_network, sep='\t') +network = pd.read_csv(args.input_network, sep="\t") if ( args.source not in network.columns @@ -92,17 +112,21 @@ weight=args.weight, verbose=True, min_n=args.min_n, - use_raw=args.use_raw + use_raw=args.use_raw, ) if args.output is not None: - # write adata.obsm[mlm_key] and adata.obsm[mlm_pvals_key] to the output network files - combined_df = pd.concat([adata.obsm["mlm_estimate"], adata.obsm["mlm_pvals"]], axis=1) + # write adata.obsm[mlm_key] and adata.obsm[mlm_pvals_key] to the + # output network files + combined_df = pd.concat( + [adata.obsm["mlm_estimate"], adata.obsm["mlm_pvals"]], axis=1 + ) # Save the combined dataframe to a file combined_df.to_csv(args.output + ".tsv", sep="\t") - # if args.activities_path is specified, generate the activities AnnData and save the AnnData object to the specified path + # if args.activities_path is specified, generate the activities AnnData + # and save the AnnData object to the specified path if args.activities_path is not None: acts = dc.get_acts(adata, obsm_key="mlm_estimate") acts.write_h5ad(args.activities_path) @@ -116,17 +140,21 @@ weight=args.weight, verbose=True, min_n=args.min_n, - use_raw=args.use_raw + use_raw=args.use_raw, ) if args.output is not None: - # write adata.obsm[mlm_key] and adata.obsm[mlm_pvals_key] to the output network files - combined_df = pd.concat([adata.obsm["ulm_estimate"], adata.obsm["ulm_pvals"]], axis=1) + # write adata.obsm[mlm_key] and adata.obsm[mlm_pvals_key] to the + # output network files + combined_df = pd.concat( + [adata.obsm["ulm_estimate"], adata.obsm["ulm_pvals"]], axis=1 + ) # Save the combined dataframe to a file combined_df.to_csv(args.output + ".tsv", sep="\t") - # if args.activities_path is specified, generate the activities AnnData and save the AnnData object to the specified path + # if args.activities_path is specified, generate the activities AnnData + # and save the AnnData object to the specified path if args.activities_path is not None: acts = dc.get_acts(adata, obsm_key="ulm_estimate") acts.write_h5ad(args.activities_path)
--- a/decoupler_pseudobulk.py Tue Apr 16 11:49:19 2024 +0000 +++ b/decoupler_pseudobulk.py Mon Jul 15 10:56:37 2024 +0000 @@ -40,8 +40,108 @@ return index_value +def genes_to_ignore_per_contrast_field( + count_matrix_df, + samples_metadata, + sample_metadata_col_contrasts, + min_counts_per_sample=5, + use_cpms=False, +): + """ + # This function calculates the genes to ignore per contrast field + # (e.g., bulk_labels, louvain). + # It does this by first getting the count matrix for each group, + # then identifying genes with a count below a specified threshold. + # The genes to ignore are those that are present in more than a specified + # number of groups. + + >>> import pandas as pd + >>> samples_metadata = pd.DataFrame({'sample': + ... ['S1', 'S2', 'S3', + ... 'S4', 'S5', 'S6'], + ... 'contrast_field': + ... ['A', 'A', 'A', 'B', 'B', 'B']}) + >>> count_matrix_df = pd.DataFrame( + ... {'S1': + ... [30, 1, 40, 50, 30], + ... 'S2': + ... [40, 2, 60, 50, 80], + ... 'S3': + ... [80, 1, 60, 50, 50], + ... 'S4': [1, 50, 50, 50, 2], + ... 'S5': [3, 40, 40, 40, 2], + ... 'S6': [0, 50, 50, 50, 1]}) + >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5'] + >>> df = genes_to_ignore_per_contrast_field(count_matrix_df, + ... samples_metadata, min_counts_per_sample=5, + ... sample_metadata_col_contrasts='contrast_field') + >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0] + 'Gene2' + >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0] + 'Gene1' + >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1] + 'Gene5' + """ + + # Initialize a dictionary to store the genes to ignore per contrast field + contrast_fields = [] + genes_to_ignore = [] + + # Iterate over the contrast fields + for contrast_field in samples_metadata[ + sample_metadata_col_contrasts + ].unique(): + # Get the count matrix for the current contrast field + count_matrix_field = count_matrix_df.loc[ + :, + ( + samples_metadata[sample_metadata_col_contrasts] + == contrast_field + ).tolist(), + ] + + # We derive min_counts from the number of samples with that + # contrast_field value + min_counts = count_matrix_field.shape[1] * min_counts_per_sample + + if use_cpms: + # Convert counts to counts per million (CPM) + count_matrix_field = ( + count_matrix_field.div(count_matrix_field.sum(axis=1), axis=0) + * 1e6 + ) + min_counts = 1 # use 1 CPM + + # Calculate the total number of cells in the current contrast field + # (this produces a vector of counts per gene) + total_counts_per_gene = count_matrix_field.sum(axis=1) + + # Identify genes with a count below the specified threshold + genes = total_counts_per_gene[ + total_counts_per_gene < min_counts + ].index.tolist() + if len(genes) > 0: + # genes_to_ignore[contrast_field] = " ".join(genes) + for gene in genes: + genes_to_ignore.append(gene) + contrast_fields.append(contrast_field) + # transform gene_to_ignore to a DataFrame + # genes_to_ignore_df = pd.DataFrame(genes_to_ignore.items(), + # columns=["contrast_field", "genes_to_ignore"]) + genes_to_ignore_df = pd.DataFrame( + {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore} + ) + return genes_to_ignore_df + + # write results for loading into DESeq2 -def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None): +def write_DESeq2_inputs( + pdata, + layer=None, + output_dir="", + factor_fields=None, + min_counts_per_sample_marking=20, +): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() @@ -62,7 +162,9 @@ # write obs to a col_metadata file if factor_fields: # only output the index plus the columns in factor_fields in that order - obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep="\t", index=True) + obs_for_deseq[factor_fields].to_csv( + col_metadata_file, sep="\t", index=True + ) else: obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True) # write var to a gene_metadata file @@ -70,13 +172,28 @@ # write the counts matrix of a specified layer to file if layer is None: # write the X numpy matrix transposed to file - df = pd.DataFrame(pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index) + df = pd.DataFrame( + pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index + ) else: df = pd.DataFrame( - pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index + pdata.layers[layer].T, + index=pdata.var.index, + columns=obs_for_deseq.index, ) df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="") + if factor_fields: + df_genes_ignore = genes_to_ignore_per_contrast_field( + count_matrix_df=df, + samples_metadata=obs_for_deseq, + sample_metadata_col_contrasts=factor_fields[0], + min_counts_per_sample=min_counts_per_sample_marking, + ) + df_genes_ignore.to_csv( + f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t" + ) + def plot_pseudobulk_samples( pseudobulk_data, @@ -89,7 +206,9 @@ >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") - >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10)) + >>> plot_pseudobulk_samples(pseudobulk, + ... groupby=["bulk_labels", "louvain"], + ... figsize=(10, 10)) """ fig = decoupler.plot_psbulk_samples( pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True @@ -101,14 +220,19 @@ def plot_filter_by_expr( - pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None + pseudobulk_data, + group, + min_count=None, + min_total_count=None, + save_path=None, ): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") - >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200) + >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", + ... min_count=10, min_total_count=200) """ fig = decoupler.plot_filter_by_expr( pseudobulk_data, @@ -129,7 +253,8 @@ >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") - >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200) + >>> pdata_filt = filter_by_expr(pseudobulk, + ... min_count=10, min_total_count=200) """ genes = decoupler.filter_by_expr( pdata, min_count=min_count, min_total_count=min_total_count @@ -150,12 +275,16 @@ if obs: if not set(fields).issubset(set(adata.obs.columns)): raise ValueError( - f"Some of the following fields {legend} are not present in adata.obs: {fields}. Possible fields are: {list(set(adata.obs.columns))}" + f"Some of the following fields {legend} are not present \ + in adata.obs: {fields}. \ + Possible fields are: {list(set(adata.obs.columns))}" ) else: if not set(fields).issubset(set(adata.var.columns)): raise ValueError( - f"Some of the following fields {legend} are not present in adata.var: {fields}. Possible fields are: {list(set(adata.var.columns))}" + f"Some of the following fields {legend} are not present \ + in adata.var: {fields}. \ + Possible fields are: {list(set(adata.var.columns))}" ) @@ -219,10 +348,15 @@ # Save the pseudobulk data if args.anndata_output_path: - pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip") + pseudobulk_data.write_h5ad( + args.anndata_output_path, compression="gzip" + ) write_DESeq2_inputs( - pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields + pseudobulk_data, + output_dir=args.deseq2_output_path, + factor_fields=factor_fields, + min_counts_per_sample_marking=args.min_counts_per_sample_marking, ) @@ -254,7 +388,9 @@ field_name = "_".join(obs_fields_to_merge) for field in obs_fields_to_merge: if field not in adata.obs.columns: - raise ValueError(f"The '{field}' column is not present in adata.obs.") + raise ValueError( + f"The '{field}' column is not present in adata.obs." + ) if field_name not in adata.obs.columns: adata.obs[field_name] = adata.obs[field].astype(str) else: @@ -271,12 +407,16 @@ ) # Add arguments - parser.add_argument("adata_file", type=str, help="Path to the AnnData file") + parser.add_argument( + "adata_file", type=str, help="Path to the AnnData file" + ) parser.add_argument( "-m", "--adata_obs_fields_to_merge", type=str, - help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by semi-colon ;", + help="Fields in adata.obs to merge, comma separated. \ + You can have more than one set of fields, \ + separated by semi-colon ;", ) parser.add_argument( "--groupby", @@ -328,6 +468,13 @@ help="Minimum count threshold for filtering by expression", ) parser.add_argument( + "--min_counts_per_sample_marking", + type=int, + default=20, + help="Minimum count threshold per sample for \ + marking genes to be ignored after DE", + ) + parser.add_argument( "--min_total_counts", type=int, help="Minimum total count threshold for filtering by expression", @@ -338,7 +485,9 @@ help="Path to save the filtered AnnData object or pseudobulk data", ) parser.add_argument( - "--filter_expr", action="store_true", help="Enable filtering by expression" + "--filter_expr", + action="store_true", + help="Enable filtering by expression", ) parser.add_argument( "--factor_fields", @@ -358,7 +507,9 @@ nargs=2, help="Size of the samples plot as a tuple (two arguments)", ) - parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2) + parser.add_argument( + "--plot_filtering_figsize", type=int, default=[10, 10], nargs=2 + ) # Parse the command line arguments args = parser.parse_args()