nick / sequence_content_trimmer: trimmer.py @ 1:464aee13e2df (draft, default, tip)
"planemo upload commit 8e52aac4afce4ab7c4d244e2b70f205f70c16749-dirty"
author   | nick
date     | Fri, 27 May 2022 23:29:45 +0000
parents  | 7f170cb06e2e
children |
#!/usr/bin/env python3
import sys
import argparse
import collections
import getreads

QUANT_ORDER = 5
USAGE = "%(prog)s [options] [input_1.fq [input_2.fq output_1.fq output_2.fq]]"
DESCRIPTION = """Trim the 3' ends of reads by sequence content, e.g. by GC content or presence of N's."""


def make_argparser():
  parser = argparse.ArgumentParser(description=DESCRIPTION, usage=USAGE)
  parser.add_argument('infile1', metavar='reads_1.fq', nargs='?', type=argparse.FileType('r'),
      default=sys.stdin,
      help='Input reads (mate 1). Omit to read from stdin.')
  parser.add_argument('infile2', metavar='reads_2.fq', nargs='?', type=argparse.FileType('r'),
      help='Input reads (mate 2). If given, it will preserve pairs (if one read is filtered out '
           'entirely, the other will also be lost).')
  parser.add_argument('outfile1', metavar='output_1.fq', nargs='?', type=argparse.FileType('w'),
      default=sys.stdout,
      help='Output file for mate 1. WARNING: Will overwrite.')
  parser.add_argument('outfile2', metavar='output_2.fq', nargs='?', type=argparse.FileType('w'),
      help='Output file for mate 2. WARNING: Will overwrite.')
  parser.add_argument('-f', '--format', dest='filetype', choices=('fasta', 'fastq'),
      help='Input read format.')
  parser.add_argument('-F', '--out-format', dest='out_filetype', choices=('fasta', 'fastq'),
      help='Output read format. Default: whatever the input format is.')
  parser.add_argument('-b', '--filt-bases', default='N',
      help='The bases to filter on. Case-insensitive. Default: %(default)s.')
  parser.add_argument('-t', '--thres', type=float, default=0.5,
      help='The threshold. The read will be trimmed once the proportion of filter bases in the '
           'window exceeds this fraction (not a percentage). Default: %(default)s.')
  parser.add_argument('-w', '--window', dest='win_len', type=int, default=1,
      help='Window size for trimming. Default: %(default)s.')
  parser.add_argument('-i', '--invert', action='store_true',
      help='Invert the filter bases: filter on bases NOT present in the --filt-bases.')
  parser.add_argument('-m', '--min-length', type=int, default=1,
      help='Set a minimum read length. Reads which are trimmed below this length will be filtered '
           'out (omitted entirely from the output). Read pairs will be preserved: both reads in a '
           'pair must exceed this length to be kept. Set to 1 to only omit empty reads. '
           'Default: %(default)s.')
  parser.add_argument('-A', '--acgt', action='store_true',
      help='Filter on any non-ACGT base (shortcut for "--invert --filt-bases ACGT").')
  parser.add_argument('-I', '--iupac', action='store_true',
      help='Filter on any non-IUPAC base (shortcut for "--invert --filt-bases ACGTUWSMKRYBDHVN-").')
  parser.add_argument('-q', '--quiet', action='store_true',
      help='Don\'t print trimming stats on completion.')
  parser.add_argument('-T', '--tsv', dest='stats_format', default='human', action='store_const',
      const='tsv')
  return parser
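
# Illustrative invocations (a sketch only; the filenames are hypothetical):
#   $ python3 trimmer.py reads.fq > trimmed.fq
#       Trim each read at its first N (defaults: --filt-bases N --thres 0.5 --window 1).
#   $ python3 trimmer.py -A -w 10 -t 0.3 reads_1.fq reads_2.fq out_1.fq out_2.fq
#       Paired-end mode: trim a read once more than 30% of a 10 bp window is non-ACGT; if either
#       mate ends up shorter than --min-length, both mates are omitted.
#   $ cat reads.fa | python3 trimmer.py --format fasta -I > trimmed.fa
#       Reading from stdin requires --format; -I trims at non-IUPAC characters.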


def main(argv):
  parser = make_argparser()
  args = parser.parse_args(argv[1:])
  # Catch invalid argument combinations.
  if args.infile1 and args.infile2 and not (args.outfile1 and args.outfile2):
    fail('Error: If giving two input files (paired end), must specify both output files.')
  # Determine filetypes, open input file parsers.
  filetype1 = get_filetype(args.infile1, args.filetype)
  file1_parser = iter(getreads.getparser(args.infile1, filetype=filetype1))
  if args.infile2:
    paired = True
    filetype2 = get_filetype(args.infile2, args.filetype)
    file2_parser = iter(getreads.getparser(args.infile2, filetype=filetype2))
  else:
    filetype2 = None
    file2_parser = None
    paired = False
  # Override the output filetype if one was specified on the command line.
  if args.out_filetype:
    filetype1 = args.out_filetype
    filetype2 = args.out_filetype
  # Determine the filter bases and whether to invert the selection.
  filt_bases = args.filt_bases
  invert = args.invert
  if args.acgt:
    filt_bases = 'ACGT'
    invert = True
  elif args.iupac:
    filt_bases = 'ACGTUWSMKRYBDHVN-'
    invert = True
  # Do the actual trimming.
  filters = {'win_len':args.win_len, 'thres':args.thres, 'filt_bases':filt_bases, 'invert':invert,
             'min_len':args.min_length}
  try:
    stats = trim_reads(file1_parser, file2_parser, args.outfile1, args.outfile2, filetype1,
                       filetype2, paired, filters)
  finally:
    for filehandle in (args.infile1, args.infile2, args.outfile1, args.outfile2):
      if filehandle and filehandle is not sys.stdin and filehandle is not sys.stdout:
        filehandle.close()
  if not args.quiet:
    print_stats(stats, args.stats_format)


def trim_reads(file1_parser, file2_parser, outfile1, outfile2, filetype1, filetype2, paired,
               filters):
  """Trim all the reads in the input file(s), writing to the output file(s)."""
  min_len = filters['min_len']
  trims1 = collections.Counter()
  trims2 = collections.Counter()
  omitted1 = collections.Counter()
  omitted2 = collections.Counter()
  read1 = None
  read2 = None
  while True:
    # Read in the reads.
    try:
      read1 = next(file1_parser)
      if paired:
        read2 = next(file2_parser)
    except StopIteration:
      break
    # Do trimming.
    read1, trim_len1 = trim_read(read1, filters, filetype1)
    trims1[trim_len1] += 1
    if paired:
      read2, trim_len2 = trim_read(read2, filters, filetype2)
      trims2[trim_len2] += 1
      # Output reads if they both pass the minimum length threshold (if any was given).
      if min_len is None or (len(read1.seq) >= min_len and len(read2.seq) >= min_len):
        write_read(outfile1, read1, filetype1)
        write_read(outfile2, read2, filetype2)
      else:
        if len(read1.seq) < min_len:
          omitted1[trim_len1] += 1
        if len(read2.seq) < min_len:
          omitted2[trim_len2] += 1
    else:
      # Output read if it passes the minimum length threshold (if any was given).
      if min_len is None or len(read1.seq) >= min_len:
        write_read(outfile1, read1, filetype1)
      else:
        omitted1[trim_len1] += 1
  # Compile stats.
  stats = {}
  stats['reads'] = sum(trims1.values()) + sum(trims2.values())
  stats['trimmed'] = stats['reads'] - trims1[0] - trims2[0]
  stats['omitted'] = sum(omitted1.values()) + sum(omitted2.values())
  if paired:
    stats['trimmed1'] = stats['reads']//2 - trims1[0]
    stats['trimmed2'] = stats['reads']//2 - trims2[0]
    stats['omitted1'] = sum(omitted1.values())
    stats['omitted2'] = sum(omitted2.values())
  # Quintiles for trim lengths.
  stats['quants'] = {'order':QUANT_ORDER}
  if paired:
    stats['quants']['trim1'] = get_counter_quantiles(trims1, order=QUANT_ORDER)
    stats['quants']['trim2'] = get_counter_quantiles(trims2, order=QUANT_ORDER)
    stats['quants']['trim'] = get_counter_quantiles(trims1 + trims2, order=QUANT_ORDER)
    stats['quants']['omitted_trim1'] = get_counter_quantiles(omitted1, order=QUANT_ORDER)
    stats['quants']['omitted_trim2'] = get_counter_quantiles(omitted2, order=QUANT_ORDER)
    stats['quants']['omitted_trim'] = get_counter_quantiles(omitted1 + omitted2, order=QUANT_ORDER)
  else:
    stats['quants']['trim'] = get_counter_quantiles(trims1)
    stats['quants']['omitted_trim'] = get_counter_quantiles(omitted1)
  return stats
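
# For orientation, the stats dict built above has roughly this shape in paired-end mode (the
# numbers here are made up):
#   {'reads': 2000, 'trimmed': 350, 'trimmed1': 180, 'trimmed2': 170,
#    'omitted': 12, 'omitted1': 7, 'omitted2': 5,
#    'quants': {'order': 5, 'trim': [...], 'trim1': [...], 'trim2': [...],
#               'omitted_trim': [...], 'omitted_trim1': [...], 'omitted_trim2': [...]}}
# Each 'quants' entry is the list returned by get_counter_quantiles() below (the minimum trim
# length followed by the quintile cut points). Single-end runs omit the per-mate keys, leaving
# 'reads', 'trimmed', 'omitted', and a 'quants' dict with 'order', 'trim' and 'omitted_trim'.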


def get_filetype(infile, filetype_arg):
  if infile is sys.stdin:
    if filetype_arg:
      filetype = filetype_arg
    else:
      fail('Error: You must specify the --format if reading from stdin.')
  elif infile:
    if filetype_arg:
      filetype = filetype_arg
    else:
      if infile.name.endswith('.fa') or infile.name.endswith('.fasta'):
        filetype = 'fasta'
      elif infile.name.endswith('.fq') or infile.name.endswith('.fastq'):
        filetype = 'fastq'
      else:
        fail('Error: Unrecognized file ending on "{}". Please specify the --format.'
             .format(infile.name))
  else:
    fail('Error: infile is {}'.format(infile))
  return filetype


def write_read(filehandle, read, filetype):
  if filetype == 'fasta':
    filehandle.write('>{name}\n{seq}\n'.format(**vars(read)))
  elif filetype == 'fastq':
    filehandle.write('@{name}\n{seq}\n+\n{qual}\n'.format(**vars(read)))


def trim_read(read, filters, filetype):
  trimmed_seq = trim_seq(read.seq, **filters)
  trim_len = len(read.seq) - len(trimmed_seq)
  read.seq = trimmed_seq
  if filetype == 'fastq':
    # If the output filetype is FASTQ, trim the quality scores too.
    # If there are no input quality scores (i.e. the input is FASTA), use dummy scores instead.
    # "z" is the highest alphanumeric score (PHRED 89), higher than any expected real score.
    qual = read.qual or 'z' * len(read.seq)
    read.qual = qual[:len(read.seq)]
  return read, trim_len


def trim_seq(seq, win_len=1, thres=1.0, filt_bases='N', invert=False, **kwargs):
  """Trim an individual read and return its trimmed sequence.
  This will track the frequency of bad bases in a window of length win_len, and trim once the
  frequency rises above thres. The trim point will be just before the first (leftmost) bad base in
  the offending window (the first window with a frequency above thres).
  The "bad" bases are the ones in filt_bases if invert is False, or any base NOT in filt_bases if
  invert is True."""
  # Algorithm:
  # The window is a list which acts as a FIFO. As we scan from the left (5') end to the right (3')
  # end, we append new bases to the right end of the window and pop them from the left end.
  # Each base is only examined twice: when it enters the window and when it leaves it.
  # We keep a running total of the number of bad bases in bad_bases_count, incrementing it when bad
  # bases enter the window and decrementing it when they leave.
  # We also track the location of bad bases in the window with bad_bases_coords so we can figure
  # out where to cut if we have to trim.
  max_bad_bases = win_len * thres
  window = []
  bad_bases_count = 0
  bad_bases_coords = []
  for coord, base in enumerate(seq.upper()):
    # Shift window, adjust bad_bases_count and bad_bases_coords list.
    window.append(base)
    # Is the new base we're adding to the window a bad base?
    if invert:
      bad_base = base not in filt_bases
    else:
      bad_base = base in filt_bases
    # If so, increment the total and add its coordinate to bad_bases_coords.
    if bad_base:
      bad_bases_count += 1
      bad_bases_coords.append(coord)
    if len(window) > win_len:
      first_base = window.pop(0)
      # Is the base we're removing (the first base in the window) a bad base?
      if invert:
        bad_base = first_base not in filt_bases
      else:
        bad_base = first_base in filt_bases
      # If so, decrement the total and remove its coordinate from bad_bases_coords.
      if bad_base:
        bad_bases_count -= 1
        bad_bases_coords.pop(0)
    # Are we over the threshold?
    if bad_bases_count > max_bad_bases:
      break
  # If we exceeded the threshold, trim the sequence at the first (leftmost) bad base in the window.
  if bad_bases_count > max_bad_bases:
    first_bad_base = bad_bases_coords[0]
    return seq[0:first_bad_base]
  else:
    return seq
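
# A worked example of the sliding window above (the sequence is made up):
#   trim_seq('ACGTNCNAA', win_len=3, thres=0.5, filt_bases='N')
# Here max_bad_bases is 1.5, so the first window holding two N's triggers the trim. That happens
# at the window covering coordinates 4-6 ('NCN'); its leftmost bad base sits at coordinate 4, so
# the function returns seq[0:4], i.e. 'ACGT'. A read whose windows never hold two N's at once is
# returned unchanged.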


def get_counter_quantiles(counter, order=5):
  """Return an arbitrary set of quantiles (including min and max values).
  `counter` is a collections.Counter.
  `order` is which quantile to perform (4 = quartiles, 5 = quintiles).
  Warning: This expects a counter which has counted at least `order` elements. If it receives a
  counter with fewer elements, it will simply return `list(counter.elements())`. This will have
  fewer than the usual order+1 elements, and may not fit normal expectations of what "quantiles"
  should be."""
  quantiles = []
  total = sum(counter.values())
  if total <= order:
    return list(counter.elements())
  span_size = total / order
  # Sort the items and go through them, looking for the ones at the break points.
  items = list(sorted(counter.items(), key=lambda i: i[0]))
  quantiles.append(items[0][0])
  total_seen = 0
  current_span = 1
  cut_point = int(round(current_span*span_size))
  for item, count in items:
    total_seen += count
    if total_seen >= cut_point:
      quantiles.append(item)
      current_span += 1
      cut_point = int(round(current_span*span_size))
  return quantiles


def print_stats(stats, format='human'):
  if format == 'human':
    lines = get_stats_lines_human(stats)
  elif format == 'tsv':
    lines = get_stats_lines_tsv(stats)
  else:
    fail('Error: Unrecognized format {!r}'.format(format))
  sys.stderr.write('\n'.join(lines).format(**stats)+'\n')


def get_stats_lines_human(stats):
  # Single-stat lines:
  lines = [
    'Total reads in input:\t{reads}',
    'Reads trimmed:\t{trimmed}'
  ]
  if 'trimmed1' in stats and 'trimmed2' in stats:
    lines.append(' For mate 1:\t{trimmed1}')
    lines.append(' For mate 2:\t{trimmed2}')
  lines.append('Reads filtered out:\t{omitted}')
  if 'omitted1' in stats and 'omitted2' in stats:
    lines.append(' For mate 1:\t{omitted1}')
    lines.append(' For mate 2:\t{omitted2}')
  # Quantile lines:
  quantile_lines = [
    ('Bases trimmed quintiles', 'trim'),
    (' For mate 1', 'trim1'),
    (' For mate 2', 'trim2'),
    ('Bases trimmed quintiles from filtered reads', 'omitted_trim'),
    (' For mate 1', 'omitted_trim1'),
    (' For mate 2', 'omitted_trim2')
  ]
  for desc, stat_name in quantile_lines:
    if stat_name in stats['quants']:
      quants_values = stats['quants'][stat_name]
      if quants_values:
        quants_str = ', '.join(map(str, quants_values))
      else:
        quants_str = 'N/A'
      line = desc+':\t'+quants_str
      lines.append(line)
  return lines


def get_stats_lines_tsv(stats):
  lines = ['{reads}']
  if 'trimmed1' in stats and 'trimmed2' in stats:
    lines.append('{trimmed}\t{trimmed1}\t{trimmed2}')
  else:
    lines.append('{trimmed}')
  if 'omitted1' in stats and 'omitted2' in stats:
    lines.append('{omitted}\t{omitted1}\t{omitted2}')
  else:
    lines.append('{omitted}')
  for stat_name in ('trim', 'trim1', 'trim2', 'omitted_trim', 'omitted_trim1', 'omitted_trim2'):
    if stat_name in stats['quants']:
      quants_values = stats['quants'][stat_name]
      lines.append('\t'.join(map(str, quants_values)))
  return lines
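
# A quick illustration of get_counter_quantiles() (the counter is made up):
#   get_counter_quantiles(collections.Counter({0: 2, 1: 2, 3: 2, 5: 2, 9: 2}), order=5)
# The total is 10, so the cut points fall at cumulative counts 2, 4, 6, 8 and 10, giving
# [0, 0, 1, 3, 5, 9]: the minimum followed by one value per quintile, the last being the maximum.
# Note that if one count spans several cut points, only a single value is appended for it, so the
# result can have fewer than order+1 elements even when the counter's total exceeds `order`.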


def fail(message):
  sys.stderr.write(message+"\n")
  sys.exit(1)


if __name__ == '__main__':
  sys.exit(main(sys.argv))