diff cdblib.py @ 0:2c498d40ecde

Uploaded
author miller-lab
date Mon, 09 Apr 2012 12:03:06 -0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/cdblib.py	Mon Apr 09 12:03:06 2012 -0400
@@ -0,0 +1,230 @@
+#!/usr/bin/env python
+
+'''
+Manipulate DJB's Constant Databases. These are 2 level disk-based hash tables
+that efficiently handle many keys, while remaining space-efficient.
+
+    http://cr.yp.to/cdb.html
+
+When generated databases are only used with Python code, consider using hash()
+rather than djb_hash() for a tidy speedup.
+'''
+
+from _struct import Struct
+from itertools import chain
+
+
+def py_djb_hash(s):
+    '''Return the value of DJB's hash function for the given 8-bit string.'''
+    h = 5381
+    for c in s:
+        h = (((h << 5) + h) ^ ord(c)) & 0xffffffff
+    return h
+
+try:
+    from _cdblib import djb_hash
+except ImportError:
+    djb_hash = py_djb_hash
+
+read_2_le4 = Struct('<LL').unpack
+write_2_le4 = Struct('<LL').pack
+
+
+class Reader(object):
+    '''A dictionary-like object for reading a Constant Database accessed
+    through a string or string-like sequence, such as mmap.mmap().'''
+
+    def __init__(self, data, hashfn=djb_hash):
+        '''Create an instance reading from a sequence and using hashfn to hash
+        keys.'''
+        if len(data) < 2048:
+            raise IOError('CDB too small')
+
+        self.data = data
+        self.hashfn = hashfn
+
+        self.index = [read_2_le4(data[i:i+8]) for i in xrange(0, 2048, 8)]
+        self.table_start = min(p[0] for p in self.index)
+        # Assume load load factor is 0.5 like official CDB.
+        self.length = sum(p[1] >> 1 for p in self.index)
+
+    def iteritems(self):
+        '''Like dict.iteritems(). Items are returned in insertion order.'''
+        pos = 2048
+        while pos < self.table_start:
+            klen, dlen = read_2_le4(self.data[pos:pos+8])
+            pos += 8
+
+            key = self.data[pos:pos+klen]
+            pos += klen
+
+            data = self.data[pos:pos+dlen]
+            pos += dlen
+
+            yield key, data
+
+    def items(self):
+        '''Like dict.items().'''
+        return list(self.iteritems())
+
+    def iterkeys(self):
+        '''Like dict.iterkeys().'''
+        return (p[0] for p in self.iteritems())
+    __iter__ = iterkeys
+
+    def itervalues(self):
+        '''Like dict.itervalues().'''
+        return (p[1] for p in self.iteritems())
+
+    def keys(self):
+        '''Like dict.keys().'''
+        return [p[0] for p in self.iteritems()]
+
+    def values(self):
+        '''Like dict.values().'''
+        return [p[1] for p in self.iteritems()]
+
+    def __getitem__(self, key):
+        '''Like dict.__getitem__().'''
+        value = self.get(key)
+        if value is None:
+            raise KeyError(key)
+        return value
+
+    def has_key(self, key):
+        '''Return True if key exists in the database.'''
+        return self.get(key) is not None
+    __contains__ = has_key
+
+    def __len__(self):
+        '''Return the number of records in the database.'''
+        return self.length
+
+    def gets(self, key):
+        '''Yield values for key in insertion order.'''
+        # Truncate to 32 bits and remove sign.
+        h = self.hashfn(key) & 0xffffffff
+        start, nslots = self.index[h & 0xff]
+
+        if nslots:
+            end = start + (nslots << 3)
+            slot_off = start + (((h >> 8) % nslots) << 3)
+
+            for pos in chain(xrange(slot_off, end, 8),
+                             xrange(start, slot_off, 8)):
+                rec_h, rec_pos = read_2_le4(self.data[pos:pos+8])
+
+                if not rec_h:
+                    break
+                elif rec_h == h:
+                    klen, dlen = read_2_le4(self.data[rec_pos:rec_pos+8])
+                    rec_pos += 8
+
+                    if self.data[rec_pos:rec_pos+klen] == key:
+                        rec_pos += klen
+                        yield self.data[rec_pos:rec_pos+dlen]
+
+    def get(self, key, default=None):
+        '''Get the first value for key, returning default if missing.'''
+        # Avoid exception catch when handling default case; much faster.
+        return chain(self.gets(key), (default,)).next()
+
+    def getint(self, key, default=None, base=0):
+        '''Get the first value for key converted it to an int, returning
+        default if missing.'''
+        value = self.get(key, default)
+        if value is not default:
+            return int(value, base)
+        return value
+
+    def getints(self, key, base=0):
+        '''Yield values for key in insertion order after converting to int.'''
+        return (int(v, base) for v in self.gets(key))
+
+    def getstring(self, key, default=None, encoding='utf-8'):
+        '''Get the first value for key decoded as unicode, returning default if
+        not found.'''
+        value = self.get(key, default)
+        if value is not default:
+            return value.decode(encoding)
+        return value
+
+    def getstrings(self, key, encoding='utf-8'):
+        '''Yield values for key in insertion order after decoding as
+        unicode.'''
+        return (v.decode(encoding) for v in self.gets(key))
+
+
+class Writer(object):
+    '''Object for building new Constant Databases, and writing them to a
+    seekable file-like object.'''
+
+    def __init__(self, fp, hashfn=djb_hash):
+        '''Create an instance writing to a file-like object, using hashfn to
+        hash keys.'''
+        self.fp = fp
+        self.hashfn = hashfn
+
+        fp.write('\x00' * 2048)
+        self._unordered = [[] for i in xrange(256)]
+
+    def put(self, key, value=''):
+        '''Write a string key/value pair to the output file.'''
+        assert type(key) is str and type(value) is str
+
+        pos = self.fp.tell()
+        self.fp.write(write_2_le4(len(key), len(value)))
+        self.fp.write(key)
+        self.fp.write(value)
+
+        h = self.hashfn(key) & 0xffffffff
+        self._unordered[h & 0xff].append((h, pos))
+
+    def puts(self, key, values):
+        '''Write more than one value for the same key to the output file.
+        Equivalent to calling put() in a loop.'''
+        for value in values:
+            self.put(key, value)
+
+    def putint(self, key, value):
+        '''Write an integer as a base-10 string associated with the given key
+        to the output file.'''
+        self.put(key, str(value))
+
+    def putints(self, key, values):
+        '''Write zero or more integers for the same key to the output file.
+        Equivalent to calling putint() in a loop.'''
+        self.puts(key, (str(value) for value in values))
+
+    def putstring(self, key, value, encoding='utf-8'):
+        '''Write a unicode string associated with the given key to the output
+        file after encoding it as UTF-8 or the given encoding.'''
+        self.put(key, unicode.encode(value, encoding))
+
+    def putstrings(self, key, values, encoding='utf-8'):
+        '''Write zero or more unicode strings to the output file. Equivalent to
+        calling putstring() in a loop.'''
+        self.puts(key, (unicode.encode(value, encoding) for value in values))
+
+    def finalize(self):
+        '''Write the final hash tables to the output file, and write out its
+        index. The output file remains open upon return.'''
+        index = []
+        for tbl in self._unordered:
+            length = len(tbl) << 1
+            ordered = [(0, 0)] * length
+            for pair in tbl:
+                where = (pair[0] >> 8) % length
+                for i in chain(xrange(where, length), xrange(0, where)):
+                    if not ordered[i][0]:
+                        ordered[i] = pair
+                        break
+
+            index.append((self.fp.tell(), length))
+            for pair in ordered:
+                self.fp.write(write_2_le4(*pair))
+
+        self.fp.seek(0)
+        for pair in index:
+            self.fp.write(write_2_le4(*pair))
+        self.fp = None # prevent double finalize()