diff decoupler_pseudobulk.py @ 8:93f61ea19336 draft

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author ebi-gxa
date Mon, 15 Jul 2024 10:56:42 +0000
parents 130e25d3ce92
children bd4b54b75888
--- a/decoupler_pseudobulk.py	Tue Apr 16 11:49:25 2024 +0000
+++ b/decoupler_pseudobulk.py	Mon Jul 15 10:56:42 2024 +0000
@@ -40,8 +40,108 @@
     return index_value
 
 
+def genes_to_ignore_per_contrast_field(
+    count_matrix_df,
+    samples_metadata,
+    sample_metadata_col_contrasts,
+    min_counts_per_sample=5,
+    use_cpms=False,
+):
+    """
+    # This function calculates the genes to ignore per contrast field
+    # (e.g., bulk_labels, louvain).
+    # It does this by first getting the count matrix for each group,
+    # then identifying genes with a count below a specified threshold.
+    # The genes to ignore are those that are present in more than a specified
+    # number of groups.
+
+    >>> import pandas as pd
+    >>> samples_metadata = pd.DataFrame({'sample':
+    ...                                    ['S1', 'S2', 'S3',
+    ...                                     'S4', 'S5', 'S6'],
+    ...                                  'contrast_field':
+    ...                                    ['A', 'A', 'A', 'B', 'B', 'B']})
+    >>> count_matrix_df = pd.DataFrame(
+    ...                       {'S1':
+    ...                          [30, 1, 40, 50, 30],
+    ...                        'S2':
+    ...                          [40, 2, 60, 50, 80],
+    ...                        'S3':
+    ...                          [80, 1, 60, 50, 50],
+    ...                        'S4': [1, 50, 50, 50, 2],
+    ...                        'S5': [3, 40, 40, 40, 2],
+    ...                        'S6': [0, 50, 50, 50, 1]})
+    >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5']
+    >>> df = genes_to_ignore_per_contrast_field(count_matrix_df,
+    ...             samples_metadata, min_counts_per_sample=5,
+    ...             sample_metadata_col_contrasts='contrast_field')
+    >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0]
+    'Gene2'
+    >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0]
+    'Gene1'
+    >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1]
+    'Gene5'
+    """
+
+    # Parallel lists collecting each flagged gene and its contrast level
+    contrast_fields = []
+    genes_to_ignore = []
+
+    # Iterate over the contrast fields
+    for contrast_field in samples_metadata[
+        sample_metadata_col_contrasts
+    ].unique():
+        # Get the count matrix for the current contrast field
+        count_matrix_field = count_matrix_df.loc[
+            :,
+            (
+                samples_metadata[sample_metadata_col_contrasts]
+                == contrast_field
+            ).tolist(),
+        ]
+
+        # We derive min_counts from the number of samples with that
+        # contrast_field value
+        min_counts = count_matrix_field.shape[1] * min_counts_per_sample
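+        # e.g. with 3 samples per group (as in the doctest above), a gene
+        # needs at least 3 * 5 = 15 counts in that group to be kept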
+
+        if use_cpms:
+            # Convert counts to counts per million (CPM): divide each
+            # sample (column) by its library size and scale to 1e6
+            count_matrix_field = (
+                count_matrix_field.div(count_matrix_field.sum(axis=0), axis=1)
+                * 1e6
+            )
+            min_counts = 1  # use 1 CPM
+
+        # Sum counts over the samples of the current contrast level
+        # (this produces a vector of total counts per gene)
+        total_counts_per_gene = count_matrix_field.sum(axis=1)
+
+        # Identify genes with a count below the specified threshold
+        genes = total_counts_per_gene[
+            total_counts_per_gene < min_counts
+        ].index.tolist()
+        if len(genes) > 0:
+            # record each flagged gene together with its contrast level
+            for gene in genes:
+                genes_to_ignore.append(gene)
+                contrast_fields.append(contrast_field)
+    # assemble the flagged genes and their contrast levels into a DataFrame
+    genes_to_ignore_df = pd.DataFrame(
+        {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore}
+    )
+    return genes_to_ignore_df
+
+
 # write results for loading into DESeq2
-def write_DESeq2_inputs(pdata, layer=None, output_dir="", factor_fields=None):
+def write_DESeq2_inputs(
+    pdata,
+    layer=None,
+    output_dir="",
+    factor_fields=None,
+    min_counts_per_sample_marking=20,
+):
     """
     >>> import scanpy as sc
     >>> adata = sc.datasets.pbmc68k_reduced()
@@ -62,7 +162,9 @@
     # write obs to a col_metadata file
     if factor_fields:
         # only output the index plus the columns in factor_fields in that order
-        obs_for_deseq[factor_fields].to_csv(col_metadata_file, sep="\t", index=True)
+        obs_for_deseq[factor_fields].to_csv(
+            col_metadata_file, sep="\t", index=True
+        )
     else:
         obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True)
     # write var to a gene_metadata file
@@ -70,13 +172,28 @@
     # write the counts matrix of a specified layer to file
     if layer is None:
         # write the X numpy matrix transposed to file
-        df = pd.DataFrame(pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index)
+        df = pd.DataFrame(
+            pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index
+        )
     else:
         df = pd.DataFrame(
-            pdata.layers[layer].T, index=pdata.var.index, columns=obs_for_deseq.index
+            pdata.layers[layer].T,
+            index=pdata.var.index,
+            columns=obs_for_deseq.index,
         )
     df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="")
 
+    if factor_fields:
+        df_genes_ignore = genes_to_ignore_per_contrast_field(
+            count_matrix_df=df,
+            samples_metadata=obs_for_deseq,
+            sample_metadata_col_contrasts=factor_fields[0],
+            min_counts_per_sample=min_counts_per_sample_marking,
+        )
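+        # The resulting TSV has one row per (contrast level, gene) pair,
+        # with columns "contrast_field" and "genes_to_ignore"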
+        df_genes_ignore.to_csv(
+            f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t"
+        )
+
 
 def plot_pseudobulk_samples(
     pseudobulk_data,
@@ -89,7 +206,9 @@
     >>> adata = sc.datasets.pbmc68k_reduced()
     >>> adata.X = abs(adata.X).astype(int)
     >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
-    >>> plot_pseudobulk_samples(pseudobulk, groupby=["bulk_labels", "louvain"], figsize=(10, 10))
+    >>> plot_pseudobulk_samples(pseudobulk,
+    ...                         groupby=["bulk_labels", "louvain"],
+    ...                         figsize=(10, 10))
     """
     fig = decoupler.plot_psbulk_samples(
         pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True
@@ -101,14 +220,19 @@
 
 
 def plot_filter_by_expr(
-    pseudobulk_data, group, min_count=None, min_total_count=None, save_path=None
+    pseudobulk_data,
+    group,
+    min_count=None,
+    min_total_count=None,
+    save_path=None,
 ):
     """
     >>> import scanpy as sc
     >>> adata = sc.datasets.pbmc68k_reduced()
     >>> adata.X = abs(adata.X).astype(int)
     >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
-    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels", min_count=10, min_total_count=200)
+    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels",
+    ...                     min_count=10, min_total_count=200)
     """
     fig = decoupler.plot_filter_by_expr(
         pseudobulk_data,
@@ -129,7 +253,8 @@
     >>> adata = sc.datasets.pbmc68k_reduced()
     >>> adata.X = abs(adata.X).astype(int)
     >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
-    >>> pdata_filt = filter_by_expr(pseudobulk, min_count=10, min_total_count=200)
+    >>> pdata_filt = filter_by_expr(pseudobulk,
+    ...                             min_count=10, min_total_count=200)
     """
     genes = decoupler.filter_by_expr(
         pdata, min_count=min_count, min_total_count=min_total_count
@@ -150,12 +275,16 @@
     if obs:
         if not set(fields).issubset(set(adata.obs.columns)):
             raise ValueError(
-                f"Some of the following fields {legend} are not present in adata.obs: {fields}. Possible fields are: {list(set(adata.obs.columns))}"
+                f"Some of the following fields {legend} are not present \
+                    in adata.obs: {fields}. \
+                        Possible fields are: {list(set(adata.obs.columns))}"
             )
     else:
         if not set(fields).issubset(set(adata.var.columns)):
             raise ValueError(
-                f"Some of the following fields {legend} are not present in adata.var: {fields}. Possible fields are: {list(set(adata.var.columns))}"
+                f"Some of the following fields {legend} are not present \
+                    in adata.var: {fields}. \
+                        Possible fields are: {list(set(adata.var.columns))}"
             )
 
 
@@ -219,10 +348,15 @@
 
     # Save the pseudobulk data
     if args.anndata_output_path:
-        pseudobulk_data.write_h5ad(args.anndata_output_path, compression="gzip")
+        pseudobulk_data.write_h5ad(
+            args.anndata_output_path, compression="gzip"
+        )
 
     write_DESeq2_inputs(
-        pseudobulk_data, output_dir=args.deseq2_output_path, factor_fields=factor_fields
+        pseudobulk_data,
+        output_dir=args.deseq2_output_path,
+        factor_fields=factor_fields,
+        min_counts_per_sample_marking=args.min_counts_per_sample_marking,
     )
 
 
@@ -254,7 +388,9 @@
     field_name = "_".join(obs_fields_to_merge)
     for field in obs_fields_to_merge:
         if field not in adata.obs.columns:
-            raise ValueError(f"The '{field}' column is not present in adata.obs.")
+            raise ValueError(
+                f"The '{field}' column is not present in adata.obs."
+            )
         if field_name not in adata.obs.columns:
             adata.obs[field_name] = adata.obs[field].astype(str)
         else:
@@ -271,12 +407,16 @@
     )
 
     # Add arguments
-    parser.add_argument("adata_file", type=str, help="Path to the AnnData file")
+    parser.add_argument(
+        "adata_file", type=str, help="Path to the AnnData file"
+    )
     parser.add_argument(
         "-m",
         "--adata_obs_fields_to_merge",
         type=str,
-        help="Fields in adata.obs to merge, comma separated. You can have more than one set of fields, separated by semi-colon ;",
+        help="Fields in adata.obs to merge, comma separated. \
+            You can have more than one set of fields, \
+                separated by semi-colon ;",
     )
     parser.add_argument(
         "--groupby",
@@ -328,6 +468,13 @@
         help="Minimum count threshold for filtering by expression",
     )
     parser.add_argument(
+        "--min_counts_per_sample_marking",
+        type=int,
+        default=20,
+        help="Minimum count threshold per sample for \
+            marking genes to be ignored after DE",
+    )
+    parser.add_argument(
         "--min_total_counts",
         type=int,
         help="Minimum total count threshold for filtering by expression",
@@ -338,7 +485,9 @@
         help="Path to save the filtered AnnData object or pseudobulk data",
     )
     parser.add_argument(
-        "--filter_expr", action="store_true", help="Enable filtering by expression"
+        "--filter_expr",
+        action="store_true",
+        help="Enable filtering by expression",
     )
     parser.add_argument(
         "--factor_fields",
@@ -358,7 +507,9 @@
         nargs=2,
         help="Size of the samples plot as a tuple (two arguments)",
     )
-    parser.add_argument("--plot_filtering_figsize", type=int, default=[10, 10], nargs=2)
+    parser.add_argument(
+        "--plot_filtering_figsize", type=int, default=[10, 10], nargs=2
+    )
 
     # Parse the command line arguments
     args = parser.parse_args()
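
As an illustration, the new genes_to_ignore_per_contrast_field.tsv output could be consumed downstream roughly as follows (a minimal sketch: the de_results.tsv table and its "gene" column are hypothetical placeholders; only the marking file and its "contrast_field"/"genes_to_ignore" columns come from this changeset):

    import pandas as pd

    # genes flagged as too lowly expressed, one row per (contrast level, gene)
    marking = pd.read_csv(
        "genes_to_ignore_per_contrast_field.tsv", sep="\t", index_col=0
    )
    # hypothetical differential expression results table, one row per gene
    de_results = pd.read_csv("de_results.tsv", sep="\t")
    # drop flagged genes for one level of the contrast field
    level = "A"  # e.g. one of the bulk_labels / louvain levels
    flagged = set(
        marking.loc[marking["contrast_field"] == level, "genes_to_ignore"]
    )
    de_results = de_results[~de_results["gene"].isin(flagged)]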