reAnnotate.py @ 11:5366d5ea04bc draft

planemo upload commit 9d1b19f98d8b7f0a0d1baf2da63a373d155626f8-dirty
author petr-novak
date Fri, 04 Aug 2023 12:35:32 +0000
#!/usr/bin/env python
"""
parse blast output table to gff file
"""
import argparse
import itertools
import os
import re
import shutil
import subprocess
import sys
import tempfile
from collections import defaultdict

# check version of python, must be at least 3.10
if sys.version_info < (3, 10):
    sys.exit("Python 3.10 or a more recent version is required.")


def make_temp_files(number_of_files):
    """
    Make named temporary files, files will not be deleted upon exit!
    :param number_of_files:
    :return:
        filepaths
    """
    temp_files = []
    for i in range(number_of_files):
        temp_files.append(tempfile.NamedTemporaryFile(delete=False).name)
        os.remove(temp_files[-1])
    return temp_files


def split_fasta_to_chunks(fasta_file, chunk_size=100000000, overlap=100000):
    """
    Split fasta file to chunks; sequences longer than chunk size are split to
    overlapping pieces. If sequences are shorter, chunks with multiple sequences
    are created.
    :param fasta_file:
    :param chunk_size:
    :param overlap:
    :return:
        fasta_file_split
        matching_table (list of lists [header, chunk_number, start, end, new_header])
    """
    min_chunk_size = chunk_size * 2
    fasta_sizes_dict = read_fasta_sequence_size(fasta_file)
    # calculate total size of items in the fasta sizes dictionary
    fasta_size = sum(fasta_sizes_dict.values())

    # calculate ranges for splitting of fasta files and store them in list
    matching_table = []
    fasta_file_split = tempfile.NamedTemporaryFile(delete=False).name
    for header, size in fasta_sizes_dict.items():
        print(header, size, min_chunk_size)

        if size > min_chunk_size:
            number_of_chunks = int(size / chunk_size)
            print("number_of_chunks", number_of_chunks)
            print("size", size)
            print("chunk_size", chunk_size)
            print("-----------------------------------------")
            adjusted_chunk_size = int(size / number_of_chunks)
            for i in range(number_of_chunks):
                start = i * adjusted_chunk_size
                end = ((i + 1) *
                       adjusted_chunk_size
                       + overlap) if i + 1 < number_of_chunks else size
                new_header = header + '_' + str(i)
                matching_table.append([header, i, start, end, new_header])
        else:
            new_header = header + '_0'
            matching_table.append([header, 0, 0, size, new_header])
    # read sequences from fasta files and split them to chunks according to matching table
    # open output and input files, use with statement to close files
    number_of_temp_files = len(matching_table)
    print('number of temp files', number_of_temp_files)
    fasta_dict = read_single_fasta_to_dictionary(open(fasta_file, 'r'))
    with open(fasta_file_split, 'w') as fh_out:
        for header in fasta_dict:
            matching_table_part = [x for x in matching_table if x[0] == header]
            for header2, i, start, end, new_header in matching_table_part:
                fh_out.write('>' + new_header + '\n')
                fh_out.write(fasta_dict[header][start:end] + '\n')
    temp_files_fasta = make_temp_files(number_of_temp_files)
    fasta_seq_size = read_fasta_sequence_size(fasta_file_split)
    seq_id_size_sorted = [i[0] for i in sorted(
        fasta_seq_size.items(), key=lambda x: int(x[1]), reverse=True
        )]
    seq_id_file_dict = dict(zip(seq_id_size_sorted, itertools.cycle(temp_files_fasta)))
    # write sequences to temporary files
    with open(fasta_file_split, 'r') as f:
        first = True
        for line in f:
            if line[0] == '>':
                # close previous file if it is not the first sequence
                if not first:
                    fout.close()
                first = False
                header = line.strip().split(' ')[0][1:]
                fout = open(seq_id_file_dict[header], 'a')
                fout.write(line)
            else:
                fout.write(line)
    os.remove(fasta_file_split)
    return temp_files_fasta, matching_table


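# Illustrative sketch (not part of the original script) of what the matching_table
# returned above can look like; the sequence names and sizes are assumptions.
# For a 250 Mb sequence 'chr1' with the default chunk_size=100000000 and
# overlap=100000, number_of_chunks = 2 and adjusted_chunk_size = 125000000, giving:
#   ['chr1', 0, 0, 125100000, 'chr1_0']
#   ['chr1', 1, 125000000, 250000000, 'chr1_1']
# A sequence shorter than 2 * chunk_size gets a single record, e.g.
#   ['chr2', 0, 0, 50000000, 'chr2_0']

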
def read_fasta_sequence_size(fasta_file):
    """Read size of sequence into dictionary"""
    fasta_dict = {}
    with open(fasta_file, 'r') as f:
        for line in f:
            if line[0] == '>':
                header = line.strip().split(' ')[0][1:]  # remove part of name after space
                fasta_dict[header] = 0
            else:
                fasta_dict[header] += len(line.strip())
    return fasta_dict


def read_single_fasta_to_dictionary(fh):
    """
    Read fasta file into dictionary
    :param fh:
    :return:
        fasta_dict
    """
    fasta_dict = {}
    for line in fh:
        if line[0] == '>':
            header = line.strip().split(' ')[0][1:]  # remove part of name after space
            fasta_dict[header] = []
        else:
            fasta_dict[header] += [line.strip()]
    fasta_dict = {k: ''.join(v) for k, v in fasta_dict.items()}
    return fasta_dict


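# Illustrative return values of the two fasta readers above, assuming a file with
# records '>chr1 some description' and '>chr2' (sizes and sequences are placeholders):
#   read_fasta_sequence_size(path)               -> {'chr1': 1050, 'chr2': 300}
#   read_single_fasta_to_dictionary(open(path))  -> {'chr1': 'ACGT...', 'chr2': 'TTGA...'}

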
def overlap(a, b):
    """
    check if two intervals overlap
    """
    return max(a[0], b[0]) <= min(a[1], b[1])


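# Quick sanity check of overlap() semantics (illustrative, closed intervals):
#   overlap((1, 10), (5, 20))  -> True   (positions 5..10 are shared)
#   overlap((1, 4), (5, 20))   -> False  (max start 5 > min end 4)

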
def blast2disjoint(
        blastfile, seqid_counts=None, start_column=6, end_column=7, class_column=1,
        bitscore_column=11, pident_column=2, canonical_classification=True
        ):
    """
    find all interval beginnings and ends in blast file and create bed file
    input blastfile is tab separated file with columns:
    'qaccver saccver pident length mismatch gapopen qstart qend sstart send
    evalue bitscore' (default outfmt 6)
    blast must be sorted on qseqid and qstart
    """
    # assume all in one chromosome!
    starts_ends = {}
    intervals = {}
    if canonical_classification:
        # make regular expression for canonical classification
        # to match: Name#classification
        # e.g. "Name_of_sequence#LTR/Ty1_copia/Angela"
        regex = re.compile(r"(.*)[#](.*)")
        group = 2
    else:
        # make regular expression for non-canonical classification
        # to match: Classification__Name
        # e.g. "LTR/Ty1_copia/Angela__Name_of_sequence"
        regex = re.compile(r"(.*)__(.*)")
        group = 1

    # identify continuous intervals
    with open(blastfile, "r") as f:
        for seqid in sorted(seqid_counts.keys()):
            n_lines = seqid_counts[seqid]
            starts_ends[seqid] = set()
            for i in range(n_lines):
                items = f.readline().strip().split()
                # note: '1s' and '2e' labels are used to distinguish between start and
                # end and to guarantee that, for the same coordinate, a start sorts
                # before an end ('1s' < '2e')
                starts_ends[seqid].add((int(items[start_column]), '1s'))
                starts_ends[seqid].add((int(items[end_column]), '2e'))
            intervals[seqid] = []
            for p1, p2 in itertools.pairwise(sorted(starts_ends[seqid])):
                if p1[1] == '1s':
                    sp = 0
                else:
                    sp = 1
                if p2[1] == '2e':
                    ep = 0
                else:
                    ep = 1
                intervals[seqid].append((p1[0] + sp, p2[0] - ep))
    # scan each blast hit against continuous regions and record hit with best score
    with open(blastfile, "r") as f:
        disjoint_regions = []
        for seqid in sorted(seqid_counts.keys()):
            n_lines = seqid_counts[seqid]
            idx_of_overlaps = {}
            best_pident = defaultdict(lambda: 0.0)
            best_bitscore = defaultdict(lambda: 0.0)
            best_hit_name = defaultdict(lambda: "")
            i1 = 0
            for i in range(n_lines):
                items = f.readline().strip().split()
                start = int(items[start_column])
                end = int(items[end_column])
                pident = float(items[pident_column])
                bitscore = float(items[bitscore_column])
                classification = items[class_column]
                j = 0
                done = False
                while True:
                    # beginning of searched region - does it overlap?
                    c_ovl = overlap(intervals[seqid][i1], (start, end))
                    if c_ovl:
                        # if overlap is detected, add to dictionary
                        idx_of_overlaps[i] = [i1]
                        if best_bitscore[i1] < bitscore:
                            best_pident[i1] = pident
                            best_bitscore[i1] = bitscore
                            best_hit_name[i1] = classification
                        # search also downstream intervals
                        while True:
                            j += 1
                            if j + i1 >= len(intervals[seqid]):
                                done = True
                                break
                            c_ovl = overlap(intervals[seqid][i1 + j], (start, end))
                            if c_ovl:
                                idx_of_overlaps[i].append(i1 + j)
                                if best_bitscore[i1 + j] < bitscore:
                                    best_pident[i1 + j] = pident
                                    best_bitscore[i1 + j] = bitscore
                                    best_hit_name[i1 + j] = classification
                            else:
                                done = True
                                break
                    else:
                        # does not overlap - search next interval
                        i1 += 1
                    if done or i1 >= (len(intervals[seqid]) - 1):
                        break

            for i in sorted(best_pident.keys()):
                try:
                    classification = re.match(regex, best_hit_name[i]).group(group)
                except AttributeError:
                    classification = best_hit_name[i]
                record = (
                    seqid, intervals[seqid][i][0], intervals[seqid][i][1],
                    best_pident[i], classification)
                disjoint_regions.append(record)
    return disjoint_regions


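# Worked example of the interval decomposition above (assumed coordinates):
# two sorted hits on one seqid spanning 100-300 and 200-400 give the boundary set
# {(100, '1s'), (200, '1s'), (300, '2e'), (400, '2e')} and the disjoint intervals
# (100, 199), (200, 300), (301, 400); each interval is then labelled with the
# classification of the overlapping hit that has the highest bitscore.

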
def remove_short_interrupting_regions(regions, min_len=10, max_gap=2):
    """
    remove intervals shorter than min_len which are directly adjacent to regions
    on both sides which are longer than min_len and have the same classification
    """
    regions_to_remove = []
    for i in range(1, len(regions) - 1):
        if regions[i][2] - regions[i][1] < min_len:
            c1 = regions[i - 1][2] - regions[i - 1][1] > min_len
            c2 = regions[i + 1][2] - regions[i + 1][1] > min_len
            c3 = regions[i - 1][4] == regions[i + 1][4]  # same classification
            c4 = regions[i + 1][4] != regions[i][4]  # different classification
            c5 = regions[i][1] - regions[i - 1][2] < max_gap  # max gap between regions
            c6 = regions[i + 1][1] - regions[i][2] < max_gap  # max gap between regions
            if c1 and c2 and c3 and c4 and c5 and c6:
                regions_to_remove.append(i)
    for i in sorted(regions_to_remove, reverse=True):
        del regions[i]
    return regions


def remove_short_regions(regions, min_l_score=600):
    """
    remove short or weak regions
    min_l_score is the minimum length-weighted score for a region to be kept
    l_score = (PID - 50) * length
    """
    regions_to_remove = []
    for i in range(len(regions)):
        l_score = (regions[i][3] - 50) * (regions[i][2] - regions[i][1])
        if l_score < min_l_score:
            regions_to_remove.append(i)
    for i in sorted(regions_to_remove, reverse=True):
        del regions[i]
    return regions


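# Illustrative effect of the l_score filter above (assumed values): a region
# spanning 180-200 with PID 95 scores (95 - 50) * (200 - 180) = 900 and is kept
# with the default min_l_score=600, while the same region with PID 70 scores
# (70 - 50) * (200 - 180) = 400 and is removed.

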
def join_disjoint_regions_by_classification(disjoint_regions, max_gap=0):
    """
    merge neighboring intervals with the same classification and calculate the mean
    weighted score; weights correspond to the lengths of the intervals
    """
    merged_regions = []
    for seqid, start, end, score, classification in disjoint_regions:
        score_length = (end - start + 1) * score
        if len(merged_regions) == 0:
            merged_regions.append([seqid, start, end, score_length, classification])
        else:
            cond_same_class = merged_regions[-1][4] == classification
            cond_same_seqid = merged_regions[-1][0] == seqid
            cond_neighboring = start - merged_regions[-1][2] + 1 <= max_gap
            if cond_same_class and cond_same_seqid and cond_neighboring:
                # extend region
                merged_regions[-1] = [merged_regions[-1][0], merged_regions[-1][1], end,
                                      merged_regions[-1][3] + score_length,
                                      merged_regions[-1][4]]
            else:
                merged_regions.append([seqid, start, end, score_length, classification])
    # recalculate length weighted score
    for record in merged_regions:
        record[3] = record[3] / (record[2] - record[1] + 1)
    return merged_regions


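# Illustrative weighted merge (assumed values): two same-class regions 101-200
# (score 90) and 201-300 (score 80), each 100 bp long, satisfy the gap condition
# for max_gap >= 2 (e.g. the max_gap=10 call in main()) and merge into 101-300
# with score (100 * 90 + 100 * 80) / 200 = 85.

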
def write_merged_regions_to_gff3(merged_regions, outfile):
    """
    write merged regions to gff3 file
    """
    with open(outfile, "w") as f:
        # write header
        f.write("##gff-version 3\n")
        for seqid, start, end, score, classification in merged_regions:
            attributes = "Name={};score={}".format(classification, score)
            f.write(
                "\t".join(
                    [seqid, "blast_parsed", "repeat_region", str(start), str(end),
                     str(round(score, 2)), ".", ".", attributes]
                    )
                )
            f.write("\n")


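# Example of a single tab-separated output line produced above (placeholder values):
#   chr1  blast_parsed  repeat_region  1200  3400  91.3  .  .  Name=LTR/Ty1_copia/Angela;score=91.3

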
def sort_blast_table(
        blastfile, seqid_column=0, start_column=6, cpu=1
        ):
    """
    split blast table by seqid and sort by start position
    stores output in temp files
    columns are indexed from 0
    but cut and sort use 1-based indexing!
    """
    blast_sorted = tempfile.NamedTemporaryFile().name
    # create dictionary of seqid counts
    seq_id_counts = {}
    # sort blast file on disk using sort on seqid and start (numeric) position columns
    # using sort command as blast output could be very large
    cmd = "sort -k {0},{0} -k {1},{1}n --parallel {4} {2} > {3}".format(
        seqid_column + 1, start_column + 1, blastfile, blast_sorted, cpu
        )
    subprocess.check_call(cmd, shell=True)

    # count seqids using uniq command
    cmd = "cut -f {0} {1} | uniq -c > {2}".format(
        seqid_column + 1, blast_sorted, blast_sorted + ".counts"
        )
    subprocess.check_call(cmd, shell=True)
    # read counts file and create dictionary
    with open(blast_sorted + ".counts", "r") as f:
        for line in f:
            line = line.strip().split()
            seq_id_counts[line[1]] = int(line[0])
    # remove counts file
    subprocess.call(["rm", blast_sorted + ".counts"])
    # return counts dictionary and sorted blast file
    return seq_id_counts, blast_sorted


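# Illustrative return value of sort_blast_table() (assumed seqids and counts):
#   seq_id_counts -> {'chr1_0': 1534, 'chr2_0': 98}
# returned together with the path of the sorted temporary blast table.

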
def run_blastn(
        query, db, blastfile, evalue=1e-3, max_target_seqs=999999999, gapopen=2,
        gapextend=1, reward=1, penalty=-1, word_size=9, num_threads=1, outfmt="6"
        ):
    """
    run blastn
    """
    # create temporary blast database:
    db_formated = tempfile.NamedTemporaryFile().name
    cmd = "makeblastdb -in {0} -dbtype nucl -out {1}".format(db, db_formated)
    subprocess.check_call(cmd, shell=True)
    # if query is smaller than max_size, run blast on a single file
    size = os.path.getsize(query)
    print("query size: {} bytes".format(size))
    max_size = 1e6
    overlap = 50000
    if size < max_size:
        cmd = ("blastn -task rmblastn -query {0} -db {1} -out {2} -evalue {3} "
               "-max_target_seqs {4} "
               "-gapopen {5} -gapextend {6} -word_size {7} -num_threads "
               "{8} -outfmt '{9}' -reward {10} -penalty {11} -dust no").format(
            query, db_formated, blastfile, evalue, max_target_seqs, gapopen, gapextend,
            word_size, num_threads, outfmt, reward, penalty
            )
        subprocess.check_call(cmd, shell=True)
    # if query is larger than max_size, split query in chunks and run blast on each chunk
    else:
        print(f"query is larger than {max_size}, splitting query in chunks")
        query_parts, matching_table = split_fasta_to_chunks(query, max_size, overlap)
        print(query_parts)
        for i, part in enumerate(query_parts):
            print(f"running blast on chunk {i}")
            print(part)
            cmd = ("blastn -task rmblastn -query {0} -db {1} -out {2} -evalue {3} "
                   "-max_target_seqs {4} "
                   "-gapopen {5} -gapextend {6} -word_size {7} -num_threads "
                   "{8} -outfmt '{9}' -reward {10} -penalty {11} -dust no").format(
                part, db_formated, f'{blastfile}.{i}', evalue, max_target_seqs, gapopen,
                gapextend,
                word_size, num_threads, outfmt, reward, penalty
                )
            subprocess.check_call(cmd, shell=True)
            print(cmd)
            # remove part file
            # os.unlink(part)
        # merge blast results and recalculate start, end positions and header
        merge_blast_results(blastfile, matching_table, n_parts=len(query_parts))

    # remove temporary blast database
    os.unlink(db_formated + ".nhr")
    os.unlink(db_formated + ".nin")
    os.unlink(db_formated + ".nsq")


def merge_blast_results(blastfile, matching_table, n_parts):
    """
    Merge blast tables and recalculate start, end positions based on
    matching table
    """
    with open(blastfile, "w") as f:
        matching_table_dict = {i[4]: i for i in matching_table}
        print(matching_table_dict)
        for i in range(n_parts):
            with open(f'{blastfile}.{i}', "r") as f2:
                for line in f2:
                    line = line.strip().split("\t")
                    # seqid (header) is in column 1
                    seqid = line[0]
                    line[0] = matching_table_dict[seqid][0]
                    # increase coordinates by start position of chunk
                    line[6] = str(int(line[6]) + matching_table_dict[seqid][2])
                    line[7] = str(int(line[7]) + matching_table_dict[seqid][2])
                    f.write("\t".join(line) + "\n")
            # remove temporary blast file
            # os.unlink(f'{blastfile}.{i}')


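# Illustrative coordinate remapping done above (assumed matching_table entry):
# with ['chr1', 1, 125000000, 250000000, 'chr1_1'], a hit reported on 'chr1_1'
# with qstart=500 and qend=900 is written back as 'chr1' with qstart=125000500
# and qend=125000900.

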
def main():
    """
    main function
    """
    # get command line arguments
    parser = argparse.ArgumentParser(
        description="""This script is used to parse blast output table to gff file""",
        formatter_class=argparse.RawTextHelpFormatter
        )
    parser.add_argument(
        '-i', '--input', default=None, required=True, help="input file", type=str,
        action='store'
        )
    parser.add_argument(
        '-d', '--db', default=None, required=False,
        help="Fasta file with repeat database", type=str, action='store'
        )
    parser.add_argument(
        '-o', '--output', default=None, required=True, help="output file name", type=str,
        action='store'
        )
    parser.add_argument(
        '-a', '--alternative_classification_coding', default=False,
        help="Use alternative classification coding", action='store_true'
        )
    parser.add_argument(
        '-f', '--fasta_input', default=False,
        help="Input is fasta file instead of blast table", action='store_true'
        )
    parser.add_argument(
        '-c', '--cpu', default=1, help="Number of cpu to use", type=int
        )

    args = parser.parse_args()

    if args.fasta_input:
        # run blast using blastn
        blastfile = tempfile.NamedTemporaryFile().name
        if args.db:
            run_blastn(args.input, args.db, blastfile, num_threads=args.cpu)
        else:
            sys.exit("No repeat database provided")
    else:
        blastfile = args.input

    # sort blast table
    seq_id_counts, blast_sorted = sort_blast_table(blastfile, cpu=args.cpu)
    disjoin_regions = blast2disjoint(
        blast_sorted, seq_id_counts,
        canonical_classification=not args.alternative_classification_coding
        )

    # remove short regions
    disjoin_regions = remove_short_interrupting_regions(disjoin_regions)

    # join neighboring regions with same classification
    merged_regions = join_disjoint_regions_by_classification(disjoin_regions)

    # remove short regions again
    merged_regions = remove_short_interrupting_regions(merged_regions)

    # merge again neighboring regions with same classification
    merged_regions = join_disjoint_regions_by_classification(merged_regions, max_gap=10)

    # remove short weak regions
    merged_regions = remove_short_regions(merged_regions)

    # last merge
    merged_regions = join_disjoint_regions_by_classification(merged_regions, max_gap=20)
    write_merged_regions_to_gff3(merged_regions, args.output)
    # remove temporary files
    os.remove(blast_sorted)


if __name__ == "__main__":
    main()
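
# Typical invocations (illustrative; file names are placeholders):
#   reAnnotate.py -i blast_hits.tsv -o repeats.gff3
#   reAnnotate.py -f -i genome.fasta -d repeat_db.fasta -o repeats.gff3 -c 4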