diff COBRAxy/utils/model_utils.py @ 426:00a78da611ba draft
Uploaded
| author | francesco_lapi |
| --- | --- |
| date | Wed, 10 Sep 2025 09:25:32 +0000 |
| parents | ed2c1f9e20ba |
| children | 4a385fdb9e58 |
--- a/COBRAxy/utils/model_utils.py	Tue Sep 09 15:05:02 2025 +0000
+++ b/COBRAxy/utils/model_utils.py	Wed Sep 10 09:25:32 2025 +0000
@@ -5,7 +5,9 @@
 import argparse
 import pandas as pd
 import re
+import logging
 from typing import Optional, Tuple, Union, List, Dict, Set
+from collections import defaultdict
 import utils.general_utils as utils
 import utils.rule_parsing as rulesUtils
 import utils.reaction_parsing as reactionUtils
@@ -73,6 +75,29 @@
     df_medium["reaction"] = trueMedium
     return df_medium
 
+def extract_objective_coefficients(model: cobraModel) -> pd.DataFrame:
+    """
+    Extract the objective-function coefficient of every reaction in the model.
+
+    Args:
+        model : cobra.Model
+
+    Returns:
+        pd.DataFrame with columns: ReactionID, ObjectiveCoefficient
+    """
+    coeffs = []
+    # model.objective.expression is a linear expression
+    objective_expr = model.objective.expression.as_coefficients_dict()
+
+    for reaction in model.reactions:
+        coeff = objective_expr.get(reaction.forward_variable, 0.0)
+        coeffs.append({
+            "ReactionID": reaction.id,
+            "ObjectiveCoefficient": coeff
+        })
+
+    return pd.DataFrame(coeffs)
+
 def generate_bounds(model:cobraModel) -> pd.DataFrame:
 
     rxns = []
@@ -403,4 +428,356 @@
         return -1
 
     rename_genes(model2,dict_genes)
-    return model2
\ No newline at end of file
+    return model2
+
+
+
+# ---------- Utility helpers ----------
+def _normalize_colname(col: str) -> str:
+    return col.strip().lower().replace(' ', '_')
+
+def _choose_columns(mapping_df: 'pd.DataFrame') -> Dict[str, str]:
+    """
+    Look for useful columns and return a dict {ensg: colname1, hgnc_id: colname2, ...}.
+    Returns an empty dict if no usable mapping column is found (the caller raises ValueError).
+    """
+    cols = { _normalize_colname(c): c for c in mapping_df.columns }
+    chosen = {}
+    # candidate column names for each category
+    candidates = {
+        'ensg': ['ensg', 'ensembl_gene_id', 'ensembl'],
+        'hgnc_id': ['hgnc_id', 'hgnc', 'hgnc:'],
+        'hgnc_symbol': ['hgnc_symbol', 'symbol'],
+        'entrez_id': ['entrez', 'entrez_id', 'entrezgene']
+    }
+    for key, names in candidates.items():
+        for n in names:
+            if n in cols:
+                chosen[key] = cols[n]
+                break
+    return chosen
+
+def _validate_target_uniqueness(mapping_df: 'pd.DataFrame',
+                                source_col: str,
+                                target_col: str,
+                                model_source_genes: Optional[Set[str]] = None,
+                                logger: Optional[logging.Logger] = None) -> None:
+    """
+    Check that, in mapping_df (possibly already filtered to the source genes of interest),
+    each target is associated with at most one source. If any target maps to more than
+    one source, raise ValueError with examples.
+
+    - mapping_df: DataFrame with columns source_col, target_col
+    - model_source_genes: if given, a set of normalized source IDs being translated
+      (if None, the whole mapping_df is used)
+    """
+    if logger is None:
+        logger = logging.getLogger(__name__)
+
+    if mapping_df.empty:
+        logger.warning("Mapping dataframe is empty for the requested source genes; skipping uniqueness validation.")
+        return
+
+    # normalize into temporary columns for grouping (without modifying the original df)
+    tmp = mapping_df[[source_col, target_col]].copy()
+    tmp['_src_norm'] = tmp[source_col].astype(str).map(_normalize_gene_id)
+    tmp['_tgt_norm'] = tmp[target_col].astype(str).str.strip()
+
+    # if a set of model genes was passed, filter here (already done by the caller, but a double check is fine)
+    if model_source_genes is not None:
+        tmp = tmp[tmp['_src_norm'].isin(model_source_genes)]
+
+    if tmp.empty:
+        logger.warning("After filtering to model source genes, mapping table is empty — nothing to validate.")
+        return
+
+    # build the reverse mapping target -> set(sources)
+    grouped = tmp.groupby('_tgt_norm')['_src_norm'].agg(lambda s: set(s.dropna()))
+    # find targets associated with more than one source
+    problematic = {t: sorted(list(s)) for t, s in grouped.items() if len(s) > 1}
+
+    if problematic:
+        # prepare an error message with examples (up to 20)
+        sample_items = list(problematic.items())[:20]
+        msg_lines = ["Mapping validation failed: some target IDs are associated with multiple source IDs."]
+        for tgt, sources in sample_items:
+            msg_lines.append(f"  - target '{tgt}' <- sources: {', '.join(sources)}")
+        if len(problematic) > len(sample_items):
+            msg_lines.append(f"  ... and {len(problematic) - len(sample_items)} more cases.")
+        full_msg = "\n".join(msg_lines)
+        # log and raise
+        logger.error(full_msg)
+        raise ValueError(full_msg)
+
+    # all good
+    logger.info("Mapping validation passed: no target ID is associated with multiple source IDs (within filtered set).")
+
+
+def _normalize_gene_id(g: str) -> str:
+    """Normalize a gene ID for use as a key (strips whitespace and removes prefixes such as 'HGNC:')."""
+    if g is None:
+        return ""
+    g = str(g).strip()
+    # remove common prefixes
+    g = re.sub(r'^(HGNC:)', '', g, flags=re.IGNORECASE)
+    g = re.sub(r'^(ENSG:)', '', g, flags=re.IGNORECASE)
+    return g
+
+# ---------- Main public function ----------
+def translate_model_genes(model: 'cobra.Model',
+                          mapping_df: 'pd.DataFrame',
+                          target_nomenclature: str,
+                          source_nomenclature: str = 'hgnc_id',
+                          logger: Optional[logging.Logger] = None) -> 'cobra.Model':
+    """
+    Translate model genes from source_nomenclature to target_nomenclature.
+    mapping_df should contain at least the columns needed for the mapping
+    (e.g. ensg, hgnc_id, hgnc_symbol, entrez).
+    """
+    if logger is None:
+        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+        logger = logging.getLogger(__name__)
+
+    logger.info(f"Translating genes from '{source_nomenclature}' to '{target_nomenclature}'")
+
+    # normalize column names and choose relevant columns
+    chosen = _choose_columns(mapping_df)
+    if not chosen:
+        raise ValueError("Could not detect useful columns in mapping_df. Expected at least one of: ensg, hgnc_id, hgnc_symbol, entrez.")
+
+    # map source/target to actual dataframe column names (allow user-specified source/target keys)
+    # normalize input args
+    src_key = source_nomenclature.strip().lower()
+    tgt_key = target_nomenclature.strip().lower()
+
+    # try to find the actual column names for the requested keys
+    # support synonyms: the user may pass "ensg" or "ENSG" etc.
+    col_for_src = None
+    col_for_tgt = None
+    # first, try exact match
+    for k, actual in chosen.items():
+        if k == src_key:
+            col_for_src = actual
+        if k == tgt_key:
+            col_for_tgt = actual
+
+    # if not found, try matching common names
+    if col_for_src is None:
+        # fallback: if the user passed 'hgnc_id' but chosen only has 'hgnc_symbol', that is not usable;
+        # we require at least the source column to exist
+        possible_src_names = {k: v for k, v in chosen.items()}
+        # try to match by contained substring
+        for k, actual in possible_src_names.items():
+            if src_key in k:
+                col_for_src = actual
+                break
+
+    if col_for_tgt is None:
+        for k, actual in chosen.items():
+            if tgt_key in k:
+                col_for_tgt = actual
+                break
+
+    if col_for_src is None:
+        raise ValueError(f"Source column for '{source_nomenclature}' not found in mapping dataframe.")
+    if col_for_tgt is None:
+        raise ValueError(f"Target column for '{target_nomenclature}' not found in mapping dataframe.")
+
+
+    model_source_genes = { _normalize_gene_id(g.id) for g in model.genes }
+    logger.info(f"Filtering mapping to {len(model_source_genes)} source genes present in model (normalized).")
+
+    tmp_map = mapping_df[[col_for_src, col_for_tgt]].dropna().copy()
+    tmp_map[col_for_src + "_norm"] = tmp_map[col_for_src].astype(str).map(_normalize_gene_id)
+
+    filtered_map = tmp_map[tmp_map[col_for_src + "_norm"].isin(model_source_genes)].copy()
+
+    # If there are no relevant rows, warn (the model genes may simply have no mapping)
+    if filtered_map.empty:
+        logger.warning("No mapping rows correspond to source genes present in the model after filtering. Proceeding with empty mapping (no translation will occur).")
+
+    # --- VALIDATION: no target may be mapped by more than one source (within the filtered set) ---
+    # To validate the whole dataframe (not only the model genes), pass model_source_genes=None.
+    _validate_target_uniqueness(filtered_map, col_for_src, col_for_tgt, model_source_genes=model_source_genes, logger=logger)
+
+    # Now build the mapping only on the filtered subset (more efficient)
+    # NOTE: _create_gene_mapping expects the original column names,
+    # so we pass filtered_map as-is (the extra col_for_src + "_norm" column is ignored)
+    gene_mapping = _create_gene_mapping(filtered_map, col_for_src, col_for_tgt, logger)
+
+    # copy model
+    model_copy = model.copy()
+
+    # statistics
+    stats = {'translated': 0, 'one_to_one': 0, 'one_to_many': 0, 'not_found': 0}
+    unmapped = []
+    multi = []
+
+    original_genes = {g.id for g in model_copy.genes}
+    logger.info(f"Original genes count: {len(original_genes)}")
+
+    # translate GPRs
+    for rxn in model_copy.reactions:
+        gpr = rxn.gene_reaction_rule
+        if gpr and gpr.strip():
+            new_gpr = _translate_gpr(gpr, gene_mapping, stats, unmapped, multi, logger)
+            if new_gpr != gpr:
+                rxn.gene_reaction_rule = new_gpr
+                logger.debug(f"Reaction {rxn.id}: '{gpr}' -> '{new_gpr}'")
+
+    # update model genes based on new GPRs
+    _update_model_genes(model_copy, logger)
+
+    # final logging
+    _log_translation_statistics(stats, unmapped, multi, original_genes, model_copy.genes, logger)
+
+    logger.info("Translation finished")
+    return model_copy
+
+
+# ---------- helper functions ----------
+def _create_gene_mapping(mapping_df, source_col: str, target_col: str, logger: logging.Logger) -> Dict[str, List[str]]:
+    """
+    Build mapping dict: source_id -> list of target_ids.
+    Normalizes IDs (removes prefixes like 'HGNC:' etc.).
+    """
+    df = mapping_df[[source_col, target_col]].dropna().copy()
+    # normalize to string
+    df[source_col] = df[source_col].astype(str).map(_normalize_gene_id)
+    df[target_col] = df[target_col].astype(str).str.strip()
+
+    df = df.drop_duplicates()
+
+    logger.info(f"Creating mapping from {len(df)} rows")
+
+    mapping = defaultdict(list)
+    for _, row in df.iterrows():
+        s = row[source_col]
+        t = row[target_col]
+        if t not in mapping[s]:
+            mapping[s].append(t)
+
+    # stats
+    one_to_one = sum(1 for v in mapping.values() if len(v) == 1)
+    one_to_many = sum(1 for v in mapping.values() if len(v) > 1)
+    logger.info(f"Mapping: {len(mapping)} source keys, {one_to_one} 1:1, {one_to_many} 1:many")
+    return dict(mapping)
+
+
+def _translate_gpr(gpr_string: str,
+                   gene_mapping: Dict[str, List[str]],
+                   stats: Dict[str, int],
+                   unmapped_genes: List[str],
+                   multi_mapping_genes: List[Tuple[str, List[str]]],
+                   logger: logging.Logger) -> str:
+    """
+    Translate genes inside a GPR string using gene_mapping.
+    Returns the new GPR string.
+    """
+    # Generic token pattern: letters, digits, :, _, -, . (captures HGNC:1234, ENSG000..., symbols)
+    token_pattern = r'\b[A-Za-z0-9:_.-]+\b'
+    tokens = re.findall(token_pattern, gpr_string)
+
+    logical = {'and', 'or', 'AND', 'OR', '(', ')'}
+    tokens = [t for t in tokens if t not in logical]
+
+    new_gpr = gpr_string
+
+    for token in sorted(set(tokens), key=lambda x: -len(x)):  # longer tokens first to avoid partial replacement
+        norm = _normalize_gene_id(token)
+        if norm in gene_mapping:
+            targets = gene_mapping[norm]
+            stats['translated'] += 1
+            if len(targets) == 1:
+                stats['one_to_one'] += 1
+                replacement = targets[0]
+            else:
+                stats['one_to_many'] += 1
+                multi_mapping_genes.append((token, targets))
+                replacement = "(" + " or ".join(targets) + ")"
+
+            pattern = r'\b' + re.escape(token) + r'\b'
+            new_gpr = re.sub(pattern, replacement, new_gpr)
+        else:
+            stats['not_found'] += 1
+            if token not in unmapped_genes:
+                unmapped_genes.append(token)
+            logger.debug(f"Token not found in mapping (left as-is): {token}")
+
+    return new_gpr
+
+
+def _update_model_genes(model: 'cobra.Model', logger: logging.Logger):
+    """
+    Rebuild model.genes from gene_reaction_rule content.
+    Removes genes not referenced and adds missing ones.
+    """
+    # collect genes in GPRs
+    gene_pattern = r'\b[A-Za-z0-9:_.-]+\b'
+    logical = {'and', 'or', 'AND', 'OR', '(', ')'}
+    genes_in_gpr: Set[str] = set()
+
+    for rxn in model.reactions:
+        gpr = rxn.gene_reaction_rule
+        if gpr and gpr.strip():
+            toks = re.findall(gene_pattern, gpr)
+            toks = [t for t in toks if t not in logical]
+            # normalize IDs consistently with the mapping normalization
+            toks = [_normalize_gene_id(t) for t in toks]
+            genes_in_gpr.update(toks)
+
+    # existing gene ids
+    existing = {g.id for g in model.genes}
+
+    # remove obsolete genes
+    to_remove = [gid for gid in existing if gid not in genes_in_gpr]
+    removed = 0
+    for gid in to_remove:
+        try:
+            gene_obj = model.genes.get_by_id(gid)
+            model.genes.remove(gene_obj)
+            removed += 1
+        except Exception:
+            # safe-ignore
+            pass
+
+    # add new genes
+    added = 0
+    for gid in genes_in_gpr:
+        if gid not in existing:
+            new_gene = cobra.Gene(gid)
+            try:
+                model.genes.add(new_gene)
+            except Exception:
+                # fallback: if model.genes doesn't support add, try append or model.add_genes
+                try:
+                    model.genes.append(new_gene)
+                except Exception:
+                    try:
+                        model.add_genes([new_gene])
+                    except Exception:
+                        logger.warning(f"Could not add gene object for {gid}")
+            added += 1
+
+    logger.info(f"Model genes updated: removed {removed}, added {added}")
+
+
+def _log_translation_statistics(stats: Dict[str, int],
+                                unmapped_genes: List[str],
+                                multi_mapping_genes: List[Tuple[str, List[str]]],
+                                original_genes: Set[str],
+                                final_genes,
+                                logger: logging.Logger):
+    logger.info("=== TRANSLATION STATISTICS ===")
+    logger.info(f"Translated: {stats.get('translated', 0)} (1:1 = {stats.get('one_to_one', 0)}, 1:many = {stats.get('one_to_many', 0)})")
+    logger.info(f"Not found tokens: {stats.get('not_found', 0)}")
+
+    final_ids = {g.id for g in final_genes}
+    logger.info(f"Genes in model: {len(original_genes)} -> {len(final_ids)}")
+
+    if unmapped_genes:
+        logger.warning(f"Unmapped tokens ({len(unmapped_genes)}): {', '.join(unmapped_genes[:20])}{(' ...' if len(unmapped_genes)>20 else '')}")
+    if multi_mapping_genes:
+        logger.info(f"Multi-mapping examples ({len(multi_mapping_genes)}):")
+        for orig, targets in multi_mapping_genes[:10]:
+            logger.info(f"  {orig} -> {', '.join(targets)}")
\ No newline at end of file