Mercurial > repos > ebi-gxa > decoupler_pseudobulk
diff decoupler_pseudobulk.py @ 8:93f61ea19336 draft
planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author | ebi-gxa |
---|---|
date | Mon, 15 Jul 2024 10:56:42 +0000 |
parents | 130e25d3ce92 |
children | bd4b54b75888 |
line wrap: on
line diff
--- a/decoupler_pseudobulk.py Tue Apr 16 11:49:25 2024 +0000 +++ b/decoupler_pseudobulk.py Mon Jul 15 10:56:42 2024 +0000 @@ -40,8 +40,108 @@ return index_value +def genes_to_ignore_per_contrast_field( + count_matrix_df, + samples_metadata, + sample_metadata_col_contrasts, + min_counts_per_sample=5, + use_cpms=False, +): + """ + # This function calculates the genes to ignore per contrast field + # (e.g., bulk_labels, louvain). + # It does this by first getting the count matrix for each group, + # then identifying genes with a count below a specified threshold. + # The genes to ignore are those that are present in more than a specified + # number of groups. + + >>> import pandas as pd + >>> samples_metadata = pd.DataFrame({'sample': + ... ['S1', 'S2', 'S3', + ... 'S4', 'S5', 'S6'], + ... 'contrast_field': + ... ['A', 'A', 'A', 'B', 'B', 'B']}) + >>> count_matrix_df = pd.DataFrame( + ... {'S1': + ... [30, 1, 40, 50, 30], + ... 'S2': + ... [40, 2, 60, 50, 80], + ... 'S3': + ... [80, 1, 60, 50, 50], + ... 'S4': [1, 50, 50, 50, 2], + ... 'S5': [3, 40, 40, 40, 2], + ... 'S6': [0, 50, 50, 50, 1]}) + >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5'] + >>> df = genes_to_ignore_per_contrast_field(count_matrix_df, + ... samples_metadata, min_counts_per_sample=5, + ... sample_metadata_col_contrasts='contrast_field') + >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0] + 'Gene2' + >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0] + 'Gene1' + >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1] + 'Gene5' + """ + + # Initialize a dictionary to store the genes to ignore per contrast field + contrast_fields = [] + genes_to_ignore = [] + + # Iterate over the contrast fields + for contrast_field in samples_metadata[ + sample_metadata_col_contrasts + ].unique(): + # Get the count matrix for the current contrast field + count_matrix_field = count_matrix_df.loc[ + :, + ( + samples_metadata[sample_metadata_col_contrasts] + == contrast_field + ).tolist(), + ] + + # We derive min_counts from the number of samples with that + # contrast_field value + min_counts = count_matrix_field.shape[1] * min_counts_per_sample + + if use_cpms: + # Convert counts to counts per million (CPM) + count_matrix_field = ( + count_matrix_field.div(count_matrix_field.sum(axis=1), axis=0) + * 1e6 + ) + min_counts = 1 # use 1 CPM + + # Calculate the total number of cells in the current contrast field + # (this produces a vector of counts per gene) + total_counts_per_gene = count_matrix_field.sum(axis=1) + + # Identify genes with a count below the specified threshold + genes = total_counts_per_gene[ + total_counts_per_gene < min_counts + ].index.tolist() + if len(genes) > 0: + # genes_to_ignore[contrast_field] = " ".join(genes) + for gene in genes: + genes_to_ignore.append(gene) + contrast_fields.append(contrast_field) + # transform gene_to_ignore to a DataFrame + # genes_to_ignore_df = pd.DataFrame(genes_to_ignore.items(), + # columns=["contrast_field", "genes_to_ignore"]) + genes_to_ignore_df = pd.DataFrame( + {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore} + ) + return genes_to_ignore_df + + # write results for loading into DESeq2 -def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None): +def write_DESeq2_inputs( + pdata, + layer=None, + output_dir="", + factor_fields=None, + min_counts_per_sample_marking=20, +): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() @@ -62,7 +162,9 @@ # write obs to a col_metadata file if factor_fields: # only output the index plus the columns in factor_fields in that order - obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep="\t", index=True) + obs_for_deseq[factor_fields].to_csv( + col_metadata_file, sep="\t", index=True + ) else: obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True) # write var to a gene_metadata file @@ -70,13 +172,28 @@ # write the counts matrix of a specified layer to file if layer is None: # write the X numpy matrix transposed to file - df = pd.DataFrame(pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index) + df = pd.DataFrame( + pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index + ) else: df = pd.DataFrame( - pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index + pdata.layers[layer].T, + index=pdata.var.index, + columns=obs_for_deseq.index, ) df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="") + if factor_fields: + df_genes_ignore = genes_to_ignore_per_contrast_field( + count_matrix_df=df, + samples_metadata=obs_for_deseq, + sample_metadata_col_contrasts=factor_fields[0], + min_counts_per_sample=min_counts_per_sample_marking, + ) + df_genes_ignore.to_csv( + f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t" + ) + def plot_pseudobulk_samples( pseudobulk_data, @@ -89,7 +206,9 @@ >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") - >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10)) + >>> plot_pseudobulk_samples(pseudobulk, + ... groupby=["bulk_labels", "louvain"], + ... figsize=(10, 10)) """ fig = decoupler.plot_psbulk_samples( pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True @@ -101,14 +220,19 @@ def plot_filter_by_expr( - pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None + pseudobulk_data, + group, + min_count=None, + min_total_count=None, + save_path=None, ): """ >>> import scanpy as sc >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") - >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200) + >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", + ... min_count=10, min_total_count=200) """ fig = decoupler.plot_filter_by_expr( pseudobulk_data, @@ -129,7 +253,8 @@ >>> adata = sc.datasets.pbmc68k_reduced() >>> adata.X = abs(adata.X).astype(int) >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain") - >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200) + >>> pdata_filt = filter_by_expr(pseudobulk, + ... min_count=10, min_total_count=200) """ genes = decoupler.filter_by_expr( pdata, min_count=min_count, min_total_count=min_total_count @@ -150,12 +275,16 @@ if obs: if not set(fields).issubset(set(adata.obs.columns)): raise ValueError( - f"Some of the following fields {legend} are not present in adata.obs: {fields}. Possible fields are: {list(set(adata.obs.columns))}" + f"Some of the following fields {legend} are not present \ + in adata.obs: {fields}. \ + Possible fields are: {list(set(adata.obs.columns))}" ) else: if not set(fields).issubset(set(adata.var.columns)): raise ValueError( - f"Some of the following fields {legend} are not present in adata.var: {fields}. Possible fields are: {list(set(adata.var.columns))}" + f"Some of the following fields {legend} are not present \ + in adata.var: {fields}. \ + Possible fields are: {list(set(adata.var.columns))}" ) @@ -219,10 +348,15 @@ # Save the pseudobulk data if args.anndata_output_path: - pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip") + pseudobulk_data.write_h5ad( + args.anndata_output_path, compression="gzip" + ) write_DESeq2_inputs( - pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields + pseudobulk_data, + output_dir=args.deseq2_output_path, + factor_fields=factor_fields, + min_counts_per_sample_marking=args.min_counts_per_sample_marking, ) @@ -254,7 +388,9 @@ field_name = "_".join(obs_fields_to_merge) for field in obs_fields_to_merge: if field not in adata.obs.columns: - raise ValueError(f"The '{field}' column is not present in adata.obs.") + raise ValueError( + f"The '{field}' column is not present in adata.obs." + ) if field_name not in adata.obs.columns: adata.obs[field_name] = adata.obs[field].astype(str) else: @@ -271,12 +407,16 @@ ) # Add arguments - parser.add_argument("adata_file", type=str, help="Path to the AnnData file") + parser.add_argument( + "adata_file", type=str, help="Path to the AnnData file" + ) parser.add_argument( "-m", "--adata_obs_fields_to_merge", type=str, - help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by semi-colon ;", + help="Fields in adata.obs to merge, comma separated. \ + You can have more than one set of fields, \ + separated by semi-colon ;", ) parser.add_argument( "--groupby", @@ -328,6 +468,13 @@ help="Minimum count threshold for filtering by expression", ) parser.add_argument( + "--min_counts_per_sample_marking", + type=int, + default=20, + help="Minimum count threshold per sample for \ + marking genes to be ignored after DE", + ) + parser.add_argument( "--min_total_counts", type=int, help="Minimum total count threshold for filtering by expression", @@ -338,7 +485,9 @@ help="Path to save the filtered AnnData object or pseudobulk data", ) parser.add_argument( - "--filter_expr", action="store_true", help="Enable filtering by expression" + "--filter_expr", + action="store_true", + help="Enable filtering by expression", ) parser.add_argument( "--factor_fields", @@ -358,7 +507,9 @@ nargs=2, help="Size of the samples plot as a tuple (two arguments)", ) - parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2) + parser.add_argument( + "--plot_filtering_figsize", type=int, default=[10, 10], nargs=2 + ) # Parse the command line arguments args = parser.parse_args()