Mercurial > repos > tduigou > get_from_db
view get_db_info.py @ 0:1769c133986b draft default tip
planemo upload for repository https://github.com/brsynth/galaxytools/tree/main/tools commit 3401816c949b538bd9c67e61cbe92badff6a4007-dirty
author | tduigou |
---|---|
date | Wed, 11 Jun 2025 09:42:52 +0000 |
parents | |
children |
line wrap: on
line source
"""Fetch fragment annotations from a PostgreSQL database and write one
GenBank (.gb) file per fragment.

Input CSV layout (no header): column 0 is a backbone ID, the remaining
columns are fragment IDs.  The DB URI may arrive with "@" escaped as
"__at__" (Galaxy convention); it is un-escaped before use.
"""
import subprocess
import argparse
import time
import os
import socket
import re

import pandas as pd
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import text
from sqlalchemy.exc import OperationalError


def fix_db_uri(uri):
    """Replace the Galaxy-escaped "__at__" with "@" in the URI if needed."""
    return uri.replace("__at__", "@")


def _is_tcp_port_in_use(host, port):
    """Return True if a TCP connection to (host, port) succeeds within 2 s."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(2)
        return s.connect_ex((host, port)) == 0


def is_port_in_use(uri):
    """Check if the host/port of a SQLAlchemy URI already accepts connections."""
    url = make_url(uri)
    return _is_tcp_port_in_use(url.host, url.port)


def extract_db_name(uri):
    """Extract the database name from the SQLAlchemy URI."""
    url = make_url(uri)
    return url.database


# Starts a Docker container when the DB runs containerized.
# NOTE: not called from main() — kept for container-based deployments.
def start_postgres_container(db_name):
    """Start a PostgreSQL container with the given database name as the container name."""
    container_name = db_name
    # Check if container is already running (list args, no shell interpolation)
    container_running = subprocess.run(
        ["docker", "ps", "-q", "-f", f"name={container_name}"],
        capture_output=True, text=True
    )
    if container_running.stdout.strip():
        print(f"Container '{container_name}' is already running.")
        return
    # Check if container exists but is stopped
    container_exists = subprocess.run(
        ["docker", "ps", "-a", "-q", "-f", f"name={container_name}"],
        capture_output=True, text=True
    )
    if container_exists.stdout.strip():
        print(f"Starting existing container '{container_name}'...")
        subprocess.run(["docker", "start", container_name])
        print(f"PostgreSQL Docker container '{container_name}' activated.")
        return
    # If container does not exist, create and start a new one.
    # Bug fix: the original called is_port_in_use(5432) with an int, but that
    # function runs make_url() on its argument and expects a URI string.
    port = 5432 if not _is_tcp_port_in_use("localhost", 5432) else 5433
    postgres_password = os.getenv("POSTGRES_PASSWORD", "RK17")
    start_command = [
        "docker", "run", "--name", container_name,
        "-e", f"POSTGRES_PASSWORD={postgres_password}",
        "-p", f"{port}:5432", "-d", "postgres",
    ]
    try:
        subprocess.run(start_command, check=True)
        print(f"PostgreSQL Docker container '{container_name}' started on port {port}.")
    except subprocess.CalledProcessError as e:
        print(f"Failed to start Docker container: {e}")


def wait_for_db(uri, timeout=60):
    """Try connecting to the DB every 2 s until it works or *timeout* seconds pass.

    Raises:
        Exception: if no connection succeeds within the timeout.
    """
    engine = create_engine(uri)
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with engine.connect():
                print("Connected to database.")
                return
        except OperationalError:
            print("Database not ready, retrying...")
            time.sleep(2)
    raise Exception("Database connection failed after timeout.")


def _load_fragment_map(engine, connection, table_name, fragment_column_name):
    """Return (column_names, {fragment_id_str: db_row}) for *table_name*.

    Raises ValueError if the table name is not a plain SQL identifier
    (it is interpolated into the query, so it must be validated) or if
    the fragment column is absent.
    """
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", table_name):
        raise ValueError(f"Invalid table name: {table_name!r}")
    inspector = inspect(engine)
    columns = [column['name'] for column in inspector.get_columns(table_name)]
    if fragment_column_name not in columns:
        raise ValueError(f"Fragment column '{fragment_column_name}' not found in table '{table_name}'.")
    fragment_column_index = columns.index(fragment_column_name)
    all_rows = connection.execute(text(f"SELECT * FROM {table_name}")).fetchall()
    # Key by str() so DB values compare consistently with CSV-derived strings.
    return columns, {str(row[fragment_column_index]): row for row in all_rows}


def _collect_csv_fragments(df):
    """Collect fragment IDs (as strings) from all non-backbone CSV cells.

    Cells whose value equals some backbone ID (column 0) are skipped, as in
    the original logic.
    """
    backbone_ids = set(df[0].dropna().astype(str))
    fragments = set()
    for _, row in df.iterrows():
        for col in df.columns[1:]:
            value = row[col]
            if pd.notna(value) and str(value) not in backbone_ids:
                fragments.add(str(value))
    return fragments


def _annotate_rows(df, csv_fragments, fragment_map, columns):
    """Build the per-backbone annotation structure from CSV rows and DB rows."""
    annotated_data = []
    for _, row in df.iterrows():
        annotated_row = {"Backbone": row[0], "Fragments": []}
        for col in df.columns[1:]:
            # Bug fix: compare as str — pandas may yield ints/floats, which
            # never matched the str entries in csv_fragments.
            value = row[col]
            fragment = str(value) if pd.notna(value) else None
            if fragment not in csv_fragments:
                continue
            db_row = fragment_map.get(fragment)
            if db_row:
                fragment_data = {"id": fragment}
                # Assumes the ID is the first DB column — TODO confirm schema.
                for i, column_name in enumerate(columns[1:]):
                    fragment_data[column_name] = db_row[i + 1]
            else:
                fragment_data = {"id": fragment, "metadata": "No data found"}
            annotated_row["Fragments"].append(fragment_data)
        annotated_data.append(annotated_row)
    return annotated_data


def _format_origin(seq_str):
    """Format a raw sequence as a GenBank ORIGIN block (60 bases/line, 10-base groups)."""
    lines = ["ORIGIN"]
    for i in range(0, len(seq_str), 60):
        chunk = seq_str[i:i + 60]
        grouped = ' '.join(chunk[j:j + 10] for j in range(0, len(chunk), 10))
        lines.append(f"{str(i + 1).rjust(9)} {grouped}")
    return "\n".join(lines)


def _write_genbank_files(annotated_data, sequence_column, annotation_columns, output):
    """Write one .gb file per annotated fragment into directory *output*."""
    os.makedirs(output, exist_ok=True)  # race-free; replaces exists()+makedirs()
    for annotated_row in annotated_data:
        backbone_id = annotated_row["Backbone"]
        for fragment in annotated_row["Fragments"]:
            fragment_id = fragment["id"]
            sequence = fragment.get(sequence_column, "")
            annotation = fragment.get(annotation_columns, "")
            record = SeqRecord(
                Seq(sequence),
                id=fragment_id,
                name=fragment_id,
                description=f"Fragment {fragment_id} from Backbone {backbone_id}"
            )
            # annotation_columns is a single column name (argparse string);
            # the original iterated its *characters*, producing an empty dict.
            if annotation_columns in fragment:
                record.annotations = {annotation_columns: str(fragment[annotation_columns])}
            # LOCUS line: copy it from the stored annotation when present.
            locus_line_match = re.search(r"LOCUS\s+.+", annotation)
            if locus_line_match:
                locus_line = locus_line_match.group()
            else:
                print(f"LOCUS info missing for fragment {fragment_id}")
                locus_line = f"LOCUS {fragment_id: <20} {len(sequence)} bp DNA linear UNK 01-JAN-2025"
            # Use the sequence verbatim if it already carries an ORIGIN block.
            if "ORIGIN" in sequence:
                origin_block = sequence.strip()
            else:
                origin_block = _format_origin(str(record.seq))
            # Copy the FEATURES section directly from the stored annotation.
            features_section = ""
            features_start = annotation.find("FEATURES")
            if features_start != -1:
                features_section = annotation[features_start:]
            gb_filename = os.path.join(output, f"{fragment_id}.gb")
            with open(gb_filename, "w") as f:
                f.write(locus_line + "\n")
                f.write(f"DEFINITION {record.description}\n")
                f.write(f"ACCESSION {record.id}\n")
                f.write("VERSION DB\n")
                f.write("KEYWORDS .\n")
                f.write("SOURCE .\n")
                f.write(features_section)
                f.write(origin_block + "\n")
                f.write("//\n")


def fetch_annotations(csv_file, sequence_column, annotation_columns, db_uri,
                      table_name, fragment_column_name, output):
    """Fetch annotations from the database and save the result as GenBank files.

    Args:
        csv_file: CSV path; column 0 = backbone ID, rest = fragment IDs.
        sequence_column: DB column holding the raw sequence.
        annotation_columns: DB column holding the GenBank header/annotation text.
        db_uri: SQLAlchemy URI (may contain the "__at__" escape).
        table_name: table to read (validated as a plain identifier).
        fragment_column_name: DB column holding the fragment ID.
        output: directory that receives one <fragment>.gb file each.

    Raises:
        ValueError: on a bad table/column name or fragments missing from the DB.
    """
    db_uri = fix_db_uri(db_uri)
    df = pd.read_csv(csv_file, sep=',', header=None)
    engine = create_engine(db_uri)
    try:
        with engine.connect() as connection:
            columns, fragment_map = _load_fragment_map(
                engine, connection, table_name, fragment_column_name)
            csv_fragments = _collect_csv_fragments(df)
            missing_fragments = sorted(csv_fragments - set(fragment_map))
            if missing_fragments:
                raise ValueError(
                    f" Missing fragments in DB: {', '.join(missing_fragments)}"
                )
            annotated_data = _annotate_rows(df, csv_fragments, fragment_map, columns)
    except Exception as e:
        print(f"Error occurred during annotation: {e}")
        raise  # Ensures the error exits the script
    # Best-effort file generation: report but do not propagate write errors,
    # matching the original behavior.
    try:
        _write_genbank_files(annotated_data, sequence_column, annotation_columns, output)
    except Exception as e:
        print(f"Error saving GenBank files: {e}")
    return


def main():
    parser = argparse.ArgumentParser(
        description="Fetch annotations from PostgreSQL database and save as JSON.")
    parser.add_argument("--input", required=True, help="Input CSV file")
    parser.add_argument("--sequence_column", required=True,
                        help="DB column contains sequence for ganbank file")
    parser.add_argument("--annotation_columns", required=True,
                        help="DB column contains head for ganbank file")
    parser.add_argument("--db_uri", required=True, help="Database URI connection string")
    parser.add_argument("--table", required=True, help="Table name in the database")
    parser.add_argument("--fragment_column", required=True,
                        help="Fragment column name in the database")
    parser.add_argument("--output", required=True, help="Output dir for gb files")
    args = parser.parse_args()

    # Wait until the database is ready
    db_uri = fix_db_uri(args.db_uri)
    # db_name = extract_db_name(db_uri)
    # start_postgres_container(db_name)
    MAX_RETRIES = 3
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            wait_for_db(db_uri)
            break  # Success
        except Exception as e:
            if attempt == MAX_RETRIES:
                print(f"Attempt {attempt} failed: Could not connect to database at {db_uri}.")
                raise e
            else:
                time.sleep(2)

    # Fetch annotations from the database and save as gb
    fetch_annotations(args.input, args.sequence_column, args.annotation_columns,
                      db_uri, args.table, args.fragment_column, args.output)


if __name__ == "__main__":
    main()