blob: e7f8a945c1a9afb0214428f7ae06d5fb54f61565 [file] [log] [blame]
"""This script contains the actual auditing tests.
It should not be imported directly, but should be run by the test_audit
module with arguments identifying each test.
"""
import contextlib
import os
import sys
class TestHook:
    """Used in standard hook tests to collect any logged events.

    Should be used in a with block to ensure that it has no impact
    after the test completes. Closing does not unregister the hook from
    sys; it only makes the hook ignore every later event.
    """

    def __init__(self, raise_on_events=None, exc_type=RuntimeError):
        """Create an audit hook that records the events it receives.

        raise_on_events: a single event name (str) or a collection of event
            names; after recording a matching event the hook raises
            *exc_type*. A bare string names ONE event; it is not treated as
            a collection of substrings.
        exc_type: exception class raised for matching events.
        """
        events = raise_on_events or ()
        # Without this, ``event in "sys.addaudithook"`` would be a substring
        # test and e.g. an event named "add" would also match.
        if isinstance(events, str):
            events = (events,)
        self.raise_on_events = events
        self.exc_type = exc_type
        self.seen = []  # (event, args) tuples in the order they arrived
        self.closed = False

    def __enter__(self, *a):
        sys.addaudithook(self)
        return self

    def __exit__(self, *a):
        self.close()

    def close(self):
        """Stop recording (the hook itself stays registered)."""
        self.closed = True

    @property
    def seen_events(self):
        """Names of all recorded events, in arrival order."""
        return [i[0] for i in self.seen]

    def __call__(self, event, args):
        if self.closed:
            return
        self.seen.append((event, args))
        if event in self.raise_on_events:
            raise self.exc_type("saw event " + event)
# Simple helpers, since we are not in unittest here
def assertEqual(x, y):
    """Raise AssertionError when *x* and *y* compare unequal (``!=``)."""
    if not (x != y):
        return
    raise AssertionError(f"{x!r} should equal {y!r}")
def assertIn(el, series):
    """Raise AssertionError unless *el* is a member of *series*."""
    if el in series:
        return
    raise AssertionError(f"{el!r} should be in {series!r}")
def assertNotIn(el, series):
    """Raise AssertionError if *el* is a member of *series*."""
    present = el in series
    if present:
        raise AssertionError(f"{el!r} should not be in {series!r}")
def assertSequenceEqual(x, y):
    """Raise AssertionError unless *x* and *y* have equal length and items.

    Items are compared pairwise with ``!=``; lengths are checked first.
    """
    same_length = len(x) == len(y)
    if not same_length or any(left != right for left, right in zip(x, y)):
        raise AssertionError(f"{x!r} should equal {y!r}")
@contextlib.contextmanager
def assertRaises(ex_type):
    """Context manager requiring its body to raise exactly *ex_type*.

    The raised exception's type must be *ex_type* itself (subclasses fail).
    A matching exception is suppressed. An AssertionError from the body is
    re-raised unchanged so failed checks inside the block are not masked.
    Failures are signalled with explicit ``raise AssertionError`` rather than
    ``assert`` so the check still works under ``python -O``.
    """
    try:
        yield
    except AssertionError:
        raise
    except BaseException as ex:
        if type(ex) is not ex_type:
            raise AssertionError(f"{ex} should be {ex_type}")
    else:
        raise AssertionError(f"expected {ex_type}")
def test_basic():
    """A TestHook receives sys.audit() events together with their arguments."""
    with TestHook() as hook:
        sys.audit("test_event", 1, 2, 3)
        first_event, first_args = hook.seen[0]
        assertEqual(first_event, "test_event")
        assertEqual(first_args, (1, 2, 3))
def test_block_add_hook():
    """A hook raising on "sys.addaudithook" prevents a new hook being added.

    The hook's exception does not propagate out of the registration.
    """
    blocker = TestHook(raise_on_events="sys.addaudithook")
    with blocker:
        with TestHook() as blocked:
            sys.audit("test_event")
            assertIn("test_event", blocker.seen_events)
            assertNotIn("test_event", blocked.seen_events)
def test_block_add_hook_baseexception():
    """A BaseException raised on "sys.addaudithook" does propagate out."""
    blocker = TestHook(
        raise_on_events="sys.addaudithook", exc_type=BaseException
    )
    with assertRaises(BaseException):
        with blocker:
            # Registering this second hook is what triggers the BaseException
            with TestHook():
                pass
def test_marshal():
    """marshal.dumps/dump/loads/load raise the expected audit events."""
    import marshal

    value = ("a", "b", "c", 1, 2, 3)
    payload = marshal.dumps(value)
    path = "test-marshal.bin"

    with TestHook() as hook:
        assertEqual(value, marshal.loads(marshal.dumps(value)))

        try:
            with open(path, "wb") as f:
                marshal.dump(value, f)
            with open(path, "rb") as f:
                assertEqual(value, marshal.load(f))
        finally:
            os.unlink(path)

    def args_for(name):
        # Argument tuples of every recorded event called *name*
        return [args for event, args in hook.seen if event == name]

    # Both dumps() and dump() are expected to report "marshal.dumps"
    dumps_args = [(args[0], args[1]) for args in args_for("marshal.dumps")]
    assertSequenceEqual(dumps_args, [(value, marshal.version)] * 2)

    loads_args = [args[0] for args in args_for("marshal.loads")]
    assertSequenceEqual(loads_args, [payload])

    load_events = ["marshal.load" for _ in args_for("marshal.load")]
    assertSequenceEqual(load_events, ["marshal.load"])
def test_pickle():
    """A hook raising on "pickle.find_class" blocks loading of globals."""
    import pickle

    class PicklePrint:
        def __reduce_ex__(self, p):
            return str, ("Pwned!",)

    with_global = pickle.dumps(PicklePrint())
    without_global = pickle.dumps(("a", "b", "c", 1, 2, 3))

    # Sanity check: with no hook installed the global-referencing pickle loads
    assertEqual("Pwned!", pickle.loads(with_global))

    with TestHook(raise_on_events="pickle.find_class"):
        # Looking up the global now makes the hook raise
        with assertRaises(RuntimeError):
            pickle.loads(with_global)
        # A pickle that references no globals still loads
        pickle.loads(without_global)
def test_monkeypatch():
    """Check which class/instance attribute writes raise "object.__setattr__"."""
    class A:
        pass

    class B:
        pass

    class C(A):
        pass

    instance = A()

    with TestHook() as hook:
        # Renaming a class is audited
        C.__name__ = "X"
        # Changing a class's bases is audited
        C.__bases__ = (B,)
        # ...even when type.__setattr__ is bypassed via the descriptor
        type.__dict__["__bases__"].__set__(C, (B,))
        # Replacing an ordinary attribute: not expected in the list below
        C.__init__ = B.__init__
        # Adding an ordinary attribute: not expected in the list below
        C.new_attr = 123
        # Changing an instance's class is audited
        instance.__class__ = B

    expected = [
        (C, "__name__"),
        (C, "__bases__"),
        (C, "__bases__"),
        (instance, "__class__"),
    ]
    recorded = [
        (args[0], args[1])
        for event, args in hook.seen
        if event == "object.__setattr__"
    ]
    assertSequenceEqual(expected, recorded)
def test_open():
    """Each "open"-style entry point raises the "open" audit event.

    sys.argv[2] is the path of a file supplied by the caller.
    """
    # SSLContext.load_dh_params uses _Py_fopen_obj rather than normal open()
    try:
        import ssl

        load_dh_params = ssl.create_default_context().load_dh_params
    except ImportError:
        load_dh_params = None

    # The hook raises on "open", so every call attempted below must fail
    with TestHook(raise_on_events={"open"}) as hook:
        for fn, *args in [
            (open, sys.argv[2], "r"),
            (open, sys.executable, "rb"),
            (open, 3, "wb"),
            (open, sys.argv[2], "w", -1, None, None, None, False, lambda *a: 1),
            (load_dh_params, sys.argv[2]),
        ]:
            if fn:
                with assertRaises(RuntimeError):
                    fn(*args)

    open_args = [args for event, args in hook.seen if event == "open"]
    actual_mode = [(args[0], args[1]) for args in open_args if args[1]]
    actual_flag = [(args[0], args[2]) for args in open_args if not args[1]]

    expected_mode = [
        (sys.argv[2], "r"),
        (sys.executable, "r"),
        (3, "w"),
        (sys.argv[2], "w"),
    ]
    if load_dh_params:
        expected_mode.append((sys.argv[2], "rb"))

    assertSequenceEqual(expected_mode, actual_mode)
    assertSequenceEqual([], actual_flag)
def test_cantrace():
    """TestHook.__call__ is traced only while the hook's __cantrace__ is truthy.

    The expected number of traced calls for each step is noted inline.
    """
    traced = []

    def trace(frame, event, *args):
        # Record only frames running the hook itself. Returning None means
        # no line-level tracing inside those frames.
        if frame.f_code == TestHook.__call__.__code__:
            traced.append(event)

    # sys.settrace() returns None, so fetch the current tracer (e.g. a
    # coverage tool) explicitly so it can be restored afterwards.
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        with TestHook() as hook:
            # No traced call
            eval("1")

            # No traced call
            hook.__cantrace__ = False
            eval("2")

            # One traced call
            hook.__cantrace__ = True
            eval("3")

            # Two traced calls (writing to private member, eval)
            hook.__cantrace__ = 1
            eval("4")

            # One traced call (writing to private member)
            hook.__cantrace__ = 0
    finally:
        sys.settrace(old)

    assertSequenceEqual(["call"] * 4, traced)
def test_mmap():
    """Creating an anonymous mmap is audited with the (fileno, length) passed."""
    import mmap

    with TestHook() as hook:
        mmap.mmap(-1, 8)
        _, first_args = hook.seen[0]
        assertEqual(first_args[:2], (-1, 8))
def test_excepthook():
    """The "sys.excepthook" audit event carries the hook and the exception.

    Deliberately ends with an uncaught RuntimeError so that the interpreter
    invokes sys.excepthook; the audit hook prints what it received.
    """
    def excepthook(exc_type, exc_value, exc_tb):
        # Swallow the expected RuntimeError; defer anything else to the default
        if exc_type is not RuntimeError:
            sys.__excepthook__(exc_type, exc_value, exc_tb)

    def audit_hook(event, args):
        if event != "sys.excepthook":
            return
        if not isinstance(args[2], args[1]):
            raise TypeError(f"Expected isinstance({args[2]!r}, {args[1]!r})")
        if args[0] != excepthook:
            raise ValueError(f"Expected {args[0]} == {excepthook}")
        print(event, repr(args[2]))

    sys.addaudithook(audit_hook)
    sys.excepthook = excepthook
    raise RuntimeError("fatal-error")
def test_unraisablehook():
    """The "sys.unraisablehook" audit event carries the hook and its args."""
    from _testcapi import write_unraisable_exc

    def unraisablehook(hookargs):
        pass

    def audit_hook(event, args):
        if event != "sys.unraisablehook":
            return
        if args[0] != unraisablehook:
            raise ValueError(f"Expected {args[0]} == {unraisablehook}")
        details = args[1]
        print(event, repr(details.exc_value), details.err_msg)

    sys.addaudithook(audit_hook)
    sys.unraisablehook = unraisablehook
    write_unraisable_exc(RuntimeError("nonfatal-error"), "for audit hook test", None)
def test_winreg():
    """Print every winreg.* audit event raised by a few registry calls."""
    from winreg import OpenKey, EnumKey, CloseKey, HKEY_LOCAL_MACHINE

    def audit_hook(event, args):
        if event.startswith("winreg."):
            print(event, *args)

    sys.addaudithook(audit_hook)

    key = OpenKey(HKEY_LOCAL_MACHINE, "Software")
    EnumKey(key, 0)

    # Index 10000 is expected to be out of range, so this call must fail
    try:
        EnumKey(key, 10000)
    except OSError:
        pass
    else:
        raise RuntimeError("Expected EnumKey(HKLM, 10000) to fail")

    handle = key.Detach()
    CloseKey(handle)
def test_socket():
    """Print every socket.* audit event raised by a few socket calls."""
    import socket

    def audit_hook(event, args):
        if event.startswith("socket."):
            print(event, *args)

    sys.addaudithook(audit_hook)

    socket.gethostname()

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Only the audit message matters; a failed bind is ignored
        with contextlib.suppress(Exception):
            sock.bind(('127.0.0.1', 8080))
def test_gc():
    """Print every gc.* audit event raised by the gc introspection functions."""
    import gc

    def audit_hook(event, args):
        if event.startswith("gc."):
            print(event, *args)

    sys.addaudithook(audit_hook)

    gc.get_objects(generation=1)

    target = object()
    holder = [target]

    gc.get_referrers(target)
    gc.get_referents(holder)
def test_http_client():
    """Print http.client.* audit events raised by a simple GET request."""
    import http.client

    def audit_hook(event, args):
        # The first audit argument is not printed (presumably the connection
        # object, whose repr is not stable -- confirm against test_audit).
        if event.startswith("http.client."):
            print(event, *args[1:])

    sys.addaudithook(audit_hook)

    with contextlib.closing(http.client.HTTPConnection('www.python.org')) as conn:
        try:
            conn.request('GET', '/')
        except OSError:
            # e.g. no network access: print a placeholder "send" line instead
            print('http.client.send', '[cannot send]')
def test_sqlite3():
    """Print sqlite3.* audit events from connecting and extension loading."""
    import sqlite3

    # NOTE: unlike the other hooks in this file, this one collects its
    # parameters with *args, so the audit-args tuple prints as one item.
    def audit_hook(event, *args):
        if event.startswith("sqlite3."):
            print(event, *args)

    sys.addaudithook(audit_hook)
    cx1 = sqlite3.connect(":memory:")
    cx2 = sqlite3.Connection(":memory:")  # referenced until the test returns

    # enable_load_extension only exists when Python was configured with
    # --enable-loadable-sqlite-extensions; otherwise there is nothing more to do.
    if not hasattr(sqlite3.Connection, "enable_load_extension"):
        return

    cx1.enable_load_extension(False)
    try:
        cx1.load_extension("test")
    except sqlite3.OperationalError:
        pass
    else:
        raise RuntimeError("Expected sqlite3.load_extension to fail")
def test_sys_getframe():
    """Print the "sys._getframe" audit event with its frame's code name."""
    def hook(event, args):
        # Match the exact event name: other "sys." events (including
        # "sys._getframemodulename", which shares the prefix) need not pass a
        # frame as args[0], and ``args[0].f_code`` would then raise.
        if event == "sys._getframe":
            print(event, args[0].f_code.co_name)

    sys.addaudithook(hook)
    sys._getframe()
def test_syslog():
    """Print syslog.* audit events for explicit, implicit and default opens."""
    import syslog

    def audit_hook(event, args):
        if event.startswith("syslog."):
            print(event, *args)

    sys.addaudithook(audit_hook)

    # Explicit open with an ident
    syslog.openlog('python')
    syslog.syslog('test')
    syslog.setlogmask(syslog.LOG_DEBUG)
    syslog.closelog()

    # Logging without a prior openlog() opens the log implicitly
    syslog.syslog('test2')

    # Open with the default ident, then again with sys.argv unset
    syslog.openlog(logoption=syslog.LOG_NDELAY, facility=syslog.LOG_LOCAL0)
    sys.argv = None
    syslog.openlog()
    syslog.closelog()
def test_not_in_gc():
    """A registered audit hook must not be reachable from a gc-tracked list."""
    import gc

    def hook(*a):
        return None

    sys.addaudithook(hook)

    for obj in gc.get_objects():
        # Compare by identity: ``hook in obj`` would also call the __eq__ of
        # unrelated list items.
        if isinstance(obj, list) and any(item is hook for item in obj):
            # Raise explicitly so the check is not stripped under ``python -O``.
            raise AssertionError(f"{hook!r} should not be in a gc-tracked list")
if __name__ == "__main__":
    from test.support import suppress_msvcrt_asserts

    suppress_msvcrt_asserts()

    # sys.argv[1] names the module-level test function to run
    test = sys.argv[1]
    test_func = globals()[test]
    test_func()