| import contextlib |
| import functools |
| import io |
| import re |
| import sqlite3 |
| import test.support |
| |
| |
| # Helper for temporary memory databases |
def memory_database(*args, **kwargs):
    """Open an in-memory SQLite connection wrapped in contextlib.closing().

    Extra positional and keyword arguments are forwarded to sqlite3.connect()
    after the ":memory:" database name.  Using the result in a ``with``
    statement closes the connection when the block exits.
    """
    return contextlib.closing(sqlite3.connect(":memory:", *args, **kwargs))
| |
| |
| # Temporarily limit a database connection parameter |
| @contextlib.contextmanager |
| def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128): |
| try: |
| _prev = cx.setlimit(category, limit) |
| yield limit |
| finally: |
| cx.setlimit(category, _prev) |
| |
| |
def with_tracebacks(exc, regex="", name=""):
    """Convenience decorator for testing callback tracebacks.

    The decorated test method is run twice:

    1. With sqlite3 callback tracebacks enabled, inside
       test.support.catch_unraisable_exception() and check_tracebacks().
       The captured unraisable exception must be of type *exc*.  If *regex*
       is non-empty, str() of the exception must match it.  If *name* is
       non-empty, the unraisable object's ``__name__`` must equal it.
    2. Once more, outside those context managers.
    """
    def decorator(func):
        pattern = None
        if regex:
            pattern = re.compile(regex)

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Pass 1: tracebacks enabled; the unraisable hook is captured
            # and its contents are checked on exit.
            with test.support.catch_unraisable_exception() as cm, \
                    check_tracebacks(self, cm, exc, pattern, name):
                func(self, *args, **kwargs)

            # Pass 2: run again with callback tracebacks disabled.
            func(self, *args, **kwargs)
        return wrapper
    return decorator
| |
| |
@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
    """Convenience context manager for testing callback tracebacks.

    Enables sqlite3 callback tracebacks and redirects stderr into a
    throwaway buffer while the ``with`` body runs.  Afterwards it asserts on
    ``cm.unraisable``.

    self: the TestCase used for assertions.
    cm: object exposing an ``unraisable`` attribute, presumably the result of
        test.support.catch_unraisable_exception().
    exc: expected exception type (compared with assertEqual).
    regex: pattern (compiled or str) that must be found in str() of the
        exception value; a falsy value skips this check.
    obj_name: expected ``__name__`` of the unraisable ``object``; a falsy
        value skips this check.

    Callback tracebacks are disabled again on exit, even if the body or an
    assertion fails.
    """
    sqlite3.enable_callback_tracebacks(True)
    try:
        buf = io.StringIO()
        # Discard anything written to stderr while the body runs.
        with contextlib.redirect_stderr(buf):
            yield

        unraisable = cm.unraisable
        # Report a missing exception as a test failure rather than letting
        # the attribute access below raise AttributeError on None.
        self.assertIsNotNone(
            unraisable, f"no unraisable exception was raised; expected {exc!r}")
        self.assertEqual(unraisable.exc_type, exc)
        if regex:
            # assertRegex accepts compiled or str patterns and reports the
            # mismatching text on failure.
            self.assertRegex(str(unraisable.exc_value), regex)
        if obj_name:
            self.assertEqual(unraisable.object.__name__, obj_name)
    finally:
        sqlite3.enable_callback_tracebacks(False)
| |
| |
class MemoryDatabaseMixin:
    """Test-case mixin providing a fresh in-memory database for each test.

    setUp() opens ``self.con``, an sqlite3 connection to ":memory:", and
    ``self.cur``, a cursor on it.  tearDown() closes both.  ``cx`` and ``cu``
    are read-only aliases for ``con`` and ``cur``.
    """

    def setUp(self):
        self.con = sqlite3.connect(":memory:")
        self.cur = self.con.cursor()

    def tearDown(self):
        # Close the connection even if closing the cursor raises, so the
        # database handle is not leaked; the cursor error still propagates.
        try:
            self.cur.close()
        finally:
            self.con.close()

    @property
    def cx(self):
        """Alias for ``self.con``."""
        return self.con

    @property
    def cu(self):
        """Alias for ``self.cur``."""
        return self.cur