view sre_yield.py @ 0:5397da1ef896 draft

Uploaded
author gianmarco_piccinno
date Tue, 21 May 2019 05:05:15 -0400
parents
children
line wrap: on
line source

#!/usr/bin/env python2
#
# Copyright 2011-2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# vim: sw=2 sts=2 et

"""This module can generate all strings that match a regular expression.

The regex is parsed using the SRE module that is standard in python,
then the data structure is executed to form a bunch of iterators.
"""

__author__ = 'alexperry@google.com (Alex Perry)'
__all__ = ['Values', 'AllStrings', 'AllMatches', 'ParseError']


import bisect
import math
import re
import sre_constants
import sre_parse
import string
import sys
import types

import cachingseq
import fastdivmod

# Python 2/3 compatibility shim: keep the existing xrange where the name is
# defined, and alias it to range where looking it up raises NameError.
try:
    xrange = xrange
except NameError:
    xrange = range

# ESCAPED_METACHAR_RE matches a backslash followed by one of these characters.
_RE_METACHARS = r'$^{}*+\\'
_ESCAPED_METACHAR = ''.join([r'\\[', _RE_METACHARS, r']'])
ESCAPED_METACHAR_RE = re.compile(_ESCAPED_METACHAR)
# ASCII by default, see https://github.com/google/sre_yield/issues/3
CHARSET = list(map(chr, range(256)))

# Characters making up the "word" category: letters, digits and underscore.
WORD = ''.join([string.ascii_letters, string.digits, '_'])

# Use re.ASCII when this Python's re module provides it, otherwise no flags.
DEFAULT_RE_FLAGS = getattr(re, 'ASCII', 0)

# Anchor-tracking states used while walking a parsed pattern.
STATE_START, STATE_MIDDLE, STATE_END = range(3)

def Not(chars):
    """Return the characters of CHARSET that are not in `chars`.

    The result is a single string with the remaining characters in sorted
    order.
    """
    remaining = set(CHARSET).difference(chars)
    return ''.join(sorted(remaining))


# Maps each supported sre_constants category code to the string of characters
# it stands for. The NOT_* entries are built with Not(), i.e. they are the
# complement of the positive entry within CHARSET.
CATEGORIES = {
    sre_constants.CATEGORY_WORD: WORD,
    sre_constants.CATEGORY_NOT_WORD: Not(WORD),
    sre_constants.CATEGORY_DIGIT: string.digits,
    sre_constants.CATEGORY_NOT_DIGIT: Not(string.digits),
    sre_constants.CATEGORY_SPACE: string.whitespace,
    sre_constants.CATEGORY_NOT_SPACE: Not(string.whitespace),
}

# This constant varies between builds of Python; this is the lower value.
# It is used as the cap on repeat counts when the caller passes no max_count.
MAX_REPEAT_COUNT = 65535


class ParseError(Exception):
    """Raised when a pattern (or a flag) cannot be turned into a sequence."""


def slice_indices(slice_obj, size):
    """slice_obj.indices() except this one supports longs.

    Args:
      slice_obj: a slice whose start, stop and step may each be None.
      size: length of the sequence the slice is applied to.

    Returns:
      A (start, stop, step) tuple in which every None has been replaced by
      its default for the given step direction.

    Raises:
      ValueError: if the step is zero (the same error slice.indices() gives).
      IndexError: if an explicit start or stop is still negative after size
        has been added to it (raised by _adjust_index).
    """
    step = slice_obj.step
    if step is None:
        step = 1
    elif step == 0:
        # A zero step can never make progress; fail here rather than with a
        # ZeroDivisionError in whoever computes the slice length.
        raise ValueError("slice step cannot be zero")

    # Defaults that come from None are deliberately NOT passed through
    # _adjust_index: the default stop of -1 for a negative step means "one
    # before index 0" and must not be reinterpreted as a negative index.
    start = slice_obj.start
    if start is None:
        start = 0 if step > 0 else size - 1
    else:
        start = _adjust_index(start, size)
        if step < 0 and start >= size:
            # When walking backwards the first element read is `start`
            # itself, so it has to be a valid index, not one past the end.
            start = size - 1

    stop = slice_obj.stop
    if stop is None:
        stop = size if step > 0 else -1
    else:
        stop = _adjust_index(stop, size)

    return (start, stop, step)


def _adjust_index(n, size):
    """Normalize index `n` against a sequence of length `size`.

    Negative values count from the end. A value that is still negative after
    that raises IndexError; a value larger than size is clamped to size.
    """
    position = n + size if n < 0 else n
    if position < 0:
        raise IndexError("Out of range")
    return size if position > size else position


def _xrange(*args):
    """Because xrange doesn't support longs :(

    Tries the real xrange first and only falls back to the pure-Python
    _bigrange generator when xrange rejects the arguments with OverflowError.
    """
    try:
        counter = xrange(*args)
    except OverflowError:
        counter = _bigrange(*args)
    return counter


def _bigrange(*args):
    """Generator version of range() that works for arbitrarily large ints.

    Accepts the same call shapes as range(): (stop), (start, stop) or
    (start, stop, step).

    Yields:
      start, start + step, ... for as long as the value has not reached stop
      in the direction of step. An empty range yields nothing.

    Raises:
      ValueError: if the number of arguments is not 1, 2 or 3, or if step is
        zero. As this is a generator, the error surfaces on first iteration.
    """
    if len(args) == 1:
        start, stop, step = 0, args[0], 1
    elif len(args) == 2:
        start, stop = args
        step = 1
    elif len(args) == 3:
        start, stop, step = args
    else:
        raise ValueError(
            "_bigrange expects 1 to 3 arguments, got %d" % len(args))

    if step == 0:
        # A zero step would never reach stop and loop forever.
        raise ValueError("_bigrange step must not be zero")

    # Test the bound BEFORE yielding, so that an empty range (start already
    # at or past stop) produces no values at all.
    i = start
    if step > 0:
        while i < stop:
            yield i
            i += step
    else:
        while i > stop:
            yield i
            i += step


class WrappedSequence(object):
    """Read-only wrapper around a sequence; base class for the other views.

    Attributes:
      raw: the wrapped sequence.
      length: number of items, obtained from raw.__len__().
    """

    def __init__(self, raw):
        # Subclasses usually replace this constructor entirely.
        self.raw = raw
        # raw.__len__() is called directly instead of len(raw): len() insists
        # on converting the result from a long-int to an ordinary int.
        self.length = raw.__len__()

    def get_item(self, i, d=None):
        """Return item `i`, handing `d` on when the wrapped object takes it."""
        index = _adjust_index(i, self.length)
        source = self.raw
        if not hasattr(source, 'get_item'):
            return source[index]
        return source.get_item(index, d)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        if not isinstance(i, slice):
            # Plain index: normalize it, then defer to get_item (which
            # subclasses override).
            return self.get_item(_adjust_index(i, self.length))
        # Slice: hand back a lazy view, unpacked into a list when short.
        view = SlicedSequence(self, slicer=i)
        if view.__len__() < 16:
            return [entry for entry in view]
        return view

    def __iter__(self):
        for position in _xrange(int(self.length)):
            yield self.get_item(position)


def _sign(x):
    """Return 1 for a strictly positive x and -1 otherwise (including 0)."""
    return 1 if x > 0 else -1


class SlicedSequence(WrappedSequence):
    """This is part of an immutable and potentially arbitrarily long list.

    A lazy view of raw[start:stop:step]; items are read from `raw` on demand.

    Attributes:
      raw: the sequence being sliced.
      start, stop, step: the resolved slice bounds.
      length: number of items in the view, always an integer >= 0.
    """

    def __init__(self, raw, slicer=None):
        """Build the view.

        Args:
          raw: the underlying sequence; only __len__ and indexing are used.
          slicer: a slice object, or None to cover the whole of `raw`.
        """
        # Derived classes will likely override this constructor
        self.raw = raw
        if slicer is None:
            self.start, self.stop, self.step = 0, raw.__len__(), 1
        else:
            self.start, self.stop, self.step = slice_indices(slicer, raw.__len__())

        # Integer round up, depending on step direction. Floor division is
        # required here: true division would turn the length into a float on
        # Python 3, which len() refuses to return.
        span = self.stop - self.start + self.step - _sign(self.step)
        # A slice whose stop lies "behind" its start is simply empty.
        self.length = max(0, span // self.step)

    def get_item(self, i, d=None):
        """Return the i-th item of the view (`d` is accepted but unused).

        Raises:
          IndexError: if i falls outside the view. Without this check an
            index at or past the end would read elements of `raw` that are
            not part of the slice.
        """
        if i < 0:
            i += self.length
        if i < 0 or i >= self.length:
            raise IndexError("Index %d out of bounds" % (i,))
        j = i * self.step + self.start
        return self.raw[j]


class ConcatenatedSequence(WrappedSequence):
    """Several sequences chained end to end, without copying their items.

    Attributes:
      list_lengths: list of (sequence, its length) pairs, in order.
      length: total number of items across all member sequences.
    """

    def __init__(self, *alternatives):
        self.list_lengths = []
        self.length = 0
        for member in alternatives:
            member_len = member.__len__()
            self.list_lengths.append((member, member_len))
            self.length += member_len

    def get_item(self, i, d=None):
        """Return item `i` by walking the members until one contains it."""
        remaining = i
        for member, member_len in self.list_lengths:
            if remaining >= member_len:
                # Not in this member; skip over all of its items.
                remaining -= member_len
                continue
            return member[remaining]
        raise IndexError('Too Big')

    def __contains__(self, item):
        return any(item in member for member, _ in self.list_lengths)

    def __repr__(self):
        return ''.join(['{concat ', repr(self.list_lengths), '}'])


class CombinatoricsSequence(WrappedSequence):
    """This uses all combinations of one item from each passed list.

    Attributes:
      list_lengths: list of (component, its length) pairs, in order.
      length: product of all component lengths (1 when there are none).
    """

    def __init__(self, *components):
        self.list_lengths = [(a, a.__len__()) for a in components]
        self.length = 1
        for _, c_len in self.list_lengths:
            self.length *= c_len

    @staticmethod
    def _pick(component, index, d):
        """Fetch one item, passing `d` along when the component accepts it."""
        if hasattr(component, 'get_item'):
            return component.get_item(index, d)
        return component[index]

    def get_item(self, i, d=None):
        """Return the i-th combination.

        Args:
          i: index; negative values count from the end.
          d: optional dict that is handed to every component exposing
            get_item, so that capture groups can be recorded or read.

        Returns:
          The single component's item when there is exactly one component,
          otherwise the chosen items joined into one string.

        Raises:
          IndexError: if i is out of range.
        """
        if i < 0:
            i += self.length
        if i < 0 or i >= self.length:
            raise IndexError("Index %d out of bounds" % (i,))

        if len(self.list_lengths) == 1:
            # skip unnecessary ''.join -- big speedup.  This still has to go
            # through _pick: plain indexing would drop `d`, and a lone
            # capture group would then never be recorded.
            return self._pick(self.list_lengths[0][0], i, d)

        result = []
        for c, c_len in self.list_lengths:
            # Mixed-radix decomposition: the first component varies fastest.
            i, mod = divmod(i, c_len)
            result.append(self._pick(c, mod, d))
        return ''.join(result)

    def __repr__(self):
        return '{combin ' + repr(self.list_lengths) + '}'


class RepetitiveSequence(WrappedSequence):
    """This chooses an entry from a list, many times, and concatenates.

    Attributes:
      content: the sequence entries are drawn from.
      content_length: content.__len__().
      lowest, highest: the repeat-count bounds passed to the constructor.
      length: fastdivmod.powersum(content_length, lowest, highest) --
        presumably the number of strings over all repeat counts; confirm
        against fastdivmod.
      offsets: sequence of (first index, repeat count) pairs, one per repeat
        count, ordered by first index.
      index_of_offset, offset_break: a fixed split point that get_item uses
        to bound its bisect search.
    """

    def __init__(self, content, lowest=1, highest=1):
        self.content = content
        self.content_length = content.__len__()
        self.length = fastdivmod.powersum(self.content_length, lowest, highest)
        self.lowest = lowest
        self.highest = highest

        def arbitrary_entry(i):
            return (fastdivmod.powersum(self.content_length, lowest, i+lowest-1), i+lowest)

        def entry_from_prev(i, prev):
            return (prev[0] + (self.content_length ** prev[1]), prev[1] + 1)

        self.offsets = cachingseq.CachingFuncSequence(
            arbitrary_entry, highest - lowest+1, entry_from_prev)

        # Defaults for the case where the search below finds no split point
        # (for instance when only the very last offset exceeds sys.maxsize,
        # which the loop never inspects).  With both set to 0, get_item
        # bisects over the full range for every non-negative index; without
        # them it would fail with AttributeError.
        self.index_of_offset = 0
        self.offset_break = 0

        # This needs to be a constant in order to reuse calculations in future
        # calls to bisect (a moving target will produce more misses).
        if self.offsets[-1][0] > sys.maxsize:
            i = 0
            while i + 2 < len(self.offsets):
                if self.offsets[i+1][0] > sys.maxsize:
                    self.index_of_offset = i
                    self.offset_break = self.offsets[i][0]
                    break
                i += 1
        else:
            self.index_of_offset = len(self.offsets)
            self.offset_break = sys.maxsize

    def get_item(self, i, d=None):
        """Finds out how many repeats this index implies, then picks strings."""
        # Locate the offsets entry whose range contains i, searching only on
        # the relevant side of the precomputed split point.
        if i < self.offset_break:
            by_bisect = bisect.bisect_left(self.offsets, (i, -1), hi=self.index_of_offset)
        else:
            by_bisect = bisect.bisect_left(self.offsets, (i, -1), lo=self.index_of_offset)

        # bisect_left returns the insertion point; step back when it points
        # at an entry that starts after i (or past the end).
        if by_bisect == len(self.offsets) or self.offsets[by_bisect][0] > i:
            by_bisect -= 1

        num = i - self.offsets[by_bisect][0]
        count = self.offsets[by_bisect][1]

        # For many repeats over a small alphabet, index a plain list copy
        # instead of the (possibly lazy) content object.
        if count > 100 and self.content_length < 1000:
            content = list(self.content)
        else:
            content = self.content

        result = []

        if count == 0:
            return ''

        # Pick one content item per value produced by divmod_iter --
        # presumably the digits of num in base content_length, least
        # significant first; confirm against fastdivmod.
        for modulus in fastdivmod.divmod_iter(num, self.content_length):
            result.append(content[modulus])

        # Pad with the first content item up to `count` entries.
        leftover = count - len(result)
        if leftover:
            assert leftover > 0
            result.extend([content[0]] * leftover)

        # smallest place value ends up on the right
        return ''.join(result[::-1])

    def __repr__(self):
        return '{repeat base=%d low=%d high=%d}' % (self.content_length, self.lowest, self.highest)


class SaveCaptureGroup(WrappedSequence):
    """Wraps a sub-sequence and records each item it returns under `key`."""

    def __init__(self, parsed, key):
        super(SaveCaptureGroup, self).__init__(parsed)
        self.key = key

    def get_item(self, n, d=None):
        """Return item `n`; when a dict `d` is given, also store it there."""
        value = super(SaveCaptureGroup, self).get_item(n, d)
        if d is None:
            return value
        d[self.key] = value
        return value


class ReadCaptureGroup(WrappedSequence):
    """One-item sequence whose value is looked up in the capture dict."""

    def __init__(self, n):
        self.length = 1
        self.num = n

    def get_item(self, i, d=None):
        """Return the text stored for group `num` in `d`.

        Only index 0 exists. A dict is mandatory; a group missing from it
        yields the placeholder string "fail".
        """
        if i != 0:
            raise IndexError(i)
        if d is None:
            raise ValueError('ReadCaptureGroup with no dict')
        captured = d.get(self.num, "fail")
        return captured


class RegexMembershipSequence(WrappedSequence):
    """Creates a sequence from the regex, knows how to test membership.

    The pattern is parsed with sre_parse and every parsed node is converted,
    through the `backends` dispatch table, into a list or one of the
    sequence classes in this module.  Membership tests do not enumerate
    anything; they use a compiled regex instead.

    Attributes set by __init__:
      matcher: compiled regex for the pattern, wrapped as (?:pattern)\\Z.
      charset: characters considered for "any"/negated classes.
      named_group_lookup: the matcher's groupindex mapping.
      max_count: cap applied to repeat upper bounds.
      has_groupref: True once a backreference has been seen while parsing.
      backends: maps sre_constants opcodes to handler callables.
      state: current anchor state (STATE_START/MIDDLE/END) during parsing.
      raw, length: the generated top-level sequence and its length.
    """

    def empty_list(self, *_):
        """Backend for ASSERT and ASSERT_NOT: ignores arguments, returns []."""
        return []

    def nothing_added(self, *_):
        """Backend for AT (anchors): contributes only the empty string."""
        return ['']

    def branch_values(self, _, items):
        """Converts SRE parser data into literals and merges those lists."""
        return ConcatenatedSequence(
            *[self.sub_values(parsed) for parsed in items])

    def max_repeat_values(self, min_count, max_count, items):
        """Sequential expansion of the count to be combinatorics."""
        # Cap the upper bound so that open-ended repeats stay finite.
        max_count = min(max_count, self.max_count)
        return RepetitiveSequence(
            self.sub_values(items), min_count, max_count)

    def in_values(self, items):
        """Backend for IN (character sets), including negated sets."""
        # Special case which distinguishes branch from charset operator
        if items and items[0][0] == sre_constants.NEGATE:
            # Negated set: expand the listed members, then keep every charset
            # character that is not among them.
            items = self.branch_values(None, items[1:])
            return [item for item in self.charset if item not in items]
        return self.branch_values(None, items)

    def not_literal(self, y):
        """Backend for NOT_LITERAL: handled as a negated set containing y."""
        return self.in_values(((sre_constants.NEGATE,),
                              (sre_constants.LITERAL, y),))

    def category(self, y):
        """Backend for CATEGORY: look the characters up in CATEGORIES."""
        # NOTE(review): a category code missing from CATEGORIES raises
        # KeyError here rather than ParseError -- confirm that is intended.
        return CATEGORIES[y]

    def groupref(self, n):
        """Backend for GROUPREF: note that a backreference exists."""
        # get_item checks this flag to decide whether a capture dict is needed.
        self.has_groupref = True
        return ReadCaptureGroup(n)

    def get_item(self, i, d=None):
        """Typically only pass i.  d is an internal detail, for consistency with other classes.

        If you care about the capture groups, you should use
        RegexMembershipSequenceMatches instead, which returns a Match object
        instead of a string."""
        if self.has_groupref or d is not None:
            # Backreferences can only be resolved through a dict, so create
            # one when the caller did not supply it.
            if d is None:
                d = {}
            return super(RegexMembershipSequence, self).get_item(i, d)
        else:
            return super(RegexMembershipSequence, self).get_item(i)

    def sub_values(self, parsed):
        """This knows how to convert one piece of parsed pattern.

        Raises ParseError when `parsed` is not a SubPattern, a list, or a
        tuple whose opcode has an entry in `backends`.
        """
        # If this is a subpattern object, we just want its data
        if isinstance(parsed, sre_parse.SubPattern):
            parsed = parsed.data
        # A list indicates sequential elements of a string
        if isinstance(parsed, list):
            elements = [self.sub_values(p) for p in parsed]
            return CombinatoricsSequence(*elements)
        # If not a list, a tuple represents a specific match type
        if isinstance(parsed, tuple) and parsed:
            matcher, arguments = parsed
            # Backends are called with unpacked arguments, so a single
            # argument is wrapped into a one-element tuple.
            if not isinstance(arguments, tuple):
                arguments = (arguments,)
            if matcher in self.backends:
                self.check_anchor_state(matcher, arguments)
                return self.backends[matcher](*arguments)
        # No idea what to do here
        raise ParseError(repr(parsed))

    def maybe_save(self, *args):
        """Backend for SUBPATTERN: wrap the result when the group is numbered."""
        # Python 3.6 has group, add_flags, del_flags, parsed
        # while earlier versions just have group, parsed
        group = args[0]
        parsed = args[-1]
        rv = self.sub_values(parsed)
        # A group of None is not saved (presumably a non-capturing group).
        if group is not None:
            rv = SaveCaptureGroup(rv, group)
        return rv

    def check_anchor_state(self, matcher, arguments):
        """Advance the anchor state for one node; raise ParseError if illegal.

        Called by sub_values for each node that has a backend, before the
        backend itself runs.
        """
        # A bit of a hack to support zero-width leading anchors.  The goal is
        # that /^(a|b)$/ will match properly, and that /a^b/ or /a\bb/ throws
        # an error.  (It's unfortunate that I couldn't easily handle /$^/ which
        # matches the empty string; I went for the common case.)
        #
        # There are three states, for example:
        # / STATE_START
        # | / STATE_START (^ causes no transition here, but is illegal at STATE_MIDDLE or STATE_END)
        # | |  / STATE_START (\b causes no transition here, but advances MIDDLE to END)
        # | |  | / (same as above for ^)
        # | |  | | / STATE_MIDDLE (anything besides ^ and \b advances START to MIDDLE)
        # | |  | | | / still STATE_MIDDLE
        # . .  . . . .  / advances MIDDLE to END
        #  ^ \b ^ X Y \b $
        # NOTE(review): old_state is assigned but never read in this method.
        old_state = self.state
        if self.state == STATE_START:
            if matcher == sre_constants.AT:
                if arguments[0] in (sre_constants.AT_END, sre_constants.AT_END_STRING):
                    self.state = STATE_END
                elif arguments[0] == sre_constants.AT_NON_BOUNDARY:
                    # This is nonsensical at beginning of string
                    raise ParseError('Anchor %r found at START state' % (arguments[0],))
                # All others (AT_BEGINNING, AT_BEGINNING_STRING, and AT_BOUNDARY) remain in START.
            elif matcher != sre_constants.SUBPATTERN:
                self.state = STATE_MIDDLE
            # subpattern remains in START
        elif self.state == STATE_END:
            if matcher == sre_constants.AT:
                if arguments[0] not in (
                    sre_constants.AT_END, sre_constants.AT_END_STRING,
                    sre_constants.AT_BOUNDARY):
                    raise ParseError('Anchor %r found at END state' % (arguments[0],))
                # those three remain in END
            elif matcher != sre_constants.SUBPATTERN:
                raise ParseError('Non-end-anchor %r found at END state' % (arguments[0],))
            # subpattern remains in END
        else:  # self.state == STATE_MIDDLE
            if matcher == sre_constants.AT:
                if arguments[0] not in (
                    sre_constants.AT_END, sre_constants.AT_END_STRING,
                    sre_constants.AT_BOUNDARY):
                    raise ParseError('Anchor %r found at MIDDLE state' % (arguments[0],))
                # All others (AT_END, AT_END_STRING, AT_BOUNDARY) advance to END.
                self.state = STATE_END

    def __init__(self, pattern, flags=0, charset=CHARSET, max_count=None):
        """Parse `pattern` and build the sequence of matching strings.

        Args:
          pattern: the regular expression source.
          flags: re flags.  IGNORECASE, UNICODE and LOCALE are rejected.
          charset: candidate characters for "any" and negated classes.
            Unless re.DOTALL is set, newline is removed from it.
          max_count: cap for repeat upper bounds; None selects
            MAX_REPEAT_COUNT.

        Raises:
          ParseError: for an unsupported flag or an unsupported construct.
          Whatever re.compile raises for a pattern it cannot compile.
        """
        # If the RE module cannot compile it, we give up quickly
        # NOTE(review): the matcher is compiled with the caller's flags only;
        # DEFAULT_RE_FLAGS is OR-ed in further down, after this compile.
        self.matcher = re.compile(r'(?:%s)\Z' % pattern, flags)
        if not flags & re.DOTALL:
            charset = ''.join(c for c in charset if c != '\n')
        self.charset = charset

        self.named_group_lookup = self.matcher.groupindex

        flags |= DEFAULT_RE_FLAGS # https://github.com/google/sre_yield/issues/3
        if flags & re.IGNORECASE:
            raise ParseError('Flag "i" not supported. https://github.com/google/sre_yield/issues/4')
        elif flags & re.UNICODE:
            raise ParseError('Flag "u" not supported. https://github.com/google/sre_yield/issues/3')
        elif flags & re.LOCALE:
            raise ParseError('Flag "l" not supported. https://github.com/google/sre_yield/issues/5')

        if max_count is None:
            self.max_count = MAX_REPEAT_COUNT
        else:
            self.max_count = max_count

        self.has_groupref = False

        # Configure the parser backends
        self.backends = {
            sre_constants.LITERAL: lambda y: [chr(y)],
            sre_constants.RANGE: lambda l, h: [chr(c) for c in range(l, h+1)],
            sre_constants.SUBPATTERN: self.maybe_save,
            sre_constants.BRANCH: self.branch_values,
            sre_constants.MIN_REPEAT: self.max_repeat_values,
            sre_constants.MAX_REPEAT: self.max_repeat_values,
            sre_constants.AT: self.nothing_added,
            sre_constants.ASSERT: self.empty_list,
            sre_constants.ASSERT_NOT: self.empty_list,
            sre_constants.ANY:
                lambda _: self.in_values(((sre_constants.NEGATE,),)),
            sre_constants.IN: self.in_values,
            sre_constants.NOT_LITERAL: self.not_literal,
            sre_constants.CATEGORY: self.category,
            sre_constants.GROUPREF: self.groupref,
        }
        # Anchor tracking starts at the beginning of the pattern.
        self.state = STATE_START
        # Now build a generator that knows all possible patterns
        self.raw = self.sub_values(sre_parse.parse(pattern, flags))
        # Configure this class instance to know about that result
        self.length = self.raw.__len__()

    def __contains__(self, item):
        """Return True when `item` matches the whole pattern."""
        # Since we have a regex, we can search the list really cheaply
        return self.matcher.match(item) is not None


class RegexMembershipSequenceMatches(RegexMembershipSequence):
    """Variant whose indexing yields Match objects instead of plain strings."""

    def __getitem__(self, i):
        if not isinstance(i, slice):
            # Single index: generate the string while collecting the capture
            # groups, then bundle both into a Match.
            captures = {}
            text = super(RegexMembershipSequenceMatches, self).get_item(
                i, captures)
            return Match(text, captures, self.named_group_lookup)

        # Slice: lazy view, unpacked into a list when it is short.
        window = SlicedSequence(self, slicer=i)
        if window.__len__() >= 16:
            return window
        return [entry for entry in window]


def AllStrings(regex, flags=0, charset=CHARSET, max_count=None):
    """Return a sequence object covering every string that matches `regex`.

    All arguments are forwarded unchanged to RegexMembershipSequence.
    """
    sequence = RegexMembershipSequence(
        regex, flags, charset, max_count=max_count)
    return sequence

# Alternative public name for AllStrings.
Values = AllStrings


class Match(object):
    """Minimal stand-in for a regex match object over a generated string.

    Args:
      string: the whole generated string (group 0).
      groups: dict mapping group number to the text captured for it; groups
        that did not take part in producing the string are absent.
      named_groups: mapping from group name to group number.

    Attributes:
      lastindex: len(groups) + 1, as computed at construction time.
    """

    def __init__(self, string, groups, named_groups):
        # TODO keep group(0) only, and spans for the rest.
        self._string = string
        self._groups = groups
        self._named_groups = named_groups
        self.lastindex = len(groups) + 1

    def _group_count(self):
        """Return the highest group number this match knows about."""
        known = list(self._groups) + list(self._named_groups.values())
        numbers = [k for k in known if isinstance(k, int)]
        # Never report fewer groups than were actually captured.
        return max([0, len(self._groups)] + numbers)

    def group(self, n=0):
        """Return group `n`, given as a number or a name; 0 is the whole string.

        A known group that did not participate yields None, as with the re
        module.  An unknown name or number raises KeyError.
        """
        if n == 0:
            return self._string
        if not isinstance(n, int):
            n = self._named_groups[n]
        if n not in self._groups and 1 <= n <= self._group_count():
            return None
        return self._groups[n]

    def groups(self):
        """Return a tuple of all known groups from 1 up, None where unset."""
        # Indexing self._groups directly would raise KeyError as soon as one
        # group (for instance in an alternative that was not taken) is absent.
        return tuple(self._groups.get(i)
                     for i in range(1, self._group_count() + 1))

    def groupdict(self):
        """Return a dict of named groups; None for those that did not match."""
        return dict((name, self._groups.get(number))
                    for name, number in self._named_groups.items())

    def span(self, n=0):
        raise NotImplementedError()


def AllMatches(regex, flags=0, charset=CHARSET, max_count=None):
    """Return a sequence object whose items are Match objects for `regex`.

    All arguments are forwarded unchanged to RegexMembershipSequenceMatches.
    """
    sequence = RegexMembershipSequenceMatches(
        regex, flags, charset, max_count=max_count)
    return sequence


def main(argv=None):
    """Command-line entry point, intended for testing.

    Every argument after the program name is treated as a regex, and each
    string it generates is printed on its own line.  With argv=None the
    arguments are taken from sys.argv.
    """
    arguments = sys.argv if argv is None else argv
    for pattern in arguments[1:]:
        for generated in AllStrings(pattern):
            print(generated)


# Run the command-line driver only when executed as a script, not on import.
if __name__ == '__main__':
    main()