view decoupler_pseudobulk.py @ 3:c6787c2aee46 draft

planemo upload for repository https://github.com/ebi-gene-expression-group/container-galaxy-sc-tertiary/ commit eea5c13f9e6e070a2359c59400773b01f9cd7567
author ebi-gxa
date Mon, 15 Jul 2024 10:56:37 +0000
parents 77d680b36e23
children 6c30272fb587
line wrap: on
line source

import argparse

import anndata
import decoupler
import pandas as pd


def get_pseudobulk(
    adata,
    sample_col,
    groups_col,
    layer=None,
    mode="sum",
    min_cells=10,
    min_counts=1000,
    use_raw=False,
):
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    """

    return decoupler.get_pseudobulk(
        adata,
        sample_col=sample_col,
        groups_col=groups_col,
        layer=layer,
        mode=mode,
        use_raw=use_raw,
        min_cells=min_cells,
        min_counts=min_counts,
    )


def prepend_c_to_index(index_value):
    if index_value and index_value[0].isdigit():
        return "C" + index_value
    return index_value


def genes_to_ignore_per_contrast_field(
    count_matrix_df,
    samples_metadata,
    sample_metadata_col_contrasts,
    min_counts_per_sample=5,
    use_cpms=False,
):
    """
    # This function calculates the genes to ignore per contrast field
    # (e.g., bulk_labels, louvain).
    # It does this by first getting the count matrix for each group,
    # then identifying genes with a count below a specified threshold.
    # The genes to ignore are those that are present in more than a specified
    # number of groups.

    >>> import pandas as pd
    >>> samples_metadata = pd.DataFrame({'sample':
    ...                                    ['S1', 'S2', 'S3',
    ...                                     'S4', 'S5', 'S6'],
    ...                                  'contrast_field':
    ...                                    ['A', 'A', 'A', 'B', 'B', 'B']})
    >>> count_matrix_df = pd.DataFrame(
    ...                       {'S1':
    ...                          [30, 1, 40, 50, 30],
    ...                        'S2':
    ...                          [40, 2, 60, 50, 80],
    ...                        'S3':
    ...                          [80, 1, 60, 50, 50],
    ...                        'S4': [1, 50, 50, 50, 2],
    ...                        'S5': [3, 40, 40, 40, 2],
    ...                        'S6': [0, 50, 50, 50, 1]})
    >>> count_matrix_df.index = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5']
    >>> df = genes_to_ignore_per_contrast_field(count_matrix_df,
    ...             samples_metadata, min_counts_per_sample=5,
    ...             sample_metadata_col_contrasts='contrast_field')
    >>> df[df['contrast_field'] == 'A'].genes_to_ignore.tolist()[0]
    'Gene2'
    >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[0]
    'Gene1'
    >>> df[df['contrast_field'] == 'B'].genes_to_ignore.tolist()[1]
    'Gene5'
    """

    # Initialize a dictionary to store the genes to ignore per contrast field
    contrast_fields = []
    genes_to_ignore = []

    # Iterate over the contrast fields
    for contrast_field in samples_metadata[
        sample_metadata_col_contrasts
    ].unique():
        # Get the count matrix for the current contrast field
        count_matrix_field = count_matrix_df.loc[
            :,
            (
                samples_metadata[sample_metadata_col_contrasts]
                == contrast_field
            ).tolist(),
        ]

        # We derive min_counts from the number of samples with that
        # contrast_field value
        min_counts = count_matrix_field.shape[1] * min_counts_per_sample

        if use_cpms:
            # Convert counts to counts per million (CPM)
            count_matrix_field = (
                count_matrix_field.div(count_matrix_field.sum(axis=1), axis=0)
                * 1e6
            )
            min_counts = 1  # use 1 CPM

        # Calculate the total number of cells in the current contrast field
        # (this produces a vector of counts per gene)
        total_counts_per_gene = count_matrix_field.sum(axis=1)

        # Identify genes with a count below the specified threshold
        genes = total_counts_per_gene[
            total_counts_per_gene < min_counts
        ].index.tolist()
        if len(genes) > 0:
            # genes_to_ignore[contrast_field] = " ".join(genes)
            for gene in genes:
                genes_to_ignore.append(gene)
                contrast_fields.append(contrast_field)
    # transform gene_to_ignore to a DataFrame
    # genes_to_ignore_df = pd.DataFrame(genes_to_ignore.items(),
    #                           columns=["contrast_field", "genes_to_ignore"])
    genes_to_ignore_df = pd.DataFrame(
        {"contrast_field": contrast_fields, "genes_to_ignore": genes_to_ignore}
    )
    return genes_to_ignore_df


# write results for loading into DESeq2
def write_DESeq2_inputs(
    pdata,
    layer=None,
    output_dir="",
    factor_fields=None,
    min_counts_per_sample_marking=20,
):
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> write_DESeq2_inputs(pseudobulk)
    """
    # add / to output_dir if is not empty or if it doesn't end with /
    if output_dir != "" and not output_dir.endswith("/"):
        output_dir = output_dir + "/"
    obs_for_deseq = pdata.obs.copy()
    # replace any index starting with digits to start with C instead.
    obs_for_deseq.rename(index=prepend_c_to_index, inplace=True)
    # avoid dash that is read as point on R colnames.
    obs_for_deseq.index = obs_for_deseq.index.str.replace("-", "_")
    obs_for_deseq.index = obs_for_deseq.index.str.replace(" ", "_")
    col_metadata_file = f"{output_dir}col_metadata.tsv"
    # write obs to a col_metadata file
    if factor_fields:
        # only output the index plus the columns in factor_fields in that order
        obs_for_deseq[factor_fields].to_csv(
            col_metadata_file, sep="\t", index=True
        )
    else:
        obs_for_deseq.to_csv(col_metadata_file, sep="\t", index=True)
    # write var to a gene_metadata file
    pdata.var.to_csv(f"{output_dir}gene_metadata.tsv", sep="\t", index=True)
    # write the counts matrix of a specified layer to file
    if layer is None:
        # write the X numpy matrix transposed to file
        df = pd.DataFrame(
            pdata.X.T, index=pdata.var.index, columns=obs_for_deseq.index
        )
    else:
        df = pd.DataFrame(
            pdata.layers[layer].T,
            index=pdata.var.index,
            columns=obs_for_deseq.index,
        )
    df.to_csv(f"{output_dir}counts_matrix.tsv", sep="\t", index_label="")

    if factor_fields:
        df_genes_ignore = genes_to_ignore_per_contrast_field(
            count_matrix_df=df,
            samples_metadata=obs_for_deseq,
            sample_metadata_col_contrasts=factor_fields[0],
            min_counts_per_sample=min_counts_per_sample_marking,
        )
        df_genes_ignore.to_csv(
            f"{output_dir}genes_to_ignore_per_contrast_field.tsv", sep="\t"
        )


def plot_pseudobulk_samples(
    pseudobulk_data,
    groupby,
    figsize=(10, 10),
    save_path=None,
):
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_pseudobulk_samples(pseudobulk,
    ...                         groupby=["bulk_labels", "louvain"],
    ...                         figsize=(10, 10))
    """
    fig = decoupler.plot_psbulk_samples(
        pseudobulk_data, groupby=groupby, figsize=figsize, return_fig=True
    )
    if save_path:
        fig.savefig(f"{save_path}/pseudobulk_samples.png")
    else:
        fig.show()


def plot_filter_by_expr(
    pseudobulk_data,
    group,
    min_count=None,
    min_total_count=None,
    save_path=None,
):
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> plot_filter_by_expr(pseudobulk, group="bulk_labels",
    ...                     min_count=10, min_total_count=200)
    """
    fig = decoupler.plot_filter_by_expr(
        pseudobulk_data,
        group=group,
        min_count=min_count,
        min_total_count=min_total_count,
        return_fig=True,
    )
    if save_path:
        fig.savefig(f"{save_path}/filter_by_expr.png")
    else:
        fig.show()


def filter_by_expr(pdata, min_count=None, min_total_count=None):
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> adata.X = abs(adata.X).astype(int)
    >>> pseudobulk = get_pseudobulk(adata, "bulk_labels", "louvain")
    >>> pdata_filt = filter_by_expr(pseudobulk,
    ...                             min_count=10, min_total_count=200)
    """
    genes = decoupler.filter_by_expr(
        pdata, min_count=min_count, min_total_count=min_total_count
    )
    return pdata[:, genes].copy()


def check_fields(fields, adata, obs=True, context=None):
    """
    >>> import scanpy as sc
    >>> adata = sc.datasets.pbmc68k_reduced()
    >>> check_fields(["bulk_labels", "louvain"], adata, obs=True)
    """

    legend = ""
    if context:
        legend = f", passed in {context},"
    if obs:
        if not set(fields).issubset(set(adata.obs.columns)):
            raise ValueError(
                f"Some of the following fields {legend} are not present \
                    in adata.obs: {fields}. \
                        Possible fields are: {list(set(adata.obs.columns))}"
            )
    else:
        if not set(fields).issubset(set(adata.var.columns)):
            raise ValueError(
                f"Some of the following fields {legend} are not present \
                    in adata.var: {fields}. \
                        Possible fields are: {list(set(adata.var.columns))}"
            )


def main(args):
    # Load AnnData object from file
    adata = anndata.read_h5ad(args.adata_file)

    # Merge adata.obs fields specified in args.adata_obs_fields_to_merge
    if args.adata_obs_fields_to_merge:
        # first split potential groups by ":" and iterate over them
        for group in args.adata_obs_fields_to_merge.split(":"):
            fields = group.split(",")
            check_fields(fields, adata)
            adata = merge_adata_obs_fields(fields, adata)

    check_fields([args.groupby, args.sample_key], adata)

    factor_fields = None
    if args.factor_fields:
        factor_fields = args.factor_fields.split(",")
        check_fields(factor_fields, adata)

    print(f"Using mode: {args.mode}")
    # Perform pseudobulk analysis
    pseudobulk_data = get_pseudobulk(
        adata,
        sample_col=args.sample_key,
        groups_col=args.groupby,
        layer=args.layer,
        mode=args.mode,
        use_raw=args.use_raw,
        min_cells=args.min_cells,
        min_counts=args.min_counts,
    )

    # Plot pseudobulk samples
    plot_pseudobulk_samples(
        pseudobulk_data,
        args.groupby,
        save_path=args.save_path,
        figsize=args.plot_samples_figsize,
    )

    plot_filter_by_expr(
        pseudobulk_data,
        group=args.groupby,
        min_count=args.min_counts,
        min_total_count=args.min_total_counts,
        save_path=args.save_path,
    )

    # Filter by expression if enabled
    if args.filter_expr:
        filtered_adata = filter_by_expr(
            pseudobulk_data,
            min_count=args.min_counts,
            min_total_count=args.min_total_counts,
        )

        pseudobulk_data = filtered_adata

    # Save the pseudobulk data
    if args.anndata_output_path:
        pseudobulk_data.write_h5ad(
            args.anndata_output_path, compression="gzip"
        )

    write_DESeq2_inputs(
        pseudobulk_data,
        output_dir=args.deseq2_output_path,
        factor_fields=factor_fields,
        min_counts_per_sample_marking=args.min_counts_per_sample_marking,
    )


def merge_adata_obs_fields(obs_fields_to_merge, adata):
    """
    Merge adata.obs fields specified in args.adata_obs_fields_to_merge

    Parameters
    ----------
    obs_fields_to_merge : str
        Fields in adata.obs to merge, comma separated
    adata : anndata.AnnData
        The AnnData object

    Returns
    -------
    anndata.AnnData
        The merged AnnData object

    docstring tests:
    >>> import scanpy as sc
    >>> ad = sc.datasets.pbmc68k_reduced()
    >>> ad = merge_adata_obs_fields(["bulk_labels","louvain"], ad)
    >>> ad.obs.columns
    Index(['bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score',
           'G2M_score', 'phase', 'louvain', 'bulk_labels_louvain'],
          dtype='object')
    """
    field_name = "_".join(obs_fields_to_merge)
    for field in obs_fields_to_merge:
        if field not in adata.obs.columns:
            raise ValueError(
                f"The '{field}' column is not present in adata.obs."
            )
        if field_name not in adata.obs.columns:
            adata.obs[field_name] = adata.obs[field].astype(str)
        else:
            adata.obs[field_name] = (
                adata.obs[field_name] + "_" + adata.obs[field].astype(str)
            )
    return adata


if __name__ == "__main__":
    # Create argument parser
    parser = argparse.ArgumentParser(
        description="Perform pseudobulk analysis on an AnnData object"
    )

    # Add arguments
    parser.add_argument(
        "adata_file", type=str, help="Path to the AnnData file"
    )
    parser.add_argument(
        "-m",
        "--adata_obs_fields_to_merge",
        type=str,
        help="Fields in adata.obs to merge, comma separated. \
            You can have more than one set of fields, \
                separated by semi-colon ;",
    )
    parser.add_argument(
        "--groupby",
        type=str,
        required=True,
        help="The column in adata.obs that defines the groups",
    )
    parser.add_argument(
        "--sample_key",
        required=True,
        type=str,
        help="The column in adata.obs that defines the samples",
    )
    # add argument for layer
    parser.add_argument(
        "--layer",
        type=str,
        default=None,
        help="The name of the layer of the AnnData object to use",
    )
    # add argument for mode
    parser.add_argument(
        "--mode",
        type=str,
        default="sum",
        help="The mode for Decoupler pseudobulk analysis",
        choices=["sum", "mean", "median"],
    )
    # add boolean argument for use_raw
    parser.add_argument(
        "--use_raw",
        action="store_true",
        default=False,
        help="Whether to use the raw part of the AnnData object",
    )
    # add argument for min_cells
    parser.add_argument(
        "--min_cells",
        type=int,
        default=10,
        help="Minimum number of cells for pseudobulk analysis",
    )
    parser.add_argument(
        "--save_path", type=str, help="Path to save the plot (optional)"
    )
    parser.add_argument(
        "--min_counts",
        type=int,
        help="Minimum count threshold for filtering by expression",
    )
    parser.add_argument(
        "--min_counts_per_sample_marking",
        type=int,
        default=20,
        help="Minimum count threshold per sample for \
            marking genes to be ignored after DE",
    )
    parser.add_argument(
        "--min_total_counts",
        type=int,
        help="Minimum total count threshold for filtering by expression",
    )
    parser.add_argument(
        "--anndata_output_path",
        type=str,
        help="Path to save the filtered AnnData object or pseudobulk data",
    )
    parser.add_argument(
        "--filter_expr",
        action="store_true",
        help="Enable filtering by expression",
    )
    parser.add_argument(
        "--factor_fields",
        type=str,
        help="Comma separated list of fields for the factors",
    )
    parser.add_argument(
        "--deseq2_output_path",
        type=str,
        help="Path to save the DESeq2 inputs",
        required=True,
    )
    parser.add_argument(
        "--plot_samples_figsize",
        type=int,
        default=[10, 10],
        nargs=2,
        help="Size of the samples plot as a tuple (two arguments)",
    )
    parser.add_argument(
        "--plot_filtering_figsize", type=int, default=[10, 10], nargs=2
    )

    # Parse the command line arguments
    args = parser.parse_args()

    # Call the main function
    main(args)