Mercurial > repos > tduigou > get_from_db
comparison get_db_info.py @ 0:1769c133986b draft default tip
planemo upload for repository https://github.com/brsynth/galaxytools/tree/main/tools commit 3401816c949b538bd9c67e61cbe92badff6a4007-dirty
| author | tduigou |
|---|---|
| date | Wed, 11 Jun 2025 09:42:52 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:1769c133986b |
|---|---|
| 1 import subprocess | |
| 2 import argparse | |
| 3 import time | |
| 4 import os | |
| 5 import socket | |
| 6 import re | |
| 7 from Bio.Seq import Seq | |
| 8 import pandas as pd | |
| 9 from Bio.SeqRecord import SeqRecord | |
| 10 from sqlalchemy import create_engine, inspect | |
| 11 from sqlalchemy.engine.url import make_url | |
| 12 from sqlalchemy.sql import text | |
| 13 from sqlalchemy.exc import OperationalError | |
| 14 | |
| 15 | |
def fix_db_uri(uri):
    """Undo the ``__at__`` escaping of ``@`` in a database URI.

    Every occurrence of ``__at__`` becomes ``@``; a URI without the
    placeholder is returned unchanged.
    """
    pieces = uri.split("__at__")
    return "@".join(pieces)
| 19 | |
| 20 | |
def is_port_in_use(uri, host="127.0.0.1", default_port=5432):
    """Check whether something accepts TCP connections on a host/port.

    Args:
        uri: Either a SQLAlchemy connection URI (host and port are read from
            it) or a bare port number (int) to probe on ``host``.
        host: Host to probe when ``uri`` is a port number, or when the URI
            does not name a host.
        default_port: Port to probe when the URI does not specify one.

    Returns:
        True if a TCP connection succeeds within 2 seconds, False if the
        connection attempt fails. Host-name resolution errors are not caught
        (``socket.connect_ex`` still raises for those).
    """
    if isinstance(uri, int):
        # Callers sometimes pass just a port number (e.g. 5432); make_url()
        # cannot parse an int, so probe that port on ``host`` directly.
        target_host, target_port = host, uri
    else:
        url = make_url(uri)
        # A URI may omit host and/or port; connecting to (None, None) would fail.
        target_host = url.host or host
        target_port = url.port or default_port

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(2)
        return s.connect_ex((target_host, target_port)) == 0
| 29 | |
| 30 | |
def extract_db_name(uri):
    """Return the database-name component of a SQLAlchemy connection URI."""
    return make_url(uri).database
| 35 | |
| 36 | |
def _docker_container_names(name, include_stopped=False):
    """Return the names of Docker containers matched by ``docker ps --filter name=<name>``.

    Docker's name filter matches substrings, so callers must still compare
    the returned names exactly.
    """
    cmd = ["docker", "ps"]
    if include_stopped:
        cmd.append("-a")
    # "{{.Names}}" is a literal Go template for the docker CLI (not an f-string).
    cmd += ["--filter", f"name={name}", "--format", "{{.Names}}"]
    listing = subprocess.run(cmd, capture_output=True, text=True)
    return listing.stdout.split()


# Helper for when the database runs in a local Docker container.
# NOTE: it is currently NOT called from main() (the call there is commented out).
def start_postgres_container(db_name):
    """Start (or reuse) a PostgreSQL Docker container named after the database.

    * If a running container with exactly this name exists, do nothing.
    * If a stopped container with this name exists, ``docker start`` it.
    * Otherwise ``docker run`` a new ``postgres`` container, publishing it on
      host port 5432, or 5433 if 5432 is already in use.

    Failures of ``docker start`` / ``docker run`` are printed, not raised.

    Args:
        db_name: Database name, used verbatim as the container name.

    Raises:
        ValueError: if ``db_name`` is empty or None.
    """
    if not db_name:
        raise ValueError("A database name is required to name the container.")
    container_name = db_name

    # Argument lists (no shell) so the name cannot inject shell syntax; exact
    # name comparison because `--filter name=` also matches e.g. "<name>_old".
    if container_name in _docker_container_names(container_name):
        print(f"Container '{container_name}' is already running.")
        return

    if container_name in _docker_container_names(container_name, include_stopped=True):
        print(f"Starting existing container '{container_name}'...")
        try:
            subprocess.run(["docker", "start", container_name], check=True)
        except subprocess.CalledProcessError as e:
            print(f"Failed to start Docker container: {e}")
            return
        print(f"PostgreSQL Docker container '{container_name}' activated.")
        return

    # is_port_in_use() parses a URI, so give it one (a bare int made make_url() fail).
    port = 5432 if not is_port_in_use("postgresql://localhost:5432") else 5433
    # NOTE(review): hard-coded fallback password; consider requiring POSTGRES_PASSWORD.
    postgres_password = os.getenv("POSTGRES_PASSWORD", "RK17")

    start_command = [
        "docker", "run", "--name", container_name,
        "-e", f"POSTGRES_PASSWORD={postgres_password}",
        "-p", f"{port}:5432",
        "-d", "postgres",
    ]

    try:
        subprocess.run(start_command, check=True)
        print(f"PostgreSQL Docker container '{container_name}' started on port {port}.")
    except subprocess.CalledProcessError as e:
        print(f"Failed to start Docker container: {e}")
| 78 | |
| 79 | |
def wait_for_db(uri, timeout=60, interval=2):
    """Poll the database at ``uri`` until a connection succeeds.

    Args:
        uri: SQLAlchemy connection URI.
        timeout: Maximum number of seconds to keep trying.
        interval: Seconds to sleep after each failed (OperationalError) attempt.

    Raises:
        TimeoutError: if no connection succeeded within ``timeout`` seconds.
            It subclasses ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    engine = create_engine(uri)
    # monotonic() is immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    try:
        while time.monotonic() < deadline:
            try:
                with engine.connect():
                    print("Connected to database.")
                    return
            except OperationalError:
                print("Database not ready, retrying...")
                time.sleep(interval)
    finally:
        # The engine is local to this probe; release its pooled connections.
        engine.dispose()
    raise TimeoutError("Database connection failed after timeout.")
| 93 | |
| 94 | |
def _as_text(value):
    """Return ``value`` as a string, mapping ``None`` (SQL NULL / missing) to ``""``."""
    return "" if value is None else str(value)


def _collect_csv_fragments(df):
    """Return the set of fragment ids (as str) referenced by a header-less CSV frame.

    Every non-empty cell outside column 0 is a candidate fragment id; values
    that also appear in column 0 (the row / backbone ids) are excluded.
    """
    row_ids = set(df[0].dropna().astype(str))
    fragments = set()
    for col in df.columns:
        if col == 0:
            continue
        for cell in df[col].dropna():
            fragment = str(cell)
            if fragment not in row_ids:
                fragments.add(fragment)
    return fragments


def _load_fragment_map(engine, connection, table_name, fragment_column_name):
    """Read ``table_name`` once and index its rows by fragment id.

    Returns:
        dict mapping ``str(<fragment column value>)`` to a ``{column: value}``
        dict for that row; later duplicate ids overwrite earlier ones and rows
        with a NULL fragment id are skipped.

    Raises:
        ValueError: if ``fragment_column_name`` is not a column of the table.
    """
    columns = [column['name'] for column in inspect(engine).get_columns(table_name)]
    if fragment_column_name not in columns:
        raise ValueError(f"Fragment column '{fragment_column_name}' not found in table '{table_name}'.")

    # table_name comes from the command line: quote it as an SQL identifier
    # instead of splicing it into the statement verbatim.
    quoted_table = connection.dialect.identifier_preparer.quote(table_name)
    result = connection.execute(text(f"SELECT * FROM {quoted_table}"))
    # Label values with the result set's own column order.
    result_columns = list(result.keys())

    fragment_map = {}
    for db_row in result.fetchall():
        record = dict(zip(result_columns, db_row))
        key = record.get(fragment_column_name)
        if key is not None:
            # str keys so they compare equal to the (text) CSV cells.
            fragment_map[str(key)] = record
    return fragment_map


def _annotate_rows(df, csv_fragments, fragment_map):
    """Build ``[{"Backbone": row[0], "Fragments": [fragment dicts]}]``, one entry per CSV row.

    Each fragment dict holds every DB column of the matching row plus ``"id"``
    (the fragment id as written in the CSV, which takes precedence over any
    DB column named "id").
    """
    annotated = []
    for _, row in df.iterrows():
        fragments = []
        for col in df.columns:
            if col == 0:
                continue
            cell = row[col]
            if pd.isna(cell):
                continue
            fragment = str(cell)
            if fragment not in csv_fragments:
                continue
            db_record = fragment_map.get(fragment)
            if db_record is None:
                fragments.append({"id": fragment, "metadata": "No data found"})
            else:
                fragment_data = dict(db_record)
                fragment_data["id"] = fragment
                fragments.append(fragment_data)
        annotated.append({"Backbone": row[0], "Fragments": fragments})
    return annotated


def _sequence_length(sequence):
    """Number of bases in ``sequence`` (raw bases, or an already formatted ORIGIN block)."""
    if "ORIGIN" in sequence:
        # Count residue letters only, skipping position numbers and spacing.
        body = sequence.split("ORIGIN", 1)[1]
        return sum(1 for ch in body if ch.isalpha())
    return len(sequence)


def _format_origin(sequence):
    """Return the GenBank ORIGIN block for ``sequence`` (kept verbatim if it already is one)."""
    if "ORIGIN" in sequence:
        return sequence.strip()
    lines = ["ORIGIN"]
    for start in range(0, len(sequence), 60):  # 60 bases per line
        chunk = sequence[start:start + 60]
        groups = " ".join(chunk[j:j + 10] for j in range(0, len(chunk), 10))  # groups of 10
        lines.append(f"{str(start + 1).rjust(9)} {groups}")
    return "\n".join(lines)


def _extract_features(annotation):
    """Return the FEATURES section of ``annotation`` ("" if absent), newline-terminated."""
    start = annotation.find("FEATURES")
    if start == -1:
        return ""
    features = annotation[start:]
    # The caller writes its own ORIGIN block and "//" terminator; drop any that
    # came along with the stored annotation to avoid duplicated sections.
    end = re.search(r"^(ORIGIN|//)", features, flags=re.MULTILINE)
    if end:
        features = features[:end.start()]
    # Without a trailing newline the ORIGIN line would be glued to the last feature line.
    if features and not features.endswith("\n"):
        features += "\n"
    return features


def _header_line(keyword, value):
    """Format a GenBank header line: keyword padded to 12 columns, then the value."""
    return f"{keyword:<12}{value}\n"


def _write_fragment_genbank(fragment, backbone_id, sequence_column, annotation_column, output):
    """Write one fragment's GenBank file to ``<output>/<fragment id>.gb``."""
    fragment_id = fragment["id"]
    # DB NULLs arrive as None; treat them as empty text.
    sequence = _as_text(fragment.get(sequence_column))
    annotation = _as_text(fragment.get(annotation_column))
    description = f"Fragment {fragment_id} from Backbone {backbone_id}"

    # Reuse the LOCUS line stored in the annotation when there is one.
    locus_line_match = re.search(r"LOCUS\s+.+", annotation)
    if locus_line_match:
        locus_line = locus_line_match.group()
    else:
        print(f"LOCUS info missing for fragment {fragment_id}")
        locus_line = f"{'LOCUS':<12}{fragment_id: <20} {_sequence_length(sequence)} bp DNA linear UNK 01-JAN-2025"

    features_section = _extract_features(annotation)
    origin_block = _format_origin(sequence)

    # Path separators in an id would escape (or break) the output directory.
    safe_name = re.sub(r"[\\/]", "_", fragment_id)
    gb_filename = os.path.join(output, f"{safe_name}.gb")
    with open(gb_filename, "w") as f:
        f.write(locus_line + "\n")
        f.write(_header_line("DEFINITION", description))
        f.write(_header_line("ACCESSION", fragment_id))
        f.write(_header_line("VERSION", "DB"))
        f.write(_header_line("KEYWORDS", "."))
        f.write(_header_line("SOURCE", "."))
        f.write(features_section)
        f.write(origin_block + "\n")
        f.write("//\n")


def fetch_annotations(csv_file, sequence_column, annotation_columns, db_uri, table_name, fragment_column_name, output):
    """Fetch fragment annotations from the database and write one GenBank file per fragment.

    Args:
        csv_file: Header-less CSV; column 0 holds row (backbone) ids, the other
            columns hold fragment ids to look up.
        sequence_column: DB column holding the sequence (raw bases or an ORIGIN block).
        annotation_columns: Name of the single DB column holding the GenBank
            header text (its LOCUS line and FEATURES section are copied).
        db_uri: SQLAlchemy URI (``__at__`` is turned back into ``@``).
        table_name: Table to read.
        fragment_column_name: Column of ``table_name`` holding fragment ids.
        output: Directory that receives ``<fragment id>.gb`` files (created if needed).

    Raises:
        ValueError: if the fragment column is missing or some CSV fragments are
            not in the table. Database and file-writing errors are printed and
            re-raised, so the script exits with a failure status.
    """
    db_uri = fix_db_uri(db_uri)
    # Read every cell as text: ids are identifiers, and numeric inference would
    # turn "12" into 12 / 12.0 and break matching against the database keys.
    df = pd.read_csv(csv_file, sep=',', header=None, dtype=str)

    engine = create_engine(db_uri)
    try:
        try:
            with engine.connect() as connection:
                fragment_map = _load_fragment_map(engine, connection, table_name, fragment_column_name)

            # Compare fragments between CSV and DB
            csv_fragments = _collect_csv_fragments(df)
            missing_fragments = sorted(csv_fragments - set(fragment_map))
            if missing_fragments:
                raise ValueError(
                    f" Missing fragments in DB: {', '.join(missing_fragments)}"
                )

            annotated_data = _annotate_rows(df, csv_fragments, fragment_map)
        except Exception as e:
            print(f"Error occurred during annotation: {e}")
            raise  # Ensures the error exits the script
    finally:
        engine.dispose()

    # GenBank file generation per fragment
    try:
        os.makedirs(output, exist_ok=True)
        for annotated_row in annotated_data:
            backbone_id = annotated_row["Backbone"]
            for fragment in annotated_row["Fragments"]:
                _write_fragment_genbank(fragment, backbone_id, sequence_column, annotation_columns, output)
    except Exception as e:
        print(f"Error saving GenBank files: {e}")
        # Re-raise: swallowing here made the tool "succeed" with missing output.
        raise
| 233 | |
| 234 | |
def _redact_uri(uri):
    """Return a printable form of ``uri`` with its password masked."""
    try:
        # repr() of a SQLAlchemy URL renders the password as "***".
        return repr(make_url(uri))
    except Exception:
        # Best effort only: never fall back to printing the raw URI.
        return "<unparseable database URI>"


def main():
    """Command-line entry point: wait for the database, then write GenBank files."""
    parser = argparse.ArgumentParser(description="Fetch annotations from PostgreSQL database and save as GenBank files.")
    parser.add_argument("--input", required=True, help="Input CSV file")
    parser.add_argument("--sequence_column", required=True, help="DB column containing the sequence for the GenBank file")
    parser.add_argument("--annotation_columns", required=True, help="DB column containing the header (LOCUS/FEATURES) for the GenBank file")
    parser.add_argument("--db_uri", required=True, help="Database URI connection string")
    parser.add_argument("--table", required=True, help="Table name in the database")
    parser.add_argument("--fragment_column", required=True, help="Fragment column name in the database")
    parser.add_argument("--output", required=True, help="Output dir for gb files")
    args = parser.parse_args()

    db_uri = fix_db_uri(args.db_uri)
    # If the DB runs in a local Docker container it could be started first with:
    # start_postgres_container(extract_db_name(db_uri))

    # Each wait_for_db() call already polls until its own timeout; retry the
    # whole wait a few times before giving up.
    max_retries = 3
    for attempt in range(1, max_retries + 1):
        try:
            wait_for_db(db_uri)
            break  # Success
        except Exception:
            if attempt == max_retries:
                # The raw URI may embed the password; only print a redacted form.
                print(f"Attempt {attempt} failed: Could not connect to database at {_redact_uri(db_uri)}.")
                raise
            time.sleep(2)

    # Fetch annotations from the database and save as gb
    fetch_annotations(args.input, args.sequence_column, args.annotation_columns, db_uri, args.table, args.fragment_column, args.output)


if __name__ == "__main__":
    main()
