diff blast_html.py @ 19:67ddcb807b7d

make it work with multiple queries
author Jan Kanis <jan.code@jankanis.nl>
date Tue, 13 May 2014 18:06:36 +0200
parents 4434ffab721a
children 53cd304c5f26
line wrap: on
line diff
--- a/blast_html.py	Tue May 13 15:26:20 2014 +0200
+++ b/blast_html.py	Tue May 13 18:06:36 2014 +0200
@@ -19,11 +19,11 @@
     "Decorator to register a function as filter in the current jinja environment"
     if isinstance(func_or_name, str):
         def inner(func):
-            _filters[func_or_name] = func
+            _filters[func_or_name] = func.__name__
             return func
         return inner
     else:
-        _filters[func_or_name.__name__] = func_or_name
+        _filters[func_or_name.__name__] = func_or_name.__name__
         return func_or_name
 
 
@@ -78,8 +78,13 @@
     )
 
 @filter('len')
-def hsplen(node):
-    return int(node['Hsp_align-len'])
+def blastxml_len(node):  # registered under the filter name 'len'; dispatches on node.tag
+    if node.tag == 'Hsp':
+        return int(node['Hsp_align-len'])  # Hsp element: its Hsp_align-len child as int
+    elif node.tag == 'Iteration':
+        return int(node['Iteration_query-len'])  # Iteration element: its Iteration_query-len child as int
+    raise Exception("Unknown XML node type: "+node.tag)  # any other tag is rejected
+        
 
 @filter
 def asframe(frame):
@@ -134,6 +139,13 @@
 
     return value
 
+@filter
+def hits(result):
+    # Return the 'Hit' children of result.Iteration_hits, ordered by each hit's longest Hsp (longest first)
+    return sorted(result.Iteration_hits.findall('Hit'),
+                  key=lambda h: max(blastxml_len(hsp) for hsp in h.Hit_hsps.Hsp),
+                  reverse=True)
+
 
 
 class BlastVisualize:
@@ -151,15 +163,15 @@
         self.environment = jinja2.Environment(loader=self.loader,
                                               lstrip_blocks=True, trim_blocks=True, autoescape=True)
 
-        self.environment.filters.update(_filters)
-        self.environment.filters['color'] = lambda length: match_colors[color_idx(length)]
+        self._addfilters(self.environment)
+
 
-        self.query_length = int(self.blast["BlastOutput_query-len"])
-        self.hits = self.blast.BlastOutput_iterations.Iteration.Iteration_hits.Hit
-        # sort hits by longest hotspot first
-        self.ordered_hits = sorted(self.hits,
-                                   key=lambda h: max(hsplen(hsp) for hsp in h.Hit_hsps.Hsp),
-                                   reverse=True)
+    def _addfilters(self, environment):  # install every filter recorded in _filters on `environment`
+        for filtername, funcname in _filters.items():  # _filters maps filter name -> function name (a string)
+            try:
+                environment.filters[filtername] = getattr(self, funcname)  # prefer an attribute of this instance
+            except AttributeError:
+                environment.filters[filtername] = globals()[funcname]  # otherwise look the name up in module globals
 
     def render(self, output):
         template = self.environment.get_template(self.templatename)
@@ -171,41 +183,38 @@
                   ('Database', self.blast.BlastOutput_db),
         )
 
-        if len(self.blast.BlastOutput_iterations.Iteration) > 1:
-            warnings.warn("Multiple 'Iteration' elements found, showing only the first")
-
         output.write(template.render(blast=self.blast,
-                                     length=self.query_length,
-                                     hits=self.blast.BlastOutput_iterations.Iteration.Iteration_hits.Hit,
+                                     iterations=self.blast.BlastOutput_iterations.Iteration,
                                      colors=self.colors,
-                                     match_colors=self.match_colors(),
-                                     queryscale=self.queryscale(),
-                                     hit_info=self.hit_info(),
+                                     # match_colors=self.match_colors(),
+                                     # hit_info=self.hit_info(),
                                      genelink=genelink,
                                      params=params))
-        
 
-    def match_colors(self):
+    @filter
+    def match_colors(self, result):
         """
         An iterator that yields lists of length-color pairs. 
         """
 
-        percent_multiplier = 100 / self.query_length
+        query_length = blastxml_len(result)
+        
+        percent_multiplier = 100 / query_length
 
-        for hit in self.hits:
+        for hit in hits(result):
             # sort hotspots from short to long, so we can overwrite index colors of
             # short matches with those of long ones.
-            hotspots = sorted(hit.Hit_hsps.Hsp, key=lambda hsp: hsplen(hsp))
-            table = bytearray([255]) * self.query_length
+            hotspots = sorted(hit.Hit_hsps.Hsp, key=lambda hsp: blastxml_len(hsp))
+            table = bytearray([255]) * query_length
             for hsp in hotspots:
                 frm = hsp['Hsp_query-from'] - 1
                 to = int(hsp['Hsp_query-to'])
-                table[frm:to] = repeat(color_idx(hsplen(hsp)), to - frm)
+                table[frm:to] = repeat(color_idx(blastxml_len(hsp)), to - frm)
 
             matches = []
             last = table[0]
             count = 0
-            for i in range(self.query_length):
+            for i in range(query_length):
                 if table[i] == last:
                     count += 1
                     continue
@@ -216,25 +225,28 @@
 
             yield dict(colors=matches, link="#hit"+hit.Hit_num.text, defline=firsttitle(hit))
 
-
-    def queryscale(self):
-        skip = math.ceil(self.query_length / self.max_scale_labels)
-        percent_multiplier = 100 / self.query_length
-        for i in range(1, self.query_length+1):
+    @filter
+    def queryscale(self, result):  # generator of dict(label=..., width=...) entries for one result
+        query_length = blastxml_len(result)  # length now comes from the passed-in node, not from self
+        skip = math.ceil(query_length / self.max_scale_labels)  # label spacing; gives at most max_scale_labels full steps
+        percent_multiplier = 100 / query_length  # one position expressed as a percentage of query_length
+        for i in range(1, query_length+1):
             if i % skip == 0:
                 yield dict(label = i, width = skip * percent_multiplier)
-        if self.query_length % skip != 0:
-            yield dict(label = self.query_length, width = (self.query_length % skip) * percent_multiplier)
-
+        if query_length % skip != 0:  # remainder after the last full step gets its own, narrower entry
+            yield dict(label = query_length, width = (query_length % skip) * percent_multiplier)
 
-    def hit_info(self):
+    @filter
+    def hit_info(self, result):
 
-        for hit in self.ordered_hits:
+        query_length = blastxml_len(result)
+
+        for hit in hits(result):
             hsps = hit.Hit_hsps.Hsp
 
-            cover = [False] * self.query_length
+            cover = [False] * query_length
             for hsp in hsps:
-                cover[hsp['Hsp_query-from']-1 : int(hsp['Hsp_query-to'])] = repeat(True, hsplen(hsp))
+                cover[hsp['Hsp_query-from']-1 : int(hsp['Hsp_query-to'])] = repeat(True, blastxml_len(hsp))
             cover_count = cover.count(True)
 
             def hsp_val(path):
@@ -245,10 +257,10 @@
                        link_id = hit.Hit_num,
                        maxscore = "{:.1f}".format(max(hsp_val('Hsp_bit-score'))),
                        totalscore = "{:.1f}".format(sum(hsp_val('Hsp_bit-score'))),
-                       cover = "{:.0%}".format(cover_count / self.query_length),
+                       cover = "{:.0%}".format(cover_count / query_length),
                        e_value = "{:.4g}".format(min(hsp_val('Hsp_evalue'))),
                        # FIXME: is this the correct formula vv?
-                       ident = "{:.0%}".format(float(min(hsp.Hsp_identity / hsplen(hsp) for hsp in hsps))),
+                       ident = "{:.0%}".format(float(min(hsp.Hsp_identity / blastxml_len(hsp) for hsp in hsps))),
                        accession = hit.Hit_accession)
 
 def main():