diff COBRAxy/ras_generator.py @ 529:6acd64232dad draft

Uploaded
author francesco_lapi
date Wed, 22 Oct 2025 12:05:30 +0000
parents b02cfa3b36dd
children 352c51a39e23
line wrap: on
line diff
--- a/COBRAxy/ras_generator.py	Tue Oct 21 09:49:31 2025 +0000
+++ b/COBRAxy/ras_generator.py	Wed Oct 22 12:05:30 2025 +0000
@@ -10,9 +10,18 @@
 import pandas as pd
 import numpy as np
 import utils.general_utils as utils
-from typing import  List, Dict
+from typing import List, Dict
 import ast
 
+# Optional imports for AnnData mode (not used in ras_generator.py)
+try:
+    from progressbar import ProgressBar, Bar, Percentage
+    from scanpy import AnnData
+    from cobra.flux_analysis.variability import find_essential_reactions, find_essential_genes
+except ImportError:
+    # These are only needed for AnnData mode, not for ras_generator.py
+    pass
+
 ERRORS = []
 ########################## argparse ##########################################
 ARGS :argparse.Namespace
@@ -150,136 +159,343 @@
     return dict_rule
 
 
+"""
+Class to compute the RAS values
 
-def computeRAS(
-            dataset,gene_rules,rules_total_string,
-            or_function=np.sum,    # type of operation to do in case of an or expression (max, sum, mean)
-            and_function=np.min,   # type of operation to do in case of an and expression(min, sum)
-            ignore_nan = True
-            ):
+"""
+
+class RAS_computation:
+
+    def __init__(self, adata=None, model=None, dataset=None, gene_rules=None, rules_total_string=None):
+        """
+        Initialize RAS computation with two possible input modes:
+        
+        Mode 1 (Original - for sampling_main.py):
+            adata: AnnData object with gene expression (cells × genes)
+            model: COBRApy model object with reactions and GPRs
+            
+        Mode 2 (New - for ras_generator.py):
+            dataset: pandas DataFrame with gene expression (genes × samples)
+            gene_rules: dict mapping reaction IDs to GPR strings
+            rules_total_string: list of all gene names in GPRs (for validation)
+        """
+        self._logic_operators = ['and', 'or', '(', ')']
+        self.val_nan = np.nan
+        
+        # Determine which mode we're in
+        if adata is not None and model is not None:
+            # Mode 1: AnnData + COBRApy model (original)
+            self._init_from_anndata(adata, model)
+        elif dataset is not None and gene_rules is not None:
+            # Mode 2: DataFrame + rules dict (ras_generator style)
+            self._init_from_dataframe(dataset, gene_rules, rules_total_string)
+        else:
+            raise ValueError(
+                "Invalid initialization. Provide either:\n"
+                "  - adata + model (for AnnData input), or\n"
+                "  - dataset + gene_rules (for DataFrame input)"
+            )
+    
+    def _normalize_gene_name(self, gene_name):
+        """Normalize gene names by replacing special characters."""
+        return gene_name.replace("-", "_").replace(":", "_")
+    
+    def _normalize_rule(self, rule):
+        """Normalize GPR rule: lowercase operators, add spaces around parentheses, normalize gene names."""
+        rule = rule.replace("OR", "or").replace("AND", "and")
+        rule = rule.replace("(", "( ").replace(")", " )")
+        # Normalize gene names in the rule
+        tokens = rule.split()
+        normalized_tokens = [token if token in self._logic_operators else self._normalize_gene_name(token) for token in tokens]
+        return " ".join(normalized_tokens)
+    
+    def _init_from_anndata(self, adata, model):
+        """Initialize from AnnData and COBRApy model (original mode)."""
+        # Build the dictionary for the GPRs
+        df_reactions = pd.DataFrame(index=[reaction.id for reaction in model.reactions])
+        gene_rules = [self._normalize_rule(reaction.gene_reaction_rule) for reaction in model.reactions]
+        df_reactions['rule'] = gene_rules
+        df_reactions = df_reactions.reset_index()
+        df_reactions = df_reactions.groupby('rule').agg(lambda x: sorted(list(x)))
+        
+        self.dict_rule_reactions = df_reactions.to_dict()['index']
+
+        # build useful structures for RAS computation
+        self.model = model
+        self.count_adata = adata.copy()
+        
+        # Normalize gene names in both model and dataset
+        model_genes = [self._normalize_gene_name(gene.id) for gene in model.genes]
+        dataset_genes = [self._normalize_gene_name(gene) for gene in self.count_adata.var.index]
+        self.genes = pd.Index(dataset_genes).intersection(model_genes)
+        
+        if len(self.genes) == 0:
+            raise ValueError("ERROR: No genes from the count matrix match the metabolic model. Check that gene annotations are consistent between model and dataset.")
+        
+        self.cell_ids = list(self.count_adata.obs.index.values)
+        # Get expression data with normalized gene names
+        self.count_df_filtered = self.count_adata.to_df().T
+        self.count_df_filtered.index = [self._normalize_gene_name(g) for g in self.count_df_filtered.index]
+        self.count_df_filtered = self.count_df_filtered.loc[self.genes]
+    
+    def _init_from_dataframe(self, dataset, gene_rules, rules_total_string):
+        """Initialize from DataFrame and rules dict (ras_generator mode)."""
+        reactions = list(gene_rules.keys())
+        
+        # Build the dictionary for the GPRs
+        df_reactions = pd.DataFrame(index=reactions)
+        gene_rules_list = [self._normalize_rule(gene_rules[reaction_id]) for reaction_id in reactions]
+        df_reactions['rule'] = gene_rules_list
+        df_reactions = df_reactions.reset_index()
+        df_reactions = df_reactions.groupby('rule').agg(lambda x: sorted(list(x)))
+
+        self.dict_rule_reactions = df_reactions.to_dict()['index']
+
+        # build useful structures for RAS computation
+        self.model = None
+        self.count_adata = None
+        
+        # Normalize gene names in dataset
+        dataset_normalized = dataset.copy()
+        dataset_normalized.index = [self._normalize_gene_name(g) for g in dataset_normalized.index]
+        
+        # Determine which genes are in both dataset and GPRs
+        if rules_total_string is not None:
+            rules_genes = [self._normalize_gene_name(g) for g in rules_total_string]
+            self.genes = dataset_normalized.index.intersection(rules_genes)
+        else:
+            # Extract all genes from rules
+            all_genes_in_rules = set()
+            for rule in gene_rules_list:
+                tokens = rule.split()
+                for token in tokens:
+                    if token not in self._logic_operators:
+                        all_genes_in_rules.add(token)
+            self.genes = dataset_normalized.index.intersection(all_genes_in_rules)
+        
+        if len(self.genes) == 0:
+            raise ValueError("ERROR: No genes from the count matrix match the metabolic model. Check that gene annotations are consistent between model and dataset.")
+        
+        self.cell_ids = list(dataset_normalized.columns)
+        self.count_df_filtered = dataset_normalized.loc[self.genes]
+ 
+    def compute(self,
+                or_expression=np.sum,       # type of operation to do in case of an or expression (sum, max, mean)
+                and_expression=np.min,      # type of operation to do in case of an and expression(min, sum)
+                drop_na_rows=True,          # if True remove the nan rows of the ras  matrix
+                drop_duplicates=False,      # if true, remove duplicates rows
+                ignore_nan=True,            # if True, ignore NaN values in GPR evaluation (e.g., A or NaN -> A)
+                print_progressbar=True,     # if True, print the progress bar
+                add_count_metadata=True,    # if True add metadata of cells in the ras adata
+                add_met_metadata=True,      # if True add metadata from the metabolic model (gpr and compartments of reactions)
+                add_essential_reactions=False,
+                add_essential_genes=False
+                ):
+
+        self.or_function = or_expression
+        self.and_function = and_expression
+        
+        ras_df = np.full((len(self.dict_rule_reactions), len(self.cell_ids)), np.nan)
+        genes_not_mapped = set()  # Track genes not in dataset
+        
+        if print_progressbar:
+            pbar = ProgressBar(widgets=[Percentage(), Bar()], maxval=len(self.dict_rule_reactions)).start()
+        
+        # Process each unique GPR rule
+        for ind, (rule, reaction_ids) in enumerate(self.dict_rule_reactions.items()):
+            if len(rule) == 0:
+                # Empty rule - keep as NaN
+                pass
+            else:
+                # Extract genes from rule
+                rule_genes = [token for token in rule.split() if token not in self._logic_operators]
+                rule_genes_unique = list(set(rule_genes))
+                
+                # Which genes are in the dataset?
+                genes_present = [g for g in rule_genes_unique if g in self.genes]
+                genes_missing = [g for g in rule_genes_unique if g not in self.genes]
+                
+                if genes_missing:
+                    genes_not_mapped.update(genes_missing)
+                
+                if len(genes_present) == 0:
+                    # No genes in dataset - keep as NaN
+                    pass
+                elif len(genes_missing) > 0 and not ignore_nan:
+                    # Some genes missing and we don't ignore NaN - set to NaN
+                    pass
+                else:
+                    # Evaluate the GPR expression using AST
+                    # For single gene, AST handles it fine: ast.parse("GENE_A") works
+                    try:
+                        tree = ast.parse(rule, mode="eval").body
+                        data = self.count_df_filtered.loc[genes_present]
+                        
+                        # Evaluate for each cell/sample
+                        for j, col in enumerate(data.columns):
+                            gene_values = dict(zip(data.index, data[col].values))
+                            ras_df[ind, j] = self._evaluate_ast(tree, gene_values, self.or_function, self.and_function, ignore_nan)
+                    except:
+                        # If parsing fails, keep as NaN
+                        pass
+            
+            if print_progressbar:
+                pbar.update(ind + 1)
+        
+        if print_progressbar:
+            pbar.finish()
+        
+        # Store genes not mapped for later use
+        self.genes_not_mapped = sorted(genes_not_mapped)
+        
+        # create the dataframe of ras (rules x samples)
+        ras_df = pd.DataFrame(data=ras_df, index=range(len(self.dict_rule_reactions)), columns=self.cell_ids)
+        ras_df['REACTIONS'] = [reaction_ids for rule, reaction_ids in self.dict_rule_reactions.items()]
+        
+        reactions_common = pd.DataFrame()
+        reactions_common["REACTIONS"] = ras_df['REACTIONS']
+        reactions_common["proof2"] = ras_df['REACTIONS']
+        reactions_common = reactions_common.explode('REACTIONS')
+        reactions_common = reactions_common.set_index("REACTIONS")
+
+        ras_df = ras_df.explode("REACTIONS")
+        ras_df = ras_df.set_index("REACTIONS")
+
+        if drop_na_rows:
+            ras_df = ras_df.dropna(how="all")
+            
+        if drop_duplicates:
+            ras_df = ras_df.drop_duplicates()
+        
+        # If initialized from DataFrame (ras_generator mode), return DataFrame instead of AnnData
+        if self.count_adata is None:
+            return ras_df, self.genes_not_mapped
+        
+        # Original AnnData mode: create AnnData structure for RAS
+        ras_adata = AnnData(ras_df.T)
+
+        #add metadata
+        if add_count_metadata:
+            ras_adata.var["common_gprs"] = reactions_common.loc[ras_df.index]
+            ras_adata.var["common_gprs"] = ras_adata.var["common_gprs"].apply(lambda x: ",".join(x))
+            for el in self.count_adata.obs.columns:
+                ras_adata.obs["countmatrix_"+el]=self.count_adata.obs[el]
+
+        if add_met_metadata:
+            if self.model is not None and len(self.model.compartments)>0:
+                  ras_adata.var['compartments']=[list(self.model.reactions.get_by_id(reaction).compartments) for reaction in ras_adata.var.index]  
+                  ras_adata.var['compartments']=ras_adata.var["compartments"].apply(lambda x: ",".join(x))
+            
+            if self.model is not None:
+                ras_adata.var['GPR rule'] = [self.model.reactions.get_by_id(reaction).gene_reaction_rule for reaction in ras_adata.var.index]
+
+        if add_essential_reactions:
+            if self.model is not None:
+                essential_reactions=find_essential_reactions(self.model)
+                essential_reactions=[el.id for el in essential_reactions]            
+                ras_adata.var['essential reactions']=["yes" if el in essential_reactions else "no" for el in ras_adata.var.index]
+        
+        if add_essential_genes:
+            if self.model is not None:
+                essential_genes=find_essential_genes(self.model)
+                essential_genes=[el.id for el in essential_genes]
+                ras_adata.var['essential genes']=[" ".join([gene for gene in genes.split()  if gene in essential_genes]) for genes in ras_adata.var["GPR rule"]]
+        
+        return ras_adata
+
+    def _evaluate_ast(self, node, values, or_function, and_function, ignore_nan):
+        """
+        Evaluate a boolean expression using AST (Abstract Syntax Tree).
+        Handles all GPR types: single gene, simple (A and B), nested (A or (B and C)).
+        
+        Args:
+            node: AST node to evaluate
+            values: Dictionary mapping gene names to their expression values
+            or_function: Function to apply for OR operations
+            and_function: Function to apply for AND operations
+            ignore_nan: If True, ignore None/NaN values (e.g., A or None -> A)
+            
+        Returns:
+            Evaluated expression result (float or np.nan)
+        """
+        if isinstance(node, ast.BoolOp):
+            # Boolean operation (and/or)
+            vals = [self._evaluate_ast(v, values, or_function, and_function, ignore_nan) for v in node.values]
+            
+            if ignore_nan:
+                # Filter out None/NaN values
+                vals = [v for v in vals if v is not None and not (isinstance(v, float) and np.isnan(v))]
+            
+            if not vals:
+                return np.nan
+            
+            if isinstance(node.op, ast.Or):
+                return or_function(vals)
+            elif isinstance(node.op, ast.And):
+                return and_function(vals)
+
+        elif isinstance(node, ast.Name):
+            # Variable (gene name)
+            return values.get(node.id, None)
+        elif isinstance(node, ast.Constant):
+            # Constant (shouldn't happen in GPRs, but handle it)
+            return values.get(str(node.value), None)
+        else:
+            raise ValueError(f"Unexpected node type in GPR: {ast.dump(node)}")
 
 
-    logic_operators = ['and', 'or', '(', ')']
-    reactions=list(gene_rules.keys())
-
-    # Build the dictionary for the GPRs
-    df_reactions = pd.DataFrame(index=reactions)
-    gene_rules=[gene_rules[reaction_id].replace("OR","or").replace("AND","and").replace("(","( ").replace(")"," )") for reaction_id in reactions]        
-    df_reactions['rule'] = gene_rules
-    df_reactions = df_reactions.reset_index()
-    df_reactions = df_reactions.groupby('rule').agg(lambda x: sorted(list(x)))
-
-    dict_rule_reactions = df_reactions.to_dict()['index']
-
-    # build useful structures for RAS computation
-    genes =dataset.index.intersection(rules_total_string)
-    
-    #check if there is one gene at least 
-    if len(genes)==0:
-        raise ValueError("ERROR: No genes from the count matrix match the metabolic model. Check that gene annotations are consistent between model and dataset.") 
-    
-    cell_ids = list(dataset.columns)
-    count_df_filtered = dataset.loc[genes]
-    count_df_filtered = count_df_filtered.rename(index=lambda x: x.replace("-", "_").replace(":", "_"))
-
-    ras_df=np.full((len(dict_rule_reactions), len(cell_ids)), np.nan)
+# ============================================================================
+# STANDALONE FUNCTION FOR RAS_GENERATOR COMPATIBILITY
+# ============================================================================
 
-    # for loop on rules
-    genes_not_mapped=[]
-    ind = 0       
-    for rule, reaction_ids in dict_rule_reactions.items():
-        if len(rule) != 0:
-            # there is one gene at least in the formula
-            warning_rule="_"
-            if "-" in rule:
-                warning_rule="-"
-            if ":" in rule:
-                warning_rule=":"
-            rule_orig=rule.split().copy()  #original rule in list
-            rule = rule.replace(warning_rule,"_")
-             #modified rule
-            rule_split = rule.split()
-            rule_split_elements = list(filter(lambda x: x not in logic_operators, rule_split))  # remove of all logical operators
-            rule_split_elements = list(set(rule_split_elements))                                # genes in formula
-            
-            # which genes are in the count matrix?                
-            genes_in_count_matrix = [el for el in rule_split_elements if el in genes]
-            genes_notin_count_matrix = []
-            for el in rule_split_elements:
-                if el not in genes: #not present in original dataset
-                    if el.replace("_",warning_rule) in rule_orig: 
-                        genes_notin_count_matrix.append(el.replace("_",warning_rule))
-                    else:
-                        genes_notin_count_matrix.append(el)
-            genes_not_mapped.extend(genes_notin_count_matrix)
-            
-            # add genes not present in the data
-            if len(genes_in_count_matrix) > 0: #there is at least one gene in the count matrix                 
-                    if len(rule_split) == 1:
-                        #one gene --> one reaction
-                        ras_df[ind] = count_df_filtered.loc[genes_in_count_matrix]
-                    else:    
-                        if len(genes_notin_count_matrix) > 0 and ignore_nan == False:
-                                ras_df[ind] = np.nan
-                        else:                   
-                            # more genes in the formula
-                            check_only_and=("and" in rule and "or" not in rule) #only and
-                            check_only_or=("or" in rule and "and" not in rule)  #only or
-                            if check_only_and or check_only_or:
-                                #or/and sequence
-                                matrix = count_df_filtered.loc[genes_in_count_matrix].values
-                                #compute for all cells
-                                if check_only_and: 
-                                    ras_df[ind] = and_function(matrix, axis=0)
-                                else:
-                                    ras_df[ind] = or_function(matrix, axis=0)
-                            else:
-                                # complex expression (e.g. A or (B and C))
-                                data = count_df_filtered.loc[genes_in_count_matrix]  # dataframe of genes in the GPRs
-                                tree = ast.parse(rule, mode="eval").body
-                                values_by_cell = [dict(zip(data.index, data[col].values)) for col in data.columns]
-                                for j, values in enumerate(values_by_cell):
-                                    ras_df[ind, j] = _evaluate_ast(tree, values, or_function, and_function)
+def computeRAS(
+    dataset, 
+    gene_rules, 
+    rules_total_string,
+    or_function=np.sum,
+    and_function=np.min,
+    ignore_nan=True
+):
+    """
+    Compute RAS from tabular data and GPR rules (ras_generator.py compatible).
+    
+    This is a standalone function that wraps the RAS_computation class
+    to provide the same interface as ras_generator.py.
+    
+    Args:
+        dataset: pandas DataFrame with gene expression (genes × samples)
+        gene_rules: dict mapping reaction IDs to GPR strings
+        rules_total_string: list of all gene names in GPRs
+        or_function: function for OR operations (default: np.sum)
+        and_function: function for AND operations (default: np.min)
+        ignore_nan: if True, ignore NaN in GPR evaluation (default: True)
     
-        ind +=1
-    
-    #create the dataframe of ras (rules x samples)
-    ras_df= pd.DataFrame(data=ras_df,index=range(len(dict_rule_reactions)), columns=cell_ids)
-    ras_df['Reactions'] = [reaction_ids for rule,reaction_ids in dict_rule_reactions.items()]
+    Returns:
+        tuple: (ras_df, genes_not_mapped)
+            - ras_df: DataFrame with RAS values (reactions × samples)
+            - genes_not_mapped: list of genes in GPRs not found in dataset
+    """
+    # Create RAS computation object in DataFrame mode
+    ras_obj = RAS_computation(
+        dataset=dataset,
+        gene_rules=gene_rules,
+        rules_total_string=rules_total_string
+    )
     
-    #create the reaction dataframe for ras (reactions x samples)
-    ras_df = ras_df.explode("Reactions").set_index("Reactions")
-
-    #total genes not mapped from the gpr
-    genes_not_mapped = sorted(set(genes_not_mapped))
-
-    return ras_df,genes_not_mapped
-
-# function to evalute a complex boolean expression e.g. A or (B and C)
-# function to evalute a complex boolean expression e.g. A or (B and C)
-def _evaluate_ast( node, values,or_function,and_function):
-    if isinstance(node, ast.BoolOp):
-        
-        vals = [_evaluate_ast(v, values,or_function,and_function) for v in node.values]
-       
-        vals = [v for v in vals if v is not None]
-        if not vals:
-            return np.nan
-      
-        vals = [np.array(v) if isinstance(v, (list, np.ndarray)) else v for v in vals]
-
-        if isinstance(node.op, ast.Or):
-            return or_function(vals)
-        elif isinstance(node.op, ast.And):
-            return and_function(vals)
-
-    elif isinstance(node, ast.Name):
-        return values.get(node.id, None)
-    elif isinstance(node, ast.Constant):
-        key = str(node.value)     #convert in str       
-        return values.get(key, None)   
-    else:
-        raise ValueError(f"Error in boolean expression: {ast.dump(node)}")
+    # Compute RAS
+    result = ras_obj.compute(
+        or_expression=or_function,
+        and_expression=and_function,
+        ignore_nan=ignore_nan,
+        print_progressbar=False,  # No progress bar for ras_generator
+        add_count_metadata=False,  # No metadata in DataFrame mode
+        add_met_metadata=False,
+        add_essential_reactions=False,
+        add_essential_genes=False
+    )
+    
+    # Result is a tuple (ras_df, genes_not_mapped) in DataFrame mode
+    return result
 
 def main(args:List[str] = None) -> None:
     """