Mercurial > repos > tduigou > get_from_db
comparison get_db_info.py @ 0:1769c133986b draft default tip
planemo upload for repository https://github.com/brsynth/galaxytools/tree/main/tools commit 3401816c949b538bd9c67e61cbe92badff6a4007-dirty
author | tduigou |
---|---|
date | Wed, 11 Jun 2025 09:42:52 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:1769c133986b |
---|---|
1 import subprocess | |
2 import argparse | |
3 import time | |
4 import os | |
5 import socket | |
6 import re | |
7 from Bio.Seq import Seq | |
8 import pandas as pd | |
9 from Bio.SeqRecord import SeqRecord | |
10 from sqlalchemy import create_engine, inspect | |
11 from sqlalchemy.engine.url import make_url | |
12 from sqlalchemy.sql import text | |
13 from sqlalchemy.exc import OperationalError | |
14 | |
15 | |
def fix_db_uri(uri):
    """Return *uri* with any Galaxy-escaped ``__at__`` token restored to ``@``."""
    unescaped = uri.replace("__at__", "@")
    return unescaped
19 | |
20 | |
def is_port_in_use(uri):
    """Return True if the TCP port referenced by *uri* already accepts connections.

    Accepts either a SQLAlchemy URI string or, for convenience, a bare integer
    port number (host then defaults to localhost) — the original only accepted
    a URI, yet ``start_postgres_container`` called it with the int ``5432``,
    which crashed inside ``make_url``.

    :param uri: SQLAlchemy connection URI string, or an integer port.
    :return: True when a connection to host:port succeeds within 2 seconds.
    """
    if isinstance(uri, int):
        host, port = "localhost", uri
    else:
        url = make_url(uri)
        host = url.host or "localhost"
        port = url.port
    if port is None:
        # URI without an explicit port: nothing meaningful to probe.
        return False
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(2)
        return s.connect_ex((host, port)) == 0
29 | |
30 | |
def extract_db_name(uri):
    """Return the database component parsed from a SQLAlchemy connection URI."""
    return make_url(uri).database
35 | |
36 | |
# Helper to start a Dockerised PostgreSQL instance when the DB lives in a
# container. NOTE: currently NOT called from main() (kept for manual use).
def start_postgres_container(db_name):
    """Start a PostgreSQL Docker container named after *db_name*.

    Reuses a running or stopped container of the same name when one exists;
    otherwise creates a fresh one, publishing host port 5432 (or 5433 when
    5432 is busy) to the container's 5432.

    :param db_name: database name, reused as the container name.
    """
    container_name = db_name

    # Already running? Nothing to do. Commands are argument lists
    # (shell=False) so the container name cannot be interpreted by a shell.
    container_running = subprocess.run(
        ["docker", "ps", "-q", "-f", f"name={container_name}"],
        capture_output=True, text=True,
    )
    if container_running.stdout.strip():
        print(f"Container '{container_name}' is already running.")
        return

    # Exists but stopped? Just start it.
    container_exists = subprocess.run(
        ["docker", "ps", "-a", "-q", "-f", f"name={container_name}"],
        capture_output=True, text=True,
    )
    if container_exists.stdout.strip():
        print(f"Starting existing container '{container_name}'...")
        subprocess.run(["docker", "start", container_name])
        print(f"PostgreSQL Docker container '{container_name}' activated.")
        return

    # No container yet: create one. Probe the port directly — the original
    # passed the bare int 5432 to is_port_in_use(), which expects a URI and
    # would have raised inside make_url().
    def _port_busy(port):
        # True when something already listens on localhost:<port>.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.settimeout(2)
            return s.connect_ex(("localhost", port)) == 0

    port = 5432 if not _port_busy(5432) else 5433
    postgres_password = os.getenv("POSTGRES_PASSWORD", "RK17")

    start_command = [
        "docker", "run", "--name", container_name,
        "-e", f"POSTGRES_PASSWORD={postgres_password}",
        "-p", f"{port}:5432",
        "-d", "postgres",
    ]

    try:
        subprocess.run(start_command, check=True)
        print(f"PostgreSQL Docker container '{container_name}' started on port {port}.")
    except subprocess.CalledProcessError as e:
        print(f"Failed to start Docker container: {e}")
78 | |
79 | |
def wait_for_db(uri, timeout=60):
    """Poll the database at *uri* until a connection succeeds.

    Retries every 2 seconds for up to *timeout* seconds.

    :param uri: SQLAlchemy connection URI.
    :param timeout: maximum number of seconds to keep retrying.
    :raises Exception: when no connection could be made within *timeout*.
    """
    engine = create_engine(uri)
    # monotonic clock: immune to wall-clock jumps (NTP, DST) during the wait.
    deadline = time.monotonic() + timeout
    try:
        while time.monotonic() < deadline:
            try:
                with engine.connect():
                    print("Connected to database.")
                    return
            except OperationalError:
                print("Database not ready, retrying...")
                time.sleep(2)
        raise Exception("Database connection failed after timeout.")
    finally:
        # Release pooled connections on both the success and failure paths
        # (the original leaked the engine's connection pool).
        engine.dispose()
93 | |
94 | |
def fetch_annotations(csv_file, sequence_column, annotation_columns, db_uri, table_name, fragment_column_name, output):
    """Fetch annotations from the database and save the result as GenBank files.

    :param csv_file: path to a header-less CSV; column 0 holds the backbone
        ID, the remaining columns hold fragment IDs.
    :param sequence_column: DB column name whose value is the fragment sequence.
    :param annotation_columns: DB column name whose value holds GenBank header
        text (the LOCUS line and FEATURES section are copied out of it).
    :param db_uri: SQLAlchemy connection URI (``__at__`` is unescaped to ``@``).
    :param table_name: DB table containing one row per fragment.
    :param fragment_column_name: column of *table_name* holding fragment IDs.
    :param output: directory receiving one ``<fragment_id>.gb`` file per
        fragment (created if missing).
    :raises ValueError: if the fragment column is absent from the table or
        CSV fragments are missing from the DB.
    """
    db_uri = fix_db_uri(db_uri)
    # header=None: the CSV has no header row, so columns are integers 0..N.
    df = pd.read_csv(csv_file, sep=',', header=None)

    engine = create_engine(db_uri)
    connection = engine.connect()

    annotated_data = []

    try:
        with connection:
            inspector = inspect(engine)
            columns = [column['name'] for column in inspector.get_columns(table_name)]

            # Fetch all fragments from the table once
            if fragment_column_name not in columns:
                raise ValueError(f"Fragment column '{fragment_column_name}' not found in table '{table_name}'.")

            fragment_column_index = columns.index(fragment_column_name)
            # NOTE(review): table_name is interpolated straight into the SQL
            # string; it comes from the CLI, so this is open to SQL injection.
            all_rows = connection.execute(text(f"SELECT * FROM {table_name}")).fetchall()
            # Map fragment ID -> full DB row for O(1) lookups below.
            fragment_map = {row[fragment_column_index]: row for row in all_rows}

            # Compare fragments between CSV and DB.
            # Column 0 holds backbone IDs; cells equal to a backbone ID are
            # excluded from the fragment set.
            csv_fragments = set()
            all_ids = set(df[0].dropna().astype(str))
            for _, row in df.iterrows():
                for col in df.columns:
                    if col != 0:
                        fragment = row[col]
                        if pd.notna(fragment):
                            fragment_str = str(fragment)
                            if fragment_str not in all_ids:
                                csv_fragments.add(fragment_str)

            # Fail fast if any CSV fragment has no DB row.
            db_fragments = set(fragment_map.keys())
            missing_fragments = sorted(list(csv_fragments - db_fragments))
            if missing_fragments:
                raise ValueError(
                    f" Missing fragments in DB: {', '.join(missing_fragments)}"
                )

            # === CONTINUE WITH GB FILE CREATION ===
            for _, row in df.iterrows():
                annotated_row = {"Backbone": row[0], "Fragments": []}
                for col in df.columns:
                    if col != 0:
                        fragment = row[col]
                        # NOTE(review): csv_fragments contains str() values but
                        # `fragment` is the raw CSV cell here (no str() cast);
                        # non-string cells would be skipped silently — confirm
                        # fragment IDs always parse as strings.
                        if fragment not in csv_fragments:
                            continue
                        db_row = fragment_map.get(fragment)

                        if db_row:
                            fragment_data = {"id": fragment}
                            # NOTE(review): assumes the ID is the FIRST table
                            # column; if fragment_column_name is not column 0
                            # this pairs names with the wrong values.
                            for i, column_name in enumerate(columns[1:]):  # skip ID column
                                fragment_data[column_name] = db_row[i + 1]
                        else:
                            fragment_data = {"id": fragment, "metadata": "No data found"}

                        annotated_row["Fragments"].append(fragment_data)

                annotated_data.append(annotated_row)

    except Exception as e:
        print(f"Error occurred during annotation: {e}")
        raise  # Ensures the error exits the script

    # GenBank file generation per fragment
    try:
        for annotated_row in annotated_data:
            backbone_id = annotated_row["Backbone"]
            for fragment in annotated_row["Fragments"]:
                fragment_id = fragment["id"]
                sequence = fragment.get(sequence_column, "")
                # annotation_columns is used as ONE dict key here (a single
                # column name), not as a list of columns.
                annotation = fragment.get(annotation_columns, "")

                # Create the SeqRecord
                record = SeqRecord(
                    Seq(sequence),
                    id=fragment_id,
                    name=fragment_id,
                    description=f"Fragment {fragment_id} from Backbone {backbone_id}"
                )

                # Add annotations to GenBank header
                # NOTE(review): iterating a *string* yields single characters,
                # so this comprehension can only match 1-character keys —
                # likely a latent bug given annotation_columns is a single
                # column name everywhere else; in practice this dict is empty.
                record.annotations = {
                    k: str(fragment[k]) for k in annotation_columns if k in fragment
                }

                # LOCUS line extraction from annotation (copy-paste the LOCUS from annotation)
                locus_line_match = re.search(r"LOCUS\s+.+", annotation)
                if locus_line_match:
                    locus_line = locus_line_match.group()
                else:
                    # Synthesise a minimal LOCUS line when the DB header lacks one.
                    print(f"LOCUS info missing for fragment {fragment_id}")
                    locus_line = f"LOCUS {fragment_id: <20} {len(sequence)} bp DNA linear UNK 01-JAN-2025"

                # Format sequence as per GenBank standards (with ORIGIN and line breaks)
                if "ORIGIN" in sequence:
                    # Sequence field already carries a formatted ORIGIN block.
                    origin_block = sequence.strip()
                else:
                    # Format sequence as per GenBank standards (with ORIGIN and line breaks)
                    formatted_sequence = "ORIGIN\n"
                    seq_str = str(record.seq)
                    for i in range(0, len(seq_str), 60):  # 60 bases per line
                        line_seq = seq_str[i:i + 60]
                        # 9-char right-aligned base index, then 10-base groups.
                        formatted_sequence += f"{str(i + 1).rjust(9)} { ' '.join([line_seq[j:j+10] for j in range(0, len(line_seq), 10)]) }\n"
                    origin_block = formatted_sequence.strip()

                # Find and copy the FEATURES section directly from annotation
                features_section = ""
                features_start = annotation.find("FEATURES")
                if features_start != -1:
                    features_section = annotation[features_start:]

                # Writing the GenBank file
                if not os.path.exists(output):
                    os.makedirs(output)

                gb_filename = os.path.join(output, f"{fragment_id}.gb")
                with open(gb_filename, "w") as f:
                    # Write the LOCUS line
                    f.write(locus_line + "\n")
                    # Write DEFINITION, ACCESSION, and other annotations
                    f.write(f"DEFINITION {record.description}\n")
                    f.write(f"ACCESSION {record.id}\n")
                    f.write(f"VERSION DB\n")
                    f.write(f"KEYWORDS .\n")
                    f.write(f"SOURCE .\n")
                    # Write the FEATURES section directly from annotation
                    f.write(features_section)
                    # Write the ORIGIN section
                    f.write(origin_block + "\n")
                    f.write("//\n")

    except Exception as e:
        # Best-effort: report and stop file generation without re-raising.
        print(f"Error saving GenBank files: {e}")
        return
233 | |
234 | |
def main():
    """CLI entry point: wait for the DB, then export annotations as GenBank files.

    Parses command-line options, retries the database connection a few times,
    and delegates to :func:`fetch_annotations`.
    """
    # Description fixed: the tool writes GenBank files, not JSON; help-string
    # typos ("ganbank") fixed as well.
    parser = argparse.ArgumentParser(description="Fetch annotations from PostgreSQL database and save as GenBank files.")
    parser.add_argument("--input", required=True, help="Input CSV file")
    parser.add_argument("--sequence_column", required=True, help="DB column containing the sequence for the GenBank file")
    parser.add_argument("--annotation_columns", required=True, help="DB column containing the header for the GenBank file")
    parser.add_argument("--db_uri", required=True, help="Database URI connection string")
    parser.add_argument("--table", required=True, help="Table name in the database")
    parser.add_argument("--fragment_column", required=True, help="Fragment column name in the database")
    parser.add_argument("--output", required=True, help="Output dir for gb files")
    args = parser.parse_args()

    # Wait until the database is ready. wait_for_db() itself polls for up to
    # 60 s per call; the outer loop retries that whole wait a few times.
    db_uri = fix_db_uri(args.db_uri)
    MAX_RETRIES = 3
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            wait_for_db(db_uri)
            break  # Success
        except Exception as e:
            if attempt == MAX_RETRIES:
                print(f"Attempt {attempt} failed: Could not connect to database at {db_uri}.")
                raise e
            time.sleep(2)

    # Fetch annotations from the database and write one .gb file per fragment
    fetch_annotations(args.input, args.sequence_column, args.annotation_columns, db_uri, args.table, args.fragment_column, args.output)
264 | |
265 | |
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()