diff get_db_info.py @ 0:1769c133986b draft default tip

planemo upload for repository https://github.com/brsynth/galaxytools/tree/main/tools commit 3401816c949b538bd9c67e61cbe92badff6a4007-dirty
author tduigou
date Wed, 11 Jun 2025 09:42:52 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/get_db_info.py	Wed Jun 11 09:42:52 2025 +0000
@@ -0,0 +1,267 @@
+import subprocess
+import argparse
+import time
+import os
+import socket
+import re
+from Bio.Seq import Seq
+import pandas as pd
+from Bio.SeqRecord import SeqRecord
+from sqlalchemy import create_engine, inspect
+from sqlalchemy.engine.url import make_url
+from sqlalchemy.sql import text
+from sqlalchemy.exc import OperationalError
+
+
+def fix_db_uri(uri):
+    """Replace __at__ with @ in the URI if needed."""
+    return uri.replace("__at__", "@")
+
+
+def is_port_in_use(uri):
+    """Check if a TCP port is already in use on host."""
+    url = make_url(uri)
+    host = url.host
+    port = url.port
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.settimeout(2)
+        return s.connect_ex((host, port)) == 0
+
+
+def extract_db_name(uri):
+    """Extract the database name from the SQLAlchemy URI."""
+    url = make_url(uri)
+    return url.database
+
+
+# this fuction is to activate the Docker id the DB is in container. BUT IT IS NOT USED IN MAIN()
+def start_postgres_container(db_name):
+    """Start a PostgreSQL container with the given database name as the container name."""
+    container_name = db_name
+
+    # Check if container is already running
+    container_running = subprocess.run(
+        f"docker ps -q -f name={container_name}", shell=True, capture_output=True, text=True
+    )
+
+    if container_running.stdout.strip():
+        print(f"Container '{container_name}' is already running.")
+        return
+
+    # Check if container exists (stopped)
+    container_exists = subprocess.run(
+        f"docker ps -a -q -f name={container_name}", shell=True, capture_output=True, text=True
+    )
+
+    if container_exists.stdout.strip():
+        print(f"Starting existing container '{container_name}'...")
+        subprocess.run(f"docker start {container_name}", shell=True)
+        print(f"PostgreSQL Docker container '{container_name}' activated.")
+        return
+
+    # If container does not exist, create and start a new one
+    port = 5432 if not is_port_in_use(5432) else 5433
+    postgres_password = os.getenv("POSTGRES_PASSWORD", "RK17")
+
+    start_command = [
+        "docker", "run", "--name", container_name,
+        "-e", f"POSTGRES_PASSWORD={postgres_password}",
+        "-p", f"{port}:5432",
+        "-d", "postgres"
+    ]
+
+    try:
+        subprocess.run(start_command, check=True)
+        print(f"PostgreSQL Docker container '{container_name}' started on port {port}.")
+    except subprocess.CalledProcessError as e:
+        print(f"Failed to start Docker container: {e}")
+
+
+def wait_for_db(uri, timeout=60):
+    """Try connecting to the DB until it works or timeout."""
+    engine = create_engine(uri)
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            with engine.connect():
+                print("Connected to database.")
+                return
+        except OperationalError:
+            print("Database not ready, retrying...")
+            time.sleep(2)
+    raise Exception("Database connection failed after timeout.")
+
+
+def fetch_annotations(csv_file, sequence_column, annotation_columns, db_uri, table_name, fragment_column_name, output):
+    """Fetch annotations from the database and save the result as GenBank files."""
+    db_uri = fix_db_uri(db_uri)
+    df = pd.read_csv(csv_file, sep=',', header=None)
+
+    engine = create_engine(db_uri)
+    connection = engine.connect()
+
+    annotated_data = []
+
+    try:
+        with connection:
+            inspector = inspect(engine)
+            columns = [column['name'] for column in inspector.get_columns(table_name)]
+
+            # Fetch all fragments from the table once
+            if fragment_column_name not in columns:
+                raise ValueError(f"Fragment column '{fragment_column_name}' not found in table '{table_name}'.")
+
+            fragment_column_index = columns.index(fragment_column_name)
+            all_rows = connection.execute(text(f"SELECT * FROM {table_name}")).fetchall()
+            fragment_map = {row[fragment_column_index]: row for row in all_rows}
+
+            # Compare fragments between CSV and DB
+            csv_fragments = set()
+            all_ids = set(df[0].dropna().astype(str))
+            for _, row in df.iterrows():
+                for col in df.columns:
+                    if col != 0:
+                        fragment = row[col]
+                        if pd.notna(fragment):
+                            fragment_str = str(fragment)
+                            if fragment_str not in all_ids:
+                                csv_fragments.add(fragment_str)
+
+            db_fragments = set(fragment_map.keys())
+            missing_fragments = sorted(list(csv_fragments - db_fragments))
+            if missing_fragments:
+                raise ValueError(
+                    f" Missing fragments in DB: {', '.join(missing_fragments)}"
+                )
+
+            # === CONTINUE WITH GB FILE CREATION ===
+            for _, row in df.iterrows():
+                annotated_row = {"Backbone": row[0], "Fragments": []}
+                for col in df.columns:
+                    if col != 0:
+                        fragment = row[col]
+                        if fragment not in csv_fragments:
+                            continue
+                        db_row = fragment_map.get(fragment)
+
+                        if db_row:
+                            fragment_data = {"id": fragment}
+                            for i, column_name in enumerate(columns[1:]):  # skip ID column
+                                fragment_data[column_name] = db_row[i + 1]
+                        else:
+                            fragment_data = {"id": fragment, "metadata": "No data found"}
+
+                        annotated_row["Fragments"].append(fragment_data)
+
+                annotated_data.append(annotated_row)
+
+    except Exception as e:
+        print(f"Error occurred during annotation: {e}")
+        raise  # Ensures the error exits the script
+
+    # GenBank file generation per fragment
+    try:
+        for annotated_row in annotated_data:
+            backbone_id = annotated_row["Backbone"]
+            for fragment in annotated_row["Fragments"]:
+                fragment_id = fragment["id"]
+                sequence = fragment.get(sequence_column, "")
+                annotation = fragment.get(annotation_columns, "")
+
+                # Create the SeqRecord
+                record = SeqRecord(
+                    Seq(sequence),
+                    id=fragment_id,
+                    name=fragment_id,
+                    description=f"Fragment {fragment_id} from Backbone {backbone_id}"
+                )
+
+                # Add annotations to GenBank header
+                record.annotations = {
+                    k: str(fragment[k]) for k in annotation_columns if k in fragment
+                }
+
+                # LOCUS line extraction from annotation (copy-paste the LOCUS from annotation)
+                locus_line_match = re.search(r"LOCUS\s+.+", annotation)
+                if locus_line_match:
+                    locus_line = locus_line_match.group()
+                else:
+                    print(f"LOCUS info missing for fragment {fragment_id}")
+                    locus_line = f"LOCUS       {fragment_id: <20} {len(sequence)} bp    DNA     linear   UNK 01-JAN-2025"
+
+                # Format sequence as per GenBank standards (with ORIGIN and line breaks)
+                if "ORIGIN" in sequence:
+                    origin_block = sequence.strip()
+                else:
+                    # Format sequence as per GenBank standards (with ORIGIN and line breaks)
+                    formatted_sequence = "ORIGIN\n"
+                    seq_str = str(record.seq)
+                    for i in range(0, len(seq_str), 60):  # 60 bases per line
+                        line_seq = seq_str[i:i + 60]
+                        formatted_sequence += f"{str(i + 1).rjust(9)} { ' '.join([line_seq[j:j+10] for j in range(0, len(line_seq), 10)]) }\n"
+                    origin_block = formatted_sequence.strip()
+
+                # Find and copy the FEATURES section directly from annotation
+                features_section = ""
+                features_start = annotation.find("FEATURES")
+                if features_start != -1:
+                    features_section = annotation[features_start:]
+
+                # Writing the GenBank file
+                if not os.path.exists(output):
+                    os.makedirs(output)
+
+                gb_filename = os.path.join(output, f"{fragment_id}.gb")
+                with open(gb_filename, "w") as f:
+                    # Write the LOCUS line
+                    f.write(locus_line + "\n")
+                    # Write DEFINITION, ACCESSION, and other annotations
+                    f.write(f"DEFINITION  {record.description}\n")
+                    f.write(f"ACCESSION   {record.id}\n")
+                    f.write(f"VERSION     DB\n")
+                    f.write(f"KEYWORDS    .\n")
+                    f.write(f"SOURCE      .\n")
+                    # Write the FEATURES section directly from annotation
+                    f.write(features_section)
+                    # Write the ORIGIN section
+                    f.write(origin_block + "\n")
+                    f.write("//\n")
+
+    except Exception as e:
+        print(f"Error saving GenBank files: {e}")
+        return
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Fetch annotations from PostgreSQL database and save as JSON.")
+    parser.add_argument("--input", required=True, help="Input CSV file")
+    parser.add_argument("--sequence_column", required=True, help="DB column contains sequence for ganbank file")
+    parser.add_argument("--annotation_columns", required=True, help="DB column contains head for ganbank file")
+    parser.add_argument("--db_uri", required=True, help="Database URI connection string")
+    parser.add_argument("--table", required=True, help="Table name in the database")
+    parser.add_argument("--fragment_column", required=True, help="Fragment column name in the database")
+    parser.add_argument("--output", required=True, help="Output dir for gb files")
+    args = parser.parse_args()
+
+    # Wait until the database is ready
+    db_uri = fix_db_uri(args.db_uri)
+    # db_name = extract_db_name(db_uri)
+    # start_postgres_container(db_name)
+    MAX_RETRIES = 3
+    for attempt in range(1, MAX_RETRIES + 1):
+        try:
+            wait_for_db(db_uri)
+            break  # Success
+        except Exception as e:
+            if attempt == MAX_RETRIES:
+                print(f"Attempt {attempt} failed: Could not connect to database at {db_uri}.")
+                raise e
+            else:
+                time.sleep(2)
+
+    # Fetch annotations from the database and save as gb
+    fetch_annotations(args.input, args.sequence_column, args.annotation_columns, db_uri, args.table, args.fragment_column, args.output)
+
+
+if __name__ == "__main__":
+    main()