Mercurial > repos > ebi-gxa > decoupler_pseudobulk
view decoupler_pseudobulk.py @ 15:09c833d9b03b draft default tip
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit b581a5b4ba88c5bf06f6223ba9aec51a8564796c
author | ebi-gxa |
---|---|
date | Fri, 29 Nov 2024 11:34:16 +0000 |
parents | a559be56720c |
children |
line wrap: on
line source
import argparse import anndata import decoupler import pandas as pd def get_pseudobulk( adata, sample_col, groups_col, layer=None, mode="sum", min_cells=10, min_counts=1000, use_raw=False, ): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") """ return decoupler.get_pseudobulk( adata, sample_col=sample_col, groups_col=groups_col, layer=layer, mode=mode, use_raw=use_raw, min_cells=min_cells, min_counts=min_counts, ) def prepend_c_to_index(index_value): if index_value and index_value[0].isdigit(): return "C" + index_value return index_value def genes_to_ignore_per_contrast_field( count_matrix_df, samples_metadata, sample_metadata_col_contrasts, min_counts_per_sample=5, use_cpms=False, ): """ # This function calculates the genes to ignore per contrast field # (e.g., bulk_labels, louvain). # It does this by first getting the count matrix for each group, # then identifying genes with a count below a specified threshold. # The genes to ignore are those that are present in more than a specified # number of groups. >>> import pandas as pd >>> samples_metadata = pd.DataFrame({'sample': ... ['S1', 'S2', 'S3', ... 'S4', 'S5', 'S6'], ... 'contrast_field': ... ['A', 'A', 'A', 'B', 'B', 'B']}) >>> count_matrix_df = pd.DataFrame( ... {'S1': ... [30, 1, 40, 50, 30], ... 'S2': ... [40, 2, 60, 50, 80], ... 'S3': ... [80, 1, 60, 50, 50], ... 'S4': [1, 50, 50, 50, 2], ... 'S5': [3, 40, 40, 40, 2], ... 'S6': [0, 50, 50, 50, 1]}) >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5'] >>> df = genes_to_ignore_per_contrast_field(count_matrix_df, ... samples_metadata, min_counts_per_sample=5, ... sample_metadata_col_contrasts='contrast_field') >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0] 'Gene2' >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0] 'Gene1' >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1] 'Gene5' """ # Initialize a dictionary to store the genes to ignore per contrast field contrast_fields = [] genes_to_ignore = [] # Iterate over the contrast fields for contrast_field in samples_metadata[ sample_metadata_col_contrasts ].unique(): # Get the count matrix for the current contrast field count_matrix_field = count_matrix_df.loc[ :, ( samples_metadata[sample_metadata_col_contrasts] == contrast_field ).tolist(), ] # We derive min_counts from the number of samples with that # contrast_field value min_counts = count_matrix_field.shape[1] * min_counts_per_sample if use_cpms: # Convert counts to counts per million (CPM) count_matrix_field = ( count_matrix_field.div(count_matrix_field.sum(axis=1), axis=0) * 1e6 ) min_counts = 1 # use 1 CPM # Calculate the total number of cells in the current contrast field # (this produces a vector of counts per gene) total_counts_per_gene = count_matrix_field.sum(axis=1) # Identify genes with a count below the specified threshold genes = total_counts_per_gene[ total_counts_per_gene < min_counts ].index.tolist() if len(genes) > 0: # genes_to_ignore[contrast_field] = " ".join(genes) for gene in genes: genes_to_ignore.append(gene) contrast_fields.append(contrast_field) # transform gene_to_ignore to a DataFrame # genes_to_ignore_df = pd.DataFrame(genes_to_ignore.items(), # columns=["contrast_field", "genes_to_ignore"]) genes_to_ignore_df = pd.DataFrame( {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore} ) return genes_to_ignore_df # write results for loading into DESeq2 def write_DESeq2_inputs( pdata, layer=None, output_dir="", factor_fields=None, min_counts_per_sample_marking=20, ): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") >>> write_DESeq2_inputs(pseudobulk) """ # add / to output_dir if is not empty or if it doesn't end with / if output_dir != "" and not output_dir.endswith("/"): output_dir = output_dir + "/" obs_for_deseq = pdata.obs.copy() # replace any index starting with digits to start with C instead. obs_for_deseq.rename(index=prepend_c_to_index, inplace=True) # avoid dash that is read as point on R colnames. obs_for_deseq.index = obs_for_deseq.index.str.replace("-", "_") obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_") col_metadata_file = f"{output_dir}col_metadata.tsv" # write obs to a col_metadata file if factor_fields: # only output the index plus the columns in factor_fields in that order obs_for_deseq[factor_fields].to_csv( col_metadata_file, sep="\t", index=True ) else: obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True) # write var to a gene_metadata file pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True) # write the counts matrix of a specified layer to file if layer is None: # write the X numpy matrix transposed to file df = pd.DataFrame( pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index ) else: df = pd.DataFrame( pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index, ) df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="") if factor_fields: df_genes_ignore = genes_to_ignore_per_contrast_field( count_matrix_df=df, samples_metadata=obs_for_deseq, sample_metadata_col_contrasts=factor_fields[0], min_counts_per_sample=min_counts_per_sample_marking, ) df_genes_ignore.to_csv( f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t" ) def plot_pseudobulk_samples( pseudobulk_data, groupby, figsize=(10, 10), save_path=None, ): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") >>> plot_pseudobulk_samples(pseudobulk, ... groupby=["bulk_labels", "louvain"], ... figsize=(10, 10)) """ fig = decoupler.plot_psbulk_samples( pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True ) if save_path: fig.savefig(f"{save_path}/pseudobulk_samples.png") else: fig.show() def plot_filter_by_expr( pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None, ): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", ... min_count=10, min_total_count=200) """ fig = decoupler.plot_filter_by_expr( pseudobulk_data, group=group, min_count=min_count, min_total_count=min_total_count, return_fig=True, ) if save_path: fig.savefig(f"{save_path}/filter_by_expr.png") else: fig.show() def filter_by_expr(pdata, min_count=None, min_total_count=None): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") >>> pdata_filt = filter_by_expr(pseudobulk, ... min_count=10, min_total_count=200) """ genes = decoupler.filter_by_expr( pdata, min_count=min_count, min_total_count=min_total_count ) return pdata[:, genes].copy() def check_fields(fields, adata, obs=True, context=None): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() >>> check_fields(["bulk_labels", "louvain"], adata, obs=True) """ legend = "" if context: legend = f", passed in {context}," if obs: if not set(fields).issubset(set(adata.obs.columns)): raise ValueError( f"Some of the following fields {legend} are not present \ in adata.obs: {fields}. \ Possible fields are: {list(set(adata.obs.columns))}" ) else: if not set(fields).issubset(set(adata.var.columns)): raise ValueError( f"Some of the following fields {legend} are not present \ in adata.var: {fields}. \ Possible fields are: {list(set(adata.var.columns))}" ) def main(args): # Load AnnData object from file adata = anndata.read_h5ad(args.adata_file) # Merge adata.obs fields specified in args.adata_obs_fields_to_merge if args.adata_obs_fields_to_merge: # first split potential groups by ":" and iterate over them for group in args.adata_obs_fields_to_merge.split(":"): fields = group.split(",") check_fields(fields, adata) merge_adata_obs_fields(fields, adata) check_fields([args.groupby, args.sample_key], adata) factor_fields = None if args.factor_fields: factor_fields = args.factor_fields.split(",") check_fields(factor_fields, adata) print(f"Using mode: {args.mode}") # Perform pseudobulk analysis pseudobulk_data = get_pseudobulk( adata, sample_col=args.sample_key, groups_col=args.groupby, layer=args.layer, mode=args.mode, use_raw=args.use_raw, min_cells=args.min_cells, min_counts=args.min_counts, ) print("Created pseudo-bulk AnnData, checking if fields still make sense.") print( "If this fails this check, it might mean that you asked for factors \ that are not compatible with you sample identifiers (ie. asked for \ phase in the factors, but each sample contains more than one phase,\ try joining fields)." ) if factor_fields: check_fields( factor_fields, pseudobulk_data, context=" after creation of pseudo-bulk AnnData", ) print("Factors requested are adequate for the pseudo-bulked AnnData!") # Plot pseudobulk samples plot_pseudobulk_samples( pseudobulk_data, args.groupby, save_path=args.save_path, figsize=args.plot_samples_figsize, ) plot_filter_by_expr( pseudobulk_data, group=args.groupby, min_count=args.min_counts, min_total_count=args.min_total_counts, save_path=args.save_path, ) # Filter by expression if enabled if args.filter_expr: filtered_adata = filter_by_expr( pseudobulk_data, min_count=args.min_counts, min_total_count=args.min_total_counts, ) pseudobulk_data = filtered_adata # Save the pseudobulk data if args.anndata_output_path: pseudobulk_data.write_h5ad( args.anndata_output_path, compression="gzip" ) write_DESeq2_inputs( pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields, min_counts_per_sample_marking=args.min_counts_per_sample_marking, ) # if contrasts file is provided, produce a file with genes that should be # filtered for each contrasts if args.contrasts_file: contrast_genes_df = identify_genes_to_filter_per_contrast( contrast_file=args.contrasts_file, min_perc_cells_expression=args.min_gene_exp_perc_per_cell, adata=adata, obs_field=args.groupby ) contrast_genes_df.to_csv( f"{args.save_path}/genes_to_filter_by_contrast.tsv", sep="\t", index=False, ) def merge_adata_obs_fields(obs_fields_to_merge, adata): """ Merge adata.obs fields specified in args.adata_obs_fields_to_merge Parameters ---------- obs_fields_to_merge : str Fields in adata.obs to merge, comma separated adata : anndata.AnnData The AnnData object Returns ------- anndata.AnnData The merged AnnData object docstring tests: >>> import scanpy as sc >>> ad = sc.datasets.pbmc68k_reduced() >>> merge_adata_obs_fields(["bulk_labels","louvain"], ad) >>> ad.obs.columns Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score', 'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'], dtype='object') """ field_name = "_".join(obs_fields_to_merge) for field in obs_fields_to_merge: if field not in adata.obs.columns: raise ValueError( f"The '{field}' column is not present in adata.obs." ) if field_name not in adata.obs.columns: adata.obs[field_name] = adata.obs[field].astype(str) else: adata.obs[field_name] = ( adata.obs[field_name] + "_" + adata.obs[field].astype(str) ) def identify_genes_to_filter_per_contrast( contrast_file, min_perc_cells_expression, adata, obs_field ): """ Identify genes to filter per contrast based on expression percentage. We need those genes to be under the threshold for all conditions in a contrast to be identified for further filtering. If one condition has the gene expressed above the threshold, the gene becomes of interest (it can be highly up or down regulated). Parameters ---------- contrast_file : str Path to the contrasts file. min_perc_cells_expression : float Minimum percentage of cells that should express a gene. adata: adata Original AnnData file obs_field: str Field in the AnnData observations where the contrasts are defined. Returns ------- None Examples -------- >>> import anndata >>> import pandas as pd >>> import numpy as np >>> import os >>> from io import StringIO >>> contrast_file = StringIO(f"contrast{os.linesep}condition1-\ condition2{os.linesep}\ 2*(condition1)-condition2{os.linesep}") >>> min_perc_cells_expression = 30.0 >>> data = { ... 'obs': pd.DataFrame({'condition': ['condition1', 'condition1', ... 'condition2', 'condition2']}), ... 'X': np.array([[1, 0, 0, 0, 0], [0, 0, 2, 2, 0], ... [0, 0, 1, 1, 0], [0, 0, 0, 2, 0]]), ... } >>> adata = anndata.AnnData(X=data['X'], obs=data['obs']) >>> df = identify_genes_to_filter_per_contrast( ... contrast_file, min_perc_cells_expression, adata, 'condition' ... ) # doctest:+ELLIPSIS Identifying genes to filter using ... >>> df.head() # doctest:+ELLIPSIS contrast gene 0 condition1-condition2... 1 condition1-condition2... 2 2*(condition1)-condition2... 3 2*(condition1)-condition2... """ import re # Implement the logic to identify genes to filter per contrast # This is a placeholder implementation print( f"Identifying genes to filter using {contrast_file} " f"with min expression {min_perc_cells_expression}%" ) sides_regex = re.compile(r"[\+\-\*\/\(\)\^]+") contrasts = pd.read_csv(contrast_file, sep="\t") # Iterate over each line in the contrast file genes_filter_for_contrast = dict() for contrast in contrasts.iloc[:, 0]: conditions = set(sides_regex.split(contrast)) selected_conditions = [] failed_conditions = [] for condition in conditions: # remove any starting or trailing whitespaces from condition condition = condition.strip() if len(condition) == 0: continue # check if the condition is simply a number, then skip it if condition.isnumeric(): continue if condition not in adata.obs[obs_field].unique(): # add condition to failed_conditions failed_conditions.append(condition) continue # append to selected_conditions selected_conditions.append(condition) if len(failed_conditions) > 0: raise ValueError( f"Condition(s) '{failed_conditions}' " f"from contrast {contrast} " f"is/are not present in the " f"obs_field '{obs_field}' from the AnnData object." f"Possible values are: " f"{', '.join(adata.obs[obs_field].unique())}.") # we want to find the genes that are below the threshold # of % of cells expressed for ALL the conditions in the # contrast. It is enough for one of the conditions # of the contrast to have the genes expressed above # the threshold of % of cells to be of interest. for condition in selected_conditions: # check the percentage of cells that express each gene # Filter the AnnData object based on the obs_field value adata_filtered = adata[adata.obs[obs_field] == condition] # Calculate the percentage of cells expressing each gene gene_expression = (adata_filtered.X > 0).mean(axis=0) * 100 genes_to_filter = set(adata_filtered.var[ gene_expression.transpose() < min_perc_cells_expression ].index.tolist()) # Update the genes_filter_for_contrast dictionary if contrast in genes_filter_for_contrast.keys(): genes_filter_for_contrast[contrast].intersection_update( genes_to_filter ) else: genes_filter_for_contrast[contrast] = genes_to_filter # write the genes_filter_for_contrast to pandas dataframe of two columns: # contrast and gene # Initialize an empty list to store the expanded pairs expanded_pairs = [] # Iterate over the dictionary for contrast, genes in genes_filter_for_contrast.items(): for gene in genes: expanded_pairs.append((contrast, gene)) # Create the DataFrame contrast_genes_df = pd.DataFrame( expanded_pairs, columns=["contrast", "gene"] ) return contrast_genes_df if __name__ == "__main__": # Create argument parser parser = argparse.ArgumentParser( description="Perform pseudobulk analysis on an AnnData object" ) # Add arguments parser.add_argument( "adata_file", type=str, help="Path to the AnnData file" ) parser.add_argument( "-m", "--adata_obs_fields_to_merge", type=str, help="Fields in adata.obs to merge, comma separated. \ You can have more than one set of fields, \ separated by semi-colon ;", ) parser.add_argument( "--groupby", type=str, required=True, help="The column in adata.obs that defines the groups", ) parser.add_argument( "--sample_key", required=True, type=str, help="The column in adata.obs that defines the samples", ) # add argument for layer parser.add_argument( "--layer", type=str, default=None, help="The name of the layer of the AnnData object to use", ) # add argument for mode parser.add_argument( "--mode", type=str, default="sum", help="The mode for Decoupler pseudobulk analysis", choices=["sum", "mean", "median"], ) # add boolean argument for use_raw parser.add_argument( "--use_raw", action="store_true", default=False, help="Whether to use the raw part of the AnnData object", ) # add argument for min_cells parser.add_argument( "--min_cells", type=int, default=10, help="Minimum number of cells for pseudobulk analysis", ) # add argument for min percentage of cells that should express a gene parser.add_argument( "--min_gene_exp_perc_per_cell", type=float, default=50, help="If all the conditions of one side of a contrast express a \ gene in less than this percentage of cells, then the genes \ will be added to a list of genes to ignore for that contrast.\ Requires the contrast file to be provided.", ) parser.add_argument( "--contrasts_file", type=str, required=False, help="Contrasts file, a one column tsv with a header, each line \ represents a contrast as a combination of conditions at each \ side of a substraction.", ) parser.add_argument( "--save_path", type=str, help="Path to save the plot (optional)" ) parser.add_argument( "--min_counts", type=int, help="Minimum count threshold for filtering by expression", ) parser.add_argument( "--min_counts_per_sample_marking", type=int, default=20, help="Minimum count threshold per sample for \ marking genes to be ignored after DE", ) parser.add_argument( "--min_total_counts", type=int, help="Minimum total count threshold for filtering by expression", ) parser.add_argument( "--anndata_output_path", type=str, help="Path to save the filtered AnnData object or pseudobulk data", ) parser.add_argument( "--filter_expr", action="store_true", help="Enable filtering by expression", ) parser.add_argument( "--factor_fields", type=str, help="Comma separated list of fields for the factors", ) parser.add_argument( "--deseq2_output_path", type=str, help="Path to save the DESeq2 inputs", required=True, ) parser.add_argument( "--plot_samples_figsize", type=int, default=[10, 10], nargs=2, help="Size of the samples plot as a tuple (two arguments)", ) parser.add_argument( "--plot_filtering_figsize", type=int, default=[10, 10], nargs=2 ) # Parse the command line arguments args = parser.parse_args() # Call the main function main(args)