blob: 11715c2c570f52b46fca92ea7f40b383790c0da5 [file] [log] [blame]
# coding: utf-8
from __future__ import unicode_literals, division, absolute_import, print_function
import sys
import unittest
import re
# Text string type of the running interpreter: ``unicode`` on Python 2 and
# ``str`` on Python 3. The name ``unicode`` is only looked up on Python 2.
str_cls = unicode if sys.version_info < (3,) else str  # noqa

# Mutable module-level state; the 'patched' flag is set by patch() once the
# unittest.TestCase methods have been installed.
_non_local = {'patched': False}
def patch():
    """
    Add the modern assertion method names to unittest.TestCase on Python 2.

    Does nothing on Python 3, or when an earlier call already applied the
    patch. On Python < 2.7 the fallback implementations from this module are
    installed. On Python 2.7 only assertRegex and assertRaisesRegex are
    added, as aliases of assertRegexpMatches and assertRaisesRegexp.
    """

    version = sys.version_info
    if version >= (3, 0) or _non_local['patched']:
        return

    case = unittest.TestCase
    if version >= (2, 7):
        # 2.7 already has the assertions, just under their older names
        case.assertRegex = case.assertRegexpMatches
        case.assertRaisesRegex = case.assertRaisesRegexp
    else:
        case.assertIsInstance = _assert_is_instance
        case.assertRegex = _assert_regex
        case.assertRaises = _assert_raises
        case.assertRaisesRegex = _assert_raises_regex
        case.assertGreaterEqual = _assert_greater_equal
        case.assertLess = _assert_less
        case.assertLessEqual = _assert_less_equal
        case.assertIn = _assert_in
        case.assertNotIn = _assert_not_in

    _non_local['patched'] = True
def _safe_repr(obj):
    """
    Return repr(obj), falling back to the default object repr on failure.

    :param obj:
        Any object

    :return:
        The result of repr(obj), or of object.__repr__(obj) when the
        object's own __repr__ raises an Exception
    """

    try:
        text = repr(obj)
    except Exception:
        text = object.__repr__(obj)
    return text
def _format_message(msg, standard_msg):
    """
    Pick the failure message to report.

    :param msg:
        The caller-supplied message, may be None or empty

    :param standard_msg:
        The default message generated by the assertion

    :return:
        msg when it is truthy, otherwise standard_msg
    """

    if msg:
        return msg
    return standard_msg
def _assert_greater_equal(self, a, b, msg=None):
    """
    Fail the test unless a >= b.

    :param a:
        The left-hand operand

    :param b:
        The right-hand operand

    :param msg:
        An optional message that replaces the default failure message
    """

    if a >= b:
        return
    default = '%s not greater than or equal to %s' % (_safe_repr(a), _safe_repr(b))
    self.fail(_format_message(msg, default))
def _assert_less(self, a, b, msg=None):
    """
    Fail the test unless a < b.

    :param a:
        The left-hand operand

    :param b:
        The right-hand operand

    :param msg:
        An optional message that replaces the default failure message
    """

    if a < b:
        return
    default = '%s not less than %s' % (_safe_repr(a), _safe_repr(b))
    self.fail(_format_message(msg, default))
def _assert_less_equal(self, a, b, msg=None):
    """
    Fail the test unless a <= b.

    :param a:
        The left-hand operand

    :param b:
        The right-hand operand

    :param msg:
        An optional message that replaces the default failure message
    """

    if a <= b:
        return
    default = '%s not less than or equal to %s' % (_safe_repr(a), _safe_repr(b))
    self.fail(_format_message(msg, default))
def _assert_is_instance(self, obj, cls, msg=None):
    """
    Fail the test unless obj is an instance of cls.

    :param obj:
        The object to check

    :param cls:
        A class, or tuple of classes, as accepted by isinstance()

    :param msg:
        An optional message that replaces the default failure message
    """

    if not isinstance(obj, cls):
        # Format obj via _safe_repr(), like the other assertion helpers in
        # this module. Interpolating it with %s into a unicode format string
        # raises UnicodeDecodeError on Python 2 for non-ASCII byte strings,
        # which would hide the real assertion failure.
        standard_msg = '%s is not an instance of %r' % (_safe_repr(obj), cls)
        self.fail(_format_message(msg, standard_msg))
def _assert_in(self, member, container, msg=None):
    """
    Fail the test unless member is in container.

    :param member:
        The value to look for

    :param container:
        Any object supporting the "in" operator

    :param msg:
        An optional message that replaces the default failure message
    """

    if member in container:
        return
    default = '%s not found in %s' % (_safe_repr(member), _safe_repr(container))
    self.fail(_format_message(msg, default))
def _assert_not_in(self, member, container, msg=None):
    """
    Fail the test if member is in container.

    :param member:
        The value to look for

    :param container:
        Any object supporting the "in" operator

    :param msg:
        An optional message that replaces the default failure message
    """

    if member not in container:
        return
    default = '%s found in %s' % (_safe_repr(member), _safe_repr(container))
    self.fail(_format_message(msg, default))
def _assert_regex(self, text, expected_regexp, msg=None):
    """
    Fail the test unless the text matches the regular expression.

    :param text:
        The string to search

    :param expected_regexp:
        A pattern string (text or byte string) or a compiled pattern object

    :param msg:
        An optional prefix for the failure message, used in place of
        "Regexp didn't match"
    """

    # Compile anything that is not already a compiled pattern. Checking only
    # for the text string type would leave byte-string patterns uncompiled,
    # and the .search() call below would then raise AttributeError.
    if not hasattr(expected_regexp, 'search'):
        expected_regexp = re.compile(expected_regexp)

    if not expected_regexp.search(text):
        msg = msg or "Regexp didn't match"
        msg = '%s: %r not found in %r' % (msg, expected_regexp.pattern, text)
        self.fail(msg)
def _assert_raises(self, excClass, callableObj=None, *args, **kwargs):  # noqa
    """
    Fail the test unless an exception of type excClass is raised.

    :param excClass:
        The exception class (or tuple of classes) that is expected

    :param callableObj:
        An optional callable to invoke with args and kwargs. When omitted,
        a context manager is returned for use in a "with" statement.

    :return:
        An _AssertRaisesContext when callableObj is None, otherwise None
    """

    guard = _AssertRaisesContext(excClass, self)
    if callableObj is not None:
        with guard:
            callableObj(*args, **kwargs)
        return None
    return guard
def _assert_raises_regex(self, expected_exception, expected_regexp, callable_obj=None, *args, **kwargs):
    """
    Fail the test unless a matching exception is raised.

    The exception must be of type expected_exception and, when
    expected_regexp is not None, its string form must match the pattern.

    :param expected_exception:
        The exception class (or tuple of classes) that is expected

    :param expected_regexp:
        A pattern passed to re.compile(), or None to skip the message check

    :param callable_obj:
        An optional callable to invoke with args and kwargs. When omitted,
        a context manager is returned for use in a "with" statement.

    :return:
        An _AssertRaisesContext when callable_obj is None, otherwise None
    """

    pattern = expected_regexp
    if pattern is not None:
        pattern = re.compile(pattern)

    guard = _AssertRaisesContext(expected_exception, self, pattern)
    if callable_obj is not None:
        with guard:
            callable_obj(*args, **kwargs)
        return None
    return guard
class _AssertRaisesContext(object):
    """
    Context manager that checks the wrapped block raises an expected
    exception, optionally verifying the exception message against a
    compiled regular expression.
    """

    def __init__(self, expected, test_case, expected_regexp=None):
        """
        :param expected:
            The exception class (or tuple of classes) that must be raised

        :param test_case:
            An object with a failureException attribute, used to report
            failures

        :param expected_regexp:
            None, or a compiled pattern the exception message must match
        """

        self.expected = expected
        self.failureException = test_case.failureException
        self.expected_regexp = expected_regexp

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """
        :return:
            True to suppress an expected exception, False to let an
            unexpected one propagate

        :raises:
            self.failureException when nothing was raised, or when the
            exception message does not match expected_regexp
        """

        if exc_type is None:
            # A tuple of classes has no __name__, so fall back to str()
            try:
                label = self.expected.__name__
            except AttributeError:
                label = str(self.expected)
            raise self.failureException("{0} not raised".format(label))

        if not issubclass(exc_type, self.expected):
            # Not the exception being waited for - do not swallow it
            return False

        # Keep the exception available to the caller after the block ends
        self.exception = exc_value

        pattern = self.expected_regexp
        if pattern is not None:
            message = str(exc_value)
            if not pattern.search(message):
                raise self.failureException(
                    '"%s" does not match "%s"' % (pattern.pattern, message)
                )
        return True