comparison trimmer.py @ 1:464aee13e2df draft default tip

"planemo upload commit 8e52aac4afce4ab7c4d244e2b70f205f70c16749-dirty"
author nick
date Fri, 27 May 2022 23:29:45 +0000
parents 7f170cb06e2e
children
comparison
equal deleted inserted replaced
0:7f170cb06e2e 1:464aee13e2df
1 #!/usr/bin/env python 1 #!/usr/bin/env python3
2 from __future__ import division
3 import sys 2 import sys
4 import argparse 3 import argparse
4 import collections
5 import getreads 5 import getreads
6 6
7 OPT_DEFAULTS = {'win_len':1, 'thres':1.0, 'filt_bases':'N'} 7 QUANT_ORDER = 5
8 USAGE = "%(prog)s [options] [input_1.fq [input_2.fq output_1.fq output_2.fq]]" 8 USAGE = "%(prog)s [options] [input_1.fq [input_2.fq output_1.fq output_2.fq]]"
9 DESCRIPTION = """Trim the 5' ends of reads by sequence content, e.g. by GC content or presence of 9 DESCRIPTION = """Trim the 5' ends of reads by sequence content, e.g. by GC content or presence of
10 N's.""" 10 N's."""
11 11
12 12
def make_argparser():
    """Build and return the command-line argument parser for the trimmer.

    Positional arguments are the optional input/output read files (defaulting to stdin/stdout
    for mate 1); options control the filter bases, threshold, window size, minimum read length
    and how the trimming stats are reported.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION, usage=USAGE)
    parser.add_argument('infile1', metavar='reads_1.fq', nargs='?', type=argparse.FileType('r'),
        default=sys.stdin,
        help='Input reads (mate 1). Omit to read from stdin.')
    parser.add_argument('infile2', metavar='reads_2.fq', nargs='?', type=argparse.FileType('r'),
        help='Input reads (mate 2). If given, it will preserve pairs (if one read is filtered out '
             'entirely, the other will also be lost).')
    parser.add_argument('outfile1', metavar='output_1.fq', nargs='?', type=argparse.FileType('w'),
        default=sys.stdout,
        help='Output file for mate 1. WARNING: Will overwrite.')
    parser.add_argument('outfile2', metavar='output_2.fq', nargs='?', type=argparse.FileType('w'),
        help='Output file for mate 2. WARNING: Will overwrite.')
    parser.add_argument('-f', '--format', dest='filetype', choices=('fasta', 'fastq'),
        help='Input read format.')
    parser.add_argument('-F', '--out-format', dest='out_filetype', choices=('fasta', 'fastq'),
        help='Output read format. Default: whatever the input format is.')
    parser.add_argument('-b', '--filt-bases', default='N',
        help='The bases to filter on. Case-insensitive. Default: %(default)s.')
    parser.add_argument('-t', '--thres', type=float, default=0.5,
        help='The threshold. The read will be trimmed once the proportion of filter bases in the '
             'window exceed this fraction (not a percentage). Default: %(default)s.')
    parser.add_argument('-w', '--window', dest='win_len', type=int, default=1,
        help='Window size for trimming. Default: %(default)s.')
    parser.add_argument('-i', '--invert', action='store_true',
        help='Invert the filter bases: filter on bases NOT present in the --filt-bases.')
    # Reads are kept when their length is >= the minimum, so the help says "at least" rather than
    # "exceed".
    parser.add_argument('-m', '--min-length', type=int, default=1,
        help='Set a minimum read length. Reads which are trimmed below this length will be filtered '
             'out (omitted entirely from the output). Read pairs will be preserved: both reads in a '
             'pair must be at least this length to be kept. Set to 1 to only omit empty reads. '
             'Default: %(default)s.')
    parser.add_argument('-A', '--acgt', action='store_true',
        help='Filter on any non-ACGT base (shortcut for "--invert --filt-bases ACGT").')
    parser.add_argument('-I', '--iupac', action='store_true',
        help='Filter on any non-IUPAC base (shortcut for "--invert --filt-bases ACGTUWSMKRYBDHVN-").')
    parser.add_argument('-q', '--quiet', action='store_true',
        help="Don't print trimming stats on completion.")
    # store_const: giving -T switches stats_format from its default 'human' to 'tsv'.
    parser.add_argument('-T', '--tsv', dest='stats_format', default='human',
        action='store_const', const='tsv',
        help='Print the trimming stats as tab-delimited values instead of human-readable text.')
    return parser
53
54
55 def main(argv):
56 parser = make_argparser()
53 args = parser.parse_args(argv[1:]) 57 args = parser.parse_args(argv[1:])
54
55 if args.error:
56 fail('Error: '+args.error)
57 58
58 # Catch invalid argument combinations. 59 # Catch invalid argument combinations.
59 if args.infile1 and args.infile2 and not (args.outfile1 and args.outfile2): 60 if args.infile1 and args.infile2 and not (args.outfile1 and args.outfile2):
60 fail('Error: If giving two input files (paired end), must specify both output files.') 61 fail('Error: If giving two input files (paired end), must specify both output files.')
61 # Determine filetypes, open input file parsers. 62 # Determine filetypes, open input file parsers.
83 elif args.iupac: 84 elif args.iupac:
84 filt_bases = 'ACGTUWSMKRYBDHVN-' 85 filt_bases = 'ACGTUWSMKRYBDHVN-'
85 invert = True 86 invert = True
86 87
87 # Do the actual trimming. 88 # Do the actual trimming.
89 filters = {'win_len':args.win_len, 'thres':args.thres, 'filt_bases':filt_bases, 'invert':invert,
90 'min_len':args.min_length}
88 try: 91 try:
89 trim_reads(file1_parser, file2_parser, args.outfile1, args.outfile2, filetype1, filetype2, 92 stats = trim_reads(file1_parser, file2_parser, args.outfile1, args.outfile2,
90 paired, args.win_len, args.thres, filt_bases, invert, args.min_length) 93 filetype1, filetype2, paired, filters)
91 finally: 94 finally:
92 for filehandle in (args.infile1, args.infile2, args.outfile1, args.outfile2): 95 for filehandle in (args.infile1, args.infile2, args.outfile1, args.outfile2):
93 if filehandle and filehandle is not sys.stdin and filehandle is not sys.stdout: 96 if filehandle and filehandle is not sys.stdin and filehandle is not sys.stdout:
94 filehandle.close() 97 filehandle.close()
95 98
99 if not args.quiet:
100 print_stats(stats, args.stats_format)
101
96 102
97 def trim_reads(file1_parser, file2_parser, outfile1, outfile2, filetype1, filetype2, paired, 103 def trim_reads(file1_parser, file2_parser, outfile1, outfile2, filetype1, filetype2, paired,
98 win_len, thres, filt_bases, invert, min_length): 104 filters):
99 """Trim all the reads in the input file(s), writing to the output file(s).""" 105 """Trim all the reads in the input file(s), writing to the output file(s)."""
106 min_len = filters['min_len']
107 trims1 = collections.Counter()
108 trims2 = collections.Counter()
109 omitted1 = collections.Counter()
110 omitted2 = collections.Counter()
100 read1 = None 111 read1 = None
101 read2 = None 112 read2 = None
102 while True: 113 while True:
103 # Read in the reads. 114 # Read in the reads.
104 try: 115 try:
106 if paired: 117 if paired:
107 read2 = next(file2_parser) 118 read2 = next(file2_parser)
108 except StopIteration: 119 except StopIteration:
109 break 120 break
110 # Do trimming. 121 # Do trimming.
111 read1.seq = trim_read(read1.seq, win_len, thres, filt_bases, invert) 122 read1, trim_len1 = trim_read(read1, filters, filetype1)
112 if filetype1 == 'fastq': 123 trims1[trim_len1] += 1
113 # If the output filetype is FASTQ, trim the quality scores too.
114 # If there are no input quality scores (i.e. the input is FASTA), use dummy scores instead.
115 # "z" is the highest alphanumeric score (PHRED 89), higher than any expected real score.
116 qual1 = read1.qual or 'z' * len(read1.seq)
117 read1.qual = qual1[:len(read1.seq)]
118 if paired: 124 if paired:
119 read2.seq = trim_read(read2.seq, win_len, thres, filt_bases, invert) 125 read2, trim_len2 = trim_read(read2, filters, filetype2)
120 if filetype2 == 'fastq': 126 trims2[trim_len2] += 1
121 qual2 = read2.qual or 'z' * len(read2.seq)
122 read2.qual = qual2[:len(read2.seq)]
123 # Output reads if they both pass the minimum length threshold (if any was given). 127 # Output reads if they both pass the minimum length threshold (if any was given).
124 if min_length is None or (len(read1.seq) >= min_length and len(read2.seq) >= min_length): 128 if min_len is None or (len(read1.seq) >= min_len and len(read2.seq) >= min_len):
125 write_read(outfile1, read1, filetype1) 129 write_read(outfile1, read1, filetype1)
126 write_read(outfile2, read2, filetype2) 130 write_read(outfile2, read2, filetype2)
131 else:
132 if len(read1.seq) < min_len:
133 omitted1[trim_len1] += 1
134 if len(read2.seq) < min_len:
135 omitted2[trim_len2] += 1
127 else: 136 else:
128 # Output read if it passes the minimum length threshold (if any was given). 137 # Output read if it passes the minimum length threshold (if any was given).
129 if min_length is None or len(read1.seq) >= min_length: 138 if min_len is None or len(read1.seq) >= min_len:
130 write_read(outfile1, read1, filetype1) 139 write_read(outfile1, read1, filetype1)
140 else:
141 omitted1[trim_len1] += 1
142 # Compile stats.
143 stats = {}
144 stats['reads'] = sum(trims1.values()) + sum(trims2.values())
145 stats['trimmed'] = stats['reads'] - trims1[0] - trims2[0]
146 stats['omitted'] = sum(omitted1.values()) + sum(omitted2.values())
147 if paired:
148 stats['trimmed1'] = stats['reads']//2 - trims1[0]
149 stats['trimmed2'] = stats['reads']//2 - trims2[0]
150 stats['omitted1'] = sum(omitted1.values())
151 stats['omitted2'] = sum(omitted2.values())
152 # Quintiles for trim lengths.
153 stats['quants'] = {'order':QUANT_ORDER}
154 if paired:
155 stats['quants']['trim1'] = get_counter_quantiles(trims1, order=QUANT_ORDER)
156 stats['quants']['trim2'] = get_counter_quantiles(trims2, order=QUANT_ORDER)
157 stats['quants']['trim'] = get_counter_quantiles(trims1 + trims2, order=QUANT_ORDER)
158 stats['quants']['omitted_trim1'] = get_counter_quantiles(omitted1, order=QUANT_ORDER)
159 stats['quants']['omitted_trim2'] = get_counter_quantiles(omitted2, order=QUANT_ORDER)
160 stats['quants']['omitted_trim'] = get_counter_quantiles(omitted1 + omitted2, order=QUANT_ORDER)
161 else:
162 stats['quants']['trim'] = get_counter_quantiles(trims1)
163 stats['quants']['omitted_trim'] = get_counter_quantiles(omitted1)
164 return stats
131 165
132 166
133 def get_filetype(infile, filetype_arg): 167 def get_filetype(infile, filetype_arg):
134 if infile is sys.stdin: 168 if infile is sys.stdin:
135 if filetype_arg: 169 if filetype_arg:
156 filehandle.write('>{name}\n{seq}\n'.format(**vars(read))) 190 filehandle.write('>{name}\n{seq}\n'.format(**vars(read)))
157 elif filetype == 'fastq': 191 elif filetype == 'fastq':
158 filehandle.write('@{name}\n{seq}\n+\n{qual}\n'.format(**vars(read))) 192 filehandle.write('@{name}\n{seq}\n+\n{qual}\n'.format(**vars(read)))
159 193
160 194
def trim_read(read, filters, filetype):
    """Trim one read in place and report how many bases were removed.

    `read` is an object with `seq` and `qual` attributes; its `seq` (and, for FASTQ output, its
    `qual`) is overwritten with the trimmed value.
    `filters` is a dict of keyword arguments passed straight through to trim_seq().
    `filetype` is the output format; quality scores are only adjusted when it is 'fastq'.
    Returns a tuple of (the same read object, number of bases trimmed off).
    """
    trimmed_seq = trim_seq(read.seq, **filters)
    trim_len = len(read.seq) - len(trimmed_seq)
    read.seq = trimmed_seq
    if filetype == 'fastq':
        # If the output filetype is FASTQ, trim the quality scores too.
        # If there are no input quality scores (i.e. the input is FASTA), use dummy scores instead.
        # "z" is the highest alphanumeric score (PHRED 89), higher than any expected real score.
        qual = read.qual or 'z' * len(read.seq)
        read.qual = qual[:len(read.seq)]
    return read, trim_len
206
207
208 def trim_seq(seq, win_len=1, thres=1.0, filt_bases='N', invert=False, **kwargs):
162 """Trim an individual read and return its trimmed sequence. 209 """Trim an individual read and return its trimmed sequence.
163 This will track the frequency of bad bases in a window of length win_len, and trim once the 210 This will track the frequency of bad bases in a window of length win_len, and trim once the
164 frequency goes below thres. The trim point will be just before the first (leftmost) bad base in 211 frequency goes below thres. The trim point will be just before the first (leftmost) bad base in
165 the window (the first window with a frequency below thres). The "bad" bases are the ones in 212 the window (the first window with a frequency below thres). The "bad" bases are the ones in
166 filt_bases if invert is False, or any base NOT in filt_bases if invert is True.""" 213 filt_bases if invert is False, or any base NOT in filt_bases if invert is True."""
197 bad_base = first_base in filt_bases 244 bad_base = first_base in filt_bases
198 # If so, decrement the total and remove its coordinate from the window. 245 # If so, decrement the total and remove its coordinate from the window.
199 if bad_base: 246 if bad_base:
200 bad_bases_count -= 1 247 bad_bases_count -= 1
201 bad_bases_coords.pop(0) 248 bad_bases_coords.pop(0)
202 # print bad_bases_coords
203 # Are we over the threshold? 249 # Are we over the threshold?
204 if bad_bases_count > max_bad_bases: 250 if bad_bases_count > max_bad_bases:
205 break 251 break
206 # If we exceeded the threshold, trim the sequence at the first (leftmost) bad base in the window. 252 # If we exceeded the threshold, trim the sequence at the first (leftmost) bad base in the window.
207 if bad_bases_count > max_bad_bases: 253 if bad_bases_count > max_bad_bases:
209 return seq[0:first_bad_base] 255 return seq[0:first_bad_base]
210 else: 256 else:
211 return seq 257 return seq
212 258
213 259
def get_counter_quantiles(counter, order=5):
    """Return an arbitrary set of quantiles (including min and max values).

    `counter` is a collections.Counter mapping each observed value to how often it was seen.
    `order` is which quantile to perform (4 = quartiles, 5 = quintiles).
    When more than `order` elements were counted, the result has exactly order+1 values: the
    minimum, followed by the value found at each of the `order` cut points (the last being the
    maximum).

    Warning: This expects a counter which has counted more than `order` elements.
    If it receives a counter with `order` or fewer elements, it will simply return
    `list(counter.elements())`. This will have fewer than the usual order+1 elements, and may
    not fit normal expectations of what "quantiles" should be.
    """
    total = sum(counter.values())
    if total <= order:
        return list(counter.elements())
    span_size = total / order
    # Sort the items and go through them, looking for the one at the break points.
    items = sorted(counter.items(), key=lambda pair: pair[0])
    quantiles = [items[0][0]]
    total_seen = 0
    current_span = 1
    cut_point = int(round(current_span * span_size))
    for value, count in items:
        total_seen += count
        # A single value with a large count can straddle several cut points, so keep emitting
        # it until the running total no longer reaches the next one. The bound on current_span
        # guarantees we never emit more than `order` cut points.
        while current_span <= order and total_seen >= cut_point:
            quantiles.append(value)
            current_span += 1
            cut_point = int(round(current_span * span_size))
    return quantiles
286
287
def print_stats(stats, format='human'):
    """Write the trimming stats to stderr.

    `stats` is the dict of statistics to report; its entries are substituted into the
    `{name}` placeholders of the generated lines.
    `format` selects the layout: 'human' or 'tsv'. Any other value aborts via fail().
    (The parameter name shadows the builtin but is kept so keyword callers keep working.)
    """
    if format == 'human':
        lines = get_stats_lines_human(stats)
    elif format == 'tsv':
        lines = get_stats_lines_tsv(stats)
    else:
        # fail() exits the process, so `lines` is always bound below.
        fail('Error: Unrecognized format {!r}'.format(format))
    sys.stderr.write('\n'.join(lines).format(**stats)+'\n')
296
297
def get_stats_lines_human(stats):
    """Build the human-readable stats report as a list of line templates.

    Single-value lines are returned with `{name}` placeholders still in them (the caller fills
    them in with str.format); quantile lines already contain their values.
    Per-mate lines are only included when `stats` holds the corresponding per-mate keys.
    `stats['quants']` must exist; each quantile list found in it gets one line, with 'N/A' shown
    for an empty list.
    """
    # Single-stat lines:
    lines = [
        'Total reads in input:\t{reads}',
        'Reads trimmed:\t{trimmed}',
    ]
    if 'trimmed1' in stats and 'trimmed2' in stats:
        lines.append(' For mate 1:\t{trimmed1}')
        lines.append(' For mate 2:\t{trimmed2}')
    lines.append('Reads filtered out:\t{omitted}')
    if 'omitted1' in stats and 'omitted2' in stats:
        lines.append(' For mate 1:\t{omitted1}')
        lines.append(' For mate 2:\t{omitted2}')
    # Quantile lines:
    quantile_lines = [
        ('Bases trimmed quintiles', 'trim'),
        (' For mate 1', 'trim1'),
        (' For mate 2', 'trim2'),
        ('Bases trimmed quintiles from filtered reads', 'omitted_trim'),
        (' For mate 1', 'omitted_trim1'),
        (' For mate 2', 'omitted_trim2'),
    ]
    quants = stats['quants']
    for desc, stat_name in quantile_lines:
        if stat_name not in quants:
            continue
        quants_values = quants[stat_name]
        if quants_values:
            quants_str = ', '.join(map(str, quants_values))
        else:
            quants_str = 'N/A'
        lines.append(desc+':\t'+quants_str)
    return lines
330
331
def get_stats_lines_tsv(stats):
    """Build the tab-delimited stats report as a list of line templates.

    The first three lines are `{name}` placeholders for the read, trimmed and omitted counts
    (with per-mate columns appended when `stats` holds the per-mate keys); the caller fills them
    in with str.format. Each quantile list present in `stats['quants']` then gets one line of
    tab-joined values, in a fixed order. An empty quantile list yields an empty line.
    """
    lines = ['{reads}']
    if 'trimmed1' in stats and 'trimmed2' in stats:
        lines.append('{trimmed}\t{trimmed1}\t{trimmed2}')
    else:
        lines.append('{trimmed}')
    if 'omitted1' in stats and 'omitted2' in stats:
        lines.append('{omitted}\t{omitted1}\t{omitted2}')
    else:
        lines.append('{omitted}')
    quants = stats['quants']
    for stat_name in ('trim', 'trim1', 'trim2', 'omitted_trim', 'omitted_trim1', 'omitted_trim2'):
        if stat_name in quants:
            lines.append('\t'.join(map(str, quants[stat_name])))
    return lines
347
348
def fail(message):
    """Write `message` (plus a newline) to stderr and exit the process with status 1."""
    sys.stderr.write(message+"\n")
    sys.exit(1)
217 352
218 353