comparison COBRAxy/utils/model_utils.py @ 426:00a78da611ba draft
Uploaded
author | francesco_lapi |
---|---|
date | Wed, 10 Sep 2025 09:25:32 +0000 |
parents | ed2c1f9e20ba |
children | 4a385fdb9e58 |
425:233d5d1e6bb2 | 426:00a78da611ba |
---|---|
3 import cobra | 3 import cobra |
4 import pickle | 4 import pickle |
5 import argparse | 5 import argparse |
6 import pandas as pd | 6 import pandas as pd |
7 import re | 7 import re |
8 import logging | |
8 from typing import Optional, Tuple, Union, List, Dict, Set | 9 from typing import Optional, Tuple, Union, List, Dict, Set |
10 from collections import defaultdict | |
9 import utils.general_utils as utils | 11 import utils.general_utils as utils |
10 import utils.rule_parsing as rulesUtils | 12 import utils.rule_parsing as rulesUtils |
11 import utils.reaction_parsing as reactionUtils | 13 import utils.reaction_parsing as reactionUtils |
12 from cobra import Model as cobraModel, Reaction, Metabolite | 14 from cobra import Model as cobraModel, Reaction, Metabolite |
13 | 15 |
70 trueMedium.append(r.id) | 72 trueMedium.append(r.id) |
71 | 73 |
72 df_medium = pd.DataFrame() | 74 df_medium = pd.DataFrame() |
73 df_medium["reaction"] = trueMedium | 75 df_medium["reaction"] = trueMedium |
74 return df_medium | 76 return df_medium |
77 | |
78 def extract_objective_coefficients(model: cobraModel) -> pd.DataFrame: | |
79 """ | |
80 Extract the objective function coefficients for each reaction in the model. |
81 | |
82 Args: | |
83 model : cobra.Model | |
84 | |
85 Returns: | |
86 pd.DataFrame with columns: ReactionID, ObjectiveCoefficient |
87 """ | |
88 coeffs = [] | |
89 # model.objective.expression is a linear expression |
90 objective_expr = model.objective.expression.as_coefficients_dict() | |
91 | |
92 for reaction in model.reactions: | |
93 coeff = objective_expr.get(reaction.forward_variable, 0.0) | |
94 coeffs.append({ | |
95 "ReactionID": reaction.id, | |
96 "ObjectiveCoefficient": coeff | |
97 }) | |
98 | |
99 return pd.DataFrame(coeffs) | |
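As an illustration of the new helper, a minimal sketch on a toy model; the model, reaction, and metabolite IDs below are invented, and the import path is assumed:

```python
import cobra
from utils.model_utils import extract_objective_coefficients  # assumed import path

# Build a toy model with a single reaction and make it the objective.
toy = cobra.Model("toy")
rxn = cobra.Reaction("R_biomass")
rxn.add_metabolites({cobra.Metabolite("a_c", compartment="c"): -1.0})
toy.add_reactions([rxn])
toy.objective = "R_biomass"

df = extract_objective_coefficients(toy)
print(df)
#   ReactionID  ObjectiveCoefficient
# 0  R_biomass                   1.0
```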
75 | 100 |
76 def generate_bounds(model:cobraModel) -> pd.DataFrame: | 101 def generate_bounds(model:cobraModel) -> pd.DataFrame: |
77 | 102 |
78 rxns = [] | 103 rxns = [] |
79 for reaction in model.reactions: | 104 for reaction in model.reactions: |
402 print("No annotation in gene dict!") | 427 print("No annotation in gene dict!") |
403 return -1 | 428 return -1 |
404 rename_genes(model2,dict_genes) | 429 rename_genes(model2,dict_genes) |
405 | 430 |
406 return model2 | 431 return model2 |
432 | |
433 | |
434 | |
435 # ---------- Utility helpers ---------- | |
436 def _normalize_colname(col: str) -> str: | |
437 return col.strip().lower().replace(' ', '_') | |
438 | |
439 def _choose_columns(mapping_df: 'pd.DataFrame') -> Dict[str, str]: | |
440 """ | |
441 Look for useful columns and return a dict {ensg: colname1, hgnc_id: colname2, ...}. |
442 Returns an empty dict if no usable mapping column is found (the caller raises ValueError). |
443 """ | |
444 cols = { _normalize_colname(c): c for c in mapping_df.columns } | |
445 chosen = {} | |
446 # possible column names for each category |
447 candidates = { | |
448 'ensg': ['ensg', 'ensembl_gene_id', 'ensembl'], | |
449 'hgnc_id': ['hgnc_id', 'hgnc', 'hgnc:'], | |
450 'hgnc_symbol': ['hgnc_symbol', 'symbol'], |
451 'entrez_id': ['entrez', 'entrez_id', 'entrezgene'] | |
452 } | |
453 for key, names in candidates.items(): | |
454 for n in names: | |
455 if n in cols: | |
456 chosen[key] = cols[n] | |
457 break | |
458 return chosen | |
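A sketch of what _choose_columns is expected to return for a typical mapping table; the column headers below are fabricated to show the normalization at work:

```python
import pandas as pd

# Column names with mixed case and spaces are normalized before matching.
mapping_df = pd.DataFrame(columns=["Ensembl Gene ID", "HGNC ID", "HGNC symbol", "entrez"])
print(_choose_columns(mapping_df))
# {'ensg': 'Ensembl Gene ID', 'hgnc_id': 'HGNC ID', 'hgnc_symbol': 'HGNC symbol', 'entrez_id': 'entrez'}
```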
459 | |
460 def _validate_target_uniqueness(mapping_df: 'pd.DataFrame', | |
461 source_col: str, | |
462 target_col: str, | |
463 model_source_genes: Optional[Set[str]] = None, | |
464 logger: Optional[logging.Logger] = None) -> None: | |
465 """ | |
466 Check that, in mapping_df (possibly already filtered to the source genes of interest), |
467 each target is associated with at most one source. If any target is associated with |
468 more than one source, raise ValueError showing examples. |
469 |
470 - mapping_df: DataFrame with columns source_col, target_col |
471 - model_source_genes: if provided, a set of normalized source IDs that are being translated |
472 (if None, the whole mapping_df is used) |
473 """ | |
474 if logger is None: | |
475 logger = logging.getLogger(__name__) | |
476 | |
477 if mapping_df.empty: | |
478 logger.warning("Mapping dataframe is empty for the requested source genes; skipping uniqueness validation.") | |
479 return | |
480 | |
481 # normalize temporary columns for grouping (without modifying the original df) |
482 tmp = mapping_df[[source_col, target_col]].copy() | |
483 tmp['_src_norm'] = tmp[source_col].astype(str).map(_normalize_gene_id) | |
484 tmp['_tgt_norm'] = tmp[target_col].astype(str).str.strip() | |
485 | |
486 # if a set of model genes was provided, filter here (already done by the caller, but a double check is fine) |
487 if model_source_genes is not None: | |
488 tmp = tmp[tmp['_src_norm'].isin(model_source_genes)] | |
489 | |
490 if tmp.empty: | |
491 logger.warning("After filtering to model source genes, mapping table is empty — nothing to validate.") | |
492 return | |
493 | |
494 # build the reverse mapping target -> set(sources) |
495 grouped = tmp.groupby('_tgt_norm')['_src_norm'].agg(lambda s: set(s.dropna())) | |
496 # find targets associated with more than one source |
497 problematic = {t: sorted(list(s)) for t, s in grouped.items() if len(s) > 1} | |
498 | |
499 if problematic: | |
500 # prepare the error message with examples (up to 20) |
501 sample_items = list(problematic.items())[:20] | |
502 msg_lines = ["Mapping validation failed: some target IDs are associated with multiple source IDs."] | |
503 for tgt, sources in sample_items: | |
504 msg_lines.append(f" - target '{tgt}' <- sources: {', '.join(sources)}") | |
505 if len(problematic) > len(sample_items): | |
506 msg_lines.append(f" ... and {len(problematic) - len(sample_items)} more cases.") | |
507 full_msg = "\n".join(msg_lines) | |
508 # log and raise the error |
509 logger.error(full_msg) | |
510 raise ValueError(full_msg) | |
511 | |
512 # all good |
513 logger.info("Mapping validation passed: no target ID is associated with multiple source IDs (within filtered set).") | |
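To make the failure mode concrete, a small sketch with a fabricated mapping in which two source IDs collapse onto the same target, which is exactly what this check rejects:

```python
import pandas as pd

bad_map = pd.DataFrame({
    "hgnc_id": ["HGNC:5", "HGNC:6"],        # two distinct sources...
    "ensg": ["ENSG00000000001"] * 2,        # ...pointing to the same target
})
try:
    _validate_target_uniqueness(bad_map, "hgnc_id", "ensg")
except ValueError as err:
    print(err)  # reports: target 'ENSG00000000001' <- sources: 5, 6
```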
514 | |
515 | |
516 def _normalize_gene_id(g: str) -> str: | |
517 """Normalize a gene ID for use as a dict key (strip whitespace and remove prefixes such as 'HGNC:').""" |
518 if g is None: | |
519 return "" | |
520 g = str(g).strip() | |
521 # remove common prefixes | |
522 g = re.sub(r'^(HGNC:)', '', g, flags=re.IGNORECASE) | |
523 g = re.sub(r'^(ENSG:)', '', g, flags=re.IGNORECASE) | |
524 return g | |
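A few illustrative inputs and what the normalization returns:

```python
print(_normalize_gene_id(" HGNC:23336 "))         # '23336'
print(_normalize_gene_id("ensg:ENSG00000141510"))  # 'ENSG00000141510'
print(_normalize_gene_id(None))                     # ''
```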
525 | |
526 # ---------- Main public function ---------- | |
527 def translate_model_genes(model: 'cobra.Model', | |
528 mapping_df: 'pd.DataFrame', | |
529 target_nomenclature: str, | |
530 source_nomenclature: str = 'hgnc_id', | |
531 logger: Optional[logging.Logger] = None) -> 'cobra.Model': | |
532 """ | |
533 Translate model genes from source_nomenclature to target_nomenclature. | |
534 mapping_df should contain at least columns that allow the mapping | |
535 (e.g. ensg, hgnc_id, hgnc_symbol, entrez). | |
536 """ | |
537 if logger is None: | |
538 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
539 logger = logging.getLogger(__name__) | |
540 | |
541 logger.info(f"Translating genes from '{source_nomenclature}' to '{target_nomenclature}'") | |
542 | |
543 # normalize column names and choose relevant columns | |
544 chosen = _choose_columns(mapping_df) | |
545 if not chosen: | |
546 raise ValueError("Could not detect useful columns in mapping_df. Expected at least one of: ensg, hgnc_id, hgnc_symbol, entrez.") | |
547 | |
548 # map source/target to actual dataframe column names (allow user-specified source/target keys) | |
549 # normalize input args | |
550 src_key = source_nomenclature.strip().lower() | |
551 tgt_key = target_nomenclature.strip().lower() | |
552 | |
553 # try to find the actual column names for requested keys | |
554 # support synonyms: user may pass "ensg" or "ENSG" etc. | |
555 col_for_src = None | |
556 col_for_tgt = None | |
557 # first, try exact match | |
558 for k, actual in chosen.items(): | |
559 if k == src_key: | |
560 col_for_src = actual | |
561 if k == tgt_key: | |
562 col_for_tgt = actual | |
563 | |
564 # if not found, try mapping common names | |
565 if col_for_src is None: | |
566 # fallback: if user passed 'hgnc_id' but chosen has only 'hgnc_symbol', it's not useful | |
567 # we require at least the source column to exist | |
568 possible_src_names = {k: v for k, v in chosen.items()} | |
569 # try to match by contained substring | |
570 for k, actual in possible_src_names.items(): | |
571 if src_key in k: | |
572 col_for_src = actual | |
573 break | |
574 | |
575 if col_for_tgt is None: | |
576 for k, actual in chosen.items(): | |
577 if tgt_key in k: | |
578 col_for_tgt = actual | |
579 break | |
580 | |
581 if col_for_src is None: | |
582 raise ValueError(f"Source column for '{source_nomenclature}' not found in mapping dataframe.") | |
583 if col_for_tgt is None: | |
584 raise ValueError(f"Target column for '{target_nomenclature}' not found in mapping dataframe.") | |
585 | |
586 | |
587 model_source_genes = { _normalize_gene_id(g.id) for g in model.genes } | |
588 logger.info(f"Filtering mapping to {len(model_source_genes)} source genes present in model (normalized).") | |
589 | |
590 tmp_map = mapping_df[[col_for_src, col_for_tgt]].dropna().copy() | |
591 tmp_map[col_for_src + "_norm"] = tmp_map[col_for_src].astype(str).map(_normalize_gene_id) | |
592 | |
593 filtered_map = tmp_map[tmp_map[col_for_src + "_norm"].isin(model_source_genes)].copy() | |
594 | |
595 # If there are no relevant rows, warn (there may simply be no mappings for the genes in the model) |
596 if filtered_map.empty: | |
597 logger.warning("No mapping rows correspond to source genes present in the model after filtering. Proceeding with empty mapping (no translation will occur).") | |
598 | |
599 # --- VALIDATION: no target may be mapped by more than one source (within the filtered set) --- |
600 # To validate the whole dataframe (not just the model genes), pass model_source_genes=None. |
601 _validate_target_uniqueness(filtered_map, col_for_src, col_for_tgt, model_source_genes=model_source_genes, logger=logger) | |
602 | |
603 # Now build the mapping only on the filtered subset (more efficient) |
604 # NOTE: _create_gene_mapping expects the original column names, |
605 # so we pass filtered_map with those columns (the extra col_for_src + "_norm" column is simply ignored) |
606 gene_mapping = _create_gene_mapping(filtered_map, col_for_src, col_for_tgt, logger) | |
607 | |
608 # copy model | |
609 model_copy = model.copy() | |
610 | |
611 # statistics | |
612 stats = {'translated': 0, 'one_to_one': 0, 'one_to_many': 0, 'not_found': 0} | |
613 unmapped = [] | |
614 multi = [] | |
615 | |
616 original_genes = {g.id for g in model_copy.genes} | |
617 logger.info(f"Original genes count: {len(original_genes)}") | |
618 | |
619 # translate GPRs | |
620 for rxn in model_copy.reactions: | |
621 gpr = rxn.gene_reaction_rule | |
622 if gpr and gpr.strip(): | |
623 new_gpr = _translate_gpr(gpr, gene_mapping, stats, unmapped, multi, logger) | |
624 if new_gpr != gpr: | |
625 rxn.gene_reaction_rule = new_gpr | |
626 logger.debug(f"Reaction {rxn.id}: '{gpr}' -> '{new_gpr}'") | |
627 | |
628 # update model genes based on new GPRs | |
629 _update_model_genes(model_copy, logger) | |
630 | |
631 # final logging | |
632 _log_translation_statistics(stats, unmapped, multi, original_genes, model_copy.genes, logger) | |
633 | |
634 logger.info("Translation finished") | |
635 return model_copy | |
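A minimal end-to-end sketch of the public entry point: the toy model and the hgnc_symbol-to-ensg mapping below are fabricated (the two ENSG IDs are the real ones for BRCA1/BRCA2), and the import path is assumed:

```python
import pandas as pd
import cobra
from utils.model_utils import translate_model_genes  # assumed import path

# Toy model whose single GPR uses HGNC symbols.
model = cobra.Model("demo")
rxn = cobra.Reaction("R1")
rxn.gene_reaction_rule = "BRCA1 or BRCA2"
model.add_reactions([rxn])

# Fabricated mapping table: hgnc_symbol -> ensg.
mapping_df = pd.DataFrame({
    "hgnc_symbol": ["BRCA1", "BRCA2"],
    "ensg": ["ENSG00000012048", "ENSG00000139618"],
})

translated = translate_model_genes(model, mapping_df,
                                   target_nomenclature="ensg",
                                   source_nomenclature="hgnc_symbol")
print(translated.reactions.R1.gene_reaction_rule)
# 'ENSG00000012048 or ENSG00000139618'
```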
636 | |
637 | |
638 # ---------- helper functions ---------- | |
639 def _create_gene_mapping(mapping_df, source_col: str, target_col: str, logger: logging.Logger) -> Dict[str, List[str]]: | |
640 """ | |
641 Build mapping dict: source_id -> list of target_ids | |
642 Normalizes IDs (removes prefixes like 'HGNC:' etc). | |
643 """ | |
644 df = mapping_df[[source_col, target_col]].dropna().copy() | |
645 # normalize to string | |
646 df[source_col] = df[source_col].astype(str).map(_normalize_gene_id) | |
647 df[target_col] = df[target_col].astype(str).str.strip() | |
648 | |
649 df = df.drop_duplicates() | |
650 | |
651 logger.info(f"Creating mapping from {len(df)} rows") | |
652 | |
653 mapping = defaultdict(list) | |
654 for _, row in df.iterrows(): | |
655 s = row[source_col] | |
656 t = row[target_col] | |
657 if t not in mapping[s]: | |
658 mapping[s].append(t) | |
659 | |
660 # stats | |
661 one_to_one = sum(1 for v in mapping.values() if len(v) == 1) | |
662 one_to_many = sum(1 for v in mapping.values() if len(v) > 1) | |
663 logger.info(f"Mapping: {len(mapping)} source keys, {one_to_one} 1:1, {one_to_many} 1:many") | |
664 return dict(mapping) | |
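For reference, a sketch of the dict shape this helper produces, including a 1:many case (values fabricated):

```python
import logging
import pandas as pd

df = pd.DataFrame({
    "hgnc_id": ["HGNC:5", "HGNC:5", "HGNC:7"],
    "ensg": ["ENSG00000000001", "ENSG00000000002", "ENSG00000000003"],
})
mapping = _create_gene_mapping(df, "hgnc_id", "ensg", logging.getLogger(__name__))
print(mapping)
# {'5': ['ENSG00000000001', 'ENSG00000000002'], '7': ['ENSG00000000003']}
```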
665 | |
666 | |
667 def _translate_gpr(gpr_string: str, | |
668 gene_mapping: Dict[str, List[str]], | |
669 stats: Dict[str, int], | |
670 unmapped_genes: List[str], | |
671 multi_mapping_genes: List[Tuple[str, List[str]]], | |
672 logger: logging.Logger) -> str: | |
673 """ | |
674 Translate genes inside a GPR string using gene_mapping. | |
675 Returns new GPR string. | |
676 """ | |
677 # Generic token pattern: letters, digits, :, _, -, ., (captures HGNC:1234, ENSG000..., symbols) | |
678 token_pattern = r'\b[A-Za-z0-9:_.-]+\b' | |
679 tokens = re.findall(token_pattern, gpr_string) | |
680 | |
681 logical = {'and', 'or', 'AND', 'OR', '(', ')'} | |
682 tokens = [t for t in tokens if t not in logical] | |
683 | |
684 new_gpr = gpr_string | |
685 | |
686 for token in sorted(set(tokens), key=lambda x: -len(x)): # longer tokens first to avoid partial replacement | |
687 norm = _normalize_gene_id(token) | |
688 if norm in gene_mapping: | |
689 targets = gene_mapping[norm] | |
690 stats['translated'] += 1 | |
691 if len(targets) == 1: | |
692 stats['one_to_one'] += 1 | |
693 replacement = targets[0] | |
694 else: | |
695 stats['one_to_many'] += 1 | |
696 multi_mapping_genes.append((token, targets)) | |
697 replacement = "(" + " or ".join(targets) + ")" | |
698 | |
699 pattern = r'\b' + re.escape(token) + r'\b' | |
700 new_gpr = re.sub(pattern, replacement, new_gpr) | |
701 else: | |
702 stats['not_found'] += 1 | |
703 if token not in unmapped_genes: | |
704 unmapped_genes.append(token) | |
705 logger.debug(f"Token not found in mapping (left as-is): {token}") | |
706 | |
707 return new_gpr | |
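An illustrative call showing how a 1:many mapping is expanded into an or-group inside the GPR (mapping and rule are fabricated):

```python
import logging

stats = {'translated': 0, 'one_to_one': 0, 'one_to_many': 0, 'not_found': 0}
unmapped, multi = [], []
gene_mapping = {"5": ["ENSG00000000001"],
                "7": ["ENSG00000000002", "ENSG00000000003"]}

new_gpr = _translate_gpr("HGNC:5 and HGNC:7", gene_mapping, stats,
                         unmapped, multi, logging.getLogger(__name__))
print(new_gpr)
# 'ENSG00000000001 and (ENSG00000000002 or ENSG00000000003)'
```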
708 | |
709 | |
710 def _update_model_genes(model: 'cobra.Model', logger: logging.Logger): | |
711 """ | |
712 Rebuild model.genes from gene_reaction_rule content. | |
713 Removes genes not referenced and adds missing ones. | |
714 """ | |
715 # collect genes in GPRs | |
716 gene_pattern = r'\b[A-Za-z0-9:_.-]+\b' | |
717 logical = {'and', 'or', 'AND', 'OR', '(', ')'} | |
718 genes_in_gpr: Set[str] = set() | |
719 | |
720 for rxn in model.reactions: | |
721 gpr = rxn.gene_reaction_rule | |
722 if gpr and gpr.strip(): | |
723 toks = re.findall(gene_pattern, gpr) | |
724 toks = [t for t in toks if t not in logical] | |
725 # normalize IDs consistent with mapping normalization | |
726 toks = [_normalize_gene_id(t) for t in toks] | |
727 genes_in_gpr.update(toks) | |
728 | |
729 # existing gene ids | |
730 existing = {g.id for g in model.genes} | |
731 | |
732 # remove obsolete genes | |
733 to_remove = [gid for gid in existing if gid not in genes_in_gpr] | |
734 removed = 0 | |
735 for gid in to_remove: | |
736 try: | |
737 gene_obj = model.genes.get_by_id(gid) | |
738 model.genes.remove(gene_obj) | |
739 removed += 1 | |
740 except Exception: | |
741 # safe-ignore | |
742 pass | |
743 | |
744 # add new genes | |
745 added = 0 | |
746 for gid in genes_in_gpr: | |
747 if gid not in existing: | |
748 new_gene = cobra.Gene(gid) | |
749 try: | |
750 model.genes.add(new_gene) | |
751 except Exception: | |
752 # fallback: if model.genes doesn't support add, try append or model.add_genes | |
753 try: | |
754 model.genes.append(new_gene) | |
755 except Exception: | |
756 try: | |
757 model.add_genes([new_gene]) | |
758 except Exception: | |
759 logger.warning(f"Could not add gene object for {gid}") | |
760 added += 1 | |
761 | |
762 logger.info(f"Model genes updated: removed {removed}, added {added}") | |
763 | |
764 | |
765 def _log_translation_statistics(stats: Dict[str, int], | |
766 unmapped_genes: List[str], | |
767 multi_mapping_genes: List[Tuple[str, List[str]]], | |
768 original_genes: Set[str], | |
769 final_genes, | |
770 logger: logging.Logger): | |
771 logger.info("=== TRANSLATION STATISTICS ===") | |
772 logger.info(f"Translated: {stats.get('translated', 0)} (1:1 = {stats.get('one_to_one', 0)}, 1:many = {stats.get('one_to_many', 0)})") | |
773 logger.info(f"Not found tokens: {stats.get('not_found', 0)}") | |
774 | |
775 final_ids = {g.id for g in final_genes} | |
776 logger.info(f"Genes in model: {len(original_genes)} -> {len(final_ids)}") | |
777 | |
778 if unmapped_genes: | |
779 logger.warning(f"Unmapped tokens ({len(unmapped_genes)}): {', '.join(unmapped_genes[:20])}{(' ...' if len(unmapped_genes)>20 else '')}") | |
780 if multi_mapping_genes: | |
781 logger.info(f"Multi-mapping examples ({len(multi_mapping_genes)}):") | |
782 for orig, targets in multi_mapping_genes[:10]: | |
783 logger.info(f" {orig} -> {', '.join(targets)}") |