view env/lib/python3.9/site-packages/boltons/excutils.py @ 0:4f3585e2f14b draft default tip

"planemo upload commit 60cee0fc7c0cda8592644e1aad72851dec82c959"
author shellac
date Mon, 22 Mar 2021 18:12:50 +0000
parents
children
line wrap: on
line source

# -*- coding: utf-8 -*-

import sys
import traceback
import linecache
from collections import namedtuple

# TODO: last arg or first arg?  (last arg makes it harder to *args
#       into, but makes it more readable in the default exception
#       __repr__ output)
# TODO: Multiexception wrapper


__all__ = ['ExceptionCauseMixin']


class ExceptionCauseMixin(Exception):
    """
    A mixin class for wrapping an exception in another exception, or
    otherwise indicating an exception was caused by another exception.

    This is most useful in concurrent or failure-intolerant scenarios,
    where just because one operation failed, doesn't mean the remainder
    should be aborted, or that it's the appropriate time to raise
    exceptions.

    This is still a work in progress, but an example use case at the
    bottom of this module.

    NOTE: when inheriting, you will probably want to put the
    ExceptionCauseMixin first. Builtin exceptions are not good about
    calling super()
    """

    cause = None

    def __new__(cls, *args, **kw):
        cause = None
        if args and isinstance(args[0], Exception):
            cause, args = args[0], args[1:]
        ret = super(ExceptionCauseMixin, cls).__new__(cls, *args, **kw)
        ret.cause = cause
        if cause is None:
            return ret
        root_cause = getattr(cause, 'root_cause', None)
        if root_cause is None:
            ret.root_cause = cause
        else:
            ret.root_cause = root_cause

        full_trace = getattr(cause, 'full_trace', None)
        if full_trace is not None:
            ret.full_trace = list(full_trace)
            ret._tb = list(cause._tb)
            ret._stack = list(cause._stack)
            return ret

        try:
            exc_type, exc_value, exc_tb = sys.exc_info()
            if exc_type is None and exc_value is None:
                return ret
            if cause is exc_value or root_cause is exc_value:
                # handles when cause is the current exception or when
                # there are multiple wraps while handling the original
                # exception, but a cause was never provided
                ret._tb = _extract_from_tb(exc_tb)
                ret._stack = _extract_from_frame(exc_tb.tb_frame)
                ret.full_trace = ret._stack[:-1] + ret._tb
        finally:
            del exc_tb
        return ret

    def get_str(self):
        """
        Get formatted the formatted traceback and exception
        message. This function exists separately from __str__()
        because __str__() is somewhat specialized for the built-in
        traceback module's particular usage.
        """
        ret = []
        trace_str = self._get_trace_str()
        if trace_str:
            ret.extend(['Traceback (most recent call last):\n', trace_str])
        ret.append(self._get_exc_str())
        return ''.join(ret)

    def _get_message(self):
        args = getattr(self, 'args', [])
        if self.cause:
            args = args[1:]
        if args and args[0]:
            return args[0]
        return ''

    def _get_trace_str(self):
        if not self.cause:
            return super(ExceptionCauseMixin, self).__repr__()
        if self.full_trace:
            return ''.join(traceback.format_list(self.full_trace))
        return ''

    def _get_exc_str(self, incl_name=True):
        cause_str = _format_exc(self.root_cause)
        message = self._get_message()
        ret = []
        if incl_name:
            ret = [self.__class__.__name__, ': ']
        if message:
            ret.extend([message, ' (caused by ', cause_str, ')'])
        else:
            ret.extend([' caused by ', cause_str])
        return ''.join(ret)

    def __str__(self):
        if not self.cause:
            return super(ExceptionCauseMixin, self).__str__()
        trace_str = self._get_trace_str()
        ret = []
        if trace_str:
            message = self._get_message()
            if message:
                ret.extend([message, ' --- '])
            ret.extend(['Wrapped traceback (most recent call last):\n',
                        trace_str,
                        self._get_exc_str(incl_name=True)])
            return ''.join(ret)
        else:
            return self._get_exc_str(incl_name=False)


def _format_exc(exc, message=None):
    if message is None:
        message = exc
    exc_str = traceback._format_final_exc_line(exc.__class__.__name__, message)
    return exc_str.rstrip()


_BaseTBItem = namedtuple('_BaseTBItem', 'filename, lineno, name, line')


class _TBItem(_BaseTBItem):
    def __repr__(self):
        ret = super(_TBItem, self).__repr__()
        ret += ' <%r>' % self.frame_id
        return ret


class _DeferredLine(object):
    def __init__(self, filename, lineno, module_globals=None):
        self.filename = filename
        self.lineno = lineno
        module_globals = module_globals or {}
        self.module_globals = dict([(k, v) for k, v in module_globals.items()
                                    if k in ('__name__', '__loader__')])

    def __eq__(self, other):
        return (self.lineno, self.filename) == (other.lineno, other.filename)

    def __ne__(self, other):
        return (self.lineno, self.filename) != (other.lineno, other.filename)

    def __str__(self):
        if hasattr(self, '_line'):
            return self._line
        linecache.checkcache(self.filename)
        line = linecache.getline(self.filename,
                                 self.lineno,
                                 self.module_globals)
        if line:
            line = line.strip()
        else:
            line = None
        self._line = line
        return line

    def __repr__(self):
        return repr(str(self))

    def __len__(self):
        return len(str(self))

    def strip(self):
        return str(self).strip()


def _extract_from_frame(f=None, limit=None):
    ret = []
    if f is None:
        f = sys._getframe(1)  # cross-impl yadayada
    if limit is None:
        limit = getattr(sys, 'tracebacklimit', 1000)
    n = 0
    while f is not None and n < limit:
        filename = f.f_code.co_filename
        lineno = f.f_lineno
        name = f.f_code.co_name
        line = _DeferredLine(filename, lineno, f.f_globals)
        item = _TBItem(filename, lineno, name, line)
        item.frame_id = id(f)
        ret.append(item)
        f = f.f_back
        n += 1
    ret.reverse()
    return ret


def _extract_from_tb(tb, limit=None):
    ret = []
    if limit is None:
        limit = getattr(sys, 'tracebacklimit', 1000)
    n = 0
    while tb is not None and n < limit:
        filename = tb.tb_frame.f_code.co_filename
        lineno = tb.tb_lineno
        name = tb.tb_frame.f_code.co_name
        line = _DeferredLine(filename, lineno, tb.tb_frame.f_globals)
        item = _TBItem(filename, lineno, name, line)
        item.frame_id = id(tb.tb_frame)
        ret.append(item)
        tb = tb.tb_next
        n += 1
    return ret


# An Example/Prototest:


class MathError(ExceptionCauseMixin, ValueError):
    pass


def whoops_math():
    return 1/0


def math_lol(n=0):
    if n < 3:
        return math_lol(n=n+1)
    try:
        return whoops_math()
    except ZeroDivisionError as zde:
        exc = MathError(zde, 'ya done messed up')
        raise exc

def main():
    try:
        math_lol()
    except ValueError as me:
        exc = MathError(me, 'hi')
        raise exc


if __name__ == '__main__':
    try:
        main()
    except Exception:
        import pdb;pdb.post_mortem()
        raise