comparison COBRAxy/ras_generator.py @ 529:6acd64232dad draft

Uploaded
author francesco_lapi
date Wed, 22 Oct 2025 12:05:30 +0000
parents b02cfa3b36dd
children 352c51a39e23
comparison
equal deleted inserted replaced
528:698802db290d 529:6acd64232dad
8 import sys 8 import sys
9 import argparse 9 import argparse
10 import pandas as pd 10 import pandas as pd
11 import numpy as np 11 import numpy as np
12 import utils.general_utils as utils 12 import utils.general_utils as utils
13 from typing import List, Dict 13 from typing import List, Dict
14 import ast 14 import ast
15
16 # Optional imports for AnnData mode (not used in ras_generator.py)
17 try:
18 from progressbar import ProgressBar, Bar, Percentage
19 from scanpy import AnnData
20 from cobra.flux_analysis.variability import find_essential_reactions, find_essential_genes
21 except ImportError:
22 # These are only needed for AnnData mode, not for ras_generator.py
23 pass
15 24
16 ERRORS = [] 25 ERRORS = []
17 ########################## argparse ########################################## 26 ########################## argparse ##########################################
18 ARGS :argparse.Namespace 27 ARGS :argparse.Namespace
19 def process_args(args:List[str] = None) -> argparse.Namespace: 28 def process_args(args:List[str] = None) -> argparse.Namespace:
148 raise ValueError("No valid rules found in the uploaded file. Please check the file format.") 157 raise ValueError("No valid rules found in the uploaded file. Please check the file format.")
149 # csv rules need to be parsed, those in a pickle format are taken to be pre-parsed. 158 # csv rules need to be parsed, those in a pickle format are taken to be pre-parsed.
150 return dict_rule 159 return dict_rule
151 160
152 161
162 """
163 Class to compute the RAS values
164
165 """
166
class RAS_computation:
    """Compute Reaction Activity Scores (RAS) from expression data and GPR rules.

    Each distinct (normalized) GPR rule is evaluated once per sample and the
    resulting score is shared by every reaction that uses that rule.

    Attributes set by the initializers:
        dict_rule_reactions: dict mapping normalized rule -> sorted reaction IDs.
        genes: pandas Index of normalized gene names found in both the
            expression data and the rules/model.
        cell_ids: list of sample/cell identifiers (columns of the RAS matrix).
        count_df_filtered: expression DataFrame (genes x samples) restricted
            to ``genes``.
        model / count_adata: the COBRApy model and AnnData copy in AnnData
            mode; both ``None`` in DataFrame mode.

    Attributes set by :meth:`compute`:
        genes_not_mapped: sorted list of rule genes absent from the data.
        rules_not_evaluated: sorted list of rules that could not be parsed
            or evaluated (their rows stay NaN).
    """

    def __init__(self, adata=None, model=None, dataset=None, gene_rules=None, rules_total_string=None):
        """
        Initialize RAS computation with two possible input modes:

        Mode 1 (Original - for sampling_main.py):
            adata: AnnData object with gene expression (cells × genes)
            model: COBRApy model object with reactions and GPRs

        Mode 2 (New - for ras_generator.py):
            dataset: pandas DataFrame with gene expression (genes × samples)
            gene_rules: dict mapping reaction IDs to GPR strings
            rules_total_string: list of all gene names in GPRs (for validation)

        Raises:
            ValueError: if neither complete pair of inputs is supplied, or if
                no gene is shared between the expression data and the rules.
        """
        self._logic_operators = ['and', 'or', '(', ')']
        self.val_nan = np.nan

        # Determine which mode we're in
        if adata is not None and model is not None:
            # Mode 1: AnnData + COBRApy model (original)
            self._init_from_anndata(adata, model)
        elif dataset is not None and gene_rules is not None:
            # Mode 2: DataFrame + rules dict (ras_generator style)
            self._init_from_dataframe(dataset, gene_rules, rules_total_string)
        else:
            raise ValueError(
                "Invalid initialization. Provide either:\n"
                "  - adata + model (for AnnData input), or\n"
                "  - dataset + gene_rules (for DataFrame input)"
            )

    def _normalize_gene_name(self, gene_name):
        """Normalize gene names by replacing special characters ('-' and ':') with '_'."""
        return gene_name.replace("-", "_").replace(":", "_")

    def _normalize_rule(self, rule):
        """Normalize a GPR rule string.

        Parentheses are isolated as separate tokens, the stand-alone operator
        tokens ``OR``/``AND`` are lowercased, and every other token is treated
        as a gene name and normalized. Operators are matched per token rather
        than by substring so that gene symbols containing "OR" or "AND"
        (e.g. MTOR, ORC1) are left intact.

        A rule that is not a string (e.g. None or NaN for a reaction without
        a GPR) is treated as the empty rule.
        """
        if not isinstance(rule, str):
            return ""
        # Pad both sides so that ")OR" / "or(" do not fuse into one token.
        spaced = rule.replace("(", " ( ").replace(")", " ) ")
        normalized_tokens = []
        for token in spaced.split():
            if token in ("OR", "AND"):
                normalized_tokens.append(token.lower())
            elif token in self._logic_operators:
                normalized_tokens.append(token)
            else:
                normalized_tokens.append(self._normalize_gene_name(token))
        return " ".join(normalized_tokens)

    @staticmethod
    def _group_reactions_by_rule(reaction_ids, rules):
        """Map each distinct rule to the sorted list of reactions using it.

        Keys are inserted in sorted rule order, which fixes the row order of
        the RAS matrix built in :meth:`compute`.
        """
        grouped = {}
        for reaction_id, rule in zip(reaction_ids, rules):
            grouped.setdefault(rule, []).append(reaction_id)
        return {rule: sorted(grouped[rule]) for rule in sorted(grouped)}

    @staticmethod
    def _is_missing(value):
        """Return True for None or a scalar NaN of any float type (incl. numpy float32)."""
        if value is None:
            return True
        try:
            return bool(np.isnan(value))
        except (TypeError, ValueError):
            # Non-numeric or non-scalar values are not considered missing.
            return False

    def _init_from_anndata(self, adata, model):
        """Initialize from AnnData and COBRApy model (original mode)."""
        # Build the dictionary for the GPRs
        reaction_ids = [reaction.id for reaction in model.reactions]
        rules = [self._normalize_rule(reaction.gene_reaction_rule) for reaction in model.reactions]
        self.dict_rule_reactions = self._group_reactions_by_rule(reaction_ids, rules)

        # build useful structures for RAS computation
        self.model = model
        self.count_adata = adata.copy()

        # Normalize gene names in both model and dataset
        model_genes = [self._normalize_gene_name(gene.id) for gene in model.genes]
        dataset_genes = [self._normalize_gene_name(gene) for gene in self.count_adata.var.index]
        self.genes = pd.Index(dataset_genes).intersection(model_genes)

        if len(self.genes) == 0:
            raise ValueError("ERROR: No genes from the count matrix match the metabolic model. Check that gene annotations are consistent between model and dataset.")

        self.cell_ids = list(self.count_adata.obs.index.values)
        # Get expression data (transposed to genes x cells) with normalized gene names
        self.count_df_filtered = self.count_adata.to_df().T
        self.count_df_filtered.index = [self._normalize_gene_name(g) for g in self.count_df_filtered.index]
        self.count_df_filtered = self.count_df_filtered.loc[self.genes]

    def _init_from_dataframe(self, dataset, gene_rules, rules_total_string):
        """Initialize from DataFrame and rules dict (ras_generator mode)."""
        reactions = list(gene_rules.keys())

        # Build the dictionary for the GPRs
        gene_rules_list = [self._normalize_rule(gene_rules[reaction_id]) for reaction_id in reactions]
        self.dict_rule_reactions = self._group_reactions_by_rule(reactions, gene_rules_list)

        # build useful structures for RAS computation
        self.model = None
        self.count_adata = None

        # Normalize gene names in dataset (on a copy, the caller's frame is untouched)
        dataset_normalized = dataset.copy()
        dataset_normalized.index = [self._normalize_gene_name(g) for g in dataset_normalized.index]

        # Determine which genes are in both dataset and GPRs
        if rules_total_string is not None:
            rules_genes = [self._normalize_gene_name(g) for g in rules_total_string]
        else:
            # Extract all genes from the normalized rules
            rules_genes = sorted({
                token
                for rule in gene_rules_list
                for token in rule.split()
                if token not in self._logic_operators
            })
        self.genes = dataset_normalized.index.intersection(rules_genes)

        if len(self.genes) == 0:
            raise ValueError("ERROR: No genes from the count matrix match the metabolic model. Check that gene annotations are consistent between model and dataset.")

        self.cell_ids = list(dataset_normalized.columns)
        self.count_df_filtered = dataset_normalized.loc[self.genes]

    def _evaluate_rule(self, rule, genes_present, ignore_nan):
        """Evaluate one normalized rule for every sample.

        Returns a list with one score per column of ``count_df_filtered``, or
        None when the rule cannot be parsed or evaluated. The row is built
        completely before being returned so a failure never leaves a
        partially filled row behind.
        """
        try:
            tree = ast.parse(rule, mode="eval").body
            data = self.count_df_filtered.loc[genes_present]
            gene_names = list(data.index)
            scores = []
            # Iterate positionally so duplicate sample names cannot break column access.
            for column in data.to_numpy().T:
                gene_values = dict(zip(gene_names, column))
                score = self._evaluate_ast(tree, gene_values, self.or_function, self.and_function, ignore_nan)
                scores.append(np.nan if score is None else score)
        except (SyntaxError, ValueError, TypeError, RecursionError):
            # Unparsable rule (e.g. a gene name that is a Python keyword),
            # unsupported expression node, or non-numeric data: keep as NaN.
            return None
        return scores

    def compute(self,
                or_expression=np.sum,         # type of operation to do in case of an or expression (sum, max, mean)
                and_expression=np.min,        # type of operation to do in case of an and expression(min, sum)
                drop_na_rows=True,            # if True remove the nan rows of the ras matrix
                drop_duplicates=False,        # if true, remove duplicates rows
                ignore_nan=True,              # if True, ignore NaN values in GPR evaluation (e.g., A or NaN -> A)
                print_progressbar=True,       # if True, print the progress bar
                add_count_metadata=True,      # if True add metadata of cells in the ras adata
                add_met_metadata=True,        # if True add metadata from the metabolic model (gpr and compartments of reactions)
                add_essential_reactions=False,
                add_essential_genes=False
                ):
        """Compute the RAS matrix.

        Returns:
            DataFrame mode: tuple ``(ras_df, genes_not_mapped)`` where
                ``ras_df`` is reactions x samples, indexed by "REACTIONS".
            AnnData mode: an AnnData built from ``ras_df.T`` with the
                requested metadata columns attached.

        Rules whose genes are all absent from the data, rules with absent
        genes when ``ignore_nan`` is False, empty rules, and rules that fail
        to parse/evaluate all yield NaN rows (removed when ``drop_na_rows``).
        """
        self.or_function = or_expression
        self.and_function = and_expression

        n_rules = len(self.dict_rule_reactions)
        ras_values = np.full((n_rules, len(self.cell_ids)), np.nan)
        genes_not_mapped = set()  # Track genes not in dataset
        rules_failed = []
        known_genes = set(self.genes)
        operators = set(self._logic_operators)

        pbar = None
        if print_progressbar:
            try:
                pbar = ProgressBar(widgets=[Percentage(), Bar()], maxval=n_rules).start()
            except NameError:
                # progressbar is an optional import; run silently without it.
                pbar = None

        # Process each unique GPR rule (empty rules stay NaN)
        for ind, rule in enumerate(self.dict_rule_reactions):
            if rule:
                rule_genes = {token for token in rule.split() if token not in operators}
                genes_present = sorted(rule_genes & known_genes)
                genes_missing = rule_genes - known_genes
                genes_not_mapped.update(genes_missing)

                # Evaluate only if at least one gene is measured and, when NaN
                # must not be ignored, no gene of the rule is missing.
                if genes_present and (ignore_nan or not genes_missing):
                    scores = self._evaluate_rule(rule, genes_present, ignore_nan)
                    if scores is None:
                        rules_failed.append(rule)
                    else:
                        ras_values[ind] = scores

            if pbar is not None:
                pbar.update(ind + 1)

        if pbar is not None:
            pbar.finish()

        # Store diagnostics for later use
        self.genes_not_mapped = sorted(genes_not_mapped)
        self.rules_not_evaluated = sorted(rules_failed)

        # create the dataframe of ras (rules x samples)
        ras_df = pd.DataFrame(data=ras_values, index=range(n_rules), columns=self.cell_ids)
        ras_df['REACTIONS'] = [reaction_ids for rule, reaction_ids in self.dict_rule_reactions.items()]

        # For every reaction, the full list of reactions sharing its GPR
        reactions_common = pd.DataFrame()
        reactions_common["REACTIONS"] = ras_df['REACTIONS']
        reactions_common["proof2"] = ras_df['REACTIONS']
        reactions_common = reactions_common.explode('REACTIONS')
        reactions_common = reactions_common.set_index("REACTIONS")

        # One row per reaction instead of one row per rule
        ras_df = ras_df.explode("REACTIONS")
        ras_df = ras_df.set_index("REACTIONS")

        if drop_na_rows:
            ras_df = ras_df.dropna(how="all")

        if drop_duplicates:
            ras_df = ras_df.drop_duplicates()

        # If initialized from DataFrame (ras_generator mode), return DataFrame instead of AnnData
        if self.count_adata is None:
            return ras_df, self.genes_not_mapped

        # Original AnnData mode: create AnnData structure for RAS
        ras_adata = AnnData(ras_df.T)

        # add metadata
        if add_count_metadata:
            ras_adata.var["common_gprs"] = reactions_common.loc[ras_df.index]
            ras_adata.var["common_gprs"] = ras_adata.var["common_gprs"].apply(lambda x: ",".join(x))
            for el in self.count_adata.obs.columns:
                ras_adata.obs["countmatrix_" + el] = self.count_adata.obs[el]

        if add_met_metadata:
            if self.model is not None and len(self.model.compartments) > 0:
                ras_adata.var['compartments'] = [list(self.model.reactions.get_by_id(reaction).compartments) for reaction in ras_adata.var.index]
                ras_adata.var['compartments'] = ras_adata.var["compartments"].apply(lambda x: ",".join(x))

            if self.model is not None:
                ras_adata.var['GPR rule'] = [self.model.reactions.get_by_id(reaction).gene_reaction_rule for reaction in ras_adata.var.index]

        if add_essential_reactions:
            if self.model is not None:
                essential_reactions = {el.id for el in find_essential_reactions(self.model)}
                ras_adata.var['essential reactions'] = ["yes" if el in essential_reactions else "no" for el in ras_adata.var.index]

        if add_essential_genes:
            if self.model is not None:
                essential_genes = {el.id for el in find_essential_genes(self.model)}
                # Read the GPR from the model so this does not depend on
                # add_met_metadata having created the 'GPR rule' column.
                gpr_rules = [self.model.reactions.get_by_id(reaction).gene_reaction_rule for reaction in ras_adata.var.index]
                ras_adata.var['essential genes'] = [" ".join([gene for gene in genes.split() if gene in essential_genes]) for genes in gpr_rules]

        return ras_adata

    def _evaluate_ast(self, node, values, or_function, and_function, ignore_nan):
        """
        Evaluate a boolean expression using AST (Abstract Syntax Tree).
        Handles all GPR types: single gene, simple (A and B), nested (A or (B and C)).

        Args:
            node: AST node to evaluate
            values: Dictionary mapping gene names to their expression values
            or_function: Function to apply for OR operations
            and_function: Function to apply for AND operations
            ignore_nan: If True, ignore None/NaN values (e.g., A or None -> A)

        Returns:
            Evaluated expression result; None for a gene/constant absent from
            ``values``; np.nan when every operand of an and/or was dropped.

        Raises:
            ValueError: if the expression contains a node that is not a
                boolean operation, a name or a constant.
        """
        if isinstance(node, ast.BoolOp):
            # Boolean operation (and/or)
            vals = [self._evaluate_ast(v, values, or_function, and_function, ignore_nan) for v in node.values]

            if ignore_nan:
                # Filter out None/NaN values
                vals = [v for v in vals if not self._is_missing(v)]

            if not vals:
                return np.nan

            if isinstance(node.op, ast.Or):
                return or_function(vals)
            elif isinstance(node.op, ast.And):
                return and_function(vals)

        elif isinstance(node, ast.Name):
            # Variable (gene name)
            return values.get(node.id, None)
        elif isinstance(node, ast.Constant):
            # Constant (e.g. a purely numeric gene identifier): look it up by its string form
            return values.get(str(node.value), None)
        else:
            raise ValueError(f"Unexpected node type in GPR: {ast.dump(node)}")
445
446
447 # ============================================================================
448 # STANDALONE FUNCTION FOR RAS_GENERATOR COMPATIBILITY
449 # ============================================================================
153 450
def computeRAS(
        dataset,
        gene_rules,
        rules_total_string,
        or_function=np.sum,
        and_function=np.min,
        ignore_nan=True
):
    """
    Compute RAS from tabular data and GPR rules (ras_generator.py compatible).

    Thin functional entry point: builds a RAS_computation object in
    DataFrame mode and runs its compute() step with the progress bar and all
    AnnData-only metadata options switched off. compute() is called with its
    default ``drop_na_rows=True``, so reactions whose score is NaN for every
    sample are not part of the returned frame.

    Args:
        dataset: pandas DataFrame with gene expression (genes × samples)
        gene_rules: dict mapping reaction IDs to GPR strings
        rules_total_string: list of all gene names in GPRs
        or_function: function for OR operations (default: np.sum)
        and_function: function for AND operations (default: np.min)
        ignore_nan: if True, ignore NaN in GPR evaluation (default: True)

    Returns:
        tuple: (ras_df, genes_not_mapped)
            - ras_df: DataFrame with RAS values (reactions × samples)
            - genes_not_mapped: list of genes in GPRs not found in dataset
    """
    # Options forwarded to compute(); the metadata flags only matter in AnnData mode.
    compute_options = {
        "or_expression": or_function,
        "and_expression": and_function,
        "ignore_nan": ignore_nan,
        "print_progressbar": False,
        "add_count_metadata": False,
        "add_met_metadata": False,
        "add_essential_reactions": False,
        "add_essential_genes": False,
    }

    calculator = RAS_computation(
        dataset=dataset,
        gene_rules=gene_rules,
        rules_total_string=rules_total_string,
    )

    # In DataFrame mode compute() yields the (ras_df, genes_not_mapped) tuple.
    return calculator.compute(**compute_options)
257
258 # function to evalute a complex boolean expression e.g. A or (B and C)
259 # function to evalute a complex boolean expression e.g. A or (B and C)
260 def _evaluate_ast( node, values,or_function,and_function):
261 if isinstance(node, ast.BoolOp):
262
263 vals = [_evaluate_ast(v, values,or_function,and_function) for v in node.values]
264
265 vals = [v for v in vals if v is not None]
266 if not vals:
267 return np.nan
268
269 vals = [np.array(v) if isinstance(v, (list, np.ndarray)) else v for v in vals]
270
271 if isinstance(node.op, ast.Or):
272 return or_function(vals)
273 elif isinstance(node.op, ast.And):
274 return and_function(vals)
275
276 elif isinstance(node, ast.Name):
277 return values.get(node.id, None)
278 elif isinstance(node, ast.Constant):
279 key = str(node.value) #convert in str
280 return values.get(key, None)
281 else:
282 raise ValueError(f"Error in boolean expression: {ast.dump(node)}")
283 499
284 def main(args:List[str] = None) -> None: 500 def main(args:List[str] = None) -> None:
285 """ 501 """
286 Initializes everything and sets the program in motion based on the fronted input arguments. 502 Initializes everything and sets the program in motion based on the fronted input arguments.
287 503