view blast2html.py @ 124:6719353162b0

change sanitize_all_html warning style
author Jan Kanis <jan.code@jankanis.nl>
date Mon, 11 Aug 2014 17:32:09 +0200
parents 720efb818f1b
children 0fa962dd83b4
line wrap: on
line source

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Actually this program works with both python 2 and 3, tested against python 2.6

# Copyright The Hyve B.V. 2014
# License: GPL version 3 or (at your option) any higher version

from __future__ import unicode_literals, division

import sys
import math
import warnings
import six, codecs, io
from six.moves import builtins
from os import path
from itertools import repeat
from collections import defaultdict, namedtuple
from array import array
import glob
import argparse
from lxml import objectify
import jinja2

# Keep the native string type (bytes on python 2, text on python 3) for the
# few APIs that insist on it, e.g. array typecodes in hit_info.
builtin_str = str
# From here on str(...) produces text (unicode) on both python 2 and 3.
str = six.text_type



# Registry of jinja filters: filter name -> name of the implementing function.
# BlastVisualize._addfilters resolves each function name on the visualizer
# instance first, then among module globals, then among builtins; so the
# 'float' entry ends up as the builtin float.
_filters = dict(float='float')


def filter(func_or_name):
    """Decorator to register a function as filter in the current jinja environment.

    Use it bare (``@filter``) to register a function under its own name, or
    with a string (``@filter('len')``) to register it under that name. Only
    the function's name is recorded; the function itself is returned
    unchanged.
    """
    if not isinstance(func_or_name, six.string_types):
        func = func_or_name
        _filters[func.__name__] = func.__name__
        return func

    alias = func_or_name

    def register(func):
        _filters[alias] = func.__name__
        return func
    return register


def color_idx(length):
    """Map an alignment length to an index into BlastVisualize.colors.

    Lengths below 40 give 0, 40-49 give 1, 50-79 give 2, 80-199 give 3 and
    anything longer gives 4.
    """
    for idx, upper_bound in enumerate((40, 50, 80, 200)):
        if length < upper_bound:
            return idx
    return 4

@filter
def fmt(val, fmt):
    """Convert ``val`` to a float and format it with the format spec ``fmt`` (e.g. '.2f')."""
    number = float(val)
    return format(number, fmt)

@filter
def numfmt(val):
    """Format a number in plain decimal notation without superfluous trailing zeros.

    Whole numbers are printed as integers. Other values are rounded to 10
    decimals and printed in fixed point notation with the trailing zeros
    removed, keeping at least one digit after the decimal point. This avoids
    the scientific notation python's default float formatting uses for some
    values, the zero padding of the 'f' format type, and the different number
    of digits python 2 and 3 print.
    """
    fractional_part = math.modf(val)[0]
    if fractional_part == 0:
        return str(int(val))
    # round to 10 decimals to get identical representations in python 2 and 3
    text = format(round(val, 10), '.10f').rstrip('0')
    if text.endswith('.'):
        text += '0'
    return text

@filter
def firsttitle(hit):
    """Return the first title of a hit: the part of its Hit_def before the first '>'."""
    title, _, _ = str(hit.Hit_def).partition('>')
    return title

@filter
def othertitles(hit):
    """Split a hit.Hit_def that contains multiple titles up, splitting out the hit ids from the titles.

    Titles in Hit_def are separated by '>'. The first one is handled by
    firsttitle; every following segment is expected to look like
    '<id> <title>'. Each becomes an argparse.Namespace carrying the
    attributes hitid/genelink read from a hit (Hit_id, Hit_def,
    Hit_accession and the original hit's getroottree).

    Blank segments (e.g. from a trailing '>') are skipped. A segment without
    a space is taken as an id with an empty title; previously it raised
    ValueError.
    """
    segments = str(hit.Hit_def).split('>')

    titles = []
    for segment in segments[1:]:
        if not segment.strip():
            continue
        # partition instead of split(' ', 1): a bare id has no space to split on
        hit_id, _, title = segment.partition(' ')
        titles.append(argparse.Namespace(Hit_id = hit_id,
                                         Hit_def = title,
                                         Hit_accession = '',
                                         getroottree = hit.getroottree))
    return titles

@filter
def hitid(hit):
    """Return the Hit_id of a hit as a text string."""
    hit_id = hit.Hit_id
    return str(hit_id)


@filter
def alignment_pre(hsp):
    """Create the preformatted alignment blocks for an Hsp node.

    Yields one string per block of at most ``linewidth`` alignment columns.
    Each string holds a Query line, the midline and a Subject line. The
    numbers before and after each sequence are the coordinates of the first
    and last residue shown on that line. Gap characters ('-') take up an
    alignment column but do not advance the coordinate.

    A Hsp_query-frame / Hsp_hit-frame other than 1 or -1 is reported with a
    warning and clamped to its sign. After all blocks are yielded, a warning
    is also emitted if the number of residues in Hsp_qseq / Hsp_hseq does not
    match the Hsp's from/to coordinates.
    """

    # line break length
    linewidth = 60

    qfrom = int(hsp['Hsp_query-from'])
    qto = int(hsp['Hsp_query-to'])
    qframe = int(hsp['Hsp_query-frame'])
    hfrom = int(hsp['Hsp_hit-from'])
    hto = int(hsp['Hsp_hit-to'])
    hframe = int(hsp['Hsp_hit-frame'])

    qseq = hsp.Hsp_qseq.text
    midline = hsp.Hsp_midline.text
    hseq = hsp.Hsp_hseq.text

    if not qframe in (1, -1):
        warnings.warn("Error in BlastXML input: Hsp node {0} has a Hsp_query-frame of {1}. (should be 1 or -1)".format(nodeid(hsp), qframe))
        qframe = -1 if qframe < 0 else 1
    if not hframe in (1, -1):
        warnings.warn("Error in BlastXML input: Hsp node {0} has a Hsp_hit-frame of {1}. (should be 1 or -1)".format(nodeid(hsp), hframe))
        hframe = -1 if hframe < 0 else 1

    def split(txt):
        return [txt[i:i+linewidth] for i in range(0, len(txt), linewidth)]

    def residues(txt):
        # number of sequence residues in txt, i.e. not counting gap characters
        return len(txt) - txt.count('-')

    # coordinates of the next residue to be printed
    qpos, hpos = qfrom, hfrom
    for qs, mid, hs in zip(split(qseq), split(midline), split(hseq)):
        qcount, hcount = residues(qs), residues(hs)
        yield (
            "Query  {0:>7}  {1}  {2}\n".format(qpos, qs, qpos+(qcount-1)*qframe) +
            "       {0:7}  {1}\n".format('', mid) +
            "Subject{0:>7}  {1}  {2}".format(hpos, hs, hpos+(hcount-1)*hframe)
        )
        qpos += qcount*qframe
        hpos += hcount*hframe

    qresidues = residues(qseq)
    hresidues = residues(hseq)
    if qfrom+(qresidues-1)*qframe != qto:
        warnings.warn("Error in BlastXML input: Hsp node {0} qseq length mismatch: from {1} to {2} length {3} ({4} residues)".format(
            nodeid(hsp), qfrom, qto, len(qseq), qresidues))
    if hfrom+(hresidues-1)*hframe != hto:
        warnings.warn("Error in BlastXML input: Hsp node {0} hseq length mismatch: from {1} to {2} length {3} ({4} residues)".format(
            nodeid(hsp), hfrom, hto, len(hseq), hresidues))

    

@filter('len')
def blastxml_len(node):
    """Return the length of an Hsp (its alignment length) or an Iteration (its query length).

    Raises Exception for any other node type.
    """
    length_field = {'Hsp': 'Hsp_align-len',
                    'Iteration': 'Iteration_query-len'}.get(node.tag)
    if length_field is None:
        raise Exception("Unknown XML node type: "+node.tag)
    return int(node[length_field])

@filter
def nodeid(node):
    """Return a dash separated identifier for an Hsp, Hit or Iteration node.

    The id joins the numbers of the Iteration, Hit and Hsp the node is (in)
    as 'iter-hit-hsp' for an Hsp, 'iter-hit' for a Hit and 'iter' for an
    Iteration.

    Raises ValueError for any other node type. It also raises ValueError when
    the tree does not have the expected shape: an Hsp must sit two levels
    below a Hit, and a Hit two levels below an Iteration. These are explicit
    checks rather than asserts so they still run under ``python -O``.
    """
    def grandparent(n, expected_tag):
        # Hsp -> Hit_hsps -> Hit and Hit -> Iteration_hits -> Iteration
        parent = n.getparent()
        gp = parent.getparent() if parent is not None else None
        if gp is None or gp.tag != expected_tag:
            raise ValueError("Malformed BlastXML document: expected a {0} node two levels above a {1} node".format(
                expected_tag, n.tag))
        return gp

    parts = []
    if node.tag == 'Hsp':
        parts.insert(0, node.Hsp_num.text)
        node = grandparent(node, 'Hit')
    if node.tag == 'Hit':
        parts.insert(0, node.Hit_num.text)
        node = grandparent(node, 'Iteration')
    if node.tag == 'Iteration':
        parts.insert(0, node['Iteration_iter-num'].text)
        return '-'.join(parts)
    raise ValueError("The nodeid filter can only be applied to Hsp, Hit or Iteration nodes in a BlastXML document")

    
@filter
def asframe(frame):
    """Return the strand name for a frame: 'Plus' for +1, 'Minus' for -1.

    Raises Exception for any other frame value.
    """
    for value, strand in ((1, 'Plus'), (-1, 'Minus')):
        if frame == value:
            return strand
    raise Exception("frame should be either +1 or -1")

# def genelink(hit, type='genbank', hsp=None):
#     if not isinstance(hit, six.string_types):
#         hit = hitid(hit)
#     link = "http://www.ncbi.nlm.nih.gov/nucleotide/{0}?report={1}&log$=nuclalign".format(hit, type)
#     if hsp != None:
#         link += "&from={0}&to={1}".format(hsp['Hsp_hit-from'], hsp['Hsp_hit-to'])
#     return link


# javascript escape filter based on Django's, from https://github.com/dsissitka/khan-website/blob/master/templatefilters.py#L112-139
# I've removed the html escapes, since html escaping is already being performed by the template engine.

# (character, escape sequence) pairs that js_string_escape applies in order.
# The backslash must come first, otherwise the backslashes introduced by the
# other replacements would be escaped again.
# The escape text is produced with '%04X' formatting instead of being written
# as r'\u0027'-style literals: that syntax doesn't work the way we need in
# python 2.6 with unicode_literals.
# Deliberately not escaped (html escaping is left to the template engine):
# > < & = - ;
_base_js_escapes = tuple(
    (char, '\\u%04X' % ord(char))
    for char in ('\\', '\'', '"', u'\u2028', u'\u2029')
)

# Escape every ASCII character with a value less than 32. This is
# needed a.o. to prevent html parsers from jumping out of javascript
# parsing mode.
_js_escapes = _base_js_escapes + tuple(
    ('%c' % code, '\\u%04X' % code) for code in range(32)
)

@filter
def js_string_escape(value):
    """
    Escape ``value`` for embedding inside a javascript string literal.

    The value is converted to text, then each character from the
    _js_escapes table is replaced by its unicode escape sequence, in table
    order. This only makes data safe inside javascript string literals, not
    in general javascript snippets, and it does no html escaping.
    """
    escaped = str(value)
    for char, replacement in _js_escapes:
        escaped = escaped.replace(char, replacement)
    return escaped

@filter
def hits(result):
    """Return the Hit elements of an Iteration node.

    findall is used so that an Iteration without any Hit elements yields an
    empty list.
    """
    hit_container = result.Iteration_hits
    return hit_container.findall('Hit')

@filter('params')
def result_params(iteration):
    """Return (label, value) pairs summarising a query Iteration node."""
    labelled_fields = (
        ('Query number', 'Iteration_iter-num'),
        ('Query ID', 'Iteration_query-ID'),
        ('Definition line', 'Iteration_query-def'),
    )
    params = tuple((label, iteration[field]) for label, field in labelled_fields)
    return params + (('Length', blastxml_len(iteration)),)


class BlastVisualize:
    """Render a BLAST XML result as html using a jinja2 template.

    Methods decorated with @filter are installed as jinja filters, bound to
    the instance, by _addfilters, alongside the module level filters.
    """

    # colors used to draw hit coverage, indexed by color_idx()
    colors = ('black', 'blue', 'green', 'magenta', 'red')

    # upper bound on the number of evenly spaced labels on the query scale
    max_scale_labels = 10

    def __init__(self, input, templatedir, templatename, genelinks):
        """Parse the BLAST XML and set up the jinja environment.

        input: the BLAST XML input, handed to lxml.objectify.parse (main
            passes an open file object).
        templatedir: directory in which jinja looks up templates.
        templatename: name of the template to render, relative to templatedir.
        genelinks: mapping from database name (BlastOutput_db of the input)
            to an entry with .dbname and .template attributes, see genelink.
        """
        self.input = input
        self.templatename = templatename
        self.genelinks = genelinks

        self.blast = objectify.parse(self.input).getroot()
        self.loader = jinja2.FileSystemLoader(searchpath=templatedir)
        self.environment = jinja2.Environment(loader=self.loader,
                                              lstrip_blocks=True, trim_blocks=True, autoescape=True)

        self._addfilters(self.environment)


    def _addfilters(self, environment):
        """Install every function registered with @filter as a filter in ``environment``.

        A registered function name is looked up as an attribute of this
        instance first (giving a bound method), then as a module global, and
        finally as a python builtin (e.g. 'float').
        """
        for filtername, funcname in _filters.items():
            try:
                environment.filters[filtername] = getattr(self, funcname)
            except AttributeError:
                try:
                    environment.filters[filtername] = globals()[funcname]
                except KeyError:
                    environment.filters[filtername] = getattr(builtins, funcname)

    def render(self, output):
        """Render the template and stream the resulting html to ``output`` (main passes an open file)."""
        template = self.environment.get_template(self.templatename)

        params = (('Program', self.blast.BlastOutput_version),
                  ('Database', self.blast.BlastOutput_db),
        )

        result = template.stream(blast=self.blast,
                                 iterations=self.blast.BlastOutput_iterations.Iteration,
                                 colors=self.colors,
                                 params=params)

        result.dump(output)

    @filter
    def match_colors(self, result):
        """
        An iterator that yields, for every hit of the Iteration ``result``, a
        dict with the hit, its first title ('defline') and 'colors': a list of
        (width as a percentage of the query length, color name) pairs.
        Together the pairs describe which parts of the query the hit's Hsps
        cover. Uncovered stretches are 'transparent'. A covered stretch is
        colored by the alignment length (see color_idx) of the longest Hsp
        covering it.
        """

        query_length = blastxml_len(result)

        percent_multiplier = 100 / query_length

        for hit in hits(result):
            # sort hotspots from short to long, so we can overwrite index colors of
            # short matches with those of long ones.
            hotspots = sorted(hit.Hit_hsps.Hsp, key=blastxml_len)
            # one color index per query position; 255 marks "not covered"
            table = bytearray([255]) * query_length
            for hsp in hotspots:
                frm = int(hsp['Hsp_query-from']) - 1
                to = int(hsp['Hsp_query-to'])
                table[frm:to] = repeat(color_idx(blastxml_len(hsp)), to - frm)

            # run-length encode the table into (width, color) pairs
            matches = []
            last = table[0]
            count = 0
            for value in table:
                if value == last:
                    count += 1
                    continue
                matches.append((count * percent_multiplier, self.colors[last] if last != 255 else 'transparent'))
                last = value
                count = 1
            matches.append((count * percent_multiplier, self.colors[last] if last != 255 else 'transparent'))

            yield dict(colors=matches, hit=hit, defline=firsttitle(hit))

    @filter
    def queryscale(self, result):
        """Yield the labels of the query scale for the Iteration ``result``.

        Each item is a dict with 'label' (a query position) and 'width' (the
        width of the scale segment ending at that label, as a percentage of
        the query length). Labels are placed at multiples of a step chosen so
        there are at most max_scale_labels of them. A final label at the query
        length follows if the length is not a multiple of the step.
        """
        query_length = blastxml_len(result)
        # int(): math.ceil returns a float on python 2, and range needs an int step
        skip = int(math.ceil(query_length / self.max_scale_labels))
        percent_multiplier = 100 / query_length
        for label in range(skip, query_length + 1, skip):
            yield dict(label = label, width = skip * percent_multiplier)
        if query_length % skip != 0:
            yield dict(label = query_length,
                       width = (query_length % skip) * percent_multiplier)

    @filter
    def hit_info(self, result):
        """Yield a summary dict for every hit of the Iteration ``result``.

        The dict holds the hit, its first title, the bit score, e-value and
        identity of its best scoring Hsp, the sum of the bit scores of all its
        Hsps, and the fraction of the query covered by any Hsp. The numbers
        are preformatted strings.
        """

        query_length = blastxml_len(result)

        # In python 2.6 array doesn't accept unicode type codes, but in 3.4 it requires them
        typecode_B = builtin_str('B')

        for hit in hits(result):
            hsps = hit.Hit_hsps.Hsp

            # Mark every query position covered by at least one Hsp. The fill
            # must be exactly as long as the slice it replaces. Hsp_align-len
            # counts alignment columns, which exceed the query span when the
            # query has gaps, and a longer fill would grow the array and
            # inflate the coverage.
            cover = array(typecode_B, [0]) * query_length
            for hsp in hsps:
                frm = int(hsp['Hsp_query-from']) - 1
                to = int(hsp['Hsp_query-to'])
                cover[frm:to] = array(typecode_B, [1]) * (to - frm)
            cover_count = cover.count(1)

            best_hsp = max(hsps, key=lambda h: float(h['Hsp_bit-score']))

            yield dict(hit = hit,
                       title = firsttitle(hit),
                       maxscore = format(float(best_hsp['Hsp_bit-score']), '.1f'),
                       e_value = format(float(best_hsp.Hsp_evalue), '.4'),
                       # float(...) because non-flooring division doesn't work with lxml elements in python 2.6
                       ident = format(float(best_hsp.Hsp_identity) / blastxml_len(best_hsp), '.0%'),
                       totalscore = format(sum(float(hsp['Hsp_bit-score']) for hsp in hsps), '.1f'),
                       cover = format(cover_count / query_length, '.0%'),
                  )

    @filter
    def genelink(self, hit, text=None, text_from='hitid', cssclass=None, display_nolink=True):
        """Create a html link from a hit node to a configured gene bank webpage.
        text: The text of the link. If not set applies text_from.
        text_from: string, if text is not specified, take it from specified source. Either 'hitid' (default) or 'dbname'.
        cssclass: extra css classes that will be added to the <a> element
        display_nolink: boolean, if false don't display anything if no link can be created. Default True.

        The link template is looked up in self.genelinks by the BlastOutput_db
        of the document the hit belongs to. It is filled in with str.format
        using the fields id, fullid, defline, fulldefline and accession. If the
        template is None, or formatting it fails (which emits a warning), the
        plain text is returned instead ('' if display_nolink is false).
        """

        # str(): key the lookup on the plain database name, not the lxml element
        db = str(hit.getroottree().getroot().BlastOutput_db)
        entry = self.genelinks[db]
        template = entry.template

        if text is None:
            if text_from == 'hitid':
                text = hitid(hit)
            elif text_from == 'dbname':
                text = entry.dbname
            else:
                raise ValueError("Unknown value for text_from: '{0}'. Use 'hitid' or 'dbname'.".format(text_from))

        if template is None:
            return text if display_nolink else ''

        args = dict(id=hitid(hit).split('|'),
                    fullid=hitid(hit),
                    defline=str(hit.Hit_def).split(' ', 1)[0].split('|'),
                    fulldefline=str(hit.Hit_def).split(' ', 1)[0],
                    accession=str(hit.Hit_accession))
        try:
            link = template.format(**args)
        except Exception as e:
            # explicit field numbers: python 2.6 doesn't support '{}' auto-numbering
            warnings.warn('Error in formatting gene bank link {0} with {1}: {2}'.format(template, args, e))
            return text if display_nolink else ''

        # NOTE(review): Jinja 3.1 removed jinja2.escape/jinja2.Markup (they live
        # in markupsafe); confirm which jinja versions must be supported.
        classattr = 'class="{0}" '.format(jinja2.escape(cssclass)) if cssclass is not None else ''
        return jinja2.Markup("<a {0}target=\"_top\" href=\"{1}\">{2}</a>".format(classattr, jinja2.escape(link), jinja2.escape(text)))


# Gene bank link configuration for one database: the link text (dbname) and
# the str.format link template.
genelinks_entry = namedtuple('genelinks_entry', 'dbname template')
def read_blastdb(dir, default):
    """Read gene bank link configuration from the blastdb*.loc files in ``dir``.

    Every line not starting with '#' is split on tabs. The third field is the
    database path, used as key. The fourth is the link text, falling back to
    default.dbname when empty. The fifth is the link template. Lines with
    fewer than five fields are ignored. Files that cannot be opened are
    skipped. When several files configure the same database, blastdb.loc
    wins.

    Returns a defaultdict mapping database path -> genelinks_entry that
    yields ``default`` for databases not configured. Emits a warning when no
    entries were found.
    """
    links = defaultdict(lambda: default)
    # blastdb.loc, blastdb_p.loc, blastdb_d.loc, etc.
    # glob returns paths that already include dir; joining them with dir
    # again would break relative directories.
    files = sorted(glob.glob(path.join(dir, 'blastdb*.loc')))
    # reversed, so blastdb.loc will take precedence
    for filename in reversed(files):
        try:
            with open(filename) as f:
                lines = f.readlines()
        except (IOError, OSError):
            # IOError is not a subclass of OSError on python 2
            continue
        for l in lines:
            if l.strip().startswith('#'):
                continue
            # also strip '\r' so windows line endings don't end up in the template
            line = l.rstrip('\r\n').split('\t')
            try:
                links[line[2]] = genelinks_entry(dbname=line[3] or default.dbname, template=line[4])
            except IndexError:
                continue
    if not links:
        if not files:
            warnings.warn("No gene bank link templates found (no blastdb*.loc files found in {0})".format(dir))
        else:
            warnings.warn("No gene bank link templates found in {0}".format(', '.join(files)))
    return links


def main():
    """Command line entry point: parse the arguments and write the html report."""
    default_template = path.join(path.dirname(__file__), 'blast2html.html.jinja')

    parser = argparse.ArgumentParser(description="Convert a BLAST XML result into a nicely readable html page",
                                     usage="{0} [-i] INPUT [-o OUTPUT] [--genelink-template URL_TEMPLATE] [--dbname DBNAME]".format(sys.argv[0]))
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('positional_arg', metavar='INPUT', nargs='?', type=argparse.FileType(mode='r'),
                             help='The input Blast XML file, same as -i/--input')
    input_group.add_argument('-i', '--input', type=argparse.FileType(mode='r'),
                             help='The input Blast XML file')
    parser.add_argument('-o', '--output', type=argparse.FileType(mode='w'), default=sys.stdout,
                        help='The output html file')
    # We just want the file name here, so jinja can open the file
    # itself. But it is easier to just use a FileType so argparse can
    # handle the errors. This introduces a small race condition when
    # jinja later tries to re-open the template file, but we don't
    # care too much.
    parser.add_argument('--template', type=argparse.FileType(mode='r'), default=default_template,
                        help='The template file to use. Defaults to blast2html.html.jinja')

    parser.add_argument('--dbname', type=str, default='Gene Bank',
                        help="The link text to use for external links to a gene bank database. Defaults to 'Gene Bank'")
    parser.add_argument('--genelink-template', metavar='URL_TEMPLATE',
                        default='http://www.ncbi.nlm.nih.gov/nucleotide/{accession}?report=genbank&log$=nuclalign',
                        help="""A link template to link hits to a gene bank webpage. The template string is a
                        Python format string. It can contain the following replacement elements: {id[N]}, {fullid},
                        {defline[N]}, {fulldefline}, {accession}, where N is a number. id[N] and defline[N] will be
                        replaced by the Nth element of the id or defline, where '|' is the field separator.

                        The default is 'http://www.ncbi.nlm.nih.gov/nucleotide/{accession}?report=genbank&log$=nuclalign',
                        which is a link to the NCBI nucleotide database.""")

    parser.add_argument('--db-config-dir',
                        help="""The directory where databases are configured in blastdb*.loc files. These files
                        are consulted for creating a gene bank link. The files should conform to the format that
                        Galaxy's BLAST tools expect, i.e. tab-separated tables (with lines starting with '#' ignored),
                        with two extra fields. The third field of a line should be a database path, the fourth the
                        link text for that database (if empty, --dbname is used) and the fifth a genebank link
                        template conforming to the --genelink-template option syntax. Entries in these config files
                        override links specified using --genelink-template and --dbname.""")

    args = parser.parse_args()
    if args.input is None:
        args.input = args.positional_arg
    if args.input is None:
        parser.error('no input specified')

    if six.PY2:
        # The argparse.FileType wrapper doesn't support an encoding
        # argument, so for python 2 we need to wrap or reopen the
        # output. The input files are already read as utf-8 by the
        # respective libraries.
        #
        # One option is using codecs, but the codecs' writelines()
        # method doesn't support streaming but collects all output and
        # writes at once (see Python issues #5445 and #21910). On the
        # other hand the io module is slower (though not
        # significantly).

        # args.output = codecs.getwriter('utf-8')(args.output)
        # def fixed_writelines(iter, self=args.output):
        #     for i in iter:
        #         self.write(i)
        # args.output.writelines = fixed_writelines

        if args.output is sys.stdout:
            # stdout has no file name to reopen (its name is '<stdout>'), so
            # wrap its file descriptor, leaving the descriptor open on close.
            args.output = io.open(sys.stdout.fileno(), 'w', encoding='utf-8', closefd=False)
        else:
            args.output.close()
            args.output = io.open(args.output.name, 'w', encoding='utf-8')

    templatedir, templatename = path.split(args.template.name)
    args.template.close()
    if not templatedir:
        templatedir = '.'

    defaultentry = genelinks_entry(args.dbname, args.genelink_template)
    if args.db_config_dir is None:
        genelinks = defaultdict(lambda: defaultentry)
    elif not path.isdir(args.db_config_dir):
        parser.error('db-config-dir does not exist or is not a directory')
    else:
        genelinks = read_blastdb(args.db_config_dir, default=defaultentry)

    b = BlastVisualize(args.input, templatedir, templatename, genelinks=genelinks)
    # BlastVisualize parses the whole input in its constructor, so the input
    # file is no longer needed.
    args.input.close()
    b.render(args.output)
    args.output.close()


# Run the command line interface when executed as a script.
if __name__ == '__main__':
    main()