Merge master into oreo-dev
This pulls in:
net-test: try to detect vsyscall=none uml and warn about it.
Add tests for unset output marks on floating policies
Revert "Test Updating OUTPUT_MARK on Active SAs"
Test Updating OUTPUT_MARK on Active SAs
anycast_test.py: increase waiting time to 3 sec to wait for CloseFileDescriptorThread to finish
anycast_test.py: increase waiting time to 3 sec to wait for CloseFileDescriptorThread to finish
Test to check tcp initial rwnd size
net_test: fix job control in console's bash terminal
net_test: workaround for 3.18 and 4.4 late urandom init
Revert "Implement a workaround for clang + PARAVIRT failure."
net-test: deflake 4.19 entropy installation
anycast_test.py: change to use thread.join to wait for CloseFileDescriptorThread to finish
Use blocking sockets with timeout for xfrm_tunnel_test
Close the socket in socketCreateTest before exit
Test to verify cgroup socket filter
net_test - extra debugging for ReadProcNetSocket() regexp match failures
net/test/OWNERS: passing the torch from ek@ to maze@
run_net_test - further boost UML entropy
run_net_test.sh: Add support for arm64
build_rootfs.sh: Add support for arm64
Improve xfrm net test
Improve xfrm net test
Filter neighbour dumps by interface.
run_net_test: UML - fix insufficient entropy problems
run_net_test: handle UML's tendency to leave stdout in non-blocking mode
run_net_test.sh: switch to readonly by default
run_net_test: add 'no_test' test
Filter neighbour dumps by interface.
run_net_test.sh: fix UML --readonly flag
Enable virtio rng device for net tests on qemu/kvm.
Add lspci & lsusb commands to stretch image.
Remove mutable default parameter in tunnel_test
Document/enforce a bug in udp_dump_one.
Allow ROOTFS to use environment variables
Check xfrm state to delete embryonic SA
Fix sysfs mount in net_test.sh.
net_test: fix sock_diag_test.py to handle ipv6 correctly
Add tests for netfilter reject policies
Add tests for VTI rekey procedure
Always test UDP_DIAG for 4.9 kernel
Refactor VTI tests to support null encryption
Add tunnel input tests to net_tests
Refactor parameterization logic in net tests
Fix nobuild runs of run_net_test.sh without KERNEL_BINARY env var set
Add scripts for building the net tests rootfs.
anycast_test.py: change to use thread.join to wait for CloseFileDescriptorThread to finish
Fix net tests for 32-bit kernel
Fix net tests for 32-bit kernel
Enable FHANDLE to support systemd
Implement a workaround for clang + PARAVIRT failure.
Add support for running the harness with QEMU.
Annotate non-common kernel config options.
Fix some invalid config options.
Drop unnecessary CONFIG_ prefixes.
Fix selection of bpf syscall number with COMPAT_UTS_MACHINE.
Add __NR_bpf constant for i686.
Fix the flaky cgroup uid bpf test
Test for getFirstMapKey of bpf maps
Test for getFirstMapKey of bpf maps
Test experimental xfrm interfaces if supported.
Be flexible about TCP RST and SOCK_DESTROY poll return values.
Set SA mark to unused for Tunnel Mode
Be flexible about TCP RST and SOCK_DESTROY poll return values.
Verify VTI Modification using RTM_NEWLINK
Verify VTI Modification using RTM_NEWLINK
Disable qtaguid tests if qtaguid is not present.
Enable algorithm net tests for 3.18 kernels
Fix algorithm tests to ensure no lingering sockets
Enable algorithm net tests for 3.18 kernels
Fix algorithm tests to ensure no lingering sockets
Test for bpf read/write only map
Make sock_diag_test.py pass on 4.9 devices that don't have SCTP.
Make BPF tests pass on device.
Test multiple VTI tunnels at the same time.
Add a new VtiInterface class and use it in XfrmVtiTest.
net_test: fix exit code.
Test that ICMP errors work on VTI interfaces.
Don't require XFRM tests to flush all state on tearDown.
Move some utility code from xfrm_base to xfrm.
Add handling for parameterized test modules
Make tcp_fastopen_test pass on new scapy.
Switch from using epoll to poll.
Don't require that we be the only sockets in the system.
Fix Tunnel Encryption with New Scapy
Unbreak tests that use mapped addresses on bionic.
Make xfrm_test pass with newer versions of iproute.
Don't fail if CONFIG_INET_UDP_DIAG is not enabled.
Fix tests to pass on newer versions of scapy.
Initial test for dt for early mount.
Make testIPv6PktinfoRouting pass on 32-bit Python.
Make incoming mark tests pass on device.
Don't use the fwmark client when running tests on device.
Don't expect eBPF support when running on 4.4.
net_test: return non-zero exit code on failure.
Fix incorrect fallback path for iptables.
Build network tests for android device.
Add ability for all_tests.sh to run based on prefix
Check that rxmem and txmem don't differ too much from each other
Un-hardcode the VTI iface and netid
Split the VTI tests and the tunnel tests.
Move the input marking code into multinetwork_base.
Support sending ICMP PTBs for non-UDP packets.
Fix decoding RTA_UID.
Test that an SA Can be Updated with a Mark
Add tests for transport mode re-key procedure
Make GetEspPacketLength calculate lengths dynamically
Use RTPROT_RA for kernel 4.14 instead of RTPROT_BOOT in DelRA6()
Fix missing config UBD, HOSTFS and NF_SOCKET_IPV4/6 for kernel 4.14
Support Tunnel Mode Null Encryption
Verify Security Policies May Differ by Only Direction
Test IPv4 transport mode on dual-stack sockets.
Test removing socket policies.
Tests for Adding and Updating Global Policies
Fix missing argument in kernel tests
Add AES-GCM kernel tests
Add Test for Invalid Algorithms
Move OWNERS file for net_test to net/test.
Enforce that there is exactly one field name per field.
Check that VTI input/output correctly increment counters.
Specify SPIs in host byte order in xfrm code.
Properly use dual-stack SAs.
Remove duplicated code to create policies and templates.
Group the parameters passed in to create SAs.
Pass around XfrmSelector instead of its parameters.
Add helpers for null encryption.
Add code to parse link stats.
Move _SetInboundMarking to multinetwork_base.py.
Stop calling InvalidateDstCache in xfrm tests.
Improve netlink debugging.
Provide an AddressVersion utility function.
Enable null crypto for kernel tests.
Use a UDP socket instead of a ping socket for ioctls.
Tests for TCP_FASTOPEN_CONNECT.
Support the generic netlink interface to TCP metrics.
Basic support for generic netlink.
Test that dst cache is cleared with socket policy.
Add Input Checking to the VTI tests
Wait for IPv6 addresses to be actually added.
Move SetSocketTimeout to csocket.
Minor fixes to iproute.
Add Debug Dump for XfrmUserpolicyInfo
Enable XfrmOutputMarkTest everywhere.
Skip multicast packets in TunTwister
Enable XfrmOutputMarkTest on more kernels.
Use Network saddr in Tunnel Mode Test
Update XFRM/VTI Tunnel mode tests to use output_mark
Add Tests for XFRM Tunnel and VTI Interface
Convert TunTwister's Twisting Methods to Classmethod
Add a test for XFRMA_OUTPUT_MARK.
Minor cleanups for existing tests
Factor a base class out of xfrm_test
Use ApplySocketPolicy in the basic xfrm tests.
Re-enable socket policy test by invalidating the dst cache.
Allow parallel_tests.sh to run an arbitrary test.
Exhaustive socket policy tests.
Test for a bug deleting invalid UID ranges.
Create 'TunTwister' util for twisting tun/tap.
Add XfrmMark Support for SAs and SPs
Add Methods to create User Policy to xfrm.py
Make the qtaguid test more stable
Test that SOCK_DESTROY affects poll() like a TCP RST.
Add more coverage for errors after a socket is closed.
Switch InvalidateDstCache to using throw routes.
Test that qtaguid does not drop packets without a socket.
forwarding_test: Add a test for forwarding UDPv6
Check for negative values in SO_{RCV,SND}BUFFORCE.
Trim OWNERS
Revert "Revert "Add VTI Configs to Networking Unit Tests""
Revert "Add VTI Configs to Networking Unit Tests"
Add VTI Configs to Networking Unit Tests
Add tests for cgroup v2 bpf and new helper functions
Add OWNERS in kernel/tests
Add attribute offset support to cstruct
Turn on /proc/net/xfrm_stat in kernel config.
Test case to check getsockopt operation SO_COOKIE
New ways to instantiate cstruct objects.
Add tests for ALLOCSPI.
Support "with"-style errno assertions.
Clean up and reorganize the bpf test
Remove unused imports.
More tests on xt_qtaguid owner match function
Test that SHA2 hashes use 128-bit truncation with PF_KEY.
Add code to use the PF_KEY interface.
Enable qtaguid sk fd test
Support flushing XFRM state.
De-duplicate iptables command code.
Also test link-local ping on connected sockets.
Make net_test enable CONFIG_NETFILTER_TPROXY as well.
Add test to check socket get untagged after closed
Fix incorrect protocol argument to RTM_DELROUTE
Support more device-like filesystem layout.
Use actual pointer objects instead of integers.
Revert "Unit test for socket cookie upstream patch"
Unit test for socket cookie upstream patch
Enable full RIOTest.testZeroLengthPrefix test on all kernel versions
net_test: Add test for RFC7559 router solicitation backoff
Signed-off-by: Maciej Żenczykowski <maze@google.com>
Change-Id: I3af734167b09fa68085941e34f2369f5a19028e4
diff --git a/Android.bp b/Android.bp
new file mode 100644
index 0000000..3b4e960
--- /dev/null
+++ b/Android.bp
@@ -0,0 +1,12 @@
+python_defaults {
+ name: "kernel_tests_defaults",
+ version: {
+ py2: {
+ embedded_launcher: true,
+ enabled: true,
+ },
+ py3: {
+ enabled: false,
+ },
+ },
+}
diff --git a/devicetree/early_mount/Android.bp b/devicetree/early_mount/Android.bp
new file mode 100644
index 0000000..63131c6
--- /dev/null
+++ b/devicetree/early_mount/Android.bp
@@ -0,0 +1,29 @@
+// Copyright (C) 2016 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+python_test {
+ name: "dt_early_mount_test",
+ srcs: [
+ "**/*.py",
+ ],
+ version: {
+ py2: {
+ embedded_launcher: true,
+ enabled: true,
+ },
+ py3: {
+ enabled: false,
+ },
+ },
+}
diff --git a/devicetree/early_mount/dt_early_mount_test.py b/devicetree/early_mount/dt_early_mount_test.py
new file mode 100755
index 0000000..6cabde4
--- /dev/null
+++ b/devicetree/early_mount/dt_early_mount_test.py
@@ -0,0 +1,84 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test cases for device tree overlays for early mounting partitions."""
+
+import os
+import unittest
+
+
+# Early mount fstab entry must have following properties defined.
+REQUIRED_FSTAB_PROPERTIES = [
+ 'dev',
+ 'type',
+ 'mnt_flags',
+ 'fsmgr_flags',
+]
+
+
+def ReadFile(file_path):
+ with open(file_path, 'r') as f:
+ # Strip all trailing spaces, newline and null characters.
+ return f.read().rstrip(' \n\x00')
+
+
+def GetAndroidDtDir():
+ """Returns location of android device tree directory."""
+ with open('/proc/cmdline', 'r') as f:
+ cmdline_list = f.read().split()
+
+ # Find android device tree directory path passed through kernel cmdline.
+ for option in cmdline_list:
+ if option.startswith('androidboot.android_dt_dir'):
+ return option.split('=')[1]
+
+ # If no custom path found, return the default location.
+ return '/proc/device-tree/firmware/android'
+
+
+class DtEarlyMountTest(unittest.TestCase):
+ """Test device tree overlays for early mounting."""
+
+ def setUp(self):
+ self._android_dt_dir = GetAndroidDtDir()
+ self._fstab_dt_dir = os.path.join(self._android_dt_dir, 'fstab')
+
+ def GetEarlyMountedPartitions(self):
+ """Returns a list of partitions specified in fstab for early mount."""
+ # Device tree nodes are represented as directories in the filesystem.
+    return [x for x in os.listdir(self._fstab_dt_dir) if os.path.isdir(os.path.join(self._fstab_dt_dir, x))]
+
+ def VerifyFstabEntry(self, partition):
+ partition_dt_dir = os.path.join(self._fstab_dt_dir, partition)
+    properties = os.listdir(partition_dt_dir)
+
+ self.assertTrue(
+ set(REQUIRED_FSTAB_PROPERTIES).issubset(properties),
+ 'fstab entry for /%s is missing required properties' % partition)
+
+ def testFstabCompatible(self):
+ """Verify fstab compatible string."""
+ compatible = ReadFile(os.path.join(self._fstab_dt_dir, 'compatible'))
+ self.assertEqual('android,fstab', compatible)
+
+ def testFstabEntries(self):
+ """Verify properties of early mount fstab entries."""
+ for partition in self.GetEarlyMountedPartitions():
+ self.VerifyFstabEntry(partition)
+
+if __name__ == '__main__':
+ unittest.main()
+
diff --git a/net/test/Android.bp b/net/test/Android.bp
new file mode 100644
index 0000000..2151015
--- /dev/null
+++ b/net/test/Android.bp
@@ -0,0 +1,13 @@
+python_test {
+ name: "kernel_net_tests",
+ main: "all_tests.py",
+ srcs: [
+ "*.py",
+ ],
+ libs: [
+ "scapy",
+ ],
+ defaults: [
+ "kernel_tests_defaults"
+ ],
+}
diff --git a/net/test/OWNERS b/net/test/OWNERS
new file mode 100644
index 0000000..cbbfa70
--- /dev/null
+++ b/net/test/OWNERS
@@ -0,0 +1,2 @@
+lorenzo@google.com
+maze@google.com
diff --git a/net/test/all_tests.py b/net/test/all_tests.py
new file mode 100755
index 0000000..bfba0e5
--- /dev/null
+++ b/net/test/all_tests.py
@@ -0,0 +1,58 @@
+#!/usr/bin/python
+#
+# Copyright 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from importlib import import_module
+import sys
+import unittest
+
+test_modules = [
+ 'anycast_test',
+ 'bpf_test',
+ 'csocket_test',
+ 'cstruct_test',
+ 'forwarding_test',
+ 'leak_test',
+ 'multinetwork_test',
+ 'neighbour_test',
+ 'nf_test',
+ 'pf_key_test',
+ 'ping6_test',
+ 'qtaguid_test',
+ 'removed_feature_test',
+ 'resilient_rs_test',
+ 'sock_diag_test',
+ 'srcaddr_selection_test',
+ 'tcp_fastopen_test',
+ 'tcp_nuke_addr_test',
+ 'tcp_test',
+ 'xfrm_algorithm_test',
+ 'xfrm_test',
+ 'xfrm_tunnel_test',
+]
+
+if __name__ == '__main__':
+ # First, run InjectTests on all modules, to ensure that any parameterized
+ # tests in those modules are injected.
+ for name in test_modules:
+ import_module(name)
+ if hasattr(sys.modules[name], "InjectTests"):
+ sys.modules[name].InjectTests()
+
+ loader = unittest.defaultTestLoader
+ test_suite = loader.loadTestsFromNames(test_modules)
+ runner = unittest.TextTestRunner(verbosity=2)
+ result = runner.run(test_suite)
+ sys.exit(not result.wasSuccessful())
diff --git a/net/test/all_tests.sh b/net/test/all_tests.sh
index a5476f9..63576b0 100755
--- a/net/test/all_tests.sh
+++ b/net/test/all_tests.sh
@@ -16,6 +16,27 @@
readonly PREFIX="#####"
readonly RETRIES=2
+test_prefix=
+
+function checkArgOrExit() {
+ if [[ $# -lt 2 ]]; then
+ echo "Missing argument for option $1" >&2
+ exit 1
+ fi
+}
+
+function usageAndExit() {
+ cat >&2 << EOF
+ all_tests.sh - test runner with support for flake testing
+
+ all_tests.sh [options]
+
+ options:
+ -h, --help show this menu
+ -p, --prefix=TEST_PREFIX specify a prefix for the tests to be run
+EOF
+ exit 0
+}
function maybePlural() {
# $1 = integer to use for plural check
@@ -28,7 +49,6 @@
fi
}
-
function runTest() {
local cmd="$1"
@@ -46,10 +66,39 @@
echo >&2 "Warning: '$cmd' FLAKY, passed $RETRIES/$((RETRIES + 1))"
}
+# Parse arguments
+while [ -n "$1" ]; do
+ case "$1" in
+ -h|--help)
+ usageAndExit
+ ;;
+ -p|--prefix)
+ checkArgOrExit $@
+ test_prefix=$2
+ shift 2
+ ;;
+ *)
+ echo "Unknown option $1" >&2
+ echo >&2
+ usageAndExit
+ ;;
+ esac
+done
-readonly tests=$(find . -name '*_test.py' -type f -executable)
+# Find the relevant tests
+if [[ -z $test_prefix ]]; then
+ readonly tests=$(eval "find . -name '*_test.py' -type f -executable")
+else
+ readonly tests=$(eval "find . -name '$test_prefix*' -type f -executable")
+fi
+
+# Give some readable status messages
readonly count=$(echo $tests | wc -w)
-echo "$PREFIX Found $count $(maybePlural $count test tests)."
+if [[ -z $test_prefix ]]; then
+ echo "$PREFIX Found $count $(maybePlural $count test tests)."
+else
+ echo "$PREFIX Found $count $(maybePlural $count test tests) with prefix $test_prefix."
+fi
exit_code=0
diff --git a/net/test/anycast_test.py b/net/test/anycast_test.py
old mode 100755
new mode 100644
index 82130db..6222580
--- a/net/test/anycast_test.py
+++ b/net/test/anycast_test.py
@@ -93,7 +93,14 @@
# This will hang if the kernel has the bug.
thread = CloseFileDescriptorThread(self.tuns[netid])
thread.start()
- time.sleep(0.1)
+ # Wait up to 3 seconds for the thread to finish, but
+ # continue and fail the test if the thread hangs.
+
+ # For kernels with MPTCP ported, closing a tun interface needs more
+ # than 0.5 sec: the DAD procedure within the MPTCP fullmesh module
+ # takes longer, because the duplicate address timer takes a refcount
+ # on the IPv6 address, preventing it from being closed.
+ thread.join(3)
# Make teardown work.
del self.tuns[netid]
diff --git a/net/test/bpf.py b/net/test/bpf.py
index 50add04..43502bd 100755
--- a/net/test/bpf.py
+++ b/net/test/bpf.py
@@ -15,15 +15,31 @@
# limitations under the License.
import ctypes
+import os
import csocket
import cstruct
import net_test
import socket
+import platform
-# TODO: figure out how to make this arch-dependent if we run these tests
-# on non-X86
-__NR_bpf = 321
+# __NR_bpf syscall numbers for various architectures.
+# NOTE: If python inherited COMPAT_UTS_MACHINE, uname's 'machine' field will
+# return the 32-bit architecture name, even if python itself is 64-bit. To work
+# around this problem and pick the right syscall nr, we can additionally check
+# the bitness of the python interpreter. Assume that the 64-bit architectures
+# are not running with COMPAT_UTS_MACHINE and must be 64-bit at all times.
+# TODO: is there a better way of doing this?
+__NR_bpf = {
+ "aarch64-64bit": 280,
+ "armv7l-32bit": 386,
+ "armv8l-32bit": 386,
+ "armv8l-64bit": 280,
+ "i686-32bit": 357,
+ "i686-64bit": 321,
+ "x86_64-64bit": 321,
+}[os.uname()[4] + "-" + platform.architecture()[0]]
+
LOG_LEVEL = 1
LOG_SIZE = 65536
@@ -36,6 +52,8 @@
BPF_PROG_LOAD = 5
BPF_OBJ_PIN = 6
BPF_OBJ_GET = 7
+BPF_PROG_ATTACH = 8
+BPF_PROG_DETACH = 9
SO_ATTACH_BPF = 50
# BPF map type constant.
@@ -51,6 +69,16 @@
BPF_PROG_TYPE_KPROBE = 2
BPF_PROG_TYPE_SCHED_CLS = 3
BPF_PROG_TYPE_SCHED_ACT = 4
+BPF_PROG_TYPE_TRACEPOINT = 5
+BPF_PROG_TYPE_XDP = 6
+BPF_PROG_TYPE_PERF_EVENT = 7
+BPF_PROG_TYPE_CGROUP_SKB = 8
+BPF_PROG_TYPE_CGROUP_SOCK = 9
+
+# BPF program attach type.
+BPF_CGROUP_INET_INGRESS = 0
+BPF_CGROUP_INET_EGRESS = 1
+BPF_CGROUP_INET_SOCK_CREATE = 2
# BPF register constant
BPF_REG_0 = 0
@@ -124,15 +152,25 @@
BPF_FUNC_map_lookup_elem = 1
BPF_FUNC_map_update_elem = 2
BPF_FUNC_map_delete_elem = 3
+BPF_FUNC_get_current_uid_gid = 15
+BPF_FUNC_get_socket_cookie = 46
+BPF_FUNC_get_socket_uid = 47
-# BPF attr struct
-BpfAttrCreate = cstruct.Struct("bpf_attr_create", "=IIII",
- "map_type key_size value_size max_entries")
+BPF_F_RDONLY = 1 << 3
+BPF_F_WRONLY = 1 << 4
+
+# The objects below belong to the same kernel union; the names below
+# (e.g., bpf_attr_create) aren't kernel struct names, just different
+# variants of the union.
+BpfAttrCreate = cstruct.Struct("bpf_attr_create", "=IIIII",
+ "map_type key_size value_size max_entries map_flags")
BpfAttrOps = cstruct.Struct("bpf_attr_ops", "=QQQQ",
"map_fd key_ptr value_ptr flags")
BpfAttrProgLoad = cstruct.Struct(
"bpf_attr_prog_load", "=IIQQIIQI", "prog_type insn_cnt insns"
" license log_level log_size log_buf kern_version")
+BpfAttrProgAttach = cstruct.Struct(
+ "bpf_attr_prog_attach", "=III", "target_fd attach_bpf_fd attach_type")
BpfInsn = cstruct.Struct("bpf_insn", "=BBhi", "code dst_src_reg off imm")
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
@@ -140,12 +178,15 @@
# BPF program syscalls
-def CreateMap(map_type, key_size, value_size, max_entries):
- attr = BpfAttrCreate((map_type, key_size, value_size, max_entries))
- ret = libc.syscall(__NR_bpf, BPF_MAP_CREATE, attr.CPointer(), len(attr))
+def BpfSyscall(op, attr):
+ ret = libc.syscall(__NR_bpf, op, csocket.VoidPointer(attr), len(attr))
csocket.MaybeRaiseSocketError(ret)
return ret
+def CreateMap(map_type, key_size, value_size, max_entries, map_flags=0):
+ attr = BpfAttrCreate((map_type, key_size, value_size, max_entries, map_flags))
+ return BpfSyscall(BPF_MAP_CREATE, attr)
+
def UpdateMap(map_fd, key, value, flags=0):
c_value = ctypes.c_uint32(value)
@@ -153,9 +194,7 @@
value_ptr = ctypes.addressof(c_value)
key_ptr = ctypes.addressof(c_key)
attr = BpfAttrOps((map_fd, key_ptr, value_ptr, flags))
- ret = libc.syscall(__NR_bpf, BPF_MAP_UPDATE_ELEM,
- attr.CPointer(), len(attr))
- csocket.MaybeRaiseSocketError(ret)
+ BpfSyscall(BPF_MAP_UPDATE_ELEM, attr)
def LookupMap(map_fd, key):
@@ -163,48 +202,60 @@
c_key = ctypes.c_uint32(key)
attr = BpfAttrOps(
(map_fd, ctypes.addressof(c_key), ctypes.addressof(c_value), 0))
- ret = libc.syscall(__NR_bpf, BPF_MAP_LOOKUP_ELEM,
- attr.CPointer(), len(attr))
- csocket.MaybeRaiseSocketError(ret)
+ BpfSyscall(BPF_MAP_LOOKUP_ELEM, attr)
return c_value
def GetNextKey(map_fd, key):
- c_key = ctypes.c_uint32(key)
+ if key is not None:
+ c_key = ctypes.c_uint32(key)
+ c_next_key = ctypes.c_uint32(0)
+ key_ptr = ctypes.addressof(c_key)
+ else:
+ key_ptr = 0
c_next_key = ctypes.c_uint32(0)
attr = BpfAttrOps(
- (map_fd, ctypes.addressof(c_key), ctypes.addressof(c_next_key), 0))
- ret = libc.syscall(__NR_bpf, BPF_MAP_GET_NEXT_KEY,
- attr.CPointer(), len(attr))
- csocket.MaybeRaiseSocketError(ret)
+ (map_fd, key_ptr, ctypes.addressof(c_next_key), 0))
+ BpfSyscall(BPF_MAP_GET_NEXT_KEY, attr)
return c_next_key
+def GetFirstKey(map_fd):
+ return GetNextKey(map_fd, None)
def DeleteMap(map_fd, key):
c_key = ctypes.c_uint32(key)
attr = BpfAttrOps((map_fd, ctypes.addressof(c_key), 0, 0))
- ret = libc.syscall(__NR_bpf, BPF_MAP_DELETE_ELEM,
- attr.CPointer(), len(attr))
- csocket.MaybeRaiseSocketError(ret)
+ BpfSyscall(BPF_MAP_DELETE_ELEM, attr)
-def BpfProgLoad(prog_type, insn_ptr, prog_len, insn_len):
+def BpfProgLoad(prog_type, instructions):
+ bpf_prog = "".join(instructions)
+ insn_buff = ctypes.create_string_buffer(bpf_prog)
gpl_license = ctypes.create_string_buffer(b"GPL")
log_buf = ctypes.create_string_buffer(b"", LOG_SIZE)
- attr = BpfAttrProgLoad(
- (prog_type, prog_len / insn_len, insn_ptr, ctypes.addressof(gpl_license),
- LOG_LEVEL, LOG_SIZE, ctypes.addressof(log_buf), 0))
- ret = libc.syscall(__NR_bpf, BPF_PROG_LOAD, attr.CPointer(), len(attr))
- csocket.MaybeRaiseSocketError(ret)
- return ret
+ attr = BpfAttrProgLoad((prog_type, len(insn_buff) / len(BpfInsn),
+ ctypes.addressof(insn_buff),
+ ctypes.addressof(gpl_license), LOG_LEVEL,
+ LOG_SIZE, ctypes.addressof(log_buf), 0))
+ return BpfSyscall(BPF_PROG_LOAD, attr)
-
-def BpfProgAttach(sock_fd, prog_fd):
- prog_ptr = ctypes.c_uint32(prog_fd)
+# Attach a socket eBPF filter to a target socket
+def BpfProgAttachSocket(sock_fd, prog_fd):
+ uint_fd = ctypes.c_uint32(prog_fd)
ret = libc.setsockopt(sock_fd, socket.SOL_SOCKET, SO_ATTACH_BPF,
- ctypes.addressof(prog_ptr), ctypes.sizeof(prog_ptr))
+ ctypes.pointer(uint_fd), ctypes.sizeof(uint_fd))
csocket.MaybeRaiseSocketError(ret)
+# Attach an eBPF filter to a cgroup
+def BpfProgAttach(prog_fd, target_fd, prog_type):
+ attr = BpfAttrProgAttach((target_fd, prog_fd, prog_type))
+ return BpfSyscall(BPF_PROG_ATTACH, attr)
+
+# Detach an eBPF filter from a cgroup
+def BpfProgDetach(target_fd, prog_type):
+ attr = BpfAttrProgAttach((target_fd, 0, prog_type))
+ return BpfSyscall(BPF_PROG_DETACH, attr)
+
# BPF program command constructors
def BpfMov64Reg(dst, src):
@@ -275,22 +326,8 @@
return insn1.Pack() + insn2.Pack()
-def BpfFuncLookupMap():
+def BpfFuncCall(func):
code = BPF_JMP | BPF_CALL
dst_src = 0
- ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_lookup_elem))
- return ret.Pack()
-
-
-def BpfFuncUpdateMap():
- code = BPF_JMP | BPF_CALL
- dst_src = 0
- ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_update_elem))
- return ret.Pack()
-
-
-def BpfFuncDeleteMap():
- code = BPF_JMP | BPF_CALL
- dst_src = 0
- ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_delete_elem))
+ ret = BpfInsn((code, dst_src, 0, func))
return ret.Pack()
diff --git a/net/test/bpf_test.py b/net/test/bpf_test.py
index 1e294c9..ea3e56b 100755
--- a/net/test/bpf_test.py
+++ b/net/test/bpf_test.py
@@ -16,137 +16,431 @@
import ctypes
import errno
+import os
import socket
+import struct
+import subprocess
+import tempfile
import unittest
from bpf import * # pylint: disable=wildcard-import
import csocket
import net_test
+import sock_diag
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
-HAVE_EBPF_SUPPORT = net_test.LINUX_VERSION >= (4, 4, 0)
+HAVE_EBPF_ACCOUNTING = net_test.LINUX_VERSION >= (4, 9, 0)
+HAVE_EBPF_SOCKET = net_test.LINUX_VERSION >= (4, 14, 0)
+KEY_SIZE = 8
+VALUE_SIZE = 4
+TOTAL_ENTRIES = 20
+TEST_UID = 54321
+TEST_GID = 12345
+# Offset to store the map key in stack register REG10
+key_offset = -8
+# Offset to store the map value in stack register REG10
+value_offset = -16
-@unittest.skipUnless(HAVE_EBPF_SUPPORT,
- "eBPF function not fully supported")
+# Debug usage only.
+def PrintMapInfo(map_fd):
+ # A random key that the map does not contain.
+ key = 10086
+ while 1:
+ try:
+ nextKey = GetNextKey(map_fd, key).value
+ value = LookupMap(map_fd, nextKey)
+ print repr(nextKey) + " : " + repr(value.value)
+ key = nextKey
+ except:
+ print "no value"
+ break
+
+
+# A dummy loopback function that causes a socket to send traffic to itself.
+def SocketUDPLoopBack(packet_count, version, prog_fd):
+ family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
+ sock = socket.socket(family, socket.SOCK_DGRAM, 0)
+ if prog_fd is not None:
+ BpfProgAttachSocket(sock.fileno(), prog_fd)
+ net_test.SetNonBlocking(sock)
+ addr = {4: "127.0.0.1", 6: "::1"}[version]
+ sock.bind((addr, 0))
+ addr = sock.getsockname()
+ sockaddr = csocket.Sockaddr(addr)
+ for i in xrange(packet_count):
+ sock.sendto("foo", addr)
+ data, retaddr = csocket.Recvfrom(sock, 4096, 0)
+ assert "foo" == data
+ assert sockaddr == retaddr
+ return sock
+
+
+# The main code block for the eBPF packet counting program. It takes a
+# preloaded key from BPF_REG_0 and uses it to look up the bpf map; if the
+# element does not exist in the map yet, the program updates the map with
+# a new <key, 1> pair. Otherwise it jumps to the next code block to handle it.
+# REG0: register storing the return value from helper functions and the
+# final return value of the eBPF program.
+# REG1 - REG5: temporary registers used for storing values and loading
+# parameters into eBPF helper functions. After a helper function call, the
+# values in these registers are reset.
+# REG6 - REG9: registers storing values that are preserved across eBPF
+# helper function calls.
+# REG10: the stack frame pointer; values that need to be accessed by
+# address are stored on the stack, and the program retrieves a value's
+# address by its position in the stack.
+def BpfFuncCountPacketInit(map_fd):
+ key_pos = BPF_REG_7
+ insPackCountStart = [
+ # Get a preloaded key from BPF_REG_0 and store it at BPF_REG_7
+ BpfMov64Reg(key_pos, BPF_REG_10),
+ BpfAlu64Imm(BPF_ADD, key_pos, key_offset),
+ # Load map fd and look up the key in the map
+ BpfLoadMapFd(map_fd, BPF_REG_1),
+ BpfMov64Reg(BPF_REG_2, key_pos),
+ BpfFuncCall(BPF_FUNC_map_lookup_elem),
+ # If the map element already exists, jump out of this
+ # code block and let the next part handle it
+ BpfJumpImm(BPF_AND, BPF_REG_0, 0, 10),
+ BpfLoadMapFd(map_fd, BPF_REG_1),
+ BpfMov64Reg(BPF_REG_2, key_pos),
+ # Initialize a new <key, value> pair with value 1 and update the map
+ BpfStMem(BPF_W, BPF_REG_10, value_offset, 1),
+ BpfMov64Reg(BPF_REG_3, BPF_REG_10),
+ BpfAlu64Imm(BPF_ADD, BPF_REG_3, value_offset),
+ BpfMov64Imm(BPF_REG_4, 0),
+ BpfFuncCall(BPF_FUNC_map_update_elem)
+ ]
+ return insPackCountStart
+
+
+INS_BPF_EXIT_BLOCK = [
+ BpfMov64Imm(BPF_REG_0, 0),
+ BpfExitInsn()
+]
+
+# Bpf instruction for cgroup bpf filter to accept a packet and exit.
+INS_CGROUP_ACCEPT = [
+ # Set return value to 1 and exit.
+ BpfMov64Imm(BPF_REG_0, 1),
+ BpfExitInsn()
+]
+
+# Bpf instruction for socket bpf filter to accept a packet and exit.
+INS_SK_FILTER_ACCEPT = [
+ # Precondition: BPF_REG_6 = sk_buff context
+ # Load the packet length from BPF_REG_6 and store it in BPF_REG_0 as the
+ # return value.
+ BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0),
+ BpfExitInsn()
+]
+
+# Increment an existing map element's value by 1.
+INS_PACK_COUNT_UPDATE = [
+ # Precondition: BPF_REG_0 = Value retrieved from BPF maps
+ # Add one to the corresponding eBPF value field for a specific eBPF key.
+ BpfMov64Reg(BPF_REG_2, BPF_REG_0),
+ BpfMov64Imm(BPF_REG_1, 1),
+ BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, 0, 0),
+]
+
+INS_BPF_PARAM_STORE = [
+ BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset),
+]
+
+@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
+ "BPF helper function is not fully supported")
class BpfTest(net_test.NetworkTest):
+ def setUp(self):
+ self.map_fd = -1
+ self.prog_fd = -1
+ self.sock = None
+
+ def tearDown(self):
+ if self.prog_fd >= 0:
+ os.close(self.prog_fd)
+ if self.map_fd >= 0:
+ os.close(self.map_fd)
+ if self.sock:
+ self.sock.close()
+
def testCreateMap(self):
key, value = 1, 1
- map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 100)
- UpdateMap(map_fd, key, value)
- self.assertEquals(LookupMap(map_fd, key).value, value)
- DeleteMap(map_fd, key)
- self.assertRaisesErrno(errno.ENOENT, LookupMap, map_fd, key)
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES)
+ UpdateMap(self.map_fd, key, value)
+ self.assertEquals(value, LookupMap(self.map_fd, key).value)
+ DeleteMap(self.map_fd, key)
+ self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)
- def testIterateMap(self):
- map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 100)
- value = 1024
- for key in xrange(1, 100):
- UpdateMap(map_fd, key, value)
- for key in xrange(1, 100):
- self.assertEquals(LookupMap(map_fd, key).value, value)
- self.assertRaisesErrno(errno.ENOENT, LookupMap, map_fd, 101)
- key = 0
+ def CheckAllMapEntry(self, nonexistent_key, totalEntries, value):
count = 0
- while 1:
- if count == 99:
- self.assertRaisesErrno(errno.ENOENT, GetNextKey, map_fd, key)
+ key = nonexistent_key
+ while True:
+ if count == totalEntries:
+ self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key)
break
else:
- result = GetNextKey(map_fd, key)
+ result = GetNextKey(self.map_fd, key)
key = result.value
- self.assertGreater(key, 0)
- self.assertEquals(LookupMap(map_fd, key).value, value)
+ self.assertGreaterEqual(key, 0)
+ self.assertEquals(value, LookupMap(self.map_fd, key).value)
count += 1
+ def testIterateMap(self):
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES)
+ value = 1024
+ for key in xrange(0, TOTAL_ENTRIES):
+ UpdateMap(self.map_fd, key, value)
+ for key in xrange(0, TOTAL_ENTRIES):
+ self.assertEquals(value, LookupMap(self.map_fd, key).value)
+    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, TOTAL_ENTRIES)
+ nonexistent_key = -1
+ self.CheckAllMapEntry(nonexistent_key, TOTAL_ENTRIES, value)
+
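The CheckAllMapEntry walk above mirrors the standard userspace idiom for iterating a kernel BPF hash map: start from a key that is not in the map, then call GetNextKey repeatedly until the kernel answers ENOENT. A pure-Python sketch of the same loop (the dict and the get_next_key helper below are illustrative stand-ins for the kernel map and the bpf.py wrapper, not real APIs; a real hash map does not guarantee sorted order):

```python
import errno


def get_next_key(bpf_map, key):
    """Mimic BPF_MAP_GET_NEXT_KEY over a plain dict (sorted-order model)."""
    keys = sorted(bpf_map)
    if key not in bpf_map:
        # Unknown key: the kernel returns the first key in the map.
        if keys:
            return keys[0]
        raise OSError(errno.ENOENT, "empty map")
    idx = keys.index(key) + 1
    if idx == len(keys):
        # Past the last entry: ENOENT terminates the walk.
        raise OSError(errno.ENOENT, "no next key")
    return keys[idx]


def walk_map(bpf_map, start_key=-1):
    """Collect every key by calling get_next_key until ENOENT."""
    found, key = [], start_key
    while True:
        try:
            key = get_next_key(bpf_map, key)
        except OSError:
            return found
        found.append(key)


demo = {k: 1024 for k in range(5)}
print(walk_map(demo))  # → [0, 1, 2, 3, 4]
```

Starting from a nonexistent key (here -1) is what makes the walk begin at the first entry, exactly as testIterateMap does.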
+ def testFindFirstMapKey(self):
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES)
+ value = 1024
+ for key in xrange(0, TOTAL_ENTRIES):
+ UpdateMap(self.map_fd, key, value)
+ firstKey = GetFirstKey(self.map_fd)
+ key = firstKey.value
+ self.CheckAllMapEntry(key, TOTAL_ENTRIES - 1, value)
+
+ def testRdOnlyMap(self):
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES, map_flags=BPF_F_RDONLY)
+ value = 1024
+ key = 1
+    self.assertRaisesErrno(errno.EPERM, UpdateMap, self.map_fd, key, value)
+    # Lookup is permitted on a read-only map, but the key was never written.
+    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)
+
+ @unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
+ "BPF helper function is not fully supported")
+ def testWrOnlyMap(self):
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES, map_flags=BPF_F_WRONLY)
+ value = 1024
+ key = 1
+ UpdateMap(self.map_fd, key, value)
+ self.assertRaisesErrno(errno.EPERM, LookupMap, self.map_fd, key)
+
def testProgLoad(self):
- bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
- bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
- bpf_prog += BpfExitInsn()
- insn_buff = ctypes.create_string_buffer(bpf_prog)
- # Load a program that does nothing except pass every packet it receives
- # It should not block the packet transmission otherwise the test fails.
- prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER,
- ctypes.addressof(insn_buff),
- len(insn_buff), BpfInsn._length)
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
- sock.settimeout(1)
- BpfProgAttach(sock.fileno(), prog_fd)
- addr = "127.0.0.1"
- sock.bind((addr, 0))
- addr = sock.getsockname()
- sockaddr = csocket.Sockaddr(addr)
- sock.sendto("foo", addr)
- data, addr = csocket.Recvfrom(sock, 4096, 0)
- self.assertEqual("foo", data)
- self.assertEqual(sockaddr, addr)
+    # Save the skb in BPF_REG_6 for later use.
+ instructions = [
+ BpfMov64Reg(BPF_REG_6, BPF_REG_1)
+ ]
+ instructions += INS_SK_FILTER_ACCEPT
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
+ SocketUDPLoopBack(1, 4, self.prog_fd)
+ SocketUDPLoopBack(1, 6, self.prog_fd)
def testPacketBlock(self):
- bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
- bpf_prog += BpfMov64Imm(BPF_REG_0, 0)
- bpf_prog += BpfExitInsn()
- insn_buff = ctypes.create_string_buffer(bpf_prog)
- prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER,
- ctypes.addressof(insn_buff),
- len(insn_buff), BpfInsn._length)
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
- sock.settimeout(1)
- BpfProgAttach(sock.fileno(), prog_fd)
- addr = "127.0.0.1"
- sock.bind((addr, 0))
- addr = sock.getsockname()
- sock.sendto("foo", addr)
- self.assertRaisesErrno(errno.EAGAIN, csocket.Recvfrom, sock, 4096, 0)
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, INS_BPF_EXIT_BLOCK)
+ self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, self.prog_fd)
+ self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, self.prog_fd)
def testPacketCount(self):
- map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 100)
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES)
key = 0xf0f0
- bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
- bpf_prog += BpfLoadMapFd(map_fd, BPF_REG_1)
- bpf_prog += BpfMov64Imm(BPF_REG_7, key)
- bpf_prog += BpfStxMem(BPF_W, BPF_REG_10, BPF_REG_7, -4)
- bpf_prog += BpfMov64Reg(BPF_REG_8, BPF_REG_10)
- bpf_prog += BpfAlu64Imm(BPF_ADD, BPF_REG_8, -4)
- bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_8)
- bpf_prog += BpfFuncLookupMap()
- bpf_prog += BpfJumpImm(BPF_AND, BPF_REG_0, 0, 10)
- bpf_prog += BpfLoadMapFd(map_fd, BPF_REG_1)
- bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_8)
- bpf_prog += BpfStMem(BPF_W, BPF_REG_10, -8, 1)
- bpf_prog += BpfMov64Reg(BPF_REG_3, BPF_REG_10)
- bpf_prog += BpfAlu64Imm(BPF_ADD, BPF_REG_3, -8)
- bpf_prog += BpfMov64Imm(BPF_REG_4, 0)
- bpf_prog += BpfFuncUpdateMap()
- bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
- bpf_prog += BpfExitInsn()
- bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_0)
- bpf_prog += BpfMov64Imm(BPF_REG_1, 1)
- bpf_prog += BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1,
- 0, 0)
- bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
- bpf_prog += BpfExitInsn()
- insn_buff = ctypes.create_string_buffer(bpf_prog)
- # this program loaded is used to counting the packet transmitted through
- # a target socket. It will store the packet count into the eBPF map and we
- # will verify if the counting result is correct.
- prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER,
- ctypes.addressof(insn_buff),
- len(insn_buff), BpfInsn._length)
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
- sock.settimeout(1)
- BpfProgAttach(sock.fileno(), prog_fd)
- addr = "127.0.0.1"
- sock.bind((addr, 0))
- addr = sock.getsockname()
- sockaddr = csocket.Sockaddr(addr)
- packet_count = 100
- for i in xrange(packet_count):
- sock.sendto("foo", addr)
- data, retaddr = csocket.Recvfrom(sock, 4096, 0)
- self.assertEqual("foo", data)
- self.assertEqual(sockaddr, retaddr)
- self.assertEquals(LookupMap(map_fd, key).value, packet_count)
+ # Set up instruction block with key loaded at BPF_REG_0.
+ instructions = [
+ BpfMov64Reg(BPF_REG_6, BPF_REG_1),
+ BpfMov64Imm(BPF_REG_0, key)
+ ]
+ # Concatenate the generic packet count bpf program to it.
+ instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
+ + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
+ + INS_SK_FILTER_ACCEPT)
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
+ packet_count = 10
+ SocketUDPLoopBack(packet_count, 4, self.prog_fd)
+ SocketUDPLoopBack(packet_count, 6, self.prog_fd)
+ self.assertEquals(packet_count * 2, LookupMap(self.map_fd, key).value)
+ @unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
+ "BPF helper function is not fully supported")
+ def testGetSocketCookie(self):
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES)
+    # Move the skb to BPF_REG_6, then call the helper function to fetch the
+    # socket cookie, leaving it in BPF_REG_0 for the following code block.
+ instructions = [
+ BpfMov64Reg(BPF_REG_6, BPF_REG_1),
+ BpfFuncCall(BPF_FUNC_get_socket_cookie)
+ ]
+ instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
+ + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
+ + INS_SK_FILTER_ACCEPT)
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
+ packet_count = 10
+ def PacketCountByCookie(version):
+ self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd)
+ cookie = sock_diag.SockDiag.GetSocketCookie(self.sock)
+ self.assertEquals(packet_count, LookupMap(self.map_fd, cookie).value)
+ self.sock.close()
+ PacketCountByCookie(4)
+ PacketCountByCookie(6)
+
+ @unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
+ "BPF helper function is not fully supported")
+ def testGetSocketUid(self):
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES)
+ # Set up the instruction with uid at BPF_REG_0.
+ instructions = [
+ BpfMov64Reg(BPF_REG_6, BPF_REG_1),
+ BpfFuncCall(BPF_FUNC_get_socket_uid)
+ ]
+ # Concatenate the generic packet count bpf program to it.
+ instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
+ + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
+ + INS_SK_FILTER_ACCEPT)
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
+ packet_count = 10
+ uid = TEST_UID
+ with net_test.RunAsUid(uid):
+ self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
+ SocketUDPLoopBack(packet_count, 4, self.prog_fd)
+ self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
+      DeleteMap(self.map_fd, uid)
+ SocketUDPLoopBack(packet_count, 6, self.prog_fd)
+ self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
+
+@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
+ "Cgroup BPF is not fully supported")
+class BpfCgroupTest(net_test.NetworkTest):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cg_dir = tempfile.mkdtemp(prefix="cg_bpf-")
+ cmd = "mount -t cgroup2 cg_bpf %s" % cls._cg_dir
+ try:
+ subprocess.check_call(cmd.split())
+ except subprocess.CalledProcessError:
+ # If an exception is thrown in setUpClass, the test fails and
+ # tearDownClass is not called.
+ os.rmdir(cls._cg_dir)
+ raise
+ cls._cg_fd = os.open(cls._cg_dir, os.O_DIRECTORY | os.O_RDONLY)
+
+ @classmethod
+ def tearDownClass(cls):
+ os.close(cls._cg_fd)
+ subprocess.call(('umount %s' % cls._cg_dir).split())
+ os.rmdir(cls._cg_dir)
+
+ def setUp(self):
+ self.prog_fd = -1
+ self.map_fd = -1
+
+ def tearDown(self):
+ if self.prog_fd >= 0:
+ os.close(self.prog_fd)
+ if self.map_fd >= 0:
+ os.close(self.map_fd)
+ try:
+ BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
+ except socket.error:
+ pass
+ try:
+ BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
+ except socket.error:
+ pass
+ try:
+ BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
+ except socket.error:
+ pass
+
+ def testCgroupBpfAttach(self):
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
+ BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
+ BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
+
+ def testCgroupIngress(self):
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
+ BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
+ self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, None)
+ self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, None)
+ BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
+ SocketUDPLoopBack(1, 4, None)
+ SocketUDPLoopBack(1, 6, None)
+
+ def testCgroupEgress(self):
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
+ BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_EGRESS)
+ self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 4, None)
+ self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 6, None)
+ BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
+    SocketUDPLoopBack(1, 4, None)
+    SocketUDPLoopBack(1, 6, None)
+
+ def testCgroupBpfUid(self):
+ self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+ TOTAL_ENTRIES)
+ # Similar to the program used in testGetSocketUid.
+ instructions = [
+ BpfMov64Reg(BPF_REG_6, BPF_REG_1),
+ BpfFuncCall(BPF_FUNC_get_socket_uid)
+ ]
+ instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
+ + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE + INS_CGROUP_ACCEPT)
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, instructions)
+ BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
+ packet_count = 20
+ uid = TEST_UID
+ with net_test.RunAsUid(uid):
+ self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
+ SocketUDPLoopBack(packet_count, 4, None)
+ self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
+ DeleteMap(self.map_fd, uid)
+ SocketUDPLoopBack(packet_count, 6, None)
+ self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
+ BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
+
+ def checkSocketCreate(self, family, socktype, success):
+ try:
+ sock = socket.socket(family, socktype, 0)
+ sock.close()
+ except socket.error, e:
+ if success:
+ self.fail("Failed to create socket family=%d type=%d err=%s" %
+ (family, socktype, os.strerror(e.errno)))
+      return
+ if not success:
+ self.fail("unexpected socket family=%d type=%d created, should be blocked" %
+ (family, socktype))
+
+ def trySocketCreate(self, success):
+ for family in [socket.AF_INET, socket.AF_INET6]:
+ for socktype in [socket.SOCK_DGRAM, socket.SOCK_STREAM]:
+ self.checkSocketCreate(family, socktype, success)
+
+ @unittest.skipUnless(HAVE_EBPF_SOCKET,
+ "Cgroup BPF socket is not supported")
+ def testCgroupSocketCreateBlock(self):
+ instructions = [
+ BpfFuncCall(BPF_FUNC_get_current_uid_gid),
+ BpfAlu64Imm(BPF_AND, BPF_REG_0, 0xfffffff),
+ BpfJumpImm(BPF_JNE, BPF_REG_0, TEST_UID, 2),
+ ]
+    instructions += INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT
+ self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SOCK, instructions)
+ BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
+ with net_test.RunAsUid(TEST_UID):
+      # Socket creation with the target uid should fail
+      self.trySocketCreate(False)
+    # Socket creation with a different uid should succeed
+    self.trySocketCreate(True)
+ BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
+ with net_test.RunAsUid(TEST_UID):
+ self.trySocketCreate(True)
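The uid gate in testCgroupSocketCreateBlock relies on BPF_FUNC_get_current_uid_gid, which returns one 64-bit value with the gid in the upper 32 bits and the uid in the lower 32; the BPF_AND isolates the uid before the BPF_JNE compares it against TEST_UID. The same unpacking in plain Python (the packed value below is made up for illustration):

```python
def split_uid_gid(packed):
    """Unpack the 64-bit value: uid in bits 31..0, gid in bits 63..32."""
    uid = packed & 0xFFFFFFFF
    gid = packed >> 32
    return uid, gid


# Hypothetical packed value: gid=1000, uid=12345 (a TEST_UID-style value).
packed = (1000 << 32) | 12345
uid, gid = split_uid_gid(packed)
print(uid, gid)  # → 12345 1000
```

In BPF bytecode the AND-with-mask plays the role of the `& 0xFFFFFFFF` here, discarding the gid half before the equality jump.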
if __name__ == "__main__":
unittest.main()
diff --git a/net/test/build_rootfs.sh b/net/test/build_rootfs.sh
new file mode 100755
index 0000000..ce09da1
--- /dev/null
+++ b/net/test/build_rootfs.sh
@@ -0,0 +1,139 @@
+#!/bin/bash
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+set -e
+
+SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd -P)
+
+usage() {
+ echo -n "usage: $0 [-h] [-s wheezy|stretch] [-a amd64|arm64] "
+ echo "[-m http://mirror/debian] [-n net_test.rootfs.`date +%Y%m%d`]"
+ exit 1
+}
+
+mirror=http://ftp.debian.org/debian
+debootstrap=debootstrap
+suite=stretch
+arch=amd64
+
+while getopts ":hs:a:m:n:" opt; do
+ case $opt in
+ h)
+ usage
+ ;;
+ s)
+ if [ "$OPTARG" != "wheezy" -a "$OPTARG" != "stretch" ]; then
+ echo "Invalid suite: $OPTARG" >&2
+ usage
+ fi
+ suite=$OPTARG
+ ;;
+ a)
+ if [ "$OPTARG" != "amd64" -a "$OPTARG" != "arm64" ]; then
+ echo "Invalid arch: $OPTARG" >&2
+ usage
+ fi
+ arch=$OPTARG
+ ;;
+ m)
+ mirror=$OPTARG
+ ;;
+ n)
+ name=$OPTARG
+ ;;
+ \?)
+ echo "Invalid option: $OPTARG" >&2
+ usage
+ ;;
+ :)
+ echo "Invalid option: $OPTARG requires an argument" >&2
+ usage
+ ;;
+ esac
+done
+
+# Use a default image name unless one was given with -n
+name=${name:-net_test.rootfs.$arch.`date +%Y%m%d`}
+
+# Switch to qemu-debootstrap for incompatible architectures
+if [ "$arch" = "arm64" ]; then
+ debootstrap=qemu-debootstrap
+fi
+
+# Sometimes it isn't obvious when the script fails
+failure() {
+ echo "Filesystem generation process failed." >&2
+}
+trap failure ERR
+
+# Import the package list for this release
+packages=`cat $SCRIPT_DIR/rootfs/$suite.list | xargs | tr -s ' ' ','`
+
+# For the debootstrap intermediates
+workdir=`mktemp -d`
+workdir_remove() {
+ echo "Removing temporary files.." >&2
+ sudo rm -rf $workdir
+}
+trap workdir_remove EXIT
+
+# Run the debootstrap first
+cd $workdir
+sudo $debootstrap --arch=$arch --variant=minbase --include=$packages \
+ $suite . $mirror
+# Workarounds for bugs in the debootstrap suite scripts
+for mount in `cat /proc/mounts | cut -d' ' -f2 | grep -e ^$workdir`; do
+ echo "Unmounting mountpoint $mount.." >&2
+ sudo umount $mount
+done
+# Copy the chroot preparation scripts, and enter the chroot
+for file in $suite.sh common.sh net_test.sh; do
+ sudo cp -a $SCRIPT_DIR/rootfs/$file root/$file
+ sudo chown root:root root/$file
+done
+sudo chroot . /root/$suite.sh
+
+# Leave the workdir, to build the filesystem
+cd -
+
+# For the final image mount
+mount=`mktemp -d`
+mount_remove() {
+ rmdir $mount
+ workdir_remove
+}
+trap mount_remove EXIT
+
+# Create a 1G empty ext3 filesystem
+truncate -s 1G $name
+mke2fs -F -t ext3 -L ROOT $name
+
+# Mount the new filesystem locally
+sudo mount -o loop -t ext3 $name $mount
+image_unmount() {
+ sudo umount $mount
+ mount_remove
+}
+trap image_unmount EXIT
+
+# Copy the patched debootstrap results into the new filesystem
+sudo cp -a $workdir/* $mount
+
+# Fill the rest of the space with zeroes, to optimize compression
+sudo dd if=/dev/zero of=$mount/sparse bs=1M 2>/dev/null || true
+sudo rm -f $mount/sparse
+
+echo "Debian $suite for $arch filesystem generated at '$name'."
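The zero-fill-then-delete step near the end of the script works because deleted file contents remain on disk as whatever bytes were last written; overwriting the free space with zeros makes the raw image highly compressible. The effect is easy to demonstrate with zlib (the 1 MiB size is illustrative, not measured from a real image):

```python
import os
import zlib

size = 1 << 20  # 1 MiB
zeros = b"\x00" * size    # what freed blocks look like after the dd from /dev/zero
junk = os.urandom(size)   # what stale debootstrap residue would look like

compressed_zeros = len(zlib.compress(zeros))
compressed_junk = len(zlib.compress(junk))
# Zeroed free space compresses to almost nothing; random residue barely shrinks.
print(compressed_zeros < compressed_junk)
```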
diff --git a/net/test/csocket.py b/net/test/csocket.py
index ee4a8f4..ccabf4a 100644
--- a/net/test/csocket.py
+++ b/net/test/csocket.py
@@ -17,18 +17,19 @@
import ctypes
import ctypes.util
import os
+import re
import socket
import struct
-import sys
import cstruct
+import util
# Data structures.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
CMsgHdr = cstruct.Struct("cmsghdr", "@Lii", "len level type")
-Iovec = cstruct.Struct("iovec", "@LL", "base len")
-MsgHdr = cstruct.Struct("msghdr", "@LLLLLLi",
+Iovec = cstruct.Struct("iovec", "@PL", "base len")
+MsgHdr = cstruct.Struct("msghdr", "@LLPLPLi",
"name namelen iov iovlen control msg_controllen flags")
SockaddrIn = cstruct.Struct("sockaddr_in", "=HH4sxxxxxxxx", "family port addr")
SockaddrIn6 = cstruct.Struct("sockaddr_in6", "=HHI16sI",
@@ -75,19 +76,31 @@
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
-# TODO: Move this to a utils.py or constants.py file, once we have one.
+# TODO: Unlike most of this file, these functions aren't specific to wrapping C
+# library calls. Move them to a utils.py or constants.py file, once we have one.
def LinuxVersion():
- # Example: "3.4.67-00753-gb7a556f".
- # Get the part before the dash.
- version = os.uname()[2].split("-")[0]
+ # Example: "3.4.67-00753-gb7a556f", "4.4.135+".
+ # Get the prefix consisting of digits and dots.
+ version = re.search("^[0-9.]*", os.uname()[2]).group()
# Convert it into a tuple such as (3, 4, 67). That allows comparing versions
# using < and >, since tuples are compared lexicographically.
version = tuple(int(i) for i in version.split("."))
return version
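The regex-based parse above handles both vanilla release strings and ones carrying local-version suffixes such as "+". A standalone sketch of the same logic (the sample strings come from the comment above; the helper name is illustrative):

```python
import re


def parse_kernel_version(release):
    """Take the leading digits-and-dots prefix and return it as an int tuple."""
    prefix = re.search(r"^[0-9.]*", release).group()
    return tuple(int(i) for i in prefix.split("."))


print(parse_kernel_version("3.4.67-00753-gb7a556f"))  # → (3, 4, 67)
print(parse_kernel_version("4.4.135+"))               # → (4, 4, 135)
```

Returning a tuple is what makes version comparisons with `<` and `>` work, since tuples compare lexicographically.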
-def PaddedLength(length):
- return CMSG_ALIGNTO * ((length / CMSG_ALIGNTO) + (length % CMSG_ALIGNTO != 0))
+def AddressVersion(addr):
+ return 6 if ":" in addr else 4
+
+
+def SetSocketTimeout(sock, ms):
+ s = ms / 1000
+ us = (ms % 1000) * 1000
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVTIMEO,
+ struct.pack("LL", s, us))
+
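SetSocketTimeout above packs the millisecond timeout into a struct timeval (seconds, microseconds) as two native longs, the layout SO_RCVTIMEO expects on 64-bit Linux. The conversion on its own (the helper name and sample values are illustrative):

```python
import struct


def ms_to_timeval(ms):
    """Split milliseconds into the (sec, usec) pair SO_RCVTIMEO consumes."""
    s, us = ms // 1000, (ms % 1000) * 1000
    return struct.pack("LL", s, us)


packed = ms_to_timeval(1500)
print(struct.unpack("LL", packed))  # → (1, 500000)
```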
+
+def VoidPointer(s):
+ return ctypes.cast(s.CPointer(), ctypes.c_void_p)
def MaybeRaiseSocketError(ret):
@@ -140,12 +153,15 @@
msg_level, msg_type, data = opt
if isinstance(data, int):
data = struct.pack("=I", data)
+ elif isinstance(data, ctypes.c_uint32):
+ data = struct.pack("=I", data.value)
elif not isinstance(data, str):
- raise TypeError("unknown data type for opt %i: %s" % (i, type(data)))
+ raise TypeError("unknown data type for opt (%d, %d): %s" % (
+ msg_level, msg_type, type(data)))
datalen = len(data)
msg_len = len(CMsgHdr) + datalen
- padding = "\x00" * (PaddedLength(datalen) - datalen)
+ padding = "\x00" * util.GetPadLength(CMSG_ALIGNTO, datalen)
msg_control += CMsgHdr((msg_len, msg_level, msg_type)).Pack()
msg_control += data + padding
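The hunk above replaces the old PaddedLength helper with util.GetPadLength, which yields the number of zero bytes needed to round a cmsg payload up to a CMSG_ALIGNTO boundary. The arithmetic, assuming GetPadLength(alignto, length) behaves as its call sites imply (the alignment value below is typical of LP64 Linux, used here only for illustration):

```python
CMSG_ALIGNTO = 8  # typical on 64-bit Linux; illustrative


def get_pad_length(alignto, length):
    """Zero bytes needed to round `length` up to a multiple of `alignto`."""
    return (alignto - (length % alignto)) % alignto


# A 5-byte cmsg payload needs 3 pad bytes to reach the next 8-byte boundary.
print(get_pad_length(CMSG_ALIGNTO, 5))   # → 3
print(get_pad_length(CMSG_ALIGNTO, 16))  # → 0
```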
@@ -158,7 +174,8 @@
while len(buf) > 0:
cmsghdr, buf = cstruct.Read(buf, CMsgHdr)
datalen = cmsghdr.len - len(CMsgHdr)
- data, buf = buf[:datalen], buf[PaddedLength(datalen):]
+ padlen = util.GetPadLength(CMSG_ALIGNTO, datalen)
+ data, buf = buf[:datalen], buf[padlen + datalen:]
if cmsghdr.level == socket.IPPROTO_IP:
if cmsghdr.type == IP_PKTINFO:
@@ -186,14 +203,14 @@
def Bind(s, to):
"""Python wrapper for bind."""
- ret = libc.bind(s.fileno(), to.CPointer(), len(to))
+ ret = libc.bind(s.fileno(), VoidPointer(to), len(to))
MaybeRaiseSocketError(ret)
return ret
def Connect(s, to):
"""Python wrapper for connect."""
- ret = libc.connect(s.fileno(), to.CPointer(), len(to))
+ ret = libc.connect(s.fileno(), VoidPointer(to), len(to))
MaybeRaiseSocketError(ret)
return ret
@@ -305,7 +322,7 @@
msghdr = MsgHdr((msg_name, msg_namelen, msg_iov, msg_iovlen,
msg_control, msg_controllen, flags))
- ret = libc.recvmsg(s.fileno(), msghdr.CPointer(), flags)
+ ret = libc.recvmsg(s.fileno(), VoidPointer(msghdr), flags)
MaybeRaiseSocketError(ret)
data = buf.raw[:ret]
@@ -333,3 +350,24 @@
addr = _ToSocketAddress(addr.raw, alen)
return data, addr
+
+
+def Setsockopt(s, level, optname, optval, optlen):
+ """Python wrapper for setsockopt.
+
+ Mostly identical to the built-in setsockopt, but allows passing in arbitrary
+ binary blobs, including NULL options, which the built-in python setsockopt does
+ not allow.
+
+ Args:
+ s: The socket object on which to set the option.
+ level: The level parameter.
+ optname: The option to set.
+ optval: A raw byte string, the value to set the option to (None for NULL).
+ optlen: An integer, the length of the option.
+
+ Raises:
+ socket.error: if setsockopt fails.
+ """
+ ret = libc.setsockopt(s.fileno(), level, optname, optval, optlen)
+ MaybeRaiseSocketError(ret)
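The point of this wrapper is that it passes optval straight through to libc, so callers can supply arbitrary byte strings or even None (a NULL pointer), which Python's built-in setsockopt rejects. A quick standalone check of the same libc call path, using a harmless option (Linux-oriented; SO_REUSEADDR chosen only for illustration):

```python
import ctypes
import ctypes.util
import socket
import struct

libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)

s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
optval = struct.pack("i", 1)
# The same raw call the wrapper makes: a byte string (or None) as optval.
ret = libc.setsockopt(s.fileno(), socket.SOL_SOCKET, socket.SO_REUSEADDR,
                      optval, len(optval))
print(ret)  # 0 on success
s.close()
```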
diff --git a/net/test/cstruct.py b/net/test/cstruct.py
index 43c47a2..5e05263 100644
--- a/net/test/cstruct.py
+++ b/net/test/cstruct.py
@@ -25,7 +25,8 @@
>>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
>>>
>>>
->>> # Create instances from tuples or raw bytes. Data past the end is ignored.
+>>> # Create instances from a tuple of values, from raw bytes, zero-initialized,
+>>> # or field-by-field using keywords.
... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
>>> print n1
NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
@@ -35,6 +36,14 @@
>>> print n2
NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
>>>
+>>> n3 = NLMsgHdr()  # Zero-initialized
+>>> print n3
+NLMsgHdr(length=0, type=0, flags=0, seq=0, pid=0)
+>>>
+>>> n4 = NLMsgHdr(length=44, type=33)  # Other fields zero-initialized
+>>> print n4
+NLMsgHdr(length=44, type=33, flags=0, seq=0, pid=0)
+>>>
>>> # Serialize to raw bytes.
... print n1.Pack().encode("hex")
2c0000002000020000000000eb010000
@@ -103,8 +112,7 @@
# List of string fields that are ASCII strings.
_asciiz = set()
- if isinstance(_fieldnames, str):
- _fieldnames = _fieldnames.split(" ")
+ _fieldnames = _fieldnames.split(" ")
# Parse fmt into _format, converting any S format characters to "XXs",
# where XX is the length of the struct type's packed representation.
@@ -123,12 +131,31 @@
_asciiz.add(index)
_format += "s"
else:
- # Standard struct format character.
+ # Standard struct format character.
_format += fmt[i]
_length = CalcSize(_format)
+ offset_list = [0]
+ last_offset = 0
+ for i in xrange(len(_format)):
+ offset = CalcSize(_format[:i])
+ if offset > last_offset:
+ last_offset = offset
+ offset_list.append(offset)
+
+ # A dictionary that maps field names to their offsets in the struct.
+ _offsets = dict(zip(_fieldnames, offset_list))
+
+ # Check that the number of field names matches the number of fields.
+ numfields = len(struct.unpack(_format, "\x00" * _length))
+ if len(_fieldnames) != numfields:
+ raise ValueError("Invalid cstruct: \"%s\" has %d elements, \"%s\" has %d."
+ % (fmt, numfields, fieldnames, len(_fieldnames)))
+
def _SetValues(self, values):
+ # Replace self._values with the given list. We can't do direct assignment
+ # because of the __setattr__ overload on this class.
super(CStruct, self).__setattr__("_values", list(values))
def _Parse(self, data):
@@ -139,19 +166,37 @@
values[index] = self._nested[index](value)
self._SetValues(values)
- def __init__(self, values):
- # Initializing from a string.
- if isinstance(values, str):
- if len(values) < self._length:
+ def __init__(self, tuple_or_bytes=None, **kwargs):
+ """Construct an instance of this Struct.
+
+ 1. With no args, the whole struct is zero-initialized.
+ 2. With keyword args, the matching fields are populated; rest are zeroed.
+ 3. With one tuple as the arg, the fields are assigned based on position.
+ 4. With one string arg, the Struct is parsed from bytes.
+ """
+ if tuple_or_bytes and kwargs:
+ raise TypeError(
+ "%s: cannot specify both a tuple and keyword args" % self._name)
+
+ if tuple_or_bytes is None:
+ # Default construct from null bytes.
+ self._Parse("\x00" * len(self))
+ # If any keywords were supplied, set those fields.
+ for k, v in kwargs.iteritems():
+ setattr(self, k, v)
+ elif isinstance(tuple_or_bytes, str):
+ # Initializing from a string.
+ if len(tuple_or_bytes) < self._length:
raise TypeError("%s requires string of length %d, got %d" %
- (self._name, self._length, len(values)))
- self._Parse(values)
+ (self._name, self._length, len(tuple_or_bytes)))
+ self._Parse(tuple_or_bytes)
else:
# Initializing from a tuple.
- if len(values) != len(self._fieldnames):
+ if len(tuple_or_bytes) != len(self._fieldnames):
raise TypeError("%s has exactly %d fieldnames (%d given)" %
- (self._name, len(self._fieldnames), len(values)))
- self._SetValues(values)
+ (self._name, len(self._fieldnames),
+ len(tuple_or_bytes)))
+ self._SetValues(tuple_or_bytes)
def _FieldIndex(self, attr):
try:
@@ -164,8 +209,15 @@
return self._values[self._FieldIndex(name)]
def __setattr__(self, name, value):
+ # TODO: check value type against self._format and throw here, or else
+ # callers get an unhelpful exception when they call Pack().
self._values[self._FieldIndex(name)] = value
+ def offset(self, name):
+ if "." in name:
+ raise NotImplementedError("offset() on nested field")
+ return self._offsets[name]
+
@classmethod
def __len__(cls):
return cls._length
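The offset() support added above derives each field's position from struct.calcsize on successively longer prefixes of the format string. The core idea in isolation (field names and format are illustrative; note that in native-alignment mode a field's packed offset can differ from the prefix size because of padding, which is why "=" is used here):

```python
import struct

# Per-field format tokens; "=" selects standard sizes with no alignment padding.
tokens = [("byte1", "B"), ("string2", "16s"), ("int3", "i")]

offsets, prefix = {}, "="
for name, tok in tokens:
    # calcsize of the prefix is the offset at which this field begins.
    offsets[name] = struct.calcsize(prefix)
    prefix += tok

print(offsets)  # → {'byte1': 0, 'string2': 1, 'int3': 17}
```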
diff --git a/net/test/cstruct_test.py b/net/test/cstruct_test.py
index fdcbd55..6b27973 100755
--- a/net/test/cstruct_test.py
+++ b/net/test/cstruct_test.py
@@ -110,6 +110,80 @@
" int3=12345, ascii4=hello\x00visible123, word5=33210)")
self.assertEquals(expected, str(t))
+ def testZeroInitialization(self):
+ TestStruct = cstruct.Struct("TestStruct", "B16si16AH",
+ "byte1 string2 int3 ascii4 word5")
+ t = TestStruct()
+ self.assertEquals(0, t.byte1)
+ self.assertEquals("\x00" * 16, t.string2)
+ self.assertEquals(0, t.int3)
+ self.assertEquals("\x00" * 16, t.ascii4)
+ self.assertEquals(0, t.word5)
+ self.assertEquals("\x00" * len(TestStruct), t.Pack())
+
+ def testKeywordInitialization(self):
+ TestStruct = cstruct.Struct("TestStruct", "=B16sIH",
+ "byte1 string2 int3 word4")
+ text = "hello world! ^_^"
+ text_bytes = text.encode("hex")
+
+ # Populate all fields
+ t1 = TestStruct(byte1=1, string2=text, int3=0xFEDCBA98, word4=0x1234)
+ expected = ("01" + text_bytes + "98BADCFE" "3412").decode("hex")
+ self.assertEquals(expected, t1.Pack())
+
+ # Partially populated
+ t1 = TestStruct(string2=text, word4=0x1234)
+ expected = ("00" + text_bytes + "00000000" "3412").decode("hex")
+ self.assertEquals(expected, t1.Pack())
+
+ def testCstructOffset(self):
+ TestStruct = cstruct.Struct("TestStruct", "B16si16AH",
+ "byte1 string2 int3 ascii4 word5")
+ nullstr = "hello" + (16 - len("hello")) * "\x00"
+ t = TestStruct((2, nullstr, 12345, nullstr, 33210))
+ self.assertEquals(0, t.offset("byte1"))
+ self.assertEquals(1, t.offset("string2")) # sizeof(byte)
+ self.assertEquals(17, t.offset("int3")) # sizeof(byte) + 16*sizeof(char)
+ # The integer is automatically padded by the struct module
+ # to match native alignment.
+ # offset = sizeof(byte) + 16*sizeof(char) + padding + sizeof(int)
+ self.assertEquals(24, t.offset("ascii4"))
+ self.assertEquals(40, t.offset("word5"))
+ self.assertRaises(KeyError, t.offset, "random")
+
+ # TODO: Add support for nested struct offset
+ Nested = cstruct.Struct("Nested", "!HSSi", "word1 nest2 nest3 int4",
+ [TestStructA, TestStructB])
+ DoubleNested = cstruct.Struct("DoubleNested", "SSB", "nest1 nest2 byte3",
+ [TestStructA, Nested])
+ d = DoubleNested((TestStructA((1, 2)), Nested((5, TestStructA((3, 4)),
+ TestStructB((7, 8)), 9)), 6))
+ self.assertEqual(0, d.offset("nest1"))
+ self.assertEqual(len(TestStructA), d.offset("nest2"))
+ self.assertEqual(len(TestStructA) + len(Nested), d.offset("byte3"))
+    self.assertRaises(KeyError, d.offset, "word1")
+
+ def testDefinitionFieldMismatch(self):
+ cstruct.Struct("TestA", "=BI", "byte1 int2")
+ cstruct.Struct("TestA", "=BxxxxxIx", "byte1 int2")
+ with self.assertRaises(ValueError):
+ cstruct.Struct("TestA", "=B", "byte1 int2")
+ with self.assertRaises(ValueError):
+ cstruct.Struct("TestA", "=BI", "byte1")
+
+ Nested = cstruct.Struct("Nested", "!II", "int1 int2")
+ cstruct.Struct("TestB", "=BSI", "byte1 nest2 int3", [Nested])
+ cstruct.Struct("TestB", "=BxSxIx", "byte1 nest2 int3", [Nested])
+ with self.assertRaises(ValueError):
+ cstruct.Struct("TestB", "=BSI", "byte1 int3", [Nested])
+ with self.assertRaises(ValueError):
+ cstruct.Struct("TestB", "=BSI", "byte1 nest2", [Nested])
+
+ cstruct.Struct("TestC", "=BSSI", "byte1 nest2 nest3 int4", [Nested, Nested])
+ with self.assertRaises(ValueError):
+ cstruct.Struct("TestC", "=BSSI", "byte1 nest2 int4", [Nested, Nested])
+
if __name__ == "__main__":
unittest.main()
diff --git a/net/test/forwarding_test.py b/net/test/forwarding_test.py
index 527d780..34394cd 100755
--- a/net/test/forwarding_test.py
+++ b/net/test/forwarding_test.py
@@ -20,26 +20,11 @@
from socket import *
-import iproute
import multinetwork_base
import net_test
import packets
-
class ForwardingTest(multinetwork_base.MultiNetworkBaseTest):
- """Checks that IPv6 forwarding doesn't crash the system.
-
- Relevant kernel commits:
- upstream net-next:
- e7eadb4 ipv6: inet6_sk() should use sk_fullsock()
- android-3.10:
- feee3c1 ipv6: inet6_sk() should use sk_fullsock()
- cdab04e net: add sk_fullsock() helper
- android-3.18:
- 8246f18 ipv6: inet6_sk() should use sk_fullsock()
- bea19db net: add sk_fullsock() helper
- """
-
TCP_TIME_WAIT = 6
def ForwardBetweenInterfaces(self, enabled, iface1, iface2):
@@ -53,7 +38,63 @@
def tearDown(self):
self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
- def CheckForwardingCrash(self, netid, iface1, iface2):
+ """Checks that IPv6 forwarding works for UDP packets and is not broken by early demux.
+
+ Relevant kernel commits:
+ upstream:
+ 5425077d73e0c8e net: ipv6: Add early demux handler for UDP unicast
+ 0bd84065b19bca1 net: ipv6: Fix UDP early demux lookup with udp_l3mdev_accept=0
+ Ifa9c2ddfaa5b51 net: ipv6: reset daddr and dport in sk if connect() fails
+ """
+ def CheckForwardingUdp(self, netid, iface1, iface2):
+ # TODO: Make a test for IPv4
+ # 1. Make version as an argument. Pick address to bind from array based
+ # on version.
+ # 2. The prefix length of the address is hardcoded to /64. Use the subnet
+ # mask there instead.
+ # 3. We recreate the address with SendRA, which obviously only works for
+ # IPv6. Use AddAddress for IPv4.
+
+ # Create a UDP socket and bind to it
+ version = 6
+ s = net_test.UDPSocket(AF_INET6)
+ self.SetSocketMark(s, netid)
+ s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+ s.bind(("::", 53))
+
+ remoteaddr = self.GetRemoteAddress(version)
+ myaddr = self.MyAddress(version, netid)
+
+ try:
+ # Delete address and check if packet is forwarded
+ # (and not dropped because an incorrect socket match happened)
+ self.iproute.DelAddress(myaddr, 64, self.ifindices[netid])
+ hoplimit = 39
+ desc, udp_pkt = packets.UDPWithOptions(version, myaddr, remoteaddr, 53)
+ # Decrements the hoplimit of a packet to simulate forwarding.
+ desc_fwded, udp_fwd = packets.UDPWithOptions(version, myaddr, remoteaddr,
+ 53, hoplimit - 1)
+ msg = "Sent %s, expected %s" % (desc, desc_fwded)
+ self.ReceivePacketOn(iface1, udp_pkt)
+ self.ExpectPacketOn(iface2, msg, udp_fwd)
+ finally:
+ # Recreate the address.
+ self.SendRA(netid)
+ s.close()
+
+ """Checks that IPv6 forwarding doesn't crash the system.
+
+ Relevant kernel commits:
+ upstream net-next:
+ e7eadb4 ipv6: inet6_sk() should use sk_fullsock()
+ android-3.10:
+ feee3c1 ipv6: inet6_sk() should use sk_fullsock()
+ cdab04e net: add sk_fullsock() helper
+ android-3.18:
+ 8246f18 ipv6: inet6_sk() should use sk_fullsock()
+ bea19db net: add sk_fullsock() helper
+ """
+ def CheckForwardingCrashTcp(self, netid, iface1, iface2):
version = 6
listensocket = net_test.IPv6TCPSocket()
self.SetSocketMark(listensocket, netid)
@@ -102,17 +143,30 @@
self.SendRA(netid)
listensocket.close()
- def testCrash(self):
+ def CheckForwardingHandlerByProto(self, protocol, netid, iif, oif):
+ if protocol == IPPROTO_UDP:
+ self.CheckForwardingUdp(netid, iif, oif)
+ elif protocol == IPPROTO_TCP:
+ self.CheckForwardingCrashTcp(netid, iif, oif)
+ else:
+ raise NotImplementedError
+
+ def CheckForwardingByProto(self, proto):
# Run the test a few times as it doesn't crash/hang the first time.
for netids in itertools.permutations(self.tuns):
# Pick an interface to send traffic on and two to forward traffic between.
netid, iface1, iface2 = random.sample(netids, 3)
self.ForwardBetweenInterfaces(True, iface1, iface2)
try:
- self.CheckForwardingCrash(netid, iface1, iface2)
+ self.CheckForwardingHandlerByProto(proto, netid, iface1, iface2)
finally:
self.ForwardBetweenInterfaces(False, iface1, iface2)
+ def testForwardingUdp(self):
+ self.CheckForwardingByProto(IPPROTO_UDP)
+
+ def testForwardingCrashTcp(self):
+ self.CheckForwardingByProto(IPPROTO_TCP)
if __name__ == "__main__":
unittest.main()
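The forwarding checks in this file assert that the forwarded copy of a packet differs from the input only in its hop limit. As a self-contained illustration (not part of the test suite), the transformation on a raw 40-byte IPv6 header looks like this, with the hop limit living at byte offset 7:

```python
import struct

def decrement_hop_limit(ipv6_header):
    """Return a copy of a raw 40-byte IPv6 header with hop limit - 1.

    Byte 7 of the fixed IPv6 header is the hop limit field.
    """
    if len(ipv6_header) < 40:
        raise ValueError("truncated IPv6 header")
    hoplimit = ipv6_header[7]
    if hoplimit == 0:
        raise ValueError("hop limit exhausted; packet must be dropped")
    return ipv6_header[:7] + bytes([hoplimit - 1]) + ipv6_header[8:]

# A minimal header: version 6, payload length 0, next header 59
# (no next header), hop limit 39, unspecified src/dst addresses.
hdr = struct.pack("!IHBB16s16s", 6 << 28, 0, 59, 39, bytes(16), bytes(16))
forwarded = decrement_hop_limit(hdr)
```

This mirrors what `packets.UDPWithOptions(..., hoplimit - 1)` encodes as the expected output packet.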
diff --git a/net/test/genetlink.py b/net/test/genetlink.py
new file mode 100755
index 0000000..dda3964
--- /dev/null
+++ b/net/test/genetlink.py
@@ -0,0 +1,123 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Classes for generic netlink."""
+
+import collections
+from socket import * # pylint: disable=wildcard-import
+import struct
+
+import cstruct
+import netlink
+
+### Generic netlink constants. See include/uapi/linux/genetlink.h.
+# The generic netlink control family.
+GENL_ID_CTRL = 16
+
+# Commands.
+CTRL_CMD_GETFAMILY = 3
+
+# Attributes.
+CTRL_ATTR_FAMILY_ID = 1
+CTRL_ATTR_FAMILY_NAME = 2
+CTRL_ATTR_VERSION = 3
+CTRL_ATTR_HDRSIZE = 4
+CTRL_ATTR_MAXATTR = 5
+CTRL_ATTR_OPS = 6
+CTRL_ATTR_MCAST_GROUPS = 7
+
+# Attributes nested inside CTRL_ATTR_OPS.
+CTRL_ATTR_OP_ID = 1
+CTRL_ATTR_OP_FLAGS = 2
+
+
+# Data structure formats.
+# These aren't constants, they're classes. So, pylint: disable=invalid-name
+Genlmsghdr = cstruct.Struct("genlmsghdr", "BBxx", "cmd version")
+
+
+class GenericNetlink(netlink.NetlinkSocket):
+ """Base class for all generic netlink classes."""
+
+ NL_DEBUG = []
+
+ def __init__(self):
+ super(GenericNetlink, self).__init__(netlink.NETLINK_GENERIC)
+
+ def _SendCommand(self, family, command, version, data, flags):
+ genlmsghdr = Genlmsghdr((command, version))
+ self._SendNlRequest(family, genlmsghdr.Pack() + data, flags)
+
+ def _Dump(self, family, command, version):
+ msg = Genlmsghdr((command, version))
+ return super(GenericNetlink, self)._Dump(family, msg, Genlmsghdr, "")
+
+
+class GenericNetlinkControl(GenericNetlink):
+ """Generic netlink control class.
+
+ This interface is used to manage other generic netlink families. We currently
+ use it only to find the family ID for address families of interest."""
+
+ def _DecodeOps(self, data):
+ ops = []
+ Op = collections.namedtuple("Op", ["id", "flags"])
+ while data:
+ # Skip the nest marker.
+ datalen, index, data = data[:2], data[2:4], data[4:]
+
+ nla, nla_data, data = self._ReadNlAttr(data)
+ if nla.nla_type != CTRL_ATTR_OP_ID:
+ raise ValueError("Expected CTRL_ATTR_OP_ID, got %d" % nla.nla_type)
+ op_id = struct.unpack("=I", nla_data)[0]
+
+ nla, nla_data, data = self._ReadNlAttr(data)
+ if nla.nla_type != CTRL_ATTR_OP_FLAGS:
+        raise ValueError("Expected CTRL_ATTR_OP_FLAGS, got %d" % nla.nla_type)
+ op_flags = struct.unpack("=I", nla_data)[0]
+
+ ops.append(Op(op_id, op_flags))
+ return ops
+
+ def _Decode(self, command, msg, nla_type, nla_data):
+ """Decodes generic netlink control attributes to human-readable format."""
+
+ name = self._GetConstantName(__name__, nla_type, "CTRL_ATTR_")
+
+ if name == "CTRL_ATTR_FAMILY_ID":
+ data = struct.unpack("=H", nla_data)[0]
+ elif name == "CTRL_ATTR_FAMILY_NAME":
+ data = nla_data.strip("\x00")
+ elif name in ["CTRL_ATTR_VERSION", "CTRL_ATTR_HDRSIZE", "CTRL_ATTR_MAXATTR"]:
+ data = struct.unpack("=I", nla_data)[0]
+ elif name == "CTRL_ATTR_OPS":
+ data = self._DecodeOps(nla_data)
+ else:
+ data = nla_data
+
+ return name, data
+
+ def GetFamily(self, name):
+ """Returns the family ID for the specified family name."""
+ data = self._NlAttrStr(CTRL_ATTR_FAMILY_NAME, name)
+ self._SendCommand(GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 0, data, netlink.NLM_F_REQUEST)
+ hdr, attrs = self._GetMsg(Genlmsghdr)
+ return attrs["CTRL_ATTR_FAMILY_ID"]
+
+
+if __name__ == "__main__":
+ g = GenericNetlinkControl()
+ print g.GetFamily("tcp_metrics")
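For reference, the CTRL_CMD_GETFAMILY request that GetFamily sends reduces to a genlmsghdr followed by one string attribute. A standalone sketch with only the standard library (the `nlattr` packer below is a hypothetical stand-in for the netlink module's `_NlAttr` helpers; constants mirror the ones above):

```python
import struct

NLA_ALIGNTO = 4
CTRL_CMD_GETFAMILY = 3
CTRL_ATTR_FAMILY_NAME = 2

def nlattr(nla_type, data):
    """Pack one netlink attribute: u16 length, u16 type, data, 4-byte pad."""
    nla_len = 4 + len(data)
    pad = (NLA_ALIGNTO - nla_len % NLA_ALIGNTO) % NLA_ALIGNTO
    return struct.pack("=HH", nla_len, nla_type) + data + b"\x00" * pad

def getfamily_payload(name):
    """genlmsghdr (cmd, version, 2 pad bytes) + CTRL_ATTR_FAMILY_NAME."""
    genlmsghdr = struct.pack("=BBxx", CTRL_CMD_GETFAMILY, 0)
    return genlmsghdr + nlattr(CTRL_ATTR_FAMILY_NAME, name + b"\x00")

payload = getfamily_payload(b"tcp_metrics")
```

The real code prepends an nlmsghdr with family GENL_ID_CTRL before sending; only the payload layout is sketched here.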
diff --git a/net/test/iproute.py b/net/test/iproute.py
index a570d3d..8376eb6 100644
--- a/net/test/iproute.py
+++ b/net/test/iproute.py
@@ -18,29 +18,18 @@
# pylint: disable=g-bad-todo
+from socket import AF_INET
+from socket import AF_INET6
+
import errno
import os
import socket
import struct
-import sys
import csocket
import cstruct
import netlink
-### Base netlink constants. See include/uapi/linux/netlink.h.
-NETLINK_ROUTE = 0
-
-# Data structure formats.
-# These aren't constants, they're classes. So, pylint: disable=invalid-name
-NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
-NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
-NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")
-
-# Alignment / padding.
-NLA_ALIGNTO = 4
-
-
### rtnetlink constants. See include/uapi/linux/rtnetlink.h.
# Message types.
RTM_NEWLINK = 16
@@ -63,10 +52,13 @@
RTN_UNSPEC = 0
RTN_UNICAST = 1
RTN_UNREACHABLE = 7
+RTN_THROW = 9
# Routing protocol values (rtm_protocol).
RTPROT_UNSPEC = 0
+RTPROT_BOOT = 3
RTPROT_STATIC = 4
+RTPROT_RA = 9
# Route scope values (rtm_scope).
RT_SCOPE_UNIVERSE = 0
@@ -78,6 +70,7 @@
# Routing attributes.
RTA_DST = 1
RTA_SRC = 2
+RTA_IIF = 3
RTA_OIF = 4
RTA_GATEWAY = 5
RTA_PRIORITY = 6
@@ -89,6 +82,9 @@
RTA_PREF = 20
RTA_UID = 25
+# Netlink groups.
+RTMGRP_IPV6_IFADDR = 0x100
+
# Route metric attributes.
RTAX_MTU = 2
RTAX_HOPLIMIT = 10
@@ -121,6 +117,17 @@
IFA_F_TENTATIVE = 0x40
IFA_F_PERMANENT = 0x80
+# This cannot contain members that do not yet exist in older kernels, because
+# GetIfaceStats will fail if the kernel returns fewer bytes than the size of
+# RtnlLinkStats[64].
+_LINK_STATS_MEMBERS = (
+ "rx_packets tx_packets rx_bytes tx_bytes rx_errors tx_errors "
+ "rx_dropped tx_dropped multicast collisions "
+ "rx_length_errors rx_over_errors rx_crc_errors rx_frame_errors "
+ "rx_fifo_errors rx_missed_errors tx_aborted_errors tx_carrier_errors "
+ "tx_fifo_errors tx_heartbeat_errors tx_window_errors "
+ "rx_compressed tx_compressed")
+
# Data structure formats.
IfAddrMsg = cstruct.Struct(
"IfAddrMsg", "=BBBBI",
@@ -129,7 +136,10 @@
"IFACacheinfo", "=IIII", "prefered valid cstamp tstamp")
NDACacheinfo = cstruct.Struct(
"NDACacheinfo", "=IIII", "confirmed used updated refcnt")
-
+RtnlLinkStats = cstruct.Struct(
+ "RtnlLinkStats", "=IIIIIIIIIIIIIIIIIIIIIII", _LINK_STATS_MEMBERS)
+RtnlLinkStats64 = cstruct.Struct(
+ "RtnlLinkStats64", "=QQQQQQQQQQQQQQQQQQQQQQQ", _LINK_STATS_MEMBERS)
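RtnlLinkStats above is just 23 consecutive native-endian u32 fields (the 64-bit variant uses u64 for each). A minimal standard-library sketch of the same decode, independent of cstruct:

```python
import collections
import struct

_FIELDS = ("rx_packets tx_packets rx_bytes tx_bytes rx_errors tx_errors "
           "rx_dropped tx_dropped multicast collisions "
           "rx_length_errors rx_over_errors rx_crc_errors rx_frame_errors "
           "rx_fifo_errors rx_missed_errors tx_aborted_errors "
           "tx_carrier_errors tx_fifo_errors tx_heartbeat_errors "
           "tx_window_errors rx_compressed tx_compressed")

LinkStats = collections.namedtuple("LinkStats", _FIELDS)

def decode_link_stats(data):
    """Decode the first 92 bytes of an IFLA_STATS blob into named fields."""
    size = struct.calcsize("=23I")
    return LinkStats(*struct.unpack("=23I", data[:size]))

# Synthetic sample: field i holds value i, so rx_packets=0, tx_packets=1, ...
sample = struct.pack("=23I", *range(23))
stats = decode_link_stats(sample)
```

As the comment above notes, the kernel must return at least the full struct, or the slice-and-unpack fails; that is why the member list stops at fields present in older kernels.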
### Neighbour table entry constants. See include/uapi/linux/neighbour.h.
# Neighbour cache entry attributes.
@@ -137,6 +147,7 @@
NDA_LLADDR = 2
NDA_CACHEINFO = 3
NDA_PROBES = 4
+NDA_IFINDEX = 8
# Neighbour cache entry states.
NUD_PERMANENT = 0x80
@@ -161,16 +172,27 @@
FibRuleUidRange = cstruct.Struct("FibRuleUidRange", "=II", "start end")
# Link constants. See include/uapi/linux/if_link.h.
+IFLA_UNSPEC = 0
IFLA_ADDRESS = 1
IFLA_BROADCAST = 2
IFLA_IFNAME = 3
IFLA_MTU = 4
+IFLA_LINK = 5
IFLA_QDISC = 6
IFLA_STATS = 7
+IFLA_COST = 8
+IFLA_PRIORITY = 9
+IFLA_MASTER = 10
+IFLA_WIRELESS = 11
+IFLA_PROTINFO = 12
IFLA_TXQLEN = 13
IFLA_MAP = 14
+IFLA_WEIGHT = 15
IFLA_OPERSTATE = 16
IFLA_LINKMODE = 17
+IFLA_LINKINFO = 18
+IFLA_NET_NS_PID = 19
+IFLA_IFALIAS = 20
IFLA_STATS64 = 23
IFLA_AF_SPEC = 26
IFLA_GROUP = 27
@@ -179,6 +201,31 @@
IFLA_NUM_TX_QUEUES = 31
IFLA_NUM_RX_QUEUES = 32
IFLA_CARRIER = 33
+IFLA_CARRIER_CHANGES = 35
+IFLA_PROTO_DOWN = 39
+IFLA_GSO_MAX_SEGS = 40
+IFLA_GSO_MAX_SIZE = 41
+IFLA_PAD = 42
+IFLA_XDP = 43
+IFLA_EVENT = 44
+
+# include/uapi/linux/if_link.h
+IFLA_INFO_UNSPEC = 0
+IFLA_INFO_KIND = 1
+IFLA_INFO_DATA = 2
+IFLA_INFO_XSTATS = 3
+
+IFLA_XFRM_UNSPEC = 0
+IFLA_XFRM_LINK = 1
+IFLA_XFRM_IF_ID = 2
+
+# include/uapi/linux/if_tunnel.h
+IFLA_VTI_UNSPEC = 0
+IFLA_VTI_LINK = 1
+IFLA_VTI_IKEY = 2
+IFLA_VTI_OKEY = 3
+IFLA_VTI_LOCAL = 4
+IFLA_VTI_REMOTE = 5
def CommandVerb(command):
@@ -199,18 +246,13 @@
class IPRoute(netlink.NetlinkSocket):
"""Provides a tiny subset of iproute functionality."""
- FAMILY = NETLINK_ROUTE
-
- def _NlAttrIPAddress(self, nla_type, family, address):
- return self._NlAttr(nla_type, socket.inet_pton(family, address))
-
def _NlAttrInterfaceName(self, nla_type, interface):
return self._NlAttr(nla_type, interface + "\x00")
def _GetConstantName(self, value, prefix):
return super(IPRoute, self)._GetConstantName(__name__, value, prefix)
- def _Decode(self, command, msg, nla_type, nla_data):
+ def _Decode(self, command, msg, nla_type, nla_data, nested=0):
"""Decodes netlink attributes to Python types.
Values for which the code knows the type (e.g., the fwmark ID in a
@@ -226,9 +268,11 @@
incoming interface name and is a string.
- If negative, one of the following (negative) values:
- RTA_METRICS: Interpret as nested route metrics.
+ - IFLA_LINKINFO: Nested interface information.
family: The address family. Used to convert IP addresses into strings.
       nla_type: An integer, the netlink attribute type.
nla_data: A byte string, the netlink attribute data.
+ nested: An integer, how deep we're currently nested.
Returns:
A tuple (name, data):
@@ -241,6 +285,10 @@
"""
if command == -RTA_METRICS:
name = self._GetConstantName(nla_type, "RTAX_")
+ elif command == -IFLA_LINKINFO:
+ name = self._GetConstantName(nla_type, "IFLA_INFO_")
+ elif command == -IFLA_INFO_DATA:
+ name = self._GetConstantName(nla_type, "IFLA_VTI_")
elif CommandSubject(command) == "ADDR":
name = self._GetConstantName(nla_type, "IFA_")
elif CommandSubject(command) == "LINK":
@@ -260,41 +308,51 @@
"IFLA_MTU", "IFLA_TXQLEN", "IFLA_GROUP", "IFLA_EXT_MASK",
"IFLA_PROMISCUITY", "IFLA_NUM_RX_QUEUES",
"IFLA_NUM_TX_QUEUES", "NDA_PROBES", "RTAX_MTU",
- "RTAX_HOPLIMIT"]:
+ "RTAX_HOPLIMIT", "IFLA_CARRIER_CHANGES", "IFLA_GSO_MAX_SEGS",
+ "IFLA_GSO_MAX_SIZE", "RTA_UID"]:
data = struct.unpack("=I", nla_data)[0]
+ elif name in ["IFLA_VTI_OKEY", "IFLA_VTI_IKEY"]:
+ data = struct.unpack("!I", nla_data)[0]
elif name == "FRA_SUPPRESS_PREFIXLEN":
data = struct.unpack("=i", nla_data)[0]
elif name in ["IFLA_LINKMODE", "IFLA_OPERSTATE", "IFLA_CARRIER"]:
data = ord(nla_data)
elif name in ["IFA_ADDRESS", "IFA_LOCAL", "RTA_DST", "RTA_SRC",
- "RTA_GATEWAY", "RTA_PREFSRC", "RTA_UID",
- "NDA_DST"]:
+ "RTA_GATEWAY", "RTA_PREFSRC", "NDA_DST"]:
data = socket.inet_ntop(msg.family, nla_data)
elif name in ["FRA_IIFNAME", "FRA_OIFNAME", "IFLA_IFNAME", "IFLA_QDISC",
- "IFA_LABEL"]:
+ "IFA_LABEL", "IFLA_INFO_KIND"]:
data = nla_data.strip("\x00")
elif name == "RTA_METRICS":
- data = self._ParseAttributes(-RTA_METRICS, None, nla_data)
+ data = self._ParseAttributes(-RTA_METRICS, None, nla_data, nested + 1)
+ elif name == "IFLA_LINKINFO":
+ data = self._ParseAttributes(-IFLA_LINKINFO, None, nla_data, nested + 1)
+ elif name == "IFLA_INFO_DATA":
+ data = self._ParseAttributes(-IFLA_INFO_DATA, None, nla_data)
elif name == "RTA_CACHEINFO":
data = RTACacheinfo(nla_data)
elif name == "IFA_CACHEINFO":
data = IFACacheinfo(nla_data)
elif name == "NDA_CACHEINFO":
data = NDACacheinfo(nla_data)
- elif name in ["NDA_LLADDR", "IFLA_ADDRESS"]:
+ elif name in ["NDA_LLADDR", "IFLA_ADDRESS", "IFLA_BROADCAST"]:
data = ":".join(x.encode("hex") for x in nla_data)
elif name == "FRA_UID_RANGE":
data = FibRuleUidRange(nla_data)
+ elif name == "IFLA_STATS":
+ data = RtnlLinkStats(nla_data)
+ elif name == "IFLA_STATS64":
+ data = RtnlLinkStats64(nla_data)
else:
data = nla_data
return name, data
def __init__(self):
- super(IPRoute, self).__init__()
+ super(IPRoute, self).__init__(netlink.NETLINK_ROUTE)
def _AddressFamily(self, version):
- return {4: socket.AF_INET, 6: socket.AF_INET6}[version]
+ return {4: AF_INET, 6: AF_INET6}[version]
def _SendNlRequest(self, command, data, flags=0):
"""Sends a netlink request and expects an ack."""
@@ -303,8 +361,8 @@
if CommandVerb(command) != "GET":
flags |= netlink.NLM_F_ACK
if CommandVerb(command) == "NEW":
- if not flags & netlink.NLM_F_REPLACE:
- flags |= (netlink.NLM_F_EXCL | netlink.NLM_F_CREATE)
+ if flags & (netlink.NLM_F_REPLACE | netlink.NLM_F_CREATE) == 0:
+ flags |= netlink.NLM_F_CREATE | netlink.NLM_F_EXCL
super(IPRoute, self)._SendNlRequest(command, data, flags)
@@ -347,13 +405,14 @@
try:
self._SendNlRequest(RTM_DELRULE, rtmsg)
except IOError, e:
- if e.errno == -errno.ENOENT:
+ if e.errno == errno.ENOENT:
break
else:
raise
- def FwmarkRule(self, version, is_add, fwmark, table, priority):
+ def FwmarkRule(self, version, is_add, fwmark, fwmask, table, priority):
nlattr = self._NlAttrU32(FRA_FWMARK, fwmark)
+ nlattr += self._NlAttrU32(FRA_FWMASK, fwmask)
return self._Rule(version, is_add, RTN_UNICAST, table, nlattr, priority)
def IifRule(self, version, is_add, iif, table, priority):
@@ -374,7 +433,7 @@
return self._Rule(version, is_add, RTN_UNREACHABLE, None, None, priority)
def DefaultRule(self, version, is_add, table, priority):
- return self.FwmarkRule(version, is_add, 0, table, priority)
+ return self.FwmarkRule(version, is_add, 0, 0, table, priority)
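With FRA_FWMASK in the rule, matching compares only the masked bits of the packet mark; this is what lets multinetwork_base select networks on the low 16 bits of the fwmark. A tiny sketch of the comparison (a simplified model of the kernel's fib-rule check, not the exact kernel code):

```python
def fwmark_matches(pkt_mark, fwmark, fwmask):
    """True if pkt_mark matches the rule's fwmark under fwmask.

    Models the kernel's fib-rule behaviour: bits outside fwmask are
    ignored, so marks that differ only in masked-out bits still match.
    """
    return (pkt_mark & fwmask) == (fwmark & fwmask)

# High bits of the packet mark are ignored with a 0xffff mask.
match = fwmark_matches(0xdead0064, 0x64, 0xffff)
```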
def CommandToString(self, command, data):
try:
@@ -399,11 +458,11 @@
print self.CommandToString(command, data)
def MaybeDebugMessage(self, message):
- hdr = NLMsgHdr(message)
+ hdr = netlink.NLMsgHdr(message)
self.MaybeDebugCommand(hdr.type, message)
def PrintMessage(self, message):
- hdr = NLMsgHdr(message)
+ hdr = netlink.NLMsgHdr(message)
print self.CommandToString(hdr.type, message)
def DumpRules(self, version):
@@ -431,13 +490,62 @@
ifaddrmsg += self._NlAttrIPAddress(IFA_LOCAL, family, addr)
self._SendNlRequest(command, ifaddrmsg)
+ def _WaitForAddress(self, sock, address, ifindex):
+ # IPv6 addresses aren't immediately usable when the netlink ACK comes back.
+ # Even if DAD is disabled via IFA_F_NODAD or on the interface, when the ACK
+ # arrives the input route has not yet been added to the local table. The
+ # route is added in addrconf_dad_begin with a delayed timer of 0, but if
+ # the system is under load, we could win the race against that timer and
+    # cause the tests to be flaky. So, wait for RTM_NEWADDR to arrive.
+ csocket.SetSocketTimeout(sock, 100)
+ while True:
+ try:
+ data = sock.recv(4096)
+ except EnvironmentError as e:
+ raise AssertionError("Address %s did not appear on ifindex %d: %s" %
+ (address, ifindex, e.strerror))
+ msg, attrs = self._ParseNLMsg(data, IfAddrMsg)[0]
+ if msg.index == ifindex and attrs["IFA_ADDRESS"] == address:
+ return
+
def AddAddress(self, address, prefixlen, ifindex):
- self._Address(6 if ":" in address else 4,
- RTM_NEWADDR, address, prefixlen,
- IFA_F_PERMANENT, RT_SCOPE_UNIVERSE, ifindex)
+ """Adds a statically-configured IP address to an interface.
+
+ The address is created with flags IFA_F_PERMANENT, and, if IPv6,
+ IFA_F_NODAD. The requested scope is RT_SCOPE_UNIVERSE, but at least for
+ IPv6, is instead determined by the kernel.
+
+ In order to avoid races (see comments in _WaitForAddress above), when
+ configuring IPv6 addresses, the method blocks until it receives an
+ RTM_NEWADDR from the kernel confirming that the address has been added.
+    If the address does not appear within 100ms, AssertionError is raised.
+
+ Args:
+ address: A string, the IP address to configure.
+ prefixlen: The prefix length passed to the kernel. If not /32 for IPv4 or
+ /128 for IPv6, the kernel creates an implicit directly-connected route.
+ ifindex: The interface index to add the address to.
+
+ Raises:
+ AssertionError: An IPv6 address was requested, and it did not appear
+ within the timeout.
+ """
+ version = csocket.AddressVersion(address)
+
+ flags = IFA_F_PERMANENT
+ if version == 6:
+ flags |= IFA_F_NODAD
+ sock = self._OpenNetlinkSocket(netlink.NETLINK_ROUTE,
+ groups=RTMGRP_IPV6_IFADDR)
+
+ self._Address(version, RTM_NEWADDR, address, prefixlen, flags,
+ RT_SCOPE_UNIVERSE, ifindex)
+
+ if version == 6:
+ self._WaitForAddress(sock, address, ifindex)
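_WaitForAddress above is a deadline-bounded wait: keep reading until the expected RTM_NEWADDR arrives or a timeout expires. The same pattern in miniature, with a hypothetical polling predicate standing in for the netlink recv loop (the real code blocks in recv() with a socket timeout rather than polling):

```python
import time

def wait_for(predicate, timeout=0.1, interval=0.005):
    """Poll predicate until it returns truthy or the deadline passes."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        result = predicate()
        if result:
            return result
        time.sleep(interval)
    raise AssertionError("condition did not hold within %.0f ms" %
                         (timeout * 1000))

# Simulate an address that "appears" on the third poll: the first two
# calls append a marker and return falsy, the third sees len(calls) == 2.
calls = []
appeared = wait_for(lambda: len(calls) >= 2 or calls.append(None))
```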
def DelAddress(self, address, prefixlen, ifindex):
- self._Address(6 if ":" in address else 4,
+ self._Address(csocket.AddressVersion(address),
RTM_DELADDR, address, prefixlen, 0, 0, ifindex)
def GetAddress(self, address, ifindex=0):
@@ -450,13 +558,13 @@
self._Address(6, RTM_GETADDR, address, 0, 0, RT_SCOPE_UNIVERSE, ifindex)
return self._GetMsg(IfAddrMsg)
- def _Route(self, version, command, table, dest, prefixlen, nexthop, dev,
- mark, uid):
+ def _Route(self, version, proto, command, table, dest, prefixlen, nexthop,
+ dev, mark, uid, route_type=RTN_UNICAST, priority=None, iif=None):
"""Adds, deletes, or queries a route."""
family = self._AddressFamily(version)
scope = RT_SCOPE_UNIVERSE if nexthop else RT_SCOPE_LINK
rtmsg = RTMsg((family, prefixlen, 0, 0, RT_TABLE_UNSPEC,
- RTPROT_STATIC, scope, RTN_UNICAST, 0)).Pack()
+ proto, scope, route_type, 0)).Pack()
if command == RTM_NEWROUTE and not table:
# Don't allow setting routes in table 0, since its behaviour is confusing
# and differs between IPv4 and IPv6.
@@ -473,30 +581,35 @@
rtmsg += self._NlAttrU32(RTA_MARK, mark)
if uid is not None:
rtmsg += self._NlAttrU32(RTA_UID, uid)
+ if priority is not None:
+ rtmsg += self._NlAttrU32(RTA_PRIORITY, priority)
+ if iif is not None:
+ rtmsg += self._NlAttrU32(RTA_IIF, iif)
self._SendNlRequest(command, rtmsg)
def AddRoute(self, version, table, dest, prefixlen, nexthop, dev):
- self._Route(version, RTM_NEWROUTE, table, dest, prefixlen, nexthop, dev,
- None, None)
+ self._Route(version, RTPROT_STATIC, RTM_NEWROUTE, table, dest, prefixlen,
+ nexthop, dev, None, None)
def DelRoute(self, version, table, dest, prefixlen, nexthop, dev):
- self._Route(version, RTM_DELROUTE, table, dest, prefixlen, nexthop, dev,
- None, None)
+ self._Route(version, RTPROT_STATIC, RTM_DELROUTE, table, dest, prefixlen,
+ nexthop, dev, None, None)
- def GetRoutes(self, dest, oif, mark, uid):
- version = 6 if ":" in dest else 4
+ def GetRoutes(self, dest, oif, mark, uid, iif=None):
+ version = csocket.AddressVersion(dest)
prefixlen = {4: 32, 6: 128}[version]
- self._Route(version, RTM_GETROUTE, 0, dest, prefixlen, None, oif, mark, uid)
+ self._Route(version, RTPROT_STATIC, RTM_GETROUTE, 0, dest, prefixlen, None,
+ oif, mark, uid, iif=iif)
data = self._Recv()
# The response will either be an error or a list of routes.
- if NLMsgHdr(data).type == netlink.NLMSG_ERROR:
+ if netlink.NLMsgHdr(data).type == netlink.NLMSG_ERROR:
self._ParseAck(data)
routes = self._GetMsgList(RTMsg, data, False)
return routes
def DumpRoutes(self, version, ifindex):
- ndmsg = NdMsg((self._AddressFamily(version), 0, 0, 0, 0))
- return [(m, r) for (m, r) in self._Dump(RTM_GETROUTE, ndmsg, NdMsg, "")
+ rtmsg = RTMsg(family=self._AddressFamily(version))
+ return [(m, r) for (m, r) in self._Dump(RTM_GETROUTE, rtmsg, RTMsg, "")
if r['RTA_TABLE'] == ifindex]
def _Neighbour(self, version, is_add, addr, lladdr, dev, state, flags=0):
@@ -527,14 +640,131 @@
self._Neighbour(version, True, addr, lladdr, dev, state,
flags=netlink.NLM_F_REPLACE)
- def DumpNeighbours(self, version):
+ def DumpNeighbours(self, version, ifindex):
ndmsg = NdMsg((self._AddressFamily(version), 0, 0, 0, 0))
- return self._Dump(RTM_GETNEIGH, ndmsg, NdMsg, "")
+ attrs = self._NlAttrU32(NDA_IFINDEX, ifindex) if ifindex else ""
+ return self._Dump(RTM_GETNEIGH, ndmsg, NdMsg, attrs)
def ParseNeighbourMessage(self, msg):
msg, _ = self._ParseNLMsg(msg, NdMsg)
return msg
+ def DeleteLink(self, dev_name):
+ ifinfo = IfinfoMsg().Pack()
+ ifinfo += self._NlAttrStr(IFLA_IFNAME, dev_name)
+ return self._SendNlRequest(RTM_DELLINK, ifinfo)
+
+ def GetIfinfo(self, dev_name):
+ """Fetches information about the specified interface.
+
+ Args:
+ dev_name: A string, the name of the interface.
+
+ Returns:
+ A tuple containing an IfinfoMsg struct and raw, undecoded attributes.
+ """
+ ifinfo = IfinfoMsg().Pack()
+ ifinfo += self._NlAttrStr(IFLA_IFNAME, dev_name)
+ self._SendNlRequest(RTM_GETLINK, ifinfo)
+ hdr, data = cstruct.Read(self._Recv(), netlink.NLMsgHdr)
+ if hdr.type == RTM_NEWLINK:
+ return cstruct.Read(data, IfinfoMsg)
+ elif hdr.type == netlink.NLMSG_ERROR:
+ error = netlink.NLMsgErr(data).error
+ raise IOError(error, os.strerror(-error))
+ else:
+ raise ValueError("Unknown Netlink Message Type %d" % hdr.type)
+
+ def GetIfIndex(self, dev_name):
+ """Returns the interface index for the specified interface."""
+ ifinfo, _ = self.GetIfinfo(dev_name)
+ return ifinfo.index
+
+ def GetIfaceStats(self, dev_name):
+ """Returns an RtnlLinkStats64 stats object for the specified interface."""
+ _, attrs = self.GetIfinfo(dev_name)
+ attrs = self._ParseAttributes(RTM_NEWLINK, IfinfoMsg, attrs)
+ return attrs["IFLA_STATS64"]
+
+ def GetIfinfoData(self, dev_name):
+ """Returns an IFLA_INFO_DATA dict object for the specified interface."""
+ _, attrs = self.GetIfinfo(dev_name)
+ attrs = self._ParseAttributes(RTM_NEWLINK, IfinfoMsg, attrs)
+ return attrs["IFLA_LINKINFO"]["IFLA_INFO_DATA"]
+
+ def GetRxTxPackets(self, dev_name):
+ stats = self.GetIfaceStats(dev_name)
+ return stats.rx_packets, stats.tx_packets
+
+ def CreateVirtualTunnelInterface(self, dev_name, local_addr, remote_addr,
+ i_key=None, o_key=None, is_update=False):
+    """Creates a Virtual Tunnel Interface that serves as a proxy interface
+    for IPsec tunnels.
+
+ The VTI Newlink structure is a series of nested netlink
+ attributes following a mostly-ignored 'struct ifinfomsg':
+
+ NLMSGHDR (type=RTM_NEWLINK)
+ |
+ |-{IfinfoMsg}
+ |
+ |-IFLA_IFNAME = <user-provided ifname>
+ |
+ |-IFLA_LINKINFO
+ |
+ |-IFLA_INFO_KIND = "vti"
+ |
+ |-IFLA_INFO_DATA
+ |
+ |-IFLA_VTI_LOCAL = <local addr>
+ |-IFLA_VTI_REMOTE = <remote addr>
+ |-IFLA_VTI_LINK = ????
+ |-IFLA_VTI_OKEY = [outbound mark]
+ |-IFLA_VTI_IKEY = [inbound mark]
+ """
+ family = AF_INET6 if ":" in remote_addr else AF_INET
+
+ ifinfo = IfinfoMsg().Pack()
+ ifinfo += self._NlAttrStr(IFLA_IFNAME, dev_name)
+
+ linkinfo = self._NlAttrStr(IFLA_INFO_KIND,
+ {AF_INET6: "vti6", AF_INET: "vti"}[family])
+
+ ifdata = self._NlAttrIPAddress(IFLA_VTI_LOCAL, family, local_addr)
+ ifdata += self._NlAttrIPAddress(IFLA_VTI_REMOTE, family,
+ remote_addr)
+ if i_key is not None:
+ ifdata += self._NlAttrU32(IFLA_VTI_IKEY, socket.htonl(i_key))
+ if o_key is not None:
+ ifdata += self._NlAttrU32(IFLA_VTI_OKEY, socket.htonl(o_key))
+ linkinfo += self._NlAttr(IFLA_INFO_DATA, ifdata)
+
+ ifinfo += self._NlAttr(IFLA_LINKINFO, linkinfo)
+
+ # Always pass CREATE to prevent _SendNlRequest() from incorrectly
+ # guessing the flags.
+ flags = netlink.NLM_F_CREATE
+ if not is_update:
+ flags |= netlink.NLM_F_EXCL
+ return self._SendNlRequest(RTM_NEWLINK, ifinfo, flags)
+
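The nesting shown in the diagram above is plain attribute containment: IFLA_LINKINFO's payload is itself a sequence of attributes. A minimal sketch of packing the IFLA_INFO_KIND nest with the standard library (`nlattr` is a hypothetical stand-in for `_NlAttr`; constants copied from above):

```python
import struct

IFLA_LINKINFO = 18
IFLA_INFO_KIND = 1

def nlattr(nla_type, data):
    """One attribute: u16 length (header + data), u16 type, pad to 4 bytes."""
    nla_len = 4 + len(data)
    pad = -nla_len % 4
    return struct.pack("=HH", nla_len, nla_type) + data + b"\x00" * pad

# Inner attribute: the NUL-terminated link kind.
kind = nlattr(IFLA_INFO_KIND, b"vti\x00")
# Outer attribute: IFLA_LINKINFO wrapping the inner attribute bytes.
linkinfo = nlattr(IFLA_LINKINFO, kind)
```

The nla_len of the outer attribute covers the fully padded inner attributes, which is why the parser can walk nested payloads with the same `_ReadNlAttr` loop at every level.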
+ def CreateXfrmInterface(self, dev_name, xfrm_if_id, underlying_ifindex):
+ """Creates an XFRM interface with the specified parameters."""
+ # The netlink attribute structure is essentially identical to the one
+    # for VTI above (q.v.).
+ ifdata = self._NlAttrU32(IFLA_XFRM_LINK, underlying_ifindex)
+ ifdata += self._NlAttrU32(IFLA_XFRM_IF_ID, xfrm_if_id)
+
+ linkinfo = self._NlAttrStr(IFLA_INFO_KIND, "xfrm")
+ linkinfo += self._NlAttr(IFLA_INFO_DATA, ifdata)
+
+ msg = IfinfoMsg().Pack()
+ msg += self._NlAttrStr(IFLA_IFNAME, dev_name)
+ msg += self._NlAttr(IFLA_LINKINFO, linkinfo)
+
+ return self._SendNlRequest(RTM_NEWLINK, msg)
+
if __name__ == "__main__":
iproute = IPRoute()
diff --git a/net/test/leak_test.py b/net/test/leak_test.py
index f13eda0..8ef4b41 100755
--- a/net/test/leak_test.py
+++ b/net/test/leak_test.py
@@ -31,7 +31,7 @@
s.bind(("::1", 0))
# Call shutdown on another thread while a recvfrom is in progress.
- net_test.SetSocketTimeout(s, 2000)
+ csocket.SetSocketTimeout(s, 2000)
def ShutdownSocket():
time.sleep(0.5)
self.assertRaisesErrno(ENOTCONN, s.shutdown, SHUT_RDWR)
@@ -47,5 +47,38 @@
self.assertEqual(None, addr)
+class ForceSocketBufferOptionTest(net_test.NetworkTest):
+
+ SO_SNDBUFFORCE = 32
+ SO_RCVBUFFORCE = 33
+
+ def CheckForceSocketBufferOption(self, option, force_option):
+ s = socket(AF_INET6, SOCK_DGRAM, 0)
+
+ # Find the minimum buffer value.
+ s.setsockopt(SOL_SOCKET, option, 0)
+ minbuf = s.getsockopt(SOL_SOCKET, option)
+
+ # Check that the force option works to set reasonable values.
+ val = 4097
+ self.assertGreater(2 * val, minbuf)
+ s.setsockopt(SOL_SOCKET, force_option, val)
+ self.assertEquals(2 * val, s.getsockopt(SOL_SOCKET, option))
+
+ # Check that the force option sets the minimum value instead of a negative
+ # value on integer overflow. Because the kernel multiplies passed-in values
+ # by 2, pick a value that becomes a small negative number if treated as
+ # unsigned.
+ bogusval = 2 ** 31 - val
+ s.setsockopt(SOL_SOCKET, force_option, bogusval)
+ self.assertEquals(minbuf, s.getsockopt(SOL_SOCKET, option))
+
+ def testRcvBufForce(self):
+ self.CheckForceSocketBufferOption(SO_RCVBUF, self.SO_RCVBUFFORCE)
+
+  def testSndBufForce(self):
+ self.CheckForceSocketBufferOption(SO_SNDBUF, self.SO_SNDBUFFORCE)
+
+
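The overflow case in CheckForceSocketBufferOption relies on the kernel doubling the requested value in a signed 32-bit int, so 2 * (2**31 - 4097) wraps to a small negative number that the kernel then clamps to the minimum. The wraparound itself can be checked in isolation:

```python
import struct

def double_as_s32(val):
    """Double val the way a C 'int' would: wrap modulo 2**32, then
    reinterpret the result as signed 32-bit."""
    return struct.unpack("=i", struct.pack("=I", (2 * val) % 2**32))[0]

val = 4097
bogusval = 2 ** 31 - val
doubled = double_as_s32(bogusval)  # small negative number
```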
if __name__ == "__main__":
unittest.main()
diff --git a/net/test/multinetwork_base.py b/net/test/multinetwork_base.py
index a4ba472..ce653b2 100644
--- a/net/test/multinetwork_base.py
+++ b/net/test/multinetwork_base.py
@@ -29,7 +29,6 @@
from scapy import all as scapy
import csocket
-import cstruct
import iproute
import net_test
@@ -54,6 +53,8 @@
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
+IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
+IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
@@ -107,9 +108,16 @@
PRIORITY_DEFAULT = 999
PRIORITY_UNREACHABLE = 1000
+ # Actual device routing is more complicated, involving more than one rule
+ # per NetId, but here we make do with just one rule that selects the lower
+ # 16 bits.
+ NETID_FWMASK = 0xffff
+
# For convenience.
IPV4_ADDR = net_test.IPV4_ADDR
IPV6_ADDR = net_test.IPV6_ADDR
+ IPV4_ADDR2 = net_test.IPV4_ADDR2
+ IPV6_ADDR2 = net_test.IPV6_ADDR2
IPV4_PING = net_test.IPV4_PING
IPV6_PING = net_test.IPV6_PING
@@ -167,10 +175,15 @@
@classmethod
def MyAddress(cls, version, netid):
return {4: cls._MyIPv4Address(netid),
- 5: "::ffff:" + cls._MyIPv4Address(netid),
+ 5: cls._MyIPv4Address(netid),
6: cls._MyIPv6Address(netid)}[version]
@classmethod
+ def MySocketAddress(cls, version, netid):
+ addr = cls.MyAddress(version, netid)
+ return "::ffff:" + addr if version == 5 else addr
+
+ @classmethod
def MyLinkLocalAddress(cls, netid):
return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)
@@ -197,7 +210,10 @@
@classmethod
def CreateTunInterface(cls, netid):
iface = cls.GetInterfaceName(netid)
- f = open("/dev/net/tun", "r+b")
+ try:
+ f = open("/dev/net/tun", "r+b")
+ except IOError:
+ f = open("/dev/tun", "r+b")
ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
ifr += "\x00" * (40 - len(ifr))
fcntl.ioctl(f, TUNSETIFF, ifr)
@@ -257,7 +273,7 @@
cls.iproute.UidRangeRule(version, is_add, start, end, table,
cls.PRIORITY_UID)
cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
- cls.iproute.FwmarkRule(version, is_add, netid, table,
+ cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
cls.PRIORITY_FWMARK)
# Configure routing and addressing.
@@ -296,6 +312,28 @@
cls.OnlinkPrefixLen(4), ifindex)
@classmethod
+ def SetMarkReflectSysctls(cls, value):
+ """Makes kernel-generated replies use the mark of the original packet."""
+ cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
+ cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
+
+ @classmethod
+ def _SetInboundMarking(cls, netid, iface, is_add):
+ for version in [4, 6]:
+ # Run iptables to set up incoming packet marking.
+ add_del = "-A" if is_add else "-D"
+ iptables = {4: "iptables", 6: "ip6tables"}[version]
+ args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
+ add_del, iface, netid)
+ if net_test.RunIptablesCommand(version, args):
+ raise ConfigurationError("Setup command failed: %s" % args)
+
+ @classmethod
+ def SetInboundMarks(cls, is_add):
+ for netid in cls.tuns:
+ cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)
+
+ @classmethod
def SetDefaultNetwork(cls, netid):
table = cls._TableForNetid(netid) if netid else None
for version in [4, 6]:
@@ -376,6 +414,9 @@
cls.loglevel = cls.GetConsoleLogLevel()
cls.SetConsoleLogLevel(net_test.KERN_INFO)
+ # When running on device, don't send connections through FwmarkServer.
+ os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"
+
# Uncomment to look around at interface and rule configuration while
# running in the background. (Once the test finishes running, all the
# interfaces and rules are gone.)
@@ -383,6 +424,8 @@
@classmethod
def tearDownClass(cls):
+ del os.environ["ANDROID_NO_USE_FWMARK_CLIENT"]
+
for version in [4, 6]:
try:
cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
@@ -427,9 +470,18 @@
def GetRemoteAddress(self, version):
return {4: self.IPV4_ADDR,
- 5: "::ffff:" + self.IPV4_ADDR,
+ 5: self.IPV4_ADDR,
6: self.IPV6_ADDR}[version]
+ def GetRemoteSocketAddress(self, version):
+ addr = self.GetRemoteAddress(version)
+ return "::ffff:" + addr if version == 5 else addr
+
+ def GetOtherRemoteSocketAddress(self, version):
+ return {4: self.IPV4_ADDR2,
+ 5: "::ffff:" + self.IPV4_ADDR2,
+ 6: self.IPV6_ADDR2}[version]
+
def SelectInterface(self, s, netid, mode):
if mode == "uid":
os.fchown(s.fileno(), self.UidForNetid(netid), -1)
@@ -454,6 +506,19 @@
return s
+ def RandomNetid(self, exclude=None):
+    """Return a random netid from the list of netids.
+
+ Args:
+ exclude: a netid or list of netids that should not be chosen
+ """
+ if exclude is None:
+ exclude = []
+ elif isinstance(exclude, int):
+ exclude = [exclude]
+ diff = [netid for netid in self.NETIDS if netid not in exclude]
+ return random.choice(diff)
+
def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
if netid is not None:
pktinfo = MakePktInfo(version, None, self.ifindices[netid])
@@ -473,6 +538,13 @@
self.ReceiveEtherPacketOn(netid, packet)
def ReadAllPacketsOn(self, netid, include_multicast=False):
+ """Return all queued packets on a netid as a list.
+
+ Args:
+ netid: The netid from which to read packets
+      include_multicast: A boolean, whether to include multicast packets
+ (default=False)
+ """
packets = []
retries = 0
max_retries = 1
@@ -501,11 +573,13 @@
raise e
return packets
- def InvalidateDstCache(self, version, remoteaddr, netid):
- """Invalidates destination cache entries of sockets to remoteaddr.
+ def InvalidateDstCache(self, version, netid):
+ """Invalidates destination cache entries of sockets on the specified table.
- Creates and then deletes a route pointing to remoteaddr, which invalidates
- the destination cache entries of any sockets connected to remoteaddr.
+ Creates and then deletes a low-priority throw route in the table for the
+ given netid, which invalidates the destination cache entries of any sockets
+ that refer to routes in that table.
+
The fact that this method actually invalidates destination cache entries is
tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
does not re-route sockets when they are remarked, but does re-route them if
@@ -513,16 +587,16 @@
Args:
version: The IP version, 4 or 6.
- remoteaddr: The IP address to temporarily reroute.
- netid: The netid to add/remove the route to.
+ netid: The netid to invalidate dst caches on.
"""
iface = self.GetInterfaceName(netid)
ifindex = self.ifindices[netid]
table = self._TableForNetid(netid)
- nexthop = self._RouterAddress(netid, version)
- plen = {4: 32, 6: 128}[version]
- self.iproute.AddRoute(version, table, remoteaddr, plen, nexthop, ifindex)
- self.iproute.DelRoute(version, table, remoteaddr, plen, nexthop, ifindex)
+ for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
+ self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
+ "default", 0, nexthop=None, dev=None, mark=None,
+ uid=None, route_type=iproute.RTN_THROW,
+ priority=100000)
def ClearTunQueues(self):
# Keep reading packets on all netids until we get no packets on any of them.
@@ -673,3 +747,17 @@
else:
self.ExpectNoPacketsOn(netid, msg)
return None
+
+
+class InboundMarkingTest(MultiNetworkBaseTest):
+ """Class that automatically sets up inbound marking."""
+
+ @classmethod
+ def setUpClass(cls):
+ super(InboundMarkingTest, cls).setUpClass()
+ cls.SetInboundMarks(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.SetInboundMarks(False)
+ super(InboundMarkingTest, cls).tearDownClass()
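The `RandomNetid` helper added above accepts `None`, a single netid, or a list of netids to exclude. A minimal standalone sketch of the same normalize-then-sample pattern (the `NETIDS` pool here is illustrative, not the test harness's actual values):

```python
import random

NETIDS = [100, 150, 200, 250]  # hypothetical netid pool


def random_netid(exclude=None):
    """Return a random netid, excluding one netid or a list of netids."""
    if exclude is None:
        exclude = []
    elif isinstance(exclude, int):
        exclude = [exclude]  # normalize a single netid to a list
    return random.choice([n for n in NETIDS if n not in exclude])


assert random_netid(exclude=150) != 150
```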
diff --git a/net/test/multinetwork_test.py b/net/test/multinetwork_test.py
index 0ade759..68d0ef4 100755
--- a/net/test/multinetwork_test.py
+++ b/net/test/multinetwork_test.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import ctypes
import errno
import os
import random
@@ -35,118 +36,82 @@
IPV6_FLOWINFO = 11
-IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
-IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"
# The IP[V6]UNICAST_IF socket option was added between 3.1 and 3.4.
HAVE_UNICAST_IF = net_test.LINUX_VERSION >= (3, 4, 0)
-MAX_PLEN_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_info_max_plen"
-HAVE_MAX_PLEN = os.path.isfile(MAX_PLEN_SYSCTL)
+# RTPROT_RA works properly on kernel 4.14 and later.
+HAVE_RTPROT_RA = net_test.LINUX_VERSION >= (4, 14, 0)
class ConfigurationError(AssertionError):
pass
-class InboundMarkingTest(multinetwork_base.MultiNetworkBaseTest):
-
- @classmethod
- def _SetInboundMarking(cls, netid, is_add):
- for version in [4, 6]:
- # Run iptables to set up incoming packet marking.
- iface = cls.GetInterfaceName(netid)
- add_del = "-A" if is_add else "-D"
- iptables = {4: "iptables", 6: "ip6tables"}[version]
- args = "%s %s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
- iptables, add_del, iface, netid)
- iptables = "/sbin/" + iptables
- ret = os.spawnvp(os.P_WAIT, iptables, args.split(" "))
- if ret:
- raise ConfigurationError("Setup command failed: %s" % args)
-
- @classmethod
- def setUpClass(cls):
- super(InboundMarkingTest, cls).setUpClass()
- for netid in cls.tuns:
- cls._SetInboundMarking(netid, True)
-
- @classmethod
- def tearDownClass(cls):
- for netid in cls.tuns:
- cls._SetInboundMarking(netid, False)
- super(InboundMarkingTest, cls).tearDownClass()
-
- @classmethod
- def SetMarkReflectSysctls(cls, value):
- cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
- try:
- cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
- except IOError:
- # This does not exist if we use the version of the patch that uses a
- # common sysctl for IPv4 and IPv6.
- pass
-
-
class OutgoingTest(multinetwork_base.MultiNetworkBaseTest):
# How many times to run outgoing packet tests.
ITERATIONS = 5
- def CheckPingPacket(self, version, netid, routing_mode, dstaddr, packet):
+ def CheckPingPacket(self, version, netid, routing_mode, packet):
s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
myaddr = self.MyAddress(version, netid)
+ mysockaddr = self.MySocketAddress(version, netid)
s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
- s.bind((myaddr, packets.PING_IDENT))
+ s.bind((mysockaddr, packets.PING_IDENT))
net_test.SetSocketTos(s, packets.PING_TOS)
+ dstaddr = self.GetRemoteAddress(version)
+ dstsockaddr = self.GetRemoteSocketAddress(version)
desc, expected = packets.ICMPEcho(version, myaddr, dstaddr)
msg = "IPv%d ping: expected %s on %s" % (
version, desc, self.GetInterfaceName(netid))
- s.sendto(packet + packets.PING_PAYLOAD, (dstaddr, 19321))
+ s.sendto(packet + packets.PING_PAYLOAD, (dstsockaddr, 19321))
self.ExpectPacketOn(netid, msg, expected)
- def CheckTCPSYNPacket(self, version, netid, routing_mode, dstaddr):
+ def CheckTCPSYNPacket(self, version, netid, routing_mode):
s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
- if version == 6 and dstaddr.startswith("::ffff"):
- version = 4
myaddr = self.MyAddress(version, netid)
+ dstaddr = self.GetRemoteAddress(version)
+ dstsockaddr = self.GetRemoteSocketAddress(version)
desc, expected = packets.SYN(53, version, myaddr, dstaddr,
sport=None, seq=None)
+
# Non-blocking TCP connects always return EINPROGRESS.
- self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
+ self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstsockaddr, 53))
msg = "IPv%s TCP connect: expected %s on %s" % (
version, desc, self.GetInterfaceName(netid))
self.ExpectPacketOn(netid, msg, expected)
s.close()
- def CheckUDPPacket(self, version, netid, routing_mode, dstaddr):
+ def CheckUDPPacket(self, version, netid, routing_mode):
s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
- if version == 6 and dstaddr.startswith("::ffff"):
- version = 4
myaddr = self.MyAddress(version, netid)
+ dstaddr = self.GetRemoteAddress(version)
+ dstsockaddr = self.GetRemoteSocketAddress(version)
+
desc, expected = packets.UDP(version, myaddr, dstaddr, sport=None)
msg = "IPv%s UDP %%s: expected %s on %s" % (
version, desc, self.GetInterfaceName(netid))
- s.sendto(UDP_PAYLOAD, (dstaddr, 53))
+ s.sendto(UDP_PAYLOAD, (dstsockaddr, 53))
self.ExpectPacketOn(netid, msg % "sendto", expected)
# IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
if routing_mode != "ucast_oif":
- s.connect((dstaddr, 53))
+ s.connect((dstsockaddr, 53))
s.send(UDP_PAYLOAD)
self.ExpectPacketOn(netid, msg % "connect/send", expected)
s.close()
- def CheckRawGrePacket(self, version, netid, routing_mode, dstaddr):
+ def CheckRawGrePacket(self, version, netid, routing_mode):
s = self.BuildSocket(version, net_test.RawGRESocket, netid, routing_mode)
inner_version = {4: 6, 6: 4}[version]
@@ -158,6 +123,7 @@
# A GRE header can be as simple as two zero bytes and the ethertype.
packet = struct.pack("!i", ethertype) + inner
myaddr = self.MyAddress(version, netid)
+ dstaddr = self.GetRemoteAddress(version)
s.sendto(packet, (dstaddr, IPPROTO_GRE))
desc, expected = packets.GRE(version, myaddr, dstaddr, ethertype, inner)
@@ -166,34 +132,30 @@
self.ExpectPacketOn(netid, msg, expected)
def CheckOutgoingPackets(self, routing_mode):
- v4addr = self.IPV4_ADDR
- v6addr = self.IPV6_ADDR
- v4mapped = "::ffff:" + v4addr
-
for _ in xrange(self.ITERATIONS):
for netid in self.tuns:
- self.CheckPingPacket(4, netid, routing_mode, v4addr, self.IPV4_PING)
+ self.CheckPingPacket(4, netid, routing_mode, self.IPV4_PING)
# Kernel bug.
if routing_mode != "oif":
- self.CheckPingPacket(6, netid, routing_mode, v6addr, self.IPV6_PING)
+ self.CheckPingPacket(6, netid, routing_mode, self.IPV6_PING)
# IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
if routing_mode != "ucast_oif":
- self.CheckTCPSYNPacket(4, netid, routing_mode, v4addr)
- self.CheckTCPSYNPacket(6, netid, routing_mode, v6addr)
- self.CheckTCPSYNPacket(6, netid, routing_mode, v4mapped)
+ self.CheckTCPSYNPacket(4, netid, routing_mode)
+ self.CheckTCPSYNPacket(6, netid, routing_mode)
+ self.CheckTCPSYNPacket(5, netid, routing_mode)
- self.CheckUDPPacket(4, netid, routing_mode, v4addr)
- self.CheckUDPPacket(6, netid, routing_mode, v6addr)
- self.CheckUDPPacket(6, netid, routing_mode, v4mapped)
+ self.CheckUDPPacket(4, netid, routing_mode)
+ self.CheckUDPPacket(6, netid, routing_mode)
+ self.CheckUDPPacket(5, netid, routing_mode)
# Creating raw sockets on non-root UIDs requires properly setting
# capabilities, which is hard to do from Python.
# IP_UNICAST_IF is not supported on raw sockets.
if routing_mode not in ["uid", "ucast_oif"]:
- self.CheckRawGrePacket(4, netid, routing_mode, v4addr)
- self.CheckRawGrePacket(6, netid, routing_mode, v6addr)
+ self.CheckRawGrePacket(4, netid, routing_mode)
+ self.CheckRawGrePacket(6, netid, routing_mode)
def testMarkRouting(self):
"""Checks that socket marking selects the right outgoing interface."""
@@ -266,7 +228,7 @@
if prevnetid:
ExpectSendUsesNetid(prevnetid)
# ... until we invalidate it.
- self.InvalidateDstCache(version, dstaddr, prevnetid)
+ self.InvalidateDstCache(version, prevnetid)
ExpectSendUsesNetid(netid)
else:
ExpectSendUsesNetid(netid)
@@ -337,10 +299,14 @@
net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xbeef)
# Specify some arbitrary options.
+ # We declare the flowlabel as ctypes.c_uint32 because on a 32-bit
+ # Python interpreter an integer greater than 0x7fffffff (such as our
+ # chosen flowlabel after being passed through htonl) is converted to
+ # long, and _MakeMsgControl doesn't know what to do with longs.
cmsgs = [
(net_test.SOL_IPV6, IPV6_HOPLIMIT, 39),
(net_test.SOL_IPV6, IPV6_TCLASS, 0x83),
- (net_test.SOL_IPV6, IPV6_FLOWINFO, int(htonl(0xbeef))),
+ (net_test.SOL_IPV6, IPV6_FLOWINFO, ctypes.c_uint(htonl(0xbeef))),
]
else:
# Support for setting IPv4 TOS and TTL via cmsg only appeared in 3.13.
@@ -368,7 +334,7 @@
self.CheckPktinfoRouting(6)
-class MarkTest(InboundMarkingTest):
+class MarkTest(multinetwork_base.InboundMarkingTest):
def CheckReflection(self, version, gen_packet, gen_reply):
"""Checks that replies go out on the same interface as the original.
@@ -430,7 +396,7 @@
self.CheckReflection(6, self.SYNToClosedPort, packets.RST)
-class TCPAcceptTest(InboundMarkingTest):
+class TCPAcceptTest(multinetwork_base.InboundMarkingTest):
MODE_BINDTODEVICE = "SO_BINDTODEVICE"
MODE_INCOMING_MARK = "incoming mark"
@@ -458,7 +424,7 @@
establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
# Attempt to confuse the kernel.
- self.InvalidateDstCache(version, remoteaddr, netid)
+ self.InvalidateDstCache(version, netid)
self.ReceivePacketOn(netid, establishing_ack)
@@ -475,21 +441,21 @@
payload=UDP_PAYLOAD)
s.send(UDP_PAYLOAD)
self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
- self.InvalidateDstCache(version, remoteaddr, netid)
+ self.InvalidateDstCache(version, netid)
# Keep up our end of the conversation.
ack = packets.ACK(version, remoteaddr, myaddr, data)[1]
- self.InvalidateDstCache(version, remoteaddr, netid)
+ self.InvalidateDstCache(version, netid)
self.ReceivePacketOn(netid, ack)
mark = self.GetSocketMark(s)
finally:
- self.InvalidateDstCache(version, remoteaddr, netid)
+ self.InvalidateDstCache(version, netid)
s.close()
- self.InvalidateDstCache(version, remoteaddr, netid)
+ self.InvalidateDstCache(version, netid)
if mode == self.MODE_INCOMING_MARK:
- self.assertEquals(netid, mark,
+ self.assertEquals(netid, mark & self.NETID_FWMASK,
msg + ": Accepted socket: Expected mark %d, got %d" % (
netid, mark))
elif mode != self.MODE_EXPLICIT_MARK:
@@ -576,19 +542,57 @@
def testIPv6ExplicitMark(self):
self.CheckTCP(6, [self.MODE_EXPLICIT_MARK])
+@unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
+ "need support for per-table autoconf")
class RIOTest(multinetwork_base.MultiNetworkBaseTest):
+  """Tests for the IPv6 RFC 4191 route information option.
+
+ Relevant kernel commits:
+ upstream:
+ f104a567e673 ipv6: use rt6_get_dflt_router to get default router in rt6_route_rcv
+ bbea124bc99d net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
+
+ android-4.9:
+ d860b2e8a7f1 FROMLIST: net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs
+
+ android-4.4:
+ e953f89b8563 net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
+
+ android-4.1:
+ 84f2f47716cd net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
+
+ android-3.18:
+ 65f8936934fa net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
+
+ android-3.10:
+ 161e88ebebc7 net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
+
+ """
def setUp(self):
+ super(RIOTest, self).setUp()
self.NETID = random.choice(self.NETIDS)
self.IFACE = self.GetInterfaceName(self.NETID)
+    # Reset min/max plen to their default values before each test case.
+ self.SetAcceptRaRtInfoMinPlen(0)
+ self.SetAcceptRaRtInfoMaxPlen(0)
def GetRoutingTable(self):
return self._TableForNetid(self.NETID)
+ def SetAcceptRaRtInfoMinPlen(self, plen):
+ self.SetSysctl(
+ "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_min_plen"
+ % self.IFACE, plen)
+
+ def GetAcceptRaRtInfoMinPlen(self):
+ return int(self.GetSysctl(
+ "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_min_plen" % self.IFACE))
+
def SetAcceptRaRtInfoMaxPlen(self, plen):
self.SetSysctl(
"/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_max_plen"
- % self.IFACE, str(plen))
+ % self.IFACE, plen)
def GetAcceptRaRtInfoMaxPlen(self):
return int(self.GetSysctl(
@@ -614,18 +618,47 @@
def GetRouteExpiration(self, route):
return float(route['RTA_CACHEINFO'].expires) / 100.0
- @unittest.skipUnless(HAVE_MAX_PLEN and multinetwork_base.HAVE_AUTOCONF_TABLE,
- "need support for RIO and per-table autoconf")
+ def AssertExpirationInRange(self, routes, lifetime, epsilon):
+ self.assertTrue(routes)
+ found = False
+ # Assert that at least one route in routes has the expected lifetime
+ for route in routes:
+ expiration = self.GetRouteExpiration(route)
+ if expiration < lifetime - epsilon:
+ continue
+ if expiration > lifetime + epsilon:
+ continue
+ found = True
+ self.assertTrue(found)
+
+ def DelRA6(self, prefix, plen):
+ version = 6
+ netid = self.NETID
+ table = self._TableForNetid(netid)
+ router = self._RouterAddress(netid, version)
+ ifindex = self.ifindices[netid]
+    # We actually want to specify RTPROT_RA, but an upstream
+    # kernel bug causes RAs to be installed with RTPROT_BOOT.
+ if HAVE_RTPROT_RA:
+ rtprot = iproute.RTPROT_RA
+ else:
+ rtprot = iproute.RTPROT_BOOT
+ self.iproute._Route(version, rtprot, iproute.RTM_DELROUTE,
+ table, prefix, plen, router, ifindex, None, None)
+
+ def testSetAcceptRaRtInfoMinPlen(self):
+ for plen in xrange(-1, 130):
+ self.SetAcceptRaRtInfoMinPlen(plen)
+ self.assertEquals(plen, self.GetAcceptRaRtInfoMinPlen())
+
def testSetAcceptRaRtInfoMaxPlen(self):
for plen in xrange(-1, 130):
self.SetAcceptRaRtInfoMaxPlen(plen)
self.assertEquals(plen, self.GetAcceptRaRtInfoMaxPlen())
- @unittest.skipUnless(HAVE_MAX_PLEN and multinetwork_base.HAVE_AUTOCONF_TABLE,
- "need support for RIO and per-table autoconf")
def testZeroRtLifetime(self):
PREFIX = "2001:db8:8901:2300::"
- RTLIFETIME = 7372
+ RTLIFETIME = 73500
PLEN = 56
PRF = 0
self.SetAcceptRaRtInfoMaxPlen(PLEN)
@@ -639,52 +672,87 @@
time.sleep(0.01)
self.assertFalse(self.FindRoutesWithDestination(PREFIX))
- @unittest.skipUnless(HAVE_MAX_PLEN and multinetwork_base.HAVE_AUTOCONF_TABLE,
- "need support for RIO and per-table autoconf")
- def testMaxPrefixLenRejection(self):
- PREFIX = "2001:db8:8901:2345::"
- RTLIFETIME = 7372
+ def testMinPrefixLenRejection(self):
+ PREFIX = "2001:db8:8902:2345::"
+ RTLIFETIME = 70372
PRF = 0
- for plen in xrange(0, 64):
+ # sweep from high to low to avoid spurious failures from late arrivals.
+ for plen in xrange(130, 1, -1):
+ self.SetAcceptRaRtInfoMinPlen(plen)
+ # RIO with plen < min_plen should be ignored
+ self.SendRIO(RTLIFETIME, plen - 1, PREFIX, PRF)
+ # Give the kernel time to notice our RAs
+ time.sleep(0.1)
+ # Expect no routes
+ routes = self.FindRoutesWithDestination(PREFIX)
+ self.assertFalse(routes)
+
+ def testMaxPrefixLenRejection(self):
+ PREFIX = "2001:db8:8903:2345::"
+ RTLIFETIME = 73078
+ PRF = 0
+ # sweep from low to high to avoid spurious failures from late arrivals.
+ for plen in xrange(-1, 128, 1):
self.SetAcceptRaRtInfoMaxPlen(plen)
# RIO with plen > max_plen should be ignored
self.SendRIO(RTLIFETIME, plen + 1, PREFIX, PRF)
- # Give the kernel time to notice our RA
- time.sleep(0.01)
- routes = self.FindRoutesWithDestination(PREFIX)
- self.assertFalse(routes)
+ # Give the kernel time to notice our RAs
+ time.sleep(0.1)
+ # Expect no routes
+ routes = self.FindRoutesWithDestination(PREFIX)
+ self.assertFalse(routes)
- @unittest.skipUnless(HAVE_MAX_PLEN and multinetwork_base.HAVE_AUTOCONF_TABLE,
- "need support for RIO and per-table autoconf")
+ def testSimpleAccept(self):
+ PREFIX = "2001:db8:8904:2345::"
+ RTLIFETIME = 9993
+ PRF = 0
+ PLEN = 56
+ self.SetAcceptRaRtInfoMinPlen(48)
+ self.SetAcceptRaRtInfoMaxPlen(64)
+ self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
+ # Give the kernel time to notice our RA
+ time.sleep(0.01)
+ routes = self.FindRoutesWithGateway()
+ self.AssertExpirationInRange(routes, RTLIFETIME, 1)
+ self.DelRA6(PREFIX, PLEN)
+
+ def testEqualMinMaxAccept(self):
+ PREFIX = "2001:db8:8905:2345::"
+ RTLIFETIME = 6326
+ PLEN = 21
+ PRF = 0
+ self.SetAcceptRaRtInfoMinPlen(PLEN)
+ self.SetAcceptRaRtInfoMaxPlen(PLEN)
+ self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
+ # Give the kernel time to notice our RA
+ time.sleep(0.01)
+ routes = self.FindRoutesWithGateway()
+ self.AssertExpirationInRange(routes, RTLIFETIME, 1)
+ self.DelRA6(PREFIX, PLEN)
+
def testZeroLengthPrefix(self):
- PREFIX = "::"
+ PREFIX = "2001:db8:8906:2345::"
RTLIFETIME = self.RA_VALIDITY * 2
PLEN = 0
PRF = 0
# Max plen = 0 still allows default RIOs!
self.SetAcceptRaRtInfoMaxPlen(PLEN)
+ self.SendRA(self.NETID)
+ # Give the kernel time to notice our RA
+ time.sleep(0.01)
default = self.FindRoutesWithGateway()
- self.assertTrue(default)
- self.assertLess(self.GetRouteExpiration(default[0]), self.RA_VALIDITY)
+ self.AssertExpirationInRange(default, self.RA_VALIDITY, 1)
    # RIO with prefix length = 0 should overwrite the default route lifetime;
# note that the RIO lifetime overwrites the RA lifetime.
self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
# Give the kernel time to notice our RA
time.sleep(0.01)
default = self.FindRoutesWithGateway()
- self.assertTrue(default)
- if net_test.LINUX_VERSION > (3, 12, 0):
- # Vanilla linux earlier than 3.13 handles RIOs with zero length prefixes
- # incorrectly. There's nothing useful to assert other than the existence
- # of a default route.
- # TODO: remove this condition after pulling bullhead/angler backports to
- # other 3.10 flavors.
- self.assertGreater(self.GetRouteExpiration(default[0]), self.RA_VALIDITY)
+ self.AssertExpirationInRange(default, RTLIFETIME, 1)
+ self.DelRA6(PREFIX, PLEN)
- @unittest.skipUnless(HAVE_MAX_PLEN and multinetwork_base.HAVE_AUTOCONF_TABLE,
- "need support for RIO and per-table autoconf")
def testManyRIOs(self):
- RTLIFETIME = 6809
+ RTLIFETIME = 68012
PLEN = 56
PRF = 0
COUNT = 1000
@@ -694,11 +762,11 @@
for i in xrange(0, COUNT):
prefix = "2001:db8:%x:1100::" % i
self.SendRIO(RTLIFETIME, PLEN, prefix, PRF)
+ time.sleep(0.1)
self.assertEquals(COUNT + baseline, self.CountRoutes())
- # Use lifetime = 0 to cleanup all previously announced RIOs.
for i in xrange(0, COUNT):
prefix = "2001:db8:%x:1100::" % i
- self.SendRIO(0, PLEN, prefix, PRF)
+ self.DelRA6(prefix, PLEN)
# Expect that we can return to baseline config without lingering routes.
self.assertEquals(baseline, self.CountRoutes())
@@ -788,7 +856,7 @@
self.assertLess(num_routes, GetNumRoutes())
-class PMTUTest(InboundMarkingTest):
+class PMTUTest(multinetwork_base.InboundMarkingTest):
PAYLOAD_SIZE = 1400
dstaddrs = set()
@@ -860,7 +928,7 @@
# If this is a connected socket, make sure the socket MTU was set.
# Note that in IPv4 this only started working in Linux 3.6!
if use_connect and (version == 6 or net_test.LINUX_VERSION >= (3, 6)):
- self.assertEquals(1280, self.GetSocketMTU(version, s))
+ self.assertEquals(packets.PTB_MTU, self.GetSocketMTU(version, s))
s.close()
@@ -870,7 +938,7 @@
# here we use a mark for simplicity.
s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
s2.connect((dstaddr, 1234))
- self.assertEquals(1280, self.GetSocketMTU(version, s2))
+ self.assertEquals(packets.PTB_MTU, self.GetSocketMTU(version, s2))
# Also check the MTU reported by ip route get, this time using the oif.
routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0, None)
@@ -879,7 +947,7 @@
rtmsg, attributes = route
self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
metrics = attributes["RTA_METRICS"]
- self.assertEquals(metrics["RTAX_MTU"], 1280)
+ self.assertEquals(packets.PTB_MTU, metrics["RTAX_MTU"])
def testIPv4BasicPMTU(self):
"""Tests IPv4 path MTU discovery.
@@ -1024,9 +1092,10 @@
self.iproute.UidRangeRule, version, False, start, end, table,
priority)
+ fwmask = 0xfefefefe
try:
# Create a rule without a UID range.
- self.iproute.FwmarkRule(version, True, 300, 301, priority + 1)
+ self.iproute.FwmarkRule(version, True, 300, fwmask, 301, priority + 1)
# Check it doesn't have a UID range.
rules = self.GetRulesAtPriority(version, priority + 1)
@@ -1035,7 +1104,7 @@
self.assertIn("FRA_TABLE", attributes)
self.assertNotIn("FRA_UID_RANGE", attributes)
finally:
- self.iproute.FwmarkRule(version, False, 300, 301, priority + 1)
+ self.iproute.FwmarkRule(version, False, 300, fwmask, 301, priority + 1)
  # Test that EEXIST works for UID range rules too. This behaviour was only
# added in 4.8.
@@ -1067,6 +1136,16 @@
def testIPv6GetAndSetRules(self):
self.CheckGetAndSetRules(6)
+ @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not backported")
+ def testDeleteErrno(self):
+ for version in [4, 6]:
+ table = self._Random()
+ priority = self._Random()
+ self.assertRaisesErrno(
+ errno.EINVAL,
+ self.iproute.UidRangeRule, version, False, 100, 0xffffffff, table,
+ priority)
+
def ExpectNoRoute(self, addr, oif, mark, uid):
# The lack of a route may be either an error, or an unreachable route.
try:
@@ -1074,7 +1153,7 @@
rtmsg, _ = routes[0]
self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type)
except IOError, e:
- if int(e.errno) != -int(errno.ENETUNREACH):
+ if int(e.errno) != int(errno.ENETUNREACH):
raise e
def ExpectRoute(self, addr, oif, mark, uid):
@@ -1129,6 +1208,7 @@
class RulesTest(net_test.NetworkTest):
RULE_PRIORITY = 99999
+ FWMASK = 0xffffffff
def setUp(self):
self.iproute = iproute.IPRoute()
@@ -1144,12 +1224,12 @@
# Add rules with mark 300 pointing at tables 301 and 302.
# This checks for a kernel bug where deletion request for tables > 256
# ignored the table.
- self.iproute.FwmarkRule(version, True, 300, 301,
+ self.iproute.FwmarkRule(version, True, 300, self.FWMASK, 301,
priority=self.RULE_PRIORITY)
- self.iproute.FwmarkRule(version, True, 300, 302,
+ self.iproute.FwmarkRule(version, True, 300, self.FWMASK, 302,
priority=self.RULE_PRIORITY)
# Delete rule with mark 300 pointing at table 302.
- self.iproute.FwmarkRule(version, False, 300, 302,
+ self.iproute.FwmarkRule(version, False, 300, self.FWMASK, 302,
priority=self.RULE_PRIORITY)
# Check that the rule pointing at table 301 is still around.
attributes = [a for _, a in self.iproute.DumpRules(version)
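The `fwmask` argument threaded through `FwmarkRule` above mirrors the kernel's fib-rule matching, which compares only the mark bits covered by the mask. A sketch of that comparison (my own illustration of the semantics, not code from iproute.py):

```python
def fwmark_matches(skb_mark, rule_mark, fwmask):
    """Kernel fib-rule semantics: only bits covered by the mask compare."""
    return (skb_mark & fwmask) == (rule_mark & fwmask)


# With mask 0xfefefefe, bit 0 of every byte is ignored by the match.
assert fwmark_matches(0x12d, 0x12c, 0xfefefefe)
assert not fwmark_matches(0x22c, 0x12c, 0xfefefefe)
```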
diff --git a/net/test/neighbour_test.py b/net/test/neighbour_test.py
index 24b434b..caf2e6e 100755
--- a/net/test/neighbour_test.py
+++ b/net/test/neighbour_test.py
@@ -17,12 +17,12 @@
import errno
import random
from socket import * # pylint: disable=wildcard-import
-import subprocess
import time
import unittest
from scapy import all as scapy
+import csocket
import multinetwork_base
import net_test
@@ -87,14 +87,14 @@
self.netid = random.choice(self.tuns.keys())
self.ifindex = self.ifindices[self.netid]
- def GetNeighbour(self, addr):
- version = 6 if ":" in addr else 4
- for msg, args in self.iproute.DumpNeighbours(version):
+ def GetNeighbour(self, addr, ifindex):
+ version = csocket.AddressVersion(addr)
+ for msg, args in self.iproute.DumpNeighbours(version, ifindex):
if args["NDA_DST"] == addr:
return msg, args
def GetNdEntry(self, addr):
- return self.GetNeighbour(addr)
+ return self.GetNeighbour(addr, self.ifindex)
def CheckNoNdEvents(self):
self.assertRaisesErrno(errno.EAGAIN, self.sock.recvfrom, 4096, MSG_PEEK)
@@ -115,7 +115,7 @@
self.assertEquals(attrs[name], actual_attrs[name])
def ExpectProbe(self, is_unicast, addr):
- version = 6 if ":" in addr else 4
+ version = csocket.AddressVersion(addr)
if version == 6:
llsrc = self.MyMacAddress(self.netid)
if is_unicast:
@@ -144,7 +144,7 @@
def ReceiveUnicastAdvertisement(self, addr, mac, srcaddr=None, dstaddr=None,
S=1, O=0, R=1):
- version = 6 if ":" in addr else 4
+ version = csocket.AddressVersion(addr)
if srcaddr is None:
srcaddr = addr
if dstaddr is None:
diff --git a/net/test/net_test.py b/net/test/net_test.py
index b469009..1c7f32f 100755
--- a/net/test/net_test.py
+++ b/net/test/net_test.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import errno
import fcntl
import os
import random
@@ -41,6 +40,7 @@
SO_MARK = 36
SO_PROTOCOL = 38
SO_DOMAIN = 39
+SO_COOKIE = 57
ETH_P_IP = 0x0800
ETH_P_IPV6 = 0x86dd
@@ -66,7 +66,9 @@
IPV6_PING = "\x80\x00\x00\x00\x0a\xce\x00\x03"
IPV4_ADDR = "8.8.8.8"
+IPV4_ADDR2 = "8.8.4.4"
IPV6_ADDR = "2001:4860:4860::8888"
+IPV6_ADDR2 = "2001:4860:4860::8844"
IPV6_SEQ_DGRAM_HEADER = (" sl "
"local_address "
@@ -74,6 +76,8 @@
"st tx_queue rx_queue tr tm->when retrnsmt"
" uid timeout inode ref pointer drops\n")
+UDP_HDR_LEN = 8
+
# Arbitrary packet payload.
UDP_PAYLOAD = str(scapy.DNS(rd=1,
id=random.randint(0, 65535),
@@ -89,11 +93,25 @@
LINUX_VERSION = csocket.LinuxVersion()
-def SetSocketTimeout(sock, ms):
- s = ms / 1000
- us = (ms % 1000) * 1000
- sock.setsockopt(SOL_SOCKET, SO_RCVTIMEO, struct.pack("LL", s, us))
+def GetWildcardAddress(version):
+ return {4: "0.0.0.0", 6: "::"}[version]
+def GetIpHdrLength(version):
+ return {4: 20, 6: 40}[version]
+
+def GetAddressFamily(version):
+ return {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
+
+
+def AddressLengthBits(version):
+ return {4: 32, 6: 128}[version]
+
+def GetAddressVersion(address):
+ if ":" not in address:
+ return 4
+ if address.startswith("::ffff"):
+ return 5
+ return 6
def SetSocketTos(s, tos):
level = {AF_INET: SOL_IP, AF_INET6: SOL_IPV6}[s.family]
@@ -109,7 +127,7 @@
# Convenience functions to create sockets.
def Socket(family, sock_type, protocol):
s = socket(family, sock_type, protocol)
- SetSocketTimeout(s, 5000)
+ csocket.SetSocketTimeout(s, 5000)
return s
@@ -189,14 +207,14 @@
def GetInterfaceIndex(ifname):
- s = IPv4PingSocket()
+ s = UDPSocket(AF_INET)
ifr = struct.pack("%dsi" % IFNAMSIZ, ifname, 0)
ifr = fcntl.ioctl(s, scapy.SIOCGIFINDEX, ifr)
return struct.unpack("%dsi" % IFNAMSIZ, ifr)[1]
def SetInterfaceHWAddr(ifname, hwaddr):
- s = IPv4PingSocket()
+ s = UDPSocket(AF_INET)
hwaddr = hwaddr.replace(":", "")
hwaddr = hwaddr.decode("hex")
if len(hwaddr) != 6:
@@ -206,7 +224,7 @@
def SetInterfaceState(ifname, up):
- s = IPv4PingSocket()
+ s = UDPSocket(AF_INET)
ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, 0)
ifr = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, ifr)
_, flags = struct.unpack("%dsH" % IFNAMSIZ, ifr)
@@ -322,6 +340,13 @@
# Caller also needs to do s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1).
+def RunIptablesCommand(version, args):
+ iptables = {4: "iptables", 6: "ip6tables"}[version]
+ iptables_path = "/sbin/" + iptables
+ if not os.access(iptables_path, os.X_OK):
+ iptables_path = "/system/bin/" + iptables
+ return os.spawnvp(os.P_WAIT, iptables_path, [iptables_path] + args.split(" "))
+
# Determine network configuration.
try:
GetDefaultRoute(version=4)
@@ -335,36 +360,59 @@
except ValueError:
HAVE_IPV6 = False
-
-CONTINUOUS_BUILD = re.search("net_test_mode=builder",
- open("/proc/cmdline").read())
-
-
-class RunAsUid(object):
+class RunAsUidGid(object):
  """Context guard to run a code block as a given UID and GID."""
- def __init__(self, uid):
+ def __init__(self, uid, gid):
self.uid = uid
+ self.gid = gid
def __enter__(self):
if self.uid:
- self.saved_uid = os.geteuid()
+ self.saved_uids = os.getresuid()
self.saved_groups = os.getgroups()
- if self.uid:
- os.setgroups(self.saved_groups + [AID_INET])
- os.seteuid(self.uid)
+ os.setgroups(self.saved_groups + [AID_INET])
+ os.setresuid(self.uid, self.uid, self.saved_uids[0])
+ if self.gid:
+ self.saved_gid = os.getgid()
+ os.setgid(self.gid)
def __exit__(self, unused_type, unused_value, unused_traceback):
if self.uid:
- os.seteuid(self.saved_uid)
+ os.setresuid(*self.saved_uids)
os.setgroups(self.saved_groups)
+ if self.gid:
+ os.setgid(self.saved_gid)
+class RunAsUid(RunAsUidGid):
+  """Context guard to run a code block as a given UID."""
+
+ def __init__(self, uid):
+ RunAsUidGid.__init__(self, uid, 0)
class NetworkTest(unittest.TestCase):
- def assertRaisesErrno(self, err_num, f, *args):
+ def assertRaisesErrno(self, err_num, f=None, *args):
+    """Asserts that an operation fails with the given errno.
+
+ This works similarly to unittest.TestCase.assertRaises. You can call it as
+ an assertion, or use it as a context manager.
+ e.g.
+ self.assertRaisesErrno(errno.ENOENT, do_things, arg1, arg2)
+ or
+ with self.assertRaisesErrno(errno.ENOENT):
+ do_things(arg1, arg2)
+
+ Args:
+ err_num: an errno constant
+      f: (optional) A callable that should raise the expected error
+ *args: arguments passed to f
+ """
msg = os.strerror(err_num)
- self.assertRaisesRegexp(EnvironmentError, msg, f, *args)
+ if f is None:
+ return self.assertRaisesRegexp(EnvironmentError, msg)
+ else:
+ self.assertRaisesRegexp(EnvironmentError, msg, f, *args)
def ReadProcNetSocket(self, protocol):
# Read file.
@@ -384,7 +432,7 @@
if protocol.startswith("tcp"):
# Real sockets have 5 extra numbers, timewait sockets have none.
- end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+|)$"
+ end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+)$"
elif re.match("icmp|udp|raw", protocol):
# Drops.
end_regexp = " +([0-9]+) *$"
@@ -409,8 +457,11 @@
# TODO: consider returning a dict or namedtuple instead.
out = []
for line in lines:
+ m = regexp.match(line)
+ if m is None:
+ raise ValueError("Failed match on [%s]" % line)
(_, src, dst, state, mem,
- _, _, uid, _, _, refcnt, _, extra) = regexp.match(line).groups()
+ _, _, uid, _, _, refcnt, _, extra) = m.groups()
out.append([src, dst, state, mem, uid, refcnt, extra])
return out
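The reworked `assertRaisesErrno` supports both a direct call and a context-manager form by returning the underlying `assertRaises*` context manager when no callable is given. A self-contained sketch of the same pattern (Python 3 spelling, `assertRaisesRegex` rather than `assertRaisesRegexp`; `DemoTest` and `fail_enoent` are illustrative names, not part of net_test.py):

```python
import errno
import os
import re
import unittest


class DemoTest(unittest.TestCase):
    """Illustrative test case demonstrating the dual-form assertion."""

    def assertRaisesErrno(self, err_num, f=None, *args):
        # Match on the strerror text, as the patched helper does. With no
        # callable, return the assertRaises* context manager instead.
        msg = re.escape(os.strerror(err_num))
        if f is None:
            return self.assertRaisesRegex(EnvironmentError, msg)
        self.assertRaisesRegex(EnvironmentError, msg, f, *args)

    def test_both_forms(self):
        def fail_enoent():
            raise OSError(errno.ENOENT, os.strerror(errno.ENOENT))
        # Call form, mirroring plain assertRaises.
        self.assertRaisesErrno(errno.ENOENT, fail_enoent)
        # Context-manager form.
        with self.assertRaisesErrno(errno.ENOENT):
            fail_enoent()
```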
diff --git a/net/test/net_test.sh b/net/test/net_test.sh
index bade6de..72c67a9 100755
--- a/net/test/net_test.sh
+++ b/net/test/net_test.sh
@@ -1,4 +1,135 @@
#!/bin/bash
+if [[ -n "${verbose}" ]]; then
+ echo 'Current working directory:'
+ echo " - according to builtin: [$(pwd)]"
+ echo " - according to /bin/pwd: [$(/bin/pwd)]"
+ echo
+
+ echo 'Shell environment:'
+ env
+ echo
+
+ echo -n "net_test.sh (pid $$, parent ${PPID}, tty $(tty)) running [$0] with args:"
+ for arg in "$@"; do
+ echo -n " [${arg}]"
+ done
+ echo
+ echo
+fi
+
+if [[ "$(tty)" == '/dev/console' ]]; then
+ ARCH="$(uname -m)"
+ # Underscore is illegal in hostname, replace with hyphen
+ ARCH="${ARCH//_/-}"
+
+  # setsid + /dev/tty{,AMA,S}0 allows bash's job control to work, i.e. Ctrl+C/Z
+ if [[ -c '/dev/tty0' ]]; then
+ # exists in UML, does not exist on graphics/vga/curses-less QEMU
+ CON='/dev/tty0'
+ hostname "uml-${ARCH}"
+ elif [[ -c '/dev/ttyAMA0' ]]; then
+ # Qemu for arm (note: /dev/ttyS0 also exists for exitcode)
+ CON='/dev/ttyAMA0'
+ hostname "qemu-${ARCH}"
+ elif [[ -c '/dev/ttyS0' ]]; then
+ # Qemu for x86 (note: /dev/ttyS1 also exists for exitcode)
+ CON='/dev/ttyS0'
+ hostname "qemu-${ARCH}"
+ else
+ # Can't figure it out, job control won't work, tough luck
+ echo 'Unable to figure out proper console - job control will not work.' >&2
+ CON=''
+ hostname "local-${ARCH}"
+ fi
+
+ unset ARCH
+
+ echo -n "$(hostname): Currently tty[/dev/console], but it should be [${CON}]..."
+
+ if [[ -n "${CON}" ]]; then
+ # Redirect std{in,out,err} to the console equivalent tty
+ # which actually supports all standard tty ioctls
+ exec <"${CON}" >&"${CON}"
+
+ # Bash wants to be session leader, hence need for setsid
+ echo " re-executing..."
+ exec /usr/bin/setsid "$0" "$@"
+ # If the above exec fails, we just fall through...
+ # (this implies failure to *find* setsid, not error return from bash,
+ # in practice due to image construction this cannot happen)
+ else
+ echo
+ fi
+
+ # In case we fall through, clean up
+ unset CON
+fi
+
+if [[ -n "${verbose}" ]]; then
+ echo 'TTY settings:'
+ stty
+ echo
+
+ echo 'TTY settings (verbose):'
+ stty -a
+ echo
+
+ echo 'Restoring TTY sanity...'
+fi
+
+stty sane
+stty 115200
+[[ -z "${console_cols}" ]] || stty columns "${console_cols}"
+[[ -z "${console_rows}" ]] || stty rows "${console_rows}"
+
+if [[ -n "${verbose}" ]]; then
+ echo
+
+ echo 'TTY settings:'
+ stty
+ echo
+
+ echo 'TTY settings (verbose):'
+ stty -a
+ echo
+fi
+
+# By the time we get here we should have a sane console:
+# - 115200 baud rate
+# - appropriate (and known) width and height (note: this assumes
+# that the terminal doesn't get further resized)
+# - it is no longer /dev/console, so job control should function
+# (this means working ctrl+c [abort] and ctrl+z [suspend])
+
+
+# This defaults to 60 which is needlessly long during boot
+# (we will reset it back to the default later)
+echo 0 > /proc/sys/kernel/random/urandom_min_reseed_secs
+
+if [[ -n "${entropy}" ]]; then
+ echo "adding entropy from hex string [${entropy}]" >&2
+
+ # In kernel/include/uapi/linux/random.h RNDADDENTROPY is defined as
+ # _IOW('R', 0x03, int[2]) =(R is 0x52)= 0x40085203 = 1074287107
+ /usr/bin/python 3>/dev/random <<EOF
+import fcntl, struct
+rnd = '${entropy}'.decode('base64')
+fcntl.ioctl(3, 0x40085203, struct.pack('ii', len(rnd) * 8, len(rnd)) + rnd)
+EOF
+
+fi
+
+# Make sure the urandom pool has a chance to initialize before we reset
+# the reseed timer back to 60 seconds. One timer tick should be enough.
+sleep 1.1
+
+# By this point either 'random: crng init done' (newer kernels)
+# or 'random: nonblocking pool is initialized' (older kernels)
+# should have been printed out to dmesg/console.
+
+# Reset it back to boot time default
+echo 60 > /proc/sys/kernel/random/urandom_min_reseed_secs
+
# In case IPv6 is compiled as a module.
[ -f /proc/net/if_inet6 ] || insmod $DIR/kernel/net-next/net/ipv6/ipv6.ko
@@ -22,6 +153,6 @@
echo -e "Running $net_test $net_test_args\n"
$net_test $net_test_args
-# Write exit code of net_test to /proc/exitcode so that the builder can use it
+# Write exit code of net_test to a file so that the builder can use it
# to signal failure if any tests fail.
-echo $? >/proc/exitcode
+echo $? >$net_test_exitcode
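The magic number 1074287107 in the entropy heredoc above is just `_IOW('R', 0x03, int[2])` evaluated by hand. A standalone sketch that derives the ioctl request number and packs the accompanying `struct rand_pool_info` payload (the helper names here are ours, not part of the script):

```python
import struct

def _IOW(magic, nr, size):
    # Reimplementation of the kernel's _IOW() macro (asm-generic/ioctl.h):
    # bits 30-31 = direction (1 = write), bits 16-29 = argument size,
    # bits 8-15 = magic type character, bits 0-7 = command number.
    return (1 << 30) | (size << 16) | (ord(magic) << 8) | nr

RNDADDENTROPY = _IOW("R", 0x03, 8)  # int[2] payload -> 8 bytes

def pack_rand_pool_info(rnd):
    # struct rand_pool_info: entropy_count (in bits), buf_size (in bytes),
    # followed by the entropy bytes themselves.
    return struct.pack("ii", len(rnd) * 8, len(rnd)) + rnd

print(hex(RNDADDENTROPY))  # 0x40085203, matching the comment in the script
```

Actually issuing the ioctl requires CAP_SYS_ADMIN and an open file descriptor on /dev/random, as the heredoc does with `3>/dev/random`.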
diff --git a/net/test/netlink.py b/net/test/netlink.py
index ee57b20..ceb547b 100644
--- a/net/test/netlink.py
+++ b/net/test/netlink.py
@@ -18,14 +18,19 @@
# pylint: disable=g-bad-todo
-import errno
import os
import socket
import struct
import sys
import cstruct
+import util
+### Base netlink constants. See include/uapi/linux/netlink.h.
+NETLINK_ROUTE = 0
+NETLINK_SOCK_DIAG = 4
+NETLINK_XFRM = 6
+NETLINK_GENERIC = 16
# Request constants.
NLM_F_REQUEST = 1
@@ -48,11 +53,9 @@
# Alignment / padding.
NLA_ALIGNTO = 4
-
-def PaddedLength(length):
- # TODO: This padding is probably overly simplistic.
- return NLA_ALIGNTO * ((length / NLA_ALIGNTO) + (length % NLA_ALIGNTO != 0))
-
+# Attributes that may legitimately appear more than once in a given netlink
+# message. They don't seem to contain any data.
+DUP_ATTRS_OK = ["INET_DIAG_NONE", "IFLA_PAD"]
class NetlinkSocket(object):
"""A basic netlink socket object."""
@@ -69,10 +72,17 @@
def _NlAttr(self, nla_type, data):
datalen = len(data)
# Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
- padding = "\x00" * (PaddedLength(datalen) - datalen)
+ padding = "\x00" * util.GetPadLength(NLA_ALIGNTO, datalen)
nla_len = datalen + len(NLAttr)
return NLAttr((nla_len, nla_type)).Pack() + data + padding
+ def _NlAttrIPAddress(self, nla_type, family, address):
+ return self._NlAttr(nla_type, socket.inet_pton(family, address))
+
+ def _NlAttrStr(self, nla_type, value):
+ value = value + "\x00"
+ return self._NlAttr(nla_type, value.encode("UTF-8"))
+
def _NlAttrU32(self, nla_type, value):
return self._NlAttr(nla_type, struct.pack("=I", value))
@@ -91,7 +101,18 @@
"""No-op, nonspecific version of decode."""
return nla_type, nla_data
- def _ParseAttributes(self, command, msg, data):
+ def _ReadNlAttr(self, data):
+ # Read the nlattr header.
+ nla, data = cstruct.Read(data, NLAttr)
+
+ # Read the data.
+ datalen = nla.nla_len - len(nla)
+ padded_len = util.GetPadLength(NLA_ALIGNTO, datalen) + datalen
+ nla_data, data = data[:datalen], data[padded_len:]
+
+ return nla, nla_data, data
+
+ def _ParseAttributes(self, command, msg, data, nested=0):
"""Parses and decodes netlink attributes.
Takes a block of NLAttr data structures, decodes them using Decode, and
@@ -101,6 +122,7 @@
command: An integer, the rtnetlink command being carried out.
msg: A Struct, the type of the data after the netlink header.
data: A byte string containing a sequence of NLAttr data structures.
+ nested: An integer, how deep we're currently nested.
Returns:
A dictionary mapping attribute types (integers) to decoded values.
@@ -110,32 +132,31 @@
"""
attributes = {}
while data:
- # Read the nlattr header.
- nla, data = cstruct.Read(data, NLAttr)
-
- # Read the data.
- datalen = nla.nla_len - len(nla)
- padded_len = PaddedLength(nla.nla_len) - len(nla)
- nla_data, data = data[:datalen], data[padded_len:]
+ nla, nla_data, data = self._ReadNlAttr(data)
# If it's an attribute we know about, try to decode it.
nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)
- # We only support unique attributes for now, except for INET_DIAG_NONE,
- # which can appear more than once but doesn't seem to contain any data.
- if nla_name in attributes and nla_name != "INET_DIAG_NONE":
+ if nla_name in attributes and nla_name not in DUP_ATTRS_OK:
raise ValueError("Duplicate attribute %s" % nla_name)
attributes[nla_name] = nla_data
- self._Debug(" %s" % str((nla_name, nla_data)))
+ if not nested:
+ self._Debug(" %s" % (str((nla_name, nla_data))))
return attributes
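The `_NlAttr`/`_ReadNlAttr` pair above implements the standard netlink TLV layout: a 4-byte `(nla_len, nla_type)` header, the payload, then zero padding up to the next `NLA_ALIGNTO` boundary. A self-contained round-trip sketch of that layout (helper names are ours; `GetPadLength` is assumed to return the number of padding bytes, as `util.GetPadLength` is used here):

```python
import struct

NLA_ALIGNTO = 4

def GetPadLength(alignto, length):
    # Number of bytes needed to round length up to a multiple of alignto.
    return (alignto - (length % alignto)) % alignto

def pack_nlattr(nla_type, data):
    # struct nlattr header (nla_len covers header + payload, not padding),
    # followed by the payload and its alignment padding.
    nla_len = 4 + len(data)
    pad = b"\x00" * GetPadLength(NLA_ALIGNTO, len(data))
    return struct.pack("=HH", nla_len, nla_type) + data + pad

def unpack_nlattrs(buf):
    # Walk a packed stream of nlattrs, skipping padding between them.
    attrs = {}
    while buf:
        nla_len, nla_type = struct.unpack("=HH", buf[:4])
        datalen = nla_len - 4
        padded = datalen + GetPadLength(NLA_ALIGNTO, datalen)
        attrs[nla_type] = buf[4:4 + datalen]
        buf = buf[4 + padded:]
    return attrs

blob = pack_nlattr(1, b"eth0\x00") + pack_nlattr(4, struct.pack("=I", 1500))
print(unpack_nlattrs(blob))
```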
- def __init__(self):
+ def _OpenNetlinkSocket(self, family, groups=None):
+ sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, family)
+ if groups:
+ sock.bind((0, groups))
+ sock.connect((0, 0)) # The kernel.
+ return sock
+
+ def __init__(self, family):
# Global sequence number.
self.seq = 0
- self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.FAMILY)
- self.sock.connect((0, 0)) # The kernel.
+ self.sock = self._OpenNetlinkSocket(family)
self.pid = self.sock.getsockname()[1]
def MaybeDebugCommand(self, command, flags, data):
@@ -164,7 +185,7 @@
if hdr.type == NLMSG_ERROR:
error = NLMsgErr(data).error
if error:
- raise IOError(error, os.strerror(-error))
+ raise IOError(-error, os.strerror(-error))
else:
raise ValueError("Expected ACK, got type %d" % hdr.type)
diff --git a/net/test/nf_test.py b/net/test/nf_test.py
new file mode 100755
index 0000000..cd6c976
--- /dev/null
+++ b/net/test/nf_test.py
@@ -0,0 +1,86 @@
+#!/usr/bin/python
+#
+# Copyright 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import errno
+from socket import *
+
+import multinetwork_base
+import net_test
+
+_TEST_IP4_ADDR = "192.0.2.1"
+_TEST_IP6_ADDR = "2001:db8::"
+
+
+# Regression tests for interactions between kernel networking and netfilter
+#
+# These tests were added to ensure that the lookup path for local-ICMP errors
+# does not cause failures. Specifically, local-ICMP packets do not have a
+# net_device in the skb, which has been known to trigger bugs in surrounding
+# code.
+class NetfilterRejectTargetTest(multinetwork_base.MultiNetworkBaseTest):
+
+ def setUp(self):
+ multinetwork_base.MultiNetworkBaseTest.setUp(self)
+ net_test.RunIptablesCommand(4, "-A OUTPUT -d " + _TEST_IP4_ADDR + " -j REJECT")
+ net_test.RunIptablesCommand(6, "-A OUTPUT -d " + _TEST_IP6_ADDR + " -j REJECT")
+
+ def tearDown(self):
+ net_test.RunIptablesCommand(4, "-D OUTPUT -d " + _TEST_IP4_ADDR + " -j REJECT")
+ net_test.RunIptablesCommand(6, "-D OUTPUT -d " + _TEST_IP6_ADDR + " -j REJECT")
+ multinetwork_base.MultiNetworkBaseTest.tearDown(self)
+
+ # Test a rejected TCP connect. The responding ICMP may not have skb->dev set.
+ # This tests the local-ICMP output-input path.
+ def CheckRejectedTcp(self, version, addr):
+ sock = net_test.TCPSocket(net_test.GetAddressFamily(version))
+ netid = self.RandomNetid()
+ self.SelectInterface(sock, netid, "mark")
+
+ # Expect this to fail with ICMP unreachable
+ try:
+ sock.connect((addr, 53))
+ except IOError:
+ pass
+
+ def testRejectTcp4(self):
+ self.CheckRejectedTcp(4, _TEST_IP4_ADDR)
+
+ def testRejectTcp6(self):
+ self.CheckRejectedTcp(6, _TEST_IP6_ADDR)
+
+ # Test a rejected UDP connect. The responding ICMP may not have skb->dev set.
+ # This tests the local-ICMP output-input path.
+ def CheckRejectedUdp(self, version, addr):
+ sock = net_test.UDPSocket(net_test.GetAddressFamily(version))
+ netid = self.RandomNetid()
+ self.SelectInterface(sock, netid, "mark")
+
+ # Expect this to fail with ICMP unreachable
+ try:
+ sock.sendto(net_test.UDP_PAYLOAD, (addr, 53))
+ except IOError:
+ pass
+
+ def testRejectUdp4(self):
+ self.CheckRejectedUdp(4, _TEST_IP4_ADDR)
+
+ def testRejectUdp6(self):
+ self.CheckRejectedUdp(6, _TEST_IP6_ADDR)
+
+
+if __name__ == "__main__":
+ unittest.main()
\ No newline at end of file
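The `try`/`except IOError: pass` pattern in these tests swallows the error that the REJECT rule provokes. As a self-contained illustration of how a rejected connect surfaces as an errno, here is a sketch that uses a plain loopback RST (no iptables rule or root needed) instead of the netfilter REJECT target; it probes for a free port first, which assumes nothing else grabs that port in between:

```python
import errno
import socket

# Find a local TCP port that nothing is listening on by binding a socket
# to an ephemeral port and closing it again.
probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
probe.bind(("127.0.0.1", 0))
closed_port = probe.getsockname()[1]
probe.close()

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    sock.connect(("127.0.0.1", closed_port))
    err = 0
except IOError as e:  # ConnectionRefusedError is a subclass of IOError/OSError
    err = e.errno
finally:
    sock.close()

print(err == errno.ECONNREFUSED)
```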
diff --git a/net/test/no_test b/net/test/no_test
new file mode 100755
index 0000000..b23e556
--- /dev/null
+++ b/net/test/no_test
@@ -0,0 +1 @@
+#!/bin/true
diff --git a/net/test/packets.py b/net/test/packets.py
index c36efe8..87a72f9 100644
--- a/net/test/packets.py
+++ b/net/test/packets.py
@@ -29,6 +29,8 @@
TCP_WINDOW = 14400
+PTB_MTU = 1280
+
PING_IDENT = 0xff19
PING_PAYLOAD = "foobarbaz"
PING_SEQ = 3
@@ -42,7 +44,7 @@
return random.randint(1025, 65535)
def _GetIpLayer(version):
- return {4: scapy.IP, 6: scapy.IPv6}[version]
+ return {4: scapy.IP, 5: scapy.IP, 6: scapy.IPv6}[version]
def _SetPacketTos(packet, tos):
if isinstance(packet, scapy.IPv6):
@@ -61,14 +63,14 @@
ip(src=srcaddr, dst=dstaddr) /
scapy.UDP(sport=sport, dport=53) / UDP_PAYLOAD)
-def UDPWithOptions(version, srcaddr, dstaddr, sport=0):
+def UDPWithOptions(version, srcaddr, dstaddr, sport=0, lifetime=39):
if version == 4:
- packet = (scapy.IP(src=srcaddr, dst=dstaddr, ttl=39, tos=0x83) /
+ packet = (scapy.IP(src=srcaddr, dst=dstaddr, ttl=lifetime, tos=0x83) /
scapy.UDP(sport=sport, dport=53) /
UDP_PAYLOAD)
else:
packet = (scapy.IPv6(src=srcaddr, dst=dstaddr,
- fl=0xbeef, hlim=39, tc=0x83) /
+ fl=0xbeef, hlim=lifetime, tc=0x83) /
scapy.UDP(sport=sport, dport=53) /
UDP_PAYLOAD)
return ("UDPv%d packet with options" % version, packet)
@@ -92,7 +94,8 @@
return ("TCP RST",
ip(src=srcaddr, dst=dstaddr) /
scapy.TCP(sport=original.dport, dport=original.sport,
- ack=original.seq + was_syn_or_fin, seq=None,
+ ack=original.seq + was_syn_or_fin,
+ seq=original.ack,
flags=TCP_RST | TCP_ACK, window=TCP_WINDOW))
def SYNACK(version, srcaddr, dstaddr, packet):
@@ -151,15 +154,20 @@
def ICMPPacketTooBig(version, srcaddr, dstaddr, packet):
if version == 4:
- return ("ICMPv4 fragmentation needed",
- scapy.IP(src=srcaddr, dst=dstaddr, proto=1) /
- scapy.ICMPerror(type=3, code=4, unused=1280) / str(packet)[:64])
+ desc = "ICMPv4 fragmentation needed"
+ pkt = (scapy.IP(src=srcaddr, dst=dstaddr, proto=1) /
+ scapy.ICMPerror(type=3, code=4) / str(packet)[:64])
+ # Only newer versions of scapy understand that since RFC 1191, the last two
+ # bytes of a fragmentation needed ICMP error contain the MTU.
+ if hasattr(scapy.ICMP, "nexthopmtu"):
+ pkt[scapy.ICMPerror].nexthopmtu = PTB_MTU
+ else:
+ pkt[scapy.ICMPerror].unused = PTB_MTU
+ return desc, pkt
else:
- udp = packet.getlayer("UDP")
- udp.payload = str(udp.payload)[:1280-40-8]
return ("ICMPv6 Packet Too Big",
scapy.IPv6(src=srcaddr, dst=dstaddr) /
- scapy.ICMPv6PacketTooBig() / str(packet)[:1232])
+ scapy.ICMPv6PacketTooBig(mtu=PTB_MTU) / str(packet)[:1232])
def ICMPEcho(version, srcaddr, dstaddr):
ip = _GetIpLayer(version)
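The two scapy branches in `ICMPPacketTooBig` above work because RFC 1191 carved the next-hop MTU out of the ICMP header's 32-bit "unused" field: the first 16 bits stay reserved and the last 16 bits carry the MTU. Assuming both fields serialize big-endian per the ICMP header layout, writing the MTU into `.unused` and into `.nexthopmtu` produce identical wire bytes for any MTU below 65536, as this sketch shows without scapy:

```python
import struct

PTB_MTU = 1280

# Old scapy: the MTU goes into the whole 32-bit unused field.
as_unused = struct.pack("!I", PTB_MTU)
# New scapy: a zero 16-bit reserved field followed by a 16-bit next-hop MTU.
as_nexthopmtu = struct.pack("!HH", 0, PTB_MTU)

print(as_unused == as_nexthopmtu)
```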
diff --git a/net/test/parallel_tests.sh b/net/test/parallel_tests.sh
index ec96cce..eb67421 100755
--- a/net/test/parallel_tests.sh
+++ b/net/test/parallel_tests.sh
@@ -3,8 +3,8 @@
# Runs many iterations of run_net_test.sh in parallel processes, for the
# purposes of finding flaky tests.
-if ! [[ $1 =~ ^[0-9]+$ ]] || ! [[ $2 =~ ^[0-9]+$ ]]; then
- echo "Usage: $0 <workers> <runs_per_worker>" >&2
+if ! [[ $1 =~ ^[0-9]+$ ]] || ! [[ $2 =~ ^[0-9]+$ ]] || [ -z "$3" ]; then
+ echo "Usage: $0 <workers> <runs_per_worker> <test>" >&2
exit 1
fi
@@ -12,9 +12,10 @@
function runtests() {
local worker=$1
local runs=$2
+ local test=$3
local j=0
while ((j < runs)); do
- $DIR/run_net_test.sh --readonly --builder --nobuild all_tests.sh \
+ $DIR/run_net_test.sh --builder --nobuild $test \
> /dev/null 2> $RESULTSDIR/results.$worker.$j
j=$((j + 1))
echo -n "." >&2
@@ -23,19 +24,26 @@
WORKERS=$1
RUNS=$2
+TEST=$3
DIR=$(dirname $0)
RESULTSDIR=$(mktemp --tmpdir -d net_test.parallel.XXXXXX)
[ -z $RESULTSDIR ] && exit 1
+test_file=$DIR/$TEST
+if [[ ! -x $test_file ]]; then
+ echo "test file '${test_file}' does not exist"
+ exit 1
+fi
+
echo "Building kernel..." >&2
-$DIR/run_net_test.sh --norun
+$DIR/run_net_test.sh --norun || exit 1
echo "Running $WORKERS worker(s) with $RUNS test run(s) each..." >&2
# Start all the workers.
worker=0
while ((worker < WORKERS)); do
- runtests $worker $RUNS &
+ runtests $worker $RUNS $TEST &
worker=$((worker + 1))
done
wait
diff --git a/net/test/parameterization_test.py b/net/test/parameterization_test.py
new file mode 100755
index 0000000..8f9e130
--- /dev/null
+++ b/net/test/parameterization_test.py
@@ -0,0 +1,83 @@
+#!/usr/bin/python
+#
+# Copyright 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import unittest
+
+import net_test
+import util
+
+
+def InjectTests():
+ ParmeterizationTest.InjectTests()
+
+
+# This test class ensures that the parameterized test generator in util.py
+# works properly. It injects test methods into itself, and ensures that they
+# are generated as expected, that the injected test closures are properly
+# defined, and that each closure runs with its own parameters.
+class ParmeterizationTest(net_test.NetworkTest):
+ tests_run_list = []
+
+ @staticmethod
+ def NameGenerator(a, b, c):
+ return str(a) + "_" + str(b) + "_" + str(c)
+
+ @classmethod
+ def InjectTests(cls):
+ PARAMS_A = (1, 2)
+ PARAMS_B = (3, 4)
+ PARAMS_C = (5, 6)
+
+ param_list = itertools.product(PARAMS_A, PARAMS_B, PARAMS_C)
+ util.InjectParameterizedTest(cls, param_list, cls.NameGenerator)
+
+ def ParamTestDummyFunc(self, a, b, c):
+ self.tests_run_list.append(
+ "testDummyFunc_" + ParmeterizationTest.NameGenerator(a, b, c))
+
+ def testParameterization(self):
+ expected = [
+ "testDummyFunc_1_3_5",
+ "testDummyFunc_1_3_6",
+ "testDummyFunc_1_4_5",
+ "testDummyFunc_1_4_6",
+ "testDummyFunc_2_3_5",
+ "testDummyFunc_2_3_6",
+ "testDummyFunc_2_4_5",
+ "testDummyFunc_2_4_6",
+ ]
+
+ actual = [name for name in dir(self) if name.startswith("testDummyFunc")]
+
+ # Check that name and contents are equal
+ self.assertEqual(len(expected), len(actual))
+ self.assertEqual(sorted(expected), sorted(actual))
+
+ # Start a clean list, and run all the tests.
+ self.tests_run_list = list()
+ for test_name in expected:
+ test_method = getattr(self, test_name)
+ test_method()
+
+ # Make sure all tests have been run with the correct parameters
+ for test_name in expected:
+ self.assertTrue(test_name in self.tests_run_list)
+
+
+if __name__ == "__main__":
+ ParmeterizationTest.InjectTests()
+ unittest.main()
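The injection machinery itself lives in `util.InjectParameterizedTest`, which is not shown in this diff. A minimal sketch of what such a helper might do (the function and class names below are ours, and the real helper's signature may differ): for each parameter tuple, synthesize a closure that forwards to the dummy function and `setattr` it onto the class under a generated name.

```python
import itertools

def inject_parameterized_tests(cls, param_list, name_generator,
                               func_name="ParamTestDummyFunc"):
    """Creates one test method on cls per parameter tuple."""
    for params in param_list:
        # Default argument binds params at definition time; without it every
        # closure would see the loop variable's final value.
        def test_closure(self, params=params):
            return getattr(self, func_name)(*params)
        setattr(cls, "testDummyFunc_" + name_generator(*params), test_closure)

class Example(object):
    calls = []

    def ParamTestDummyFunc(self, a, b):
        self.calls.append((a, b))

inject_parameterized_tests(Example, itertools.product((1, 2), (3, 4)),
                           lambda a, b: "%d_%d" % (a, b))
print(sorted(n for n in dir(Example) if n.startswith("testDummyFunc")))
```

The `params=params` default-argument trick is the classic fix for late binding in loop-generated closures; `functools.partial` would be an equivalent design choice.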
diff --git a/net/test/pf_key.py b/net/test/pf_key.py
new file mode 100755
index 0000000..875e01c
--- /dev/null
+++ b/net/test/pf_key.py
@@ -0,0 +1,328 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Partial implementation of the PFKEYv2 interface."""
+
+# pylint: disable=g-bad-todo,bad-whitespace
+
+import os
+from socket import * # pylint: disable=wildcard-import
+import sys
+
+import cstruct
+import net_test
+
+
+# AF_KEY socket type. See include/linux/socket.h.
+AF_KEY = 15
+
+# PFKEYv2 constants. See include/uapi/linux/pfkeyv2.h.
+PF_KEY_V2 = 2
+
+# IPsec constants. See include/uapi/linux/ipsec.h.
+IPSEC_MODE_ANY = 0
+IPSEC_MODE_TRANSPORT = 1
+IPSEC_MODE_TUNNEL = 2
+IPSEC_MODE_BEET = 3
+
+# Operation types.
+SADB_ADD = 3
+SADB_DELETE = 4
+SADB_DUMP = 10
+
+# SA types.
+SADB_TYPE_UNSPEC = 0
+SADB_TYPE_AH = 2
+SADB_TYPE_ESP = 3
+
+# SA states.
+SADB_SASTATE_LARVAL = 0
+SADB_SASTATE_MATURE = 1
+SADB_SASTATE_DYING = 2
+SADB_SASTATE_DEAD = 3
+
+# Authentication algorithms.
+SADB_AALG_NONE = 0
+SADB_AALG_MD5HMAC = 2
+SADB_AALG_SHA1HMAC = 3
+SADB_X_AALG_SHA2_256HMAC = 5
+SADB_X_AALG_SHA2_384HMAC = 6
+SADB_X_AALG_SHA2_512HMAC = 7
+SADB_X_AALG_RIPEMD160HMAC = 8
+SADB_X_AALG_AES_XCBC_MAC = 9
+SADB_X_AALG_NULL = 251
+
+# Encryption algorithms.
+SADB_EALG_NONE = 0
+SADB_EALG_DESCBC = 2
+SADB_EALG_3DESCBC = 3
+SADB_X_EALG_CASTCBC = 6
+SADB_X_EALG_BLOWFISHCBC = 7
+SADB_EALG_NULL = 11
+SADB_X_EALG_AESCBC = 12
+SADB_X_EALG_AESCTR = 13
+SADB_X_EALG_AES_CCM_ICV8 = 14
+SADB_X_EALG_AES_CCM_ICV12 = 15
+SADB_X_EALG_AES_CCM_ICV16 = 16
+SADB_X_EALG_AES_GCM_ICV8 = 18
+SADB_X_EALG_AES_GCM_ICV12 = 19
+SADB_X_EALG_AES_GCM_ICV16 = 20
+SADB_X_EALG_CAMELLIACBC = 22
+SADB_X_EALG_NULL_AES_GMAC = 23
+SADB_X_EALG_SERPENTCBC = 252
+SADB_X_EALG_TWOFISHCBC = 253
+
+# Extension Header values.
+SADB_EXT_RESERVED = 0
+SADB_EXT_SA = 1
+SADB_EXT_LIFETIME_CURRENT = 2
+SADB_EXT_LIFETIME_HARD = 3
+SADB_EXT_LIFETIME_SOFT = 4
+SADB_EXT_ADDRESS_SRC = 5
+SADB_EXT_ADDRESS_DST = 6
+SADB_EXT_ADDRESS_PROXY = 7
+SADB_EXT_KEY_AUTH = 8
+SADB_EXT_KEY_ENCRYPT = 9
+SADB_EXT_IDENTITY_SRC = 10
+SADB_EXT_IDENTITY_DST = 11
+SADB_EXT_SENSITIVITY = 12
+SADB_EXT_PROPOSAL = 13
+SADB_EXT_SUPPORTED_AUTH = 14
+SADB_EXT_SUPPORTED_ENCRYPT = 15
+SADB_EXT_SPIRANGE = 16
+SADB_X_EXT_KMPRIVATE = 17
+SADB_X_EXT_POLICY = 18
+SADB_X_EXT_SA2 = 19
+SADB_X_EXT_NAT_T_TYPE = 20
+SADB_X_EXT_NAT_T_SPORT = 21
+SADB_X_EXT_NAT_T_DPORT = 22
+SADB_X_EXT_NAT_T_OA = 23
+SADB_X_EXT_SEC_CTX = 24
+SADB_X_EXT_KMADDRESS = 25
+SADB_X_EXT_FILTER = 26
+
+# Data structure formats.
+# These aren't constants, they're classes. So, pylint: disable=invalid-name
+SadbMsg = cstruct.Struct(
+ "SadbMsg", "=BBBBHHII", "version type errno satype len reserved seq pid")
+
+# Fake struct containing the common beginning of all extension structs.
+SadbExt = cstruct.Struct("SadbExt", "=HH", "len exttype")
+
+SadbSa = cstruct.Struct(
+ "SadbSa", "=IBBBBI", "spi replay state auth encrypt flags")
+
+SadbLifetime = cstruct.Struct(
+ "SadbLifetime", "=IQQQ", "allocations bytes addtime usetime")
+
+SadbAddress = cstruct.Struct("SadbAddress", "=BB2x", "proto prefixlen")
+
+SadbKey = cstruct.Struct("SadbKey", "=H2x", "bits")
+
+SadbXSa2 = cstruct.Struct("SadbXSa2", "=B3xII", "mode sequence reqid")
+
+SadbXNatTType = cstruct.Struct("SadbXNatTType", "=B3x", "type")
+
+SadbXNatTPort = cstruct.Struct("SadbXNatTPort", "!H2x", "port")
+
+
+def _GetConstantName(value, prefix):
+ """Translates a number to a constant of the same value in this file."""
+ thismodule = sys.modules[__name__]
+ # Match shorter constant names first. This allows a prefix of "SADB_" and
+ # a value of 3 to match SADB_ADD instead of, say, SADB_EXT_LIFETIME_HARD,
+ # while a longer prefix such as "SADB_EXT_" still matches
+ # SADB_EXT_LIFETIME_HARD for the same value.
+ for name in sorted(dir(thismodule), key=len):
+ if (name.startswith(prefix) and
+ name.isupper() and getattr(thismodule, name) == value):
+ return name
+ return value
+
+
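The shortest-name-first behavior of `_GetConstantName` can be seen with a cut-down constant table. The real function scans this module's own globals via `sys.modules[__name__]`; a dict shows the same idea in isolation (names and values below are copied from the constants defined earlier in this file):

```python
CONSTANTS = {"SADB_ADD": 3, "SADB_DUMP": 10, "SADB_EXT_LIFETIME_HARD": 3}

def get_constant_name(value, prefix):
    # Shortest names first: a generic prefix resolves to the shortest match,
    # while a longer, more specific prefix can still pick out a longer name
    # carrying the same numeric value. Unmatched values fall through as-is.
    for name in sorted(CONSTANTS, key=len):
        if name.startswith(prefix) and CONSTANTS[name] == value:
            return name
    return value

print(get_constant_name(3, "SADB_"))      # the shorter SADB_ADD wins
print(get_constant_name(3, "SADB_EXT_"))  # longer prefix, longer name
```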
+def _GetMultiConstantName(value, prefixes):
+ for prefix in prefixes:
+ name = _GetConstantName(value, prefix)
+ try:
+ int(name)
+ continue
+ except ValueError:
+ return name
+
+
+# Converts extension blobs to a (name, struct, attrs) tuple.
+def ParseExtension(exttype, data):
+ struct_type = None
+ if exttype == SADB_EXT_SA:
+ struct_type = SadbSa
+ elif exttype in [SADB_EXT_LIFETIME_CURRENT, SADB_EXT_LIFETIME_HARD,
+ SADB_EXT_LIFETIME_SOFT]:
+ struct_type = SadbLifetime
+ elif exttype in [SADB_EXT_ADDRESS_SRC, SADB_EXT_ADDRESS_DST,
+ SADB_EXT_ADDRESS_PROXY]:
+ struct_type = SadbAddress
+ elif exttype in [SADB_EXT_KEY_AUTH, SADB_EXT_KEY_ENCRYPT]:
+ struct_type = SadbKey
+ elif exttype == SADB_X_EXT_SA2:
+ struct_type = SadbXSa2
+ elif exttype == SADB_X_EXT_NAT_T_TYPE:
+ struct_type = SadbXNatTType
+ elif exttype in [SADB_X_EXT_NAT_T_SPORT, SADB_X_EXT_NAT_T_DPORT]:
+ struct_type = SadbXNatTPort
+
+ if struct_type:
+ ext, attrs = cstruct.Read(data, struct_type)
+ else:
+ ext, attrs = data, ""
+
+ return exttype, ext, attrs
+
+class PfKey(object):
+
+ """PF_KEY interface to kernel IPsec implementation."""
+
+ def __init__(self):
+ self.sock = socket(AF_KEY, SOCK_RAW, PF_KEY_V2)
+ net_test.SetNonBlocking(self.sock)
+ self.seq = 0
+
+ def Recv(self):
+ reply = self.sock.recv(4096)
+ msg = SadbMsg(reply)
+ # print "RECV:", self.DecodeSadbMsg(msg)
+ if msg.errno != 0:
+ raise OSError(msg.errno, os.strerror(msg.errno))
+ return reply
+
+ def SendAndRecv(self, msg, extensions):
+ self.seq += 1
+ msg.seq = self.seq
+ msg.pid = os.getpid()
+ msg.len = (len(SadbMsg) + len(extensions)) / 8
+ self.sock.send(msg.Pack() + extensions)
+ # print "SEND:", self.DecodeSadbMsg(msg)
+ return self.Recv()
+
+ def PackPfKeyExtensions(self, extlist):
+ extensions = ""
+ for exttype, extstruct, attrs in extlist:
+ extdata = extstruct.Pack()
+ ext = SadbExt(((len(extdata) + len(SadbExt) + len(attrs)) / 8, exttype))
+ extensions += ext.Pack() + extdata + attrs
+ return extensions
+
+ def MakeSadbMsg(self, msgtype, satype):
+ # errno is 0. seq, pid and len are filled in by SendAndRecv().
+ return SadbMsg((PF_KEY_V2, msgtype, 0, satype, 0, 0, 0, 0))
+
+ def MakeSadbExtAddr(self, exttype, addr):
+ prefixlen = {AF_INET: 32, AF_INET6: 128}[addr.family]
+ packed = addr.Pack()
+ padbytes = (len(SadbExt) + len(SadbAddress) + len(packed)) % 8
+ packed += "\x00" * padbytes
+ return (exttype, SadbAddress((0, prefixlen)), packed)
+
+ def AddSa(self, src, dst, spi, satype, mode, reqid, encryption,
+ encryption_key, auth, auth_key):
+ """Adds a security association."""
+ msg = self.MakeSadbMsg(SADB_ADD, satype)
+ replay = 4
+ extlist = [
+ (SADB_EXT_SA, SadbSa((htonl(spi), replay, SADB_SASTATE_MATURE,
+ auth, encryption, 0)), ""),
+ self.MakeSadbExtAddr(SADB_EXT_ADDRESS_SRC, src),
+ self.MakeSadbExtAddr(SADB_EXT_ADDRESS_DST, dst),
+ (SADB_X_EXT_SA2, SadbXSa2((mode, 0, reqid)), ""),
+ (SADB_EXT_KEY_AUTH, SadbKey((len(auth_key) * 8,)), auth_key),
+ (SADB_EXT_KEY_ENCRYPT, SadbKey((len(encryption_key) * 8,)),
+ encryption_key)
+ ]
+ self.SendAndRecv(msg, self.PackPfKeyExtensions(extlist))
+
+ def DelSa(self, src, dst, spi, satype):
+ """Deletes a security association."""
+ msg = self.MakeSadbMsg(SADB_DELETE, satype)
+ extlist = [
+ (SADB_EXT_SA, SadbSa((htonl(spi), 4, SADB_SASTATE_MATURE,
+ 0, 0, 0)), ""),
+ self.MakeSadbExtAddr(SADB_EXT_ADDRESS_SRC, src),
+ self.MakeSadbExtAddr(SADB_EXT_ADDRESS_DST, dst),
+ ]
+ self.SendAndRecv(msg, self.PackPfKeyExtensions(extlist))
+
+ @staticmethod
+ def DecodeSadbMsg(msg):
+ msgtype = _GetConstantName(msg.type, "SADB_")
+ satype = _GetConstantName(msg.satype, "SADB_TYPE_")
+ return ("SadbMsg(version=%d, type=%s, errno=%d, satype=%s, "
+ "len=%d, reserved=%d, seq=%d, pid=%d)" % (
+ msg.version, msgtype, msg.errno, satype, msg.len,
+ msg.reserved, msg.seq, msg.pid))
+
+ @staticmethod
+ def DecodeSadbSa(sa):
+ state = _GetConstantName(sa.state, "SADB_SASTATE_")
+ auth = _GetMultiConstantName(sa.auth, ["SADB_AALG_", "SADB_X_AALG"])
+ encrypt = _GetMultiConstantName(sa.encrypt, ["SADB_EALG_",
+ "SADB_X_EALG_"])
+ return ("SadbSa(spi=%x, replay=%d, state=%s, "
+ "auth=%s, encrypt=%s, flags=%x)" % (
+ sa.spi, sa.replay, state, auth, encrypt, sa.flags))
+
+ @staticmethod
+ def ExtensionsLength(msg, struct_type):
+ return (msg.len * 8) - len(struct_type)
+
+ @staticmethod
+ def ParseExtensions(data):
+ """Parses the extensions in a SADB message."""
+ extensions = []
+ while data:
+ ext, data = cstruct.Read(data, SadbExt)
+ datalen = PfKey.ExtensionsLength(ext, SadbExt)
+ extdata, data = data[:datalen], data[datalen:]
+ extensions.append(ParseExtension(ext.exttype, extdata))
+ return extensions
+
+ def DumpSaInfo(self):
+ """Returns a list of (SadbMsg, [(extension, attr), ...], ...) tuples."""
+ dump = []
+ msg = self.MakeSadbMsg(SADB_DUMP, SADB_TYPE_UNSPEC)
+ received = self.SendAndRecv(msg, "")
+ while received:
+ msg, data = cstruct.Read(received, SadbMsg)
+ extlen = self.ExtensionsLength(msg, SadbMsg)
+ extensions, data = data[:extlen], data[extlen:]
+ dump.append((msg, self.ParseExtensions(extensions)))
+ if msg.seq == 0: # End of dump.
+ break
+ received = self.Recv()
+ return dump
+
+ def PrintSaInfos(self, dump):
+ for msg, extensions in dump:
+ print self.DecodeSadbMsg(msg)
+ for exttype, ext, attrs in extensions:
+ exttype = _GetMultiConstantName(exttype, ["SADB_EXT", "SADB_X_EXT"])
+ if exttype == "SADB_EXT_SA":
+ print " ", exttype, self.DecodeSadbSa(ext), attrs.encode("hex")
+ else:
+ print " ", exttype, ext, attrs.encode("hex")
+ print
+
+
+if __name__ == "__main__":
+ p = PfKey()
+ p.DumpSaInfo()
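A recurring subtlety in `SendAndRecv`, `PackPfKeyExtensions`, and `ExtensionsLength` above is that PF_KEY expresses every length in 8-byte units, and each length covers its own header. A standalone sketch of that arithmetic, reusing the struct formats defined in this file (the helper names are ours):

```python
import struct

SADB_MSG_FMT = "=BBBBHHII"  # version type errno satype len reserved seq pid
SADB_EXT_FMT = "=HH"        # len exttype

def sadb_msg_len(extensions):
    # sadb_msg.len counts the fixed header plus all extensions, in 8-byte
    # units, mirroring msg.len in SendAndRecv().
    return (struct.calcsize(SADB_MSG_FMT) + len(extensions)) // 8

def sadb_ext_header(exttype, extdata):
    # Every extension starts with a (len, exttype) header; len is again in
    # 8-byte units and includes the header itself, as in PackPfKeyExtensions.
    units = (struct.calcsize(SADB_EXT_FMT) + len(extdata)) // 8
    return struct.pack(SADB_EXT_FMT, units, exttype)

# A 12-byte payload, e.g. the size of SadbSa, under a 4-byte extension header.
ext = sadb_ext_header(1, b"\x00" * 12) + b"\x00" * 12
print(sadb_msg_len(ext))
```

Both divisions rely on the caller padding payloads to 8-byte multiples, which is why `MakeSadbExtAddr` above appends padding before building its extension.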
diff --git a/net/test/pf_key_test.py b/net/test/pf_key_test.py
new file mode 100755
index 0000000..e58947c
--- /dev/null
+++ b/net/test/pf_key_test.py
@@ -0,0 +1,99 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
+from socket import *
+import unittest
+
+import csocket
+import pf_key
+import xfrm
+
+ENCRYPTION_KEY = ("308146eb3bd84b044573d60f5a5fd159"
+ "57c7d4fe567a2120f35bae0f9869ec22".decode("hex"))
+AUTH_KEY = "af442892cdcd0ef650e9c299f9a8436a".decode("hex")
+
+
+class PfKeyTest(unittest.TestCase):
+
+ def setUp(self):
+ self.pf_key = pf_key.PfKey()
+ self.xfrm = xfrm.Xfrm()
+
+ def testAddDelSa(self):
+ src4 = csocket.Sockaddr(("192.0.2.1", 0))
+ dst4 = csocket.Sockaddr(("192.0.2.2", 1))
+ self.pf_key.AddSa(src4, dst4, 0xdeadbeef, pf_key.SADB_TYPE_ESP,
+ pf_key.IPSEC_MODE_TRANSPORT, 54321,
+ pf_key.SADB_X_EALG_AESCBC, ENCRYPTION_KEY,
+ pf_key.SADB_X_AALG_SHA2_256HMAC, ENCRYPTION_KEY)
+
+ src6 = csocket.Sockaddr(("2001:db8::1", 0))
+ dst6 = csocket.Sockaddr(("2001:db8::2", 0))
+ self.pf_key.AddSa(src6, dst6, 0xbeefdead, pf_key.SADB_TYPE_ESP,
+ pf_key.IPSEC_MODE_TRANSPORT, 12345,
+ pf_key.SADB_X_EALG_AESCBC, ENCRYPTION_KEY,
+ pf_key.SADB_X_AALG_SHA2_256HMAC, ENCRYPTION_KEY)
+
+ sainfos = self.xfrm.DumpSaInfo()
+ self.assertEquals(2, len(sainfos))
+ state4, attrs4 = [(s, a) for s, a in sainfos if s.family == AF_INET][0]
+ state6, attrs6 = [(s, a) for s, a in sainfos if s.family == AF_INET6][0]
+
+ pfkey_sainfos = self.pf_key.DumpSaInfo()
+ self.assertEquals(2, len(pfkey_sainfos))
+ self.assertTrue(all(msg.satype == pf_key.SADB_TYPE_ESP
+ for msg, _ in pfkey_sainfos))
+
+ self.assertEquals(xfrm.IPPROTO_ESP, state4.id.proto)
+ self.assertEquals(xfrm.IPPROTO_ESP, state6.id.proto)
+ self.assertEquals(54321, state4.reqid)
+ self.assertEquals(12345, state6.reqid)
+ self.assertEquals(0xdeadbeef, state4.id.spi)
+ self.assertEquals(0xbeefdead, state6.id.spi)
+
+ self.assertEquals(xfrm.PaddedAddress("192.0.2.1"), state4.saddr)
+ self.assertEquals(xfrm.PaddedAddress("192.0.2.2"), state4.id.daddr)
+ self.assertEquals(xfrm.PaddedAddress("2001:db8::1"), state6.saddr)
+ self.assertEquals(xfrm.PaddedAddress("2001:db8::2"), state6.id.daddr)
+
+ # The algorithm names are null-terminated, but after that contain garbage.
+ # Kernel bug?
+ aes_name = "cbc(aes)\x00"
+ sha256_name = "hmac(sha256)\x00"
+ self.assertTrue(attrs4["XFRMA_ALG_CRYPT"].name.startswith(aes_name))
+ self.assertTrue(attrs6["XFRMA_ALG_CRYPT"].name.startswith(aes_name))
+ self.assertTrue(attrs4["XFRMA_ALG_AUTH"].name.startswith(sha256_name))
+ self.assertTrue(attrs6["XFRMA_ALG_AUTH"].name.startswith(sha256_name))
+
+ self.assertEquals(256, attrs4["XFRMA_ALG_CRYPT"].key_len)
+ self.assertEquals(256, attrs6["XFRMA_ALG_CRYPT"].key_len)
+ self.assertEquals(256, attrs4["XFRMA_ALG_AUTH"].key_len)
+ self.assertEquals(256, attrs6["XFRMA_ALG_AUTH"].key_len)
+ self.assertEquals(256, attrs4["XFRMA_ALG_AUTH_TRUNC"].key_len)
+ self.assertEquals(256, attrs6["XFRMA_ALG_AUTH_TRUNC"].key_len)
+
+ self.assertEquals(128, attrs4["XFRMA_ALG_AUTH_TRUNC"].trunc_len)
+ self.assertEquals(128, attrs6["XFRMA_ALG_AUTH_TRUNC"].trunc_len)
+
+ self.pf_key.DelSa(src4, dst4, 0xdeadbeef, pf_key.SADB_TYPE_ESP)
+ self.assertEquals(1, len(self.xfrm.DumpSaInfo()))
+ self.pf_key.DelSa(src6, dst6, 0xbeefdead, pf_key.SADB_TYPE_ESP)
+ self.assertEquals(0, len(self.xfrm.DumpSaInfo()))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/net/test/ping6_test.py b/net/test/ping6_test.py
index cf9e2c2..dd73e88 100755
--- a/net/test/ping6_test.py
+++ b/net/test/ping6_test.py
@@ -306,6 +306,10 @@
self.assertEqual(len(data), len(rcvd))
self.assertEqual(data[6:].encode("hex"), rcvd[6:].encode("hex"))
+ @staticmethod
+ def IsAlmostEqual(expected, actual, delta):
+ return abs(expected - actual) < delta
+
def CheckSockStatFile(self, name, srcaddr, srcport, dstaddr, dstport, state,
txmem=0, rxmem=0):
expected = ["%s:%04X" % (net_test.FormatSockStatAddress(srcaddr), srcport),
@@ -313,8 +317,20 @@
"%02X" % state,
"%08X:%08X" % (txmem, rxmem),
str(os.getuid()), "2", "0"]
- actual = self.ReadProcNetSocket(name)[-1]
- self.assertListEqual(expected, actual)
+ for actual in self.ReadProcNetSocket(name):
+ # Check that rxmem and txmem don't differ too much from the expected values.
+ actual_txmem, actual_rxmem = actual[3].split(":")
+ if self.IsAlmostEqual(txmem, int(actual_txmem, 16), txmem / 4):
+ return
+ if self.IsAlmostEqual(rxmem, int(actual_rxmem, 16), rxmem / 4):
+ return
+
+ # Check all the parameters except rxmem and txmem.
+ expected[3] = actual[3]
+ if expected == actual:
+ return
+
+ self.fail("Could not find socket matching %s" % expected)
def testIPv4SendWithNoConnection(self):
s = net_test.IPv4PingSocket()
@@ -468,7 +484,7 @@
s5.bind(("0.0.0.0", 167))
s4.sendto(net_test.IPV4_PING, (net_test.IPV4_ADDR, 44))
self.assertValidPingResponse(s5, net_test.IPV4_PING)
- net_test.SetSocketTimeout(s4, 100)
+ csocket.SetSocketTimeout(s4, 100)
self.assertRaisesErrno(errno.EAGAIN, s4.recv, 32768)
# If SO_REUSEADDR is turned off, then we get EADDRINUSE.
@@ -615,15 +631,16 @@
upstream net:
5e45789 net: ipv6: Fix ping to link-local addresses.
"""
- s = net_test.IPv6PingSocket()
for mode in ["oif", "ucast_oif", None]:
s = net_test.IPv6PingSocket()
for netid in self.NETIDS:
+ s2 = net_test.IPv6PingSocket()
dst = self._RouterAddress(netid, 6)
self.assertTrue(dst.startswith("fe80:"))
if mode:
self.SelectInterface(s, netid, mode)
+ self.SelectInterface(s2, netid, mode)
scopeid = 0
else:
scopeid = self.ifindices[netid]
@@ -637,11 +654,21 @@
self.assertRaisesErrno(
errno.EINVAL,
s.sendto, net_test.IPV6_PING, (dst, 55, 0, otherscopeid))
+ self.assertRaisesErrno(
+ errno.EINVAL,
+ s.connect, (dst, 55, 0, otherscopeid))
+ # Try using both sendto and connect/send.
+ # If we get a reply, we sent the packet out on the right interface.
s.sendto(net_test.IPV6_PING, (dst, 123, 0, scopeid))
- # If we got a reply, we sent the packet out on the right interface.
self.assertValidPingResponse(s, net_test.IPV6_PING)
+ # IPV6_UNICAST_IF doesn't work on connected sockets.
+ if mode != "ucast_oif":
+ s2.connect((dst, 123, 0, scopeid))
+ s2.send(net_test.IPV6_PING)
+ self.assertValidPingResponse(s2, net_test.IPV6_PING)
+
def testMappedAddressFails(self):
s = net_test.IPv6PingSocket()
s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55))
diff --git a/net/test/qtaguid_test.py b/net/test/qtaguid_test.py
new file mode 100755
index 0000000..c121df2
--- /dev/null
+++ b/net/test/qtaguid_test.py
@@ -0,0 +1,228 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for xt_qtaguid."""
+
+import errno
+from socket import * # pylint: disable=wildcard-import
+import unittest
+import os
+
+import net_test
+import packets
+import tcp_test
+
+CTRL_PROCPATH = "/proc/net/xt_qtaguid/ctrl"
+OTHER_UID_GID = 12345
+HAVE_QTAGUID = os.path.exists(CTRL_PROCPATH)
+
+
+@unittest.skipUnless(HAVE_QTAGUID, "xt_qtaguid not supported")
+class QtaguidTest(tcp_test.TcpBaseTest):
+
+ def RunIptablesCommand(self, args):
+ self.assertFalse(net_test.RunIptablesCommand(4, args))
+ self.assertFalse(net_test.RunIptablesCommand(6, args))
+
+ def setUp(self):
+ self.RunIptablesCommand("-N qtaguid_test_OUTPUT")
+ self.RunIptablesCommand("-A OUTPUT -j qtaguid_test_OUTPUT")
+
+ def tearDown(self):
+ self.RunIptablesCommand("-D OUTPUT -j qtaguid_test_OUTPUT")
+ self.RunIptablesCommand("-F qtaguid_test_OUTPUT")
+ self.RunIptablesCommand("-X qtaguid_test_OUTPUT")
+
+ def WriteToCtrl(self, command):
+ ctrl_file = open(CTRL_PROCPATH, 'w')
+ ctrl_file.write(command)
+ ctrl_file.close()
+
+ def CheckTag(self, tag, uid):
+ for line in open(CTRL_PROCPATH, 'r').readlines():
+ if "tag=0x%x (uid=%d)" % ((tag|uid), uid) in line:
+ return True
+ return False
+
+ def SetIptablesRule(self, version, is_add, is_gid, my_id, inverted):
+ add_del = "-A" if is_add else "-D"
+ uid_gid = "--gid-owner" if is_gid else "--uid-owner"
+ if inverted:
+ args = "%s qtaguid_test_OUTPUT -m owner ! %s %d -j DROP" % (add_del, uid_gid, my_id)
+ else:
+ args = "%s qtaguid_test_OUTPUT -m owner %s %d -j DROP" % (add_del, uid_gid, my_id)
+ self.assertFalse(net_test.RunIptablesCommand(version, args))
+
+ def AddIptablesRule(self, version, is_gid, myId):
+ self.SetIptablesRule(version, True, is_gid, myId, False)
+
+ def AddIptablesInvertedRule(self, version, is_gid, myId):
+ self.SetIptablesRule(version, True, is_gid, myId, True)
+
+ def DelIptablesRule(self, version, is_gid, myId):
+ self.SetIptablesRule(version, False, is_gid, myId, False)
+
+ def DelIptablesInvertedRule(self, version, is_gid, myId):
+ self.SetIptablesRule(version, False, is_gid, myId, True)
+
+ def CheckSocketOutput(self, version, is_gid):
+ myId = os.getgid() if is_gid else os.getuid()
+ self.AddIptablesRule(version, is_gid, myId)
+ family = {4: AF_INET, 6: AF_INET6}[version]
+ s = socket(family, SOCK_DGRAM, 0)
+ addr = {4: "127.0.0.1", 6: "::1"}[version]
+ s.bind((addr, 0))
+ addr = s.getsockname()
+ self.assertRaisesErrno(errno.EPERM, s.sendto, "foo", addr)
+ self.DelIptablesRule(version, is_gid, myId)
+ s.sendto("foo", addr)
+ data, sockaddr = s.recvfrom(4096)
+ self.assertEqual("foo", data)
+ self.assertEqual(sockaddr, addr)
+
+ def CheckSocketOutputInverted(self, version, is_gid):
+ # Load an inverted iptables rule for the current uid/gid: traffic from other
+ # uids/gids should be blocked, and traffic from the current uid/gid should pass.
+ myId = os.getgid() if is_gid else os.getuid()
+ self.AddIptablesInvertedRule(version, is_gid, myId)
+ family = {4: AF_INET, 6: AF_INET6}[version]
+ s = socket(family, SOCK_DGRAM, 0)
+ addr1 = {4: "127.0.0.1", 6: "::1"}[version]
+ s.bind((addr1, 0))
+ addr1 = s.getsockname()
+ s.sendto("foo", addr1)
+ data, sockaddr = s.recvfrom(4096)
+ self.assertEqual("foo", data)
+ self.assertEqual(sockaddr, addr1)
+ with net_test.RunAsUidGid(0 if is_gid else 12345,
+ 12345 if is_gid else 0):
+ s2 = socket(family, SOCK_DGRAM, 0)
+ addr2 = {4: "127.0.0.1", 6: "::1"}[version]
+ s2.bind((addr2, 0))
+ addr2 = s2.getsockname()
+ self.assertRaisesErrno(errno.EPERM, s2.sendto, "foo", addr2)
+ self.DelIptablesInvertedRule(version, is_gid, myId)
+ s.sendto("foo", addr1)
+ data, sockaddr = s.recvfrom(4096)
+ self.assertEqual("foo", data)
+ self.assertEqual(sockaddr, addr1)
+
+ def SendRSTOnClosedSocket(self, version, netid, expect_rst):
+ self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
+ self.accepted.setsockopt(net_test.SOL_TCP, net_test.TCP_LINGER2, -1)
+ net_test.EnableFinWait(self.accepted)
+ self.accepted.shutdown(SHUT_WR)
+ desc, fin = self.FinPacket()
+ self.ExpectPacketOn(netid, "Closing FIN_WAIT1 socket", fin)
+ finversion = 4 if version == 5 else version
+ desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
+ self.ReceivePacketOn(netid, finack)
+ try:
+ self.ExpectPacketOn(netid, "Closing FIN_WAIT1 socket", fin)
+ except AssertionError:
+ pass
+ self.accepted.close()
+ desc, rst = packets.RST(version, self.myaddr, self.remoteaddr, self.last_packet)
+ if expect_rst:
+ msg = "closing socket with linger2, expecting %s: " % desc
+ self.ExpectPacketOn(netid, msg, rst)
+ else:
+ msg = "closing socket with linger2, expecting no packets"
+ self.ExpectNoPacketsOn(netid, msg)
+
+ def CheckUidGidCombination(self, version, invert_gid, invert_uid):
+ my_uid = os.getuid()
+ my_gid = os.getgid()
+ if invert_gid:
+ self.AddIptablesInvertedRule(version, True, my_gid)
+ else:
+ self.AddIptablesRule(version, True, OTHER_UID_GID)
+ if invert_uid:
+ self.AddIptablesInvertedRule(version, False, my_uid)
+ else:
+ self.AddIptablesRule(version, False, OTHER_UID_GID)
+ for netid in self.NETIDS:
+ self.SendRSTOnClosedSocket(version, netid, not invert_gid)
+ if invert_gid:
+ self.DelIptablesInvertedRule(version, True, my_gid)
+ else:
+ self.DelIptablesRule(version, True, OTHER_UID_GID)
+ if invert_uid:
+ self.DelIptablesInvertedRule(version, False, my_uid)
+ else:
+ self.DelIptablesRule(version, False, OTHER_UID_GID)
+
+ def testCloseWithoutUntag(self):
+ self.dev_file = open("/dev/xt_qtaguid", "r")
+ sk = socket(AF_INET, SOCK_DGRAM, 0)
+ uid = os.getuid()
+ tag = 0xff00ff00 << 32
+ command = "t %d %d %d" % (sk.fileno(), tag, uid)
+ self.WriteToCtrl(command)
+ self.assertTrue(self.CheckTag(tag, uid))
+ sk.close()
+ self.assertFalse(self.CheckTag(tag, uid))
+ self.dev_file.close()
+
+ def testTagWithoutDeviceOpen(self):
+ sk = socket(AF_INET, SOCK_DGRAM, 0)
+ uid = os.getuid()
+ tag = 0xff00ff00 << 32
+ command = "t %d %d %d" % (sk.fileno(), tag, uid)
+ self.WriteToCtrl(command)
+ self.assertTrue(self.CheckTag(tag, uid))
+ self.dev_file = open("/dev/xt_qtaguid", "r")
+ sk.close()
+ self.assertFalse(self.CheckTag(tag, uid))
+ self.dev_file.close()
+
+ def testUidGidMatch(self):
+ self.CheckSocketOutput(4, False)
+ self.CheckSocketOutput(6, False)
+ self.CheckSocketOutput(4, True)
+ self.CheckSocketOutput(6, True)
+ self.CheckSocketOutputInverted(4, True)
+ self.CheckSocketOutputInverted(6, True)
+ self.CheckSocketOutputInverted(4, False)
+ self.CheckSocketOutputInverted(6, False)
+
+ def testCheckNotMatchGid(self):
+ self.assertIn("match_no_sk_gid", open(CTRL_PROCPATH, 'r').read())
+
+ def testRstPacketNotDropped(self):
+ my_uid = os.getuid()
+ self.AddIptablesInvertedRule(4, False, my_uid)
+ for netid in self.NETIDS:
+ self.SendRSTOnClosedSocket(4, netid, True)
+ self.DelIptablesInvertedRule(4, False, my_uid)
+ self.AddIptablesInvertedRule(6, False, my_uid)
+ for netid in self.NETIDS:
+ self.SendRSTOnClosedSocket(6, netid, True)
+ self.DelIptablesInvertedRule(6, False, my_uid)
+
+ def testUidGidCombineMatch(self):
+ self.CheckUidGidCombination(4, invert_gid=True, invert_uid=True)
+ self.CheckUidGidCombination(4, invert_gid=True, invert_uid=False)
+ self.CheckUidGidCombination(4, invert_gid=False, invert_uid=True)
+ self.CheckUidGidCombination(4, invert_gid=False, invert_uid=False)
+ self.CheckUidGidCombination(6, invert_gid=True, invert_uid=True)
+ self.CheckUidGidCombination(6, invert_gid=True, invert_uid=False)
+ self.CheckUidGidCombination(6, invert_gid=False, invert_uid=True)
+ self.CheckUidGidCombination(6, invert_gid=False, invert_uid=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/net/test/resilient_rs_test.py b/net/test/resilient_rs_test.py
new file mode 100755
index 0000000..12843c4
--- /dev/null
+++ b/net/test/resilient_rs_test.py
@@ -0,0 +1,172 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import posix
+import select
+from socket import * # pylint: disable=wildcard-import
+import time
+import unittest
+from math import pow
+
+import multinetwork_base
+
+def accumulate(lis):
+ total = 0
+ for x in lis:
+ total += x
+ yield total
+
+# This test attempts to validate time-related behavior of the kernel
+# under test and is therefore inherently prone to races. To avoid
+# flakes, this test is biased to declare RFC 7559 RS backoff is
+# present on the assumption that repeated runs will detect
+# non-compliant kernels with high probability.
+#
+# If higher confidence is required, REQUIRED_SAMPLES and
+# SAMPLE_INTERVAL can be increased at the cost of increased runtime.
+class ResilientRouterSolicitationTest(multinetwork_base.MultiNetworkBaseTest):
+ """Tests for IPv6 'resilient rs' RFC 7559 backoff behaviour.
+
+ Relevant kernel commits:
+ upstream:
+ bd11f0741fa5 ipv6 addrconf: implement RFC7559 router solicitation backoff
+ android-4.4:
+ e246a2f11fcc UPSTREAM: ipv6 addrconf: implement RFC7559 router solicitation backoff
+
+ android-4.1:
+ c6e9a50816a0 UPSTREAM: ipv6 addrconf: implement RFC7559 router solicitation backoff
+
+ android-3.18:
+ 2a7561c61417 UPSTREAM: ipv6 addrconf: implement RFC7559 router solicitation backoff
+
+ android-3.10:
+ ce2d59ac01f3 BACKPORT: ipv6 addrconf: implement RFC7559 router solicitation backoff
+
+ """
+ ROUTER_SOLICIT = 133
+
+ _TEST_NETID = 123
+ _PROC_NET_TUNABLE = "/proc/sys/net/ipv6/conf/%s/%s"
+
+ @classmethod
+ def setUpClass(cls):
+ return
+
+ def setUp(self):
+ return
+
+ @classmethod
+ def tearDownClass(cls):
+ return
+
+ def tearDown(self):
+ return
+
+ @classmethod
+ def isIPv6RouterSolicitation(cls, packet):
+ return ((len(packet) >= 14 + 40 + 1) and
+ # Use net_test.ETH_P_IPV6 here
+ (ord(packet[12]) == 0x86) and
+ (ord(packet[13]) == 0xdd) and
+ (ord(packet[14]) >> 4 == 6) and
+ (ord(packet[14 + 40]) == cls.ROUTER_SOLICIT))
+
+ def makeTunInterface(self, netid):
+ defaultDisableIPv6Path = self._PROC_NET_TUNABLE % ("default", "disable_ipv6")
+ savedDefaultDisableIPv6 = self.GetSysctl(defaultDisableIPv6Path)
+ self.SetSysctl(defaultDisableIPv6Path, 1)
+ tun = self.CreateTunInterface(netid)
+ self.SetSysctl(defaultDisableIPv6Path, savedDefaultDisableIPv6)
+ return tun
+
+ def testFeatureExists(self):
+ return
+
+ def testRouterSolicitationBackoff(self):
+ # Test error tolerance
+ EPSILON = 0.1
+ # Minimum RFC3315 S14 backoff
+ MIN_EXP = 1.9 - EPSILON
+ # Maximum RFC3315 S14 backoff
+ MAX_EXP = 2.1 + EPSILON
+ SOLICITATION_INTERVAL = 1
+ # Linear backoff for 4 samples yields 3.6 < T < 4.4
+ # Exponential backoff for 4 samples yields 4.83 < T < 9.65
+ REQUIRED_SAMPLES = 4
+ # Give up after 10 seconds. Tuned for REQUIRED_SAMPLES = 4
+ SAMPLE_INTERVAL = 10
+ # Practically unlimited backoff
+ SOLICITATION_MAX_INTERVAL = 1000
+ MIN_LIN = SOLICITATION_INTERVAL * (0.9 - EPSILON)
+ MAX_LIN = SOLICITATION_INTERVAL * (1.1 + EPSILON)
+ netid = self._TEST_NETID
+ tun = self.makeTunInterface(netid)
+ poll = select.poll()
+ poll.register(tun, select.POLLIN | select.POLLPRI)
+
+ PROC_SETTINGS = [
+ ("router_solicitation_delay", 1),
+ ("router_solicitation_interval", SOLICITATION_INTERVAL),
+ ("router_solicitation_max_interval", SOLICITATION_MAX_INTERVAL),
+ ("router_solicitations", -1),
+ ("disable_ipv6", 0) # MUST be last
+ ]
+
+ iface = self.GetInterfaceName(netid)
+ for tunable, value in PROC_SETTINGS:
+ self.SetSysctl(self._PROC_NET_TUNABLE % (iface, tunable), value)
+
+ start = time.time()
+ deadline = start + SAMPLE_INTERVAL
+
+ rsSendTimes = []
+ while True:
+ now = time.time()
+ poll.poll((deadline - now) * 1000)
+ try:
+ packet = posix.read(tun.fileno(), 4096)
+ except OSError:
+ break
+
+ txTime = time.time()
+ if txTime > deadline:
+ break
+ if not self.isIPv6RouterSolicitation(packet):
+ continue
+
+ # Record time relative to first router solicitation
+ rsSendTimes.append(txTime - start)
+
+ # Exit early if we have at least REQUIRED_SAMPLES
+ if len(rsSendTimes) >= REQUIRED_SAMPLES:
+ break
+
+ # Expect at least REQUIRED_SAMPLES router solicitations
+ self.assertLessEqual(REQUIRED_SAMPLES, len(rsSendTimes))
+
+ # Compute minimum and maximum bounds for RFC3315 S14 exponential backoff.
+ # First retransmit is linear backoff, subsequent retransmits are exponential
+ min_exp_bound = accumulate(map(lambda i: MIN_LIN * pow(MIN_EXP, i), range(0, len(rsSendTimes))))
+ max_exp_bound = accumulate(map(lambda i: MAX_LIN * pow(MAX_EXP, i), range(0, len(rsSendTimes))))
+
+ # Assert that each sample falls within the worst case interval. If all samples fit we accept
+ # the exponential backoff hypothesis
+ for (t, min_exp, max_exp) in zip(rsSendTimes[1:], min_exp_bound, max_exp_bound):
+ self.assertLess(min_exp, t)
+ self.assertGreater(max_exp, t)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/net/test/rootfs/common.sh b/net/test/rootfs/common.sh
new file mode 100644
index 0000000..172d9b6
--- /dev/null
+++ b/net/test/rootfs/common.sh
@@ -0,0 +1,57 @@
+#!/bin/sh
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+chroot_sanity_check() {
+ if [ ! -f /var/log/bootstrap.log ]; then
+ echo "Do not run this script directly!"
+ echo "This is supposed to be run from inside a debootstrap chroot!"
+ echo "Aborting."
+ exit 1
+ fi
+}
+
+chroot_cleanup() {
+ # Read-only root breaks booting via init
+ cat >/etc/fstab << EOF
+tmpfs /tmp tmpfs defaults 0 0
+tmpfs /var/log tmpfs defaults 0 0
+tmpfs /var/tmp tmpfs defaults 0 0
+EOF
+
+ # systemd will attempt to re-create this symlink if it does not exist,
+ # which fails if it is booting from a read-only root filesystem (which
+ # is normally the case). The syslink must be relative, not absolute,
+ # and it must point to /proc/self/mounts, not /proc/mounts.
+ ln -sf ../proc/self/mounts /etc/mtab
+
+ # Remove contaminants coming from the debootstrap process
+ echo vm >/etc/hostname
+ echo "nameserver 127.0.0.1" >/etc/resolv.conf
+
+ # Put the helper net_test.sh script into place
+ mv /root/net_test.sh /sbin/net_test.sh
+
+ # Make sure the /host mountpoint exists for net_test.sh
+ mkdir /host
+
+ # Disable the root password
+ passwd -d root
+
+ # Clean up any junk created by the imaging process
+ rm -rf /var/lib/apt/lists/* /var/log/bootstrap.log /root/* /tmp/*
+ find /var/log -type f -exec rm -f '{}' ';'
+}
diff --git a/net/test/rootfs/net_test.sh b/net/test/rootfs/net_test.sh
new file mode 100755
index 0000000..9c94d06
--- /dev/null
+++ b/net/test/rootfs/net_test.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+mount -t proc none /proc
+mount -t sysfs none /sys
+mount -t tmpfs tmpfs /tmp
+mount -t tmpfs tmpfs /run
+
+# If this system was booted under UML, it will always have a /proc/exitcode
+# file. If it was booted natively or under QEMU, it will not have this file.
+if [ -e /proc/exitcode ]; then
+ mount -t hostfs hostfs /host
+else
+ mount -t 9p -o trans=virtio,version=9p2000.L host /host
+fi
+
+test=$(cat /proc/cmdline | sed -re 's/.*net_test=([^ ]*).*/\1/g')
+cd $(dirname $test)
+./net_test.sh
+poweroff -f
diff --git a/net/test/rootfs/stretch.list b/net/test/rootfs/stretch.list
new file mode 100644
index 0000000..fbeddde
--- /dev/null
+++ b/net/test/rootfs/stretch.list
@@ -0,0 +1,33 @@
+apt
+apt-utils
+bash-completion
+bsdmainutils
+ca-certificates
+file
+gpgv
+ifupdown
+insserv
+iputils-ping
+less
+libnetfilter-conntrack3
+libnfnetlink0
+mime-support
+netbase
+netcat-openbsd
+netcat-traditional
+net-tools
+openssl
+pciutils
+procps
+psmisc
+python
+python-scapy
+strace
+systemd-sysv
+tcpdump
+traceroute
+udev
+udhcpc
+usbutils
+vim-tiny
+wget
diff --git a/net/test/rootfs/stretch.sh b/net/test/rootfs/stretch.sh
new file mode 100755
index 0000000..6d8a9a4
--- /dev/null
+++ b/net/test/rootfs/stretch.sh
@@ -0,0 +1,150 @@
+#!/bin/bash
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+set -e
+
+SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd -P)
+
+. $SCRIPT_DIR/common.sh
+
+chroot_sanity_check
+
+cd /root
+
+# Add the needed debian sources
+cat >/etc/apt/sources.list <<EOF
+deb http://ftp.debian.org/debian stretch main
+deb-src http://ftp.debian.org/debian stretch main
+deb http://ftp.debian.org/debian stretch-backports main
+deb-src http://ftp.debian.org/debian stretch-backports main
+deb http://ftp.debian.org/debian buster main
+deb-src http://ftp.debian.org/debian buster main
+EOF
+
+# Make sure apt doesn't want to install from buster by default
+cat >/etc/apt/apt.conf.d/80default <<EOF
+APT::Default-Release "stretch";
+EOF
+
+# Disable the automatic installation of recommended packages
+cat >/etc/apt/apt.conf.d/90recommends <<EOF
+APT::Install-Recommends "0";
+EOF
+
+# Deprioritize buster, so it must be specified manually
+cat >/etc/apt/preferences.d/90buster <<EOF
+Package: *
+Pin: release a=buster
+Pin-Priority: 90
+EOF
+
+# Update for the above changes
+apt-get update
+
+# Install python-scapy from buster, because stretch's version is broken
+apt-get install -y -t buster python-scapy
+
+# Note what we have installed; we will go back to this
+LANG=C dpkg --get-selections | sort >originally-installed
+
+# Install everything needed from stretch to build iptables
+apt-get install -y \
+ build-essential \
+ autoconf \
+ automake \
+ bison \
+ debhelper \
+ devscripts \
+ fakeroot \
+ flex \
+ libmnl-dev \
+ libnetfilter-conntrack-dev \
+ libnfnetlink-dev \
+ libnftnl-dev \
+ libtool
+
+# Install newer linux-libc headers (these are from 4.16)
+apt-get install -y -t stretch-backports linux-libc-dev
+
+# We are done with apt; reclaim the disk space
+apt-get clean
+
+# Construct the iptables source package to build
+iptables=iptables-1.6.1
+mkdir -p /usr/src/$iptables
+
+cd /usr/src/$iptables
+# Download a specific revision of iptables from AOSP
+aosp_iptables=android-wear-p-preview-2
+wget -qO - \
+ https://android.googlesource.com/platform/external/iptables/+archive/$aosp_iptables.tar.gz | \
+ tar -zxf -
+# Download a compatible 'debian' overlay from Debian salsa
+# We don't want all of the sources, just the Debian modifications
+debian_iptables=1.6.1-2_bpo9+1
+debian_iptables_dir=pkg-iptables-debian-$debian_iptables
+wget -qO - \
+ https://salsa.debian.org/pkg-netfilter-team/pkg-iptables/-/archive/debian/$debian_iptables/$debian_iptables_dir.tar.gz | \
+ tar --strip-components 1 -zxf - \
+ $debian_iptables_dir/debian
+cd -
+
+cd /usr/src
+# Generate a source package to leave in the filesystem. This is done for license
+# compliance and build reproducibility.
+tar --exclude=debian -cf - $iptables | \
+ xz -9 >`echo $iptables | tr -s '-' '_'`.orig.tar.xz
+cd -
+
+cd /usr/src/$iptables
+# Build debian packages from the integrated iptables source
+dpkg-buildpackage -F -us -uc
+cd -
+
+# Record the list of packages we have installed now
+LANG=C dpkg --get-selections | sort >installed
+
+# Compute the difference, and remove anything installed between the snapshots
+dpkg -P `comm -3 originally-installed installed | sed -e 's,install,,' -e 's,\t,,' | xargs`
+
+cd /usr/src
+# Find any packages generated, resolve to the debian package name, then
+# exclude any compat, header or symbol packages
+packages=`find -maxdepth 1 -name '*.deb' | colrm 1 2 | cut -d'_' -f1 |
+ grep -ve '-compat$\|-dbg$\|-dbgsym$\|-dev$' | xargs`
+# Install the patched iptables packages, and 'hold' them so
+# "apt-get dist-upgrade" doesn't replace them
+dpkg -i `
+for package in $packages; do
+ echo ${package}_*.deb
+done | xargs`
+for package in $packages; do
+ echo "$package hold" | dpkg --set-selections
+done
+# Tidy up the mess we left behind, leaving just the source tarballs
+rm -rf $iptables *.buildinfo *.changes *.deb *.dsc
+cd -
+
+# Ensure a getty is spawned on ttyS0, if booting the image manually
+ln -s /lib/systemd/system/serial-getty\@.service \
+ /etc/systemd/system/getty.target.wants/serial-getty\@ttyS0.service
+
+# systemd needs some directories to be created
+mkdir -p /var/lib/systemd/coredump /var/lib/systemd/rfkill
+
+# Finalize and tidy up the created image
+chroot_cleanup
diff --git a/net/test/rootfs/wheezy.list b/net/test/rootfs/wheezy.list
new file mode 100644
index 0000000..44e3d85
--- /dev/null
+++ b/net/test/rootfs/wheezy.list
@@ -0,0 +1,33 @@
+adduser
+apt
+apt-utils
+bash-completion
+binutils
+bsdmainutils
+ca-certificates
+file
+gpgv
+ifupdown
+insserv
+iptables
+iputils-ping
+less
+libpopt0
+mime-support
+netbase
+netcat6
+netcat-traditional
+net-tools
+module-init-tools
+openssl
+procps
+psmisc
+python2.7
+python-scapy
+strace
+tcpdump
+traceroute
+udev
+udhcpc
+vim-tiny
+wget
diff --git a/net/test/rootfs/wheezy.sh b/net/test/rootfs/wheezy.sh
new file mode 100755
index 0000000..81cfad7
--- /dev/null
+++ b/net/test/rootfs/wheezy.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# NOTE: It is highly recommended that you do not create new wheezy rootfs
+# images. This script is here for forensic purposes only, to understand
+# how the original rootfs was created.
+
+set -e
+
+SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd -P)
+
+. $SCRIPT_DIR/common.sh
+
+chroot_sanity_check
+
+# Remove things pulled in by debootstrap that we do not need
+dpkg -P \
+ debconf-i18n \
+ liblocale-gettext-perl \
+ libtext-charwidth-perl \
+ libtext-iconv-perl \
+ libtext-wrapi18n-perl \
+ python2.6 \
+ python2.6-minimal \
+ xz-utils
+
+# We are done with apt; reclaim the disk space
+apt-get clean
+
+# Ensure a getty is spawned on ttyS0, if booting the image manually
+# This also removes the vt gettys, as we may have no vt
+sed -i '/tty[123456]/d' /etc/inittab
+echo "s0:1235:respawn:/sbin/getty 115200 ttyS0 linux" >>/etc/inittab
+
+# Finalize and tidy up the created image
+chroot_cleanup
diff --git a/net/test/run_net_test.sh b/net/test/run_net_test.sh
index e07d10b..a81ad33 100755
--- a/net/test/run_net_test.sh
+++ b/net/test/run_net_test.sh
@@ -1,38 +1,72 @@
#!/bin/bash
-# Kernel configuration options.
+# Builds mysteriously fail if stdout is non-blocking.
+fixup_ptys() {
+ python << 'EOF'
+import fcntl, os, sys
+fd = sys.stdout.fileno()
+flags = fcntl.fcntl(fd, fcntl.F_GETFL)
+flags &= ~(fcntl.FASYNC | os.O_NONBLOCK | os.O_APPEND)
+fcntl.fcntl(fd, fcntl.F_SETFL, flags)
+EOF
+}
+
+# Common kernel options
OPTIONS=" DEBUG_SPINLOCK DEBUG_ATOMIC_SLEEP DEBUG_MUTEXES DEBUG_RT_MUTEXES"
+OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT FHANDLE"
OPTIONS="$OPTIONS IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO"
OPTIONS="$OPTIONS TUN SYN_COOKIES IP_ADVANCED_ROUTER IP_MULTIPLE_TABLES"
OPTIONS="$OPTIONS NETFILTER NETFILTER_ADVANCED NETFILTER_XTABLES"
OPTIONS="$OPTIONS NETFILTER_XT_MARK NETFILTER_XT_TARGET_MARK"
OPTIONS="$OPTIONS IP_NF_IPTABLES IP_NF_MANGLE IP_NF_FILTER"
OPTIONS="$OPTIONS IP6_NF_IPTABLES IP6_NF_MANGLE IP6_NF_FILTER INET6_IPCOMP"
-OPTIONS="$OPTIONS IPV6_PRIVACY IPV6_OPTIMISTIC_DAD"
-OPTIONS="$OPTIONS CONFIG_IPV6_ROUTE_INFO CONFIG_IPV6_ROUTER_PREF"
-OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_TARGET_NFLOG"
-OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA"
-OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA2"
-OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG"
-OPTIONS="$OPTIONS CONFIG_INET_UDP_DIAG CONFIG_INET_DIAG_DESTROY"
-OPTIONS="$OPTIONS IP_SCTP INET_SCTP_DIAG"
-OPTIONS="$OPTIONS CONFIG_IP_NF_TARGET_REJECT CONFIG_IP_NF_TARGET_REJECT_SKERR"
-OPTIONS="$OPTIONS CONFIG_IP6_NF_TARGET_REJECT CONFIG_IP6_NF_TARGET_REJECT_SKERR"
-OPTIONS="$OPTIONS BPF_SYSCALL XFRM_USER CRYPTO_CBC CRYPTO_CTR"
-OPTIONS="$OPTIONS CRYPTO_HMAC CRYPTO_AES CRYPTO_SHA1 CRYPTO_SHA256 CRYPTO_SHA12"
-OPTIONS="$OPTIONS CRYPTO_USER INET_AH INET_ESP INET_XFRM_MODE"
-OPTIONS="$OPTIONS TRANSPORT INET_XFRM_MODE_TUNNEL INET6_AH INET6_ESP"
+OPTIONS="$OPTIONS IPV6_OPTIMISTIC_DAD"
+OPTIONS="$OPTIONS IPV6_ROUTE_INFO IPV6_ROUTER_PREF"
+OPTIONS="$OPTIONS NETFILTER_XT_TARGET_NFLOG"
+OPTIONS="$OPTIONS NETFILTER_XT_MATCH_QUOTA"
+OPTIONS="$OPTIONS NETFILTER_XT_MATCH_QUOTA2"
+OPTIONS="$OPTIONS NETFILTER_XT_MATCH_QUOTA2_LOG"
+OPTIONS="$OPTIONS NETFILTER_XT_MATCH_SOCKET"
+OPTIONS="$OPTIONS NETFILTER_XT_MATCH_QTAGUID"
+OPTIONS="$OPTIONS INET_UDP_DIAG INET_DIAG_DESTROY"
+OPTIONS="$OPTIONS IP_SCTP"
+OPTIONS="$OPTIONS IP_NF_TARGET_REJECT IP_NF_TARGET_REJECT_SKERR"
+OPTIONS="$OPTIONS IP6_NF_TARGET_REJECT IP6_NF_TARGET_REJECT_SKERR"
+OPTIONS="$OPTIONS NET_KEY XFRM_USER XFRM_STATISTICS CRYPTO_CBC"
+OPTIONS="$OPTIONS CRYPTO_CTR CRYPTO_HMAC CRYPTO_AES CRYPTO_SHA1"
+OPTIONS="$OPTIONS CRYPTO_USER INET_ESP INET_XFRM_MODE_TRANSPORT"
+OPTIONS="$OPTIONS INET_XFRM_MODE_TUNNEL INET6_ESP"
OPTIONS="$OPTIONS INET6_XFRM_MODE_TRANSPORT INET6_XFRM_MODE_TUNNEL"
-OPTIONS="$OPTIONS CRYPTO_SHA256 CRYPTO_SHA512 CRYPTO_AES_X86_64"
-OPTIONS="$OPTIONS CRYPTO_ECHAINIV"
+OPTIONS="$OPTIONS CRYPTO_SHA256 CRYPTO_SHA512 CRYPTO_AES_X86_64 CRYPTO_NULL"
+OPTIONS="$OPTIONS CRYPTO_GCM CRYPTO_ECHAINIV NET_IPVTI"
-# For 3.1 kernels, where devtmpfs is not on by default.
-OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT"
+# Kernel version specific options
+OPTIONS="$OPTIONS XFRM_INTERFACE" # Various device kernels
+OPTIONS="$OPTIONS CGROUP_BPF" # Added in android-4.9
+OPTIONS="$OPTIONS NF_SOCKET_IPV4 NF_SOCKET_IPV6" # Added in 4.9
+OPTIONS="$OPTIONS INET_SCTP_DIAG" # Added in 4.7
+OPTIONS="$OPTIONS SOCK_CGROUP_DATA" # Added in 4.5
+OPTIONS="$OPTIONS CRYPTO_ECHAINIV" # Added in 4.1
+OPTIONS="$OPTIONS BPF_SYSCALL" # Added in 3.18
+OPTIONS="$OPTIONS IPV6_VTI" # Added in 3.13
+OPTIONS="$OPTIONS IPV6_PRIVACY" # Removed in 3.12
+OPTIONS="$OPTIONS NETFILTER_TPROXY" # Removed in 3.11
+
+# UML specific options
+OPTIONS="$OPTIONS BLK_DEV_UBD HOSTFS"
+
+# QEMU specific options
+OPTIONS="$OPTIONS VIRTIO VIRTIO_PCI VIRTIO_BLK NET_9P NET_9P_VIRTIO 9P_FS"
+OPTIONS="$OPTIONS SERIAL_8250 SERIAL_8250_PCI"
+
+# Obsolete options present at some time in Android kernels
+OPTIONS="$OPTIONS IP_NF_TARGET_REJECT_SKERR IP6_NF_TARGET_REJECT_SKERR"
# These two break the flo kernel due to differences in -Werror on recent GCC.
-DISABLE_OPTIONS=" CONFIG_REISERFS_FS CONFIG_ANDROID_PMEM"
+DISABLE_OPTIONS=" REISERFS_FS ANDROID_PMEM"
+
# This one breaks the fugu kernel due to a nonexistent sem_wait_array.
-DISABLE_OPTIONS="$DISABLE_OPTIONS CONFIG_SYSVIPC"
+DISABLE_OPTIONS="$DISABLE_OPTIONS SYSVIPC"
# How many TAP interfaces to create to provide the VM with real network access
# via the host. This requires privileges (e.g., root access) on the host.
@@ -47,11 +81,12 @@
NUMTAPINTERFACES=0
# The root filesystem disk image we'll use.
-ROOTFS=net_test.rootfs.20150203
+ROOTFS=${ROOTFS:-net_test.rootfs.20150203}
COMPRESSED_ROOTFS=$ROOTFS.xz
URL=https://dl.google.com/dl/android/$COMPRESSED_ROOTFS
# Parse arguments and figure out which test to run.
+ARCH=${ARCH:-um}
J=${J:-64}
MAKE="make"
OUT_DIR=$(readlink -f ${OUT_DIR:-.})
@@ -63,25 +98,42 @@
CONFIG_SCRIPT=${KERNEL_DIR}/scripts/config
CONFIG_FILE=${OUT_DIR}/.config
consolemode=
+netconfig=
testmode=
-blockdevice=ubda
+cmdline=
+nowrite=1
nobuild=0
norun=0
-while [ -n "$1" ]; do
- if [ "$1" = "--builder" ]; then
+if tty >/dev/null; then
+ verbose=
+else
+ verbose=1
+fi
+
+while [[ -n "$1" ]]; do
+ if [[ "$1" == "--builder" ]]; then
consolemode="con=null,fd:1"
testmode=builder
shift
- elif [ "$1" == "--readonly" ]; then
- blockdevice="${blockdevice}r"
+ elif [[ "$1" == "--readwrite" || "$1" == "--rw" ]]; then
+ nowrite=0
shift
- elif [ "$1" == "--nobuild" ]; then
+ elif [[ "$1" == "--readonly" || "$1" == "--ro" ]]; then
+ nowrite=1
+ shift
+ elif [[ "$1" == "--nobuild" ]]; then
nobuild=1
shift
- elif [ "$1" == "--norun" ]; then
+ elif [[ "$1" == "--norun" ]]; then
norun=1
shift
+ elif [[ "$1" == "--verbose" ]]; then
+ verbose=1
+ shift
+ elif [[ "$1" == "--noverbose" ]]; then
+ verbose=
+ shift
else
test=$1
break # Arguments after the test file are passed to the test itself.
@@ -113,7 +165,7 @@
if ! isRunningTest && ! isBuildOnly; then
echo "Usage:" >&2
- echo " $0 [--builder] [--readonly] [--nobuild] <test>" >&2
+ echo " $0 [--builder] [--readonly|--ro|--readwrite|--rw] [--nobuild] [--verbose] <test>" >&2
echo " $0 --norun" >&2
exit 1
fi
@@ -143,12 +195,16 @@
if (( $NUMTAPINTERFACES > 0 )); then
user=${USER:0:10}
tapinterfaces=
- netconfig=
for id in $(seq 0 $(( NUMTAPINTERFACES - 1 )) ); do
tap=${user}TAP$id
tapinterfaces="$tapinterfaces $tap"
mac=$(printf fe:fd:00:00:00:%02x $id)
- netconfig="$netconfig eth$id=tuntap,$tap,$mac"
+ if [ "$ARCH" == "um" ]; then
+ netconfig="$netconfig eth$id=tuntap,$tap,$mac"
+ else
+ netconfig="$netconfig -netdev tap,id=hostnet$id,ifname=$tap,script=no,downscript=no"
+ netconfig="$netconfig -device virtio-net-pci,netdev=hostnet$id,id=net$id,mac=$mac"
+ fi
done
for tap in $tapinterfaces; do
@@ -163,49 +219,184 @@
if [ -n "$KERNEL_BINARY" ]; then
nobuild=1
else
- KERNEL_BINARY=./linux
+ # Set default KERNEL_BINARY location if it was not provided.
+ if [ "$ARCH" == "um" ]; then
+ KERNEL_BINARY=./linux
+ elif [ "$ARCH" == "i386" -o "$ARCH" == "x86_64" -o "$ARCH" == "x86" ]; then
+ KERNEL_BINARY=./arch/x86/boot/bzImage
+ elif [ "$ARCH" == "arm64" ]; then
+ KERNEL_BINARY=./arch/arm64/boot/Image.gz
+ fi
fi
if ((nobuild == 0)); then
- # Exporting ARCH=um SUBARCH=x86_64 doesn't seem to work, as it "sometimes"
- # (?) results in a 32-bit kernel.
-
- # If there's no kernel config at all, create one or UML won't work.
- [ -f $CONFIG_FILE ] || (cd $KERNEL_DIR && $MAKE defconfig ARCH=um SUBARCH=x86_64)
-
- # Enable the kernel config options listed in $OPTIONS.
- cmdline=${OPTIONS// / -e }
- $CONFIG_SCRIPT --file $CONFIG_FILE $cmdline
-
- # Disable the kernel config options listed in $DISABLE_OPTIONS.
- cmdline=${DISABLE_OPTIONS// / -d }
- $CONFIG_SCRIPT --file $CONFIG_FILE $cmdline
-
- # olddefconfig doesn't work on old kernels.
- if ! $MAKE olddefconfig ARCH=um SUBARCH=x86_64 CROSS_COMPILE= ; then
- cat >&2 << EOF
-
-Warning: "make olddefconfig" failed.
-Perhaps this kernel is too old to support it.
-You may get asked lots of questions.
-Keep enter pressed to accept the defaults.
-
-EOF
+ make_flags=
+ if [ "$ARCH" == "um" ]; then
+ # Exporting ARCH=um SUBARCH=x86_64 doesn't seem to work, as it
+ # "sometimes" (?) results in a 32-bit kernel.
+ make_flags="$make_flags ARCH=$ARCH SUBARCH=x86_64 CROSS_COMPILE= "
+ fi
+ if [ -n "$CC" ]; then
+ # The CC flag is *not* inherited from the environment, so it must be
+ # passed in on the command line.
+ make_flags="$make_flags CC=$CC"
fi
+ # If there's no kernel config at all, create one or UML won't work.
+ [ -n "$DEFCONFIG" ] || DEFCONFIG=defconfig
+ [ -f $CONFIG_FILE ] || (cd $KERNEL_DIR && $MAKE $make_flags $DEFCONFIG)
+
+ # Enable the kernel config options listed in $OPTIONS.
+ $CONFIG_SCRIPT --file $CONFIG_FILE ${OPTIONS// / -e }
+
+ # Disable the kernel config options listed in $DISABLE_OPTIONS.
+ $CONFIG_SCRIPT --file $CONFIG_FILE ${DISABLE_OPTIONS// / -d }
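The `${OPTIONS// / -e }` and `${DISABLE_OPTIONS// / -d }` substitutions above turn a space-separated option list into per-option flags for the config script. A minimal Python 3 sketch of the same transformation; note how the leading space in `DISABLE_OPTIONS` (" REISERFS_FS ...") is what gives the *first* option its own flag:

```python
# Sketch of the bash pattern substitution ${VAR// / -d }: every space in the
# variable is replaced with " -d ". The leading space in DISABLE_OPTIONS
# therefore produces a flag in front of the first option as well.
disable_options = " REISERFS_FS ANDROID_PMEM"
args = disable_options.replace(" ", " -d ")
assert args == " -d REISERFS_FS -d ANDROID_PMEM"
```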
+
+ $MAKE $make_flags olddefconfig
+
# Compile the kernel.
- $MAKE -j$J linux ARCH=um SUBARCH=x86_64 CROSS_COMPILE=
+ if [ "$ARCH" == "um" ]; then
+ $MAKE -j$J $make_flags linux
+ else
+ $MAKE -j$J $make_flags
+ fi
fi
if (( norun == 1 )); then
exit 0
fi
-# Get the absolute path to the test file that's being run.
-dir=/host$SCRIPT_DIR
+if (( nowrite == 1 )); then
+ cmdline="ro"
+fi
-# Start the VM.
-exec $KERNEL_BINARY umid=net_test $blockdevice=$SCRIPT_DIR/$ROOTFS \
- mem=512M init=/sbin/net_test.sh net_test=$dir/$test \
- net_test_args=\"$test_args\" \
- net_test_mode=$testmode $netconfig $consolemode >&2
+if (( verbose == 1 )); then
+ cmdline="$cmdline verbose=1"
+fi
+
+cmdline="$cmdline init=/sbin/net_test.sh"
+cmdline="$cmdline net_test_args=\"$test_args\" net_test_mode=$testmode"
+
+if [ "$ARCH" == "um" ]; then
+ # Get the absolute path to the test file that's being run.
+ cmdline="$cmdline net_test=/host$SCRIPT_DIR/$test"
+
+ # Use UML's /proc/exitcode feature to communicate errors on test failure
+ cmdline="$cmdline net_test_exitcode=/proc/exitcode"
+
+ # Experience shows that we need at least 128 bits of entropy for the
+  # kernel's crng init to complete (before it fully initializes, things behave
+  # *weirdly*, there are plenty of kernel warnings, and some tests even fail),
+ # hence net_test.sh needs at least 32 hex chars (which is the amount of hex
+ # in a single random UUID) provided to it on the kernel cmdline.
+ #
+ # Just to be safe, we'll pass in 384 bits, and we'll do this as a random
+ # 64 character base64 seed (because this is shorter than base16).
+ # We do this by getting *three* random UUIDs and concatenating their hex
+ # digits into an *even* length hex encoded string, which we then convert
+ # into base64.
+ entropy="$(cat /proc/sys/kernel/random{/,/,/}uuid | tr -d '\n-')"
+ entropy="$(xxd -r -p <<< "${entropy}" | base64 -w 0)"
+ cmdline="${cmdline} entropy=${entropy}"
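The entropy seeding above can be sketched in Python 3, with `uuid`/`base64` standing in for the shell pipeline of `/proc/sys/kernel/random/uuid`, `xxd -r -p` and `base64 -w 0`:

```python
import base64
import uuid

# Three random UUIDs give 3 * 32 = 96 hex digits (384 bits of entropy), which
# re-encode to a 64-character base64 string -- shorter than the hex form, as
# the comment above notes.
hex_digits = "".join(uuid.uuid4().hex for _ in range(3))
entropy = base64.b64encode(bytes.fromhex(hex_digits)).decode()
assert len(hex_digits) == 96
assert len(entropy) == 64  # 48 bytes -> 64 base64 chars, no padding
```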
+
+ # Map the --readonly flag to UML block device names
+ if ((nowrite == 0)); then
+ blockdevice=ubda
+ else
+ blockdevice=ubdar
+ fi
+
+ exitcode=0
+ $KERNEL_BINARY >&2 umid=net_test mem=512M \
+ $blockdevice=$SCRIPT_DIR/$ROOTFS $netconfig $consolemode $cmdline \
+ || exitcode=$?
+
+  # UML is kind of crazy in how guest syscalls work. It requires the host
+  # kernel to not be in vsyscall=none mode.
+ if [[ "${exitcode}" != '0' ]]; then
+ {
+ # Hopefully one of these exists
+ cat /proc/config || :
+ zcat /proc/config.gz || :
+ cat "/boot/config-$(uname -r)" || :
+ zcat "/boot/config-$(uname -r).gz" || :
+ } 2>/dev/null \
+ | egrep -q '^CONFIG_LEGACY_VSYSCALL_NONE=y' \
+ && ! egrep -q '(^| )vsyscall=(native|emulate)( |$)' /proc/cmdline \
+ && {
+ echo '-----=====-----'
+ echo 'If above you saw a "net_test.sh[1]: segfault at ..." followed by'
+ echo '"Kernel panic - not syncing: Attempted to kill init!" then please'
+ echo 'set "vsyscall=emulate" on *host* kernel command line.'
+ echo '(for example via GRUB_CMDLINE_LINUX in /etc/default/grub)'
+ echo '-----=====-----'
+ }
+ fi
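A Python 3 sketch of the host-kernel check above, using the same pattern the egrep applies to /proc/cmdline: only a standalone `vsyscall=native` or `vsyscall=emulate` token counts.

```python
import re

# Mirrors the egrep pattern: the token must be bounded by start-of-line/space
# on the left and space/end-of-line on the right.
pattern = re.compile(r'(^| )vsyscall=(native|emulate)( |$)')

assert pattern.search("ro quiet vsyscall=emulate") is not None
assert pattern.search("ro vsyscall=none quiet") is None   # triggers the warning
assert pattern.search("novsyscall=emulate") is None       # not a standalone token
```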
+else
+ # We boot into the filesystem image directly in all cases
+ cmdline="$cmdline root=/dev/vda"
+
+ # The path is stripped by the 9p export; we don't need SCRIPT_DIR
+ cmdline="$cmdline net_test=/host/$test"
+
+ # Map the --readonly flag to a QEMU block device flag
+ if ((nowrite > 0)); then
+ blockdevice=",readonly"
+ else
+ blockdevice=
+ fi
+ blockdevice="-drive file=$SCRIPT_DIR/$ROOTFS,format=raw,if=none,id=drive-virtio-disk0$blockdevice"
+ blockdevice="$blockdevice -device virtio-blk-pci,drive=drive-virtio-disk0"
+
+ # Pass through our current console/screen size to inner shell session
+ read rows cols < <(stty size 2>/dev/null)
+ [[ -z "${rows}" ]] || cmdline="${cmdline} console_rows=${rows}"
+ [[ -z "${cols}" ]] || cmdline="${cmdline} console_cols=${cols}"
+ unset rows cols
+
+ # QEMU has no way to modify its exitcode; simulate it with a serial port.
+ #
+  # We choose this over writing a file to /host because QEMU initializes the
+  # 'exitcode' file for us, it avoids unnecessary writes to the host
+  # filesystem (which is normally not written to), and it lets us communicate
+  # an exit code back even when /host is not mounted.
+ #
+ if [ "$ARCH" == "i386" -o "$ARCH" == "x86_64" -o "$ARCH" == "x86" ]; then
+ # Assume we have hardware-accelerated virtualization support for amd64
+ qemu="qemu-system-x86_64 -machine pc,accel=kvm -cpu host"
+
+ # The assignment of 'ttyS1' here is magical -- we know 'ttyS0' will be our
+ # serial port from the hard-coded '-serial stdio' flag below, and so this
+ # second serial port will be 'ttyS1'.
+ cmdline="$cmdline net_test_exitcode=/dev/ttyS1"
+ elif [ "$ARCH" == "arm64" ]; then
+ # This uses a software model CPU, based on cortex-a57
+ qemu="qemu-system-aarch64 -machine virt -cpu cortex-a57"
+
+ # The kernel will print messages via a virtual ARM serial port (ttyAMA0),
+ # but for command line consistency with x86, we put the exitcode serial
+ # port on the PCI bus, and it will be the only one.
+ cmdline="$cmdline net_test_exitcode=/dev/ttyS0"
+ fi
+
+ $qemu >&2 -name net_test -m 512 \
+ -kernel $KERNEL_BINARY \
+ -no-user-config -nodefaults -no-reboot \
+ -display none -nographic -serial mon:stdio -parallel none \
+ -smp 4,sockets=4,cores=1,threads=1 \
+ -device virtio-rng-pci \
+ -chardev file,id=exitcode,path=exitcode \
+ -device pci-serial,chardev=exitcode \
+ -fsdev local,security_model=mapped-xattr,id=fsdev0,fmode=0644,dmode=0755,path=$SCRIPT_DIR \
+ -device virtio-9p-pci,id=fs0,fsdev=fsdev0,mount_tag=host \
+ $blockdevice $netconfig -append "$cmdline"
+  [[ -s exitcode ]] && exitcode=$(tr -d '\r' < exitcode) || exitcode=1
+ rm -f exitcode
+fi
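The exit-code handling above can be sketched in Python 3 with a hypothetical serial-port value: because the code arrives over an emulated serial device, the bytes may carry a trailing carriage return, which the `tr -d '\r'` strips before the value is used as the script's exit status.

```python
# Hypothetical raw content of the 'exitcode' file written via the pci-serial
# chardev; serial output is CRLF-terminated, so the CR must be removed before
# the value can be parsed.
raw = "7\r\n"
exitcode = int(raw.replace("\r", "").strip())
assert exitcode == 7
```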
+
+# UML reliably screws up the ptys, QEMU probably can as well...
+fixup_ptys
+stty sane || :
+
+echo "Returning exit code ${exitcode}." 1>&2
+exit "${exitcode}"
diff --git a/net/test/sock_diag.py b/net/test/sock_diag.py
index 1865891..46cc92d 100755
--- a/net/test/sock_diag.py
+++ b/net/test/sock_diag.py
@@ -28,9 +28,6 @@
import net_test
import netlink
-### Base netlink constants. See include/uapi/linux/netlink.h.
-NETLINK_SOCK_DIAG = 4
-
### sock_diag constants. See include/uapi/linux/sock_diag.h.
# Message types.
SOCK_DIAG_BY_FAMILY = 20
@@ -112,9 +109,11 @@
class SockDiag(netlink.NetlinkSocket):
- FAMILY = NETLINK_SOCK_DIAG
NL_DEBUG = []
+ def __init__(self):
+ super(SockDiag, self).__init__(netlink.NETLINK_SOCK_DIAG)
+
def _Decode(self, command, msg, nla_type, nla_data):
"""Decodes netlink attributes to Python types."""
if msg.family == AF_INET or msg.family == AF_INET6:
@@ -375,6 +374,11 @@
sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
+ @staticmethod
+ def GetSocketCookie(s):
+ cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
+ return struct.unpack("=Q", cookie)[0]
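GetSocketCookie above decodes the 8-byte SO_COOKIE value with `struct.unpack("=Q", ...)`, i.e. as a native-endian unsigned 64-bit integer. A Python 3 sketch with a fabricated cookie value:

```python
import struct

# SO_COOKIE hands back 8 raw bytes; "=Q" turns them into one unsigned 64-bit
# integer. 42 here is a made-up cookie, just to show the round trip.
cookie_bytes = struct.pack("=Q", 42)
assert len(cookie_bytes) == 8
assert struct.unpack("=Q", cookie_bytes)[0] == 42
```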
+
def FindSockInfoFromFd(self, s):
"""Gets a diag_msg and attrs from the kernel for the specified socket."""
req = self.DiagReqFromSocket(s)
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index c7ac0d4..daa2fa4 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -18,22 +18,69 @@
from errno import * # pylint: disable=wildcard-import
import os
import random
-import re
+import select
from socket import * # pylint: disable=wildcard-import
+import struct
import threading
import time
import unittest
+import cstruct
import multinetwork_base
import net_test
-import netlink
import packets
import sock_diag
import tcp_test
+# Mostly empty structure definition containing only the fields we currently use.
+TcpInfo = cstruct.Struct("TcpInfo", "64xI", "tcpi_rcv_ssthresh")
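The `"64xI"` format above skips the first 64 bytes of struct tcp_info and reads a single unsigned 32-bit field, which on the kernel layouts this test targets is `tcpi_rcv_ssthresh`. A Python 3 sketch of the same format string using plain `struct`:

```python
import struct

# "x" is a pad byte, so "64xI" means: ignore 64 bytes, then read one native
# unsigned 32-bit integer. Offset 64 is where tcpi_rcv_ssthresh is assumed to
# sit in the targeted struct tcp_info layout.
raw = struct.pack("64xI", 65535)  # 64 zero bytes followed by the field
assert struct.calcsize("64xI") == 68
assert struct.unpack("64xI", raw) == (65535,)
```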
NUM_SOCKETS = 30
NO_BYTECODE = ""
+LINUX_4_9_OR_ABOVE = net_test.LINUX_VERSION >= (4, 9, 0)
+LINUX_4_19_OR_ABOVE = net_test.LINUX_VERSION >= (4, 19, 0)
+
+IPPROTO_SCTP = 132
+
+def HaveUdpDiag():
+ """Checks if the current kernel has config CONFIG_INET_UDP_DIAG enabled.
+
+  This config is required for devices running a 4.9 kernel that ship with P. In
+  that case, always assume the config is present and use the tests to check
+  that it is enabled as required.
+
+  For all other kernel versions, there is no way to tell whether a dump
+  succeeded: if the appropriate handler wasn't found, __inet_diag_dump just
+  returns an empty result instead of an error. So, just check whether a UDP
+  dump returns a socket when we know it should return one. If it does not,
+  some tests will be skipped.
+
+ Returns:
+ True if the kernel is 4.9 or above, or the CONFIG_INET_UDP_DIAG is enabled.
+ False otherwise.
+ """
+ if LINUX_4_9_OR_ABOVE:
+    return True
+ s = socket(AF_INET6, SOCK_DGRAM, 0)
+ s.bind(("::", 0))
+  s.connect(s.getsockname())
+ sd = sock_diag.SockDiag()
+ have_udp_diag = len(sd.DumpAllInetSockets(IPPROTO_UDP, "")) > 0
+ s.close()
+ return have_udp_diag
+
+def HaveSctp():
+ if net_test.LINUX_VERSION < (4, 7, 0):
+ return False
+ try:
+ s = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP)
+ s.close()
+ return True
+ except IOError:
+ return False
+
+HAVE_UDP_DIAG = HaveUdpDiag()
+HAVE_SCTP = HaveSctp()
class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
@@ -108,19 +155,37 @@
self.assertFalse("???" in decoded)
return bytecode
- def CloseDuringBlockingCall(self, sock, call, expected_errno):
+ def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
+ """Simulates an external event during a blocking call on sock.
+
+ Args:
+ sock: The socket to use.
+ call: A function, the call to make. Takes one parameter, sock.
+ expected_errno: The value that call is expected to fail with, or None if
+ call is expected to succeed.
+ event: A function, the event that will happen during the blocking call.
+ Takes one parameter, sock.
+ """
thread = SocketExceptionThread(sock, call)
thread.start()
time.sleep(0.1)
- self.sock_diag.CloseSocketFromFd(sock)
+ event(sock)
thread.join(1)
self.assertFalse(thread.is_alive())
- self.assertIsNotNone(thread.exception)
- self.assertTrue(isinstance(thread.exception, IOError),
- "Expected IOError, got %s" % thread.exception)
- self.assertEqual(expected_errno, thread.exception.errno)
+ if expected_errno is not None:
+ self.assertIsNotNone(thread.exception)
+ self.assertTrue(isinstance(thread.exception, IOError),
+ "Expected IOError, got %s" % thread.exception)
+ self.assertEqual(expected_errno, thread.exception.errno)
+ else:
+ self.assertIsNone(thread.exception)
self.assertSocketClosed(sock)
+ def CloseDuringBlockingCall(self, sock, call, expected_errno):
+ self._EventDuringBlockingCall(
+ sock, call, expected_errno,
+ lambda sock: self.sock_diag.CloseSocketFromFd(sock))
+
def setUp(self):
super(SockDiagBaseTest, self).setUp()
self.sock_diag = sock_diag.SockDiag()
@@ -145,10 +210,10 @@
self.sock_diag.GetSockInfo(diag_req)
# No errors? Good.
- def testFindsAllMySockets(self):
+ def CheckFindsAllMySockets(self, socktype, proto):
"""Tests that basic socket dumping works."""
- self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
- sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
+ self.socketpairs = self._CreateLotsOfSockets(socktype)
+ sockets = self.sock_diag.DumpAllInetSockets(proto, NO_BYTECODE)
self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
# Find the cookies for all of our sockets.
@@ -178,9 +243,21 @@
# Check that we can find a diag_msg once we know the cookie.
req = self.sock_diag.DiagReqFromSocket(sock)
req.id.cookie = cookie
+ if proto == IPPROTO_UDP:
+ # Kernel bug: for UDP sockets, the order of arguments must be swapped.
+ # See testDemonstrateUdpGetSockIdBug.
+ req.id.sport, req.id.dport = req.id.dport, req.id.sport
+ req.id.src, req.id.dst = req.id.dst, req.id.src
info = self.sock_diag.GetSockInfo(req)
self.assertSockInfoMatchesSocket(sock, info)
+ def testFindsAllMySocketsTcp(self):
+ self.CheckFindsAllMySockets(SOCK_STREAM, IPPROTO_TCP)
+
+ @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
+ def testFindsAllMySocketsUdp(self):
+ self.CheckFindsAllMySockets(SOCK_DGRAM, IPPROTO_UDP)
+
def testBytecodeCompilation(self):
# pylint: disable=bad-whitespace
instructions = [
@@ -305,6 +382,53 @@
DiagDump(op) # No errors? Good.
self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
+ def CheckSocketCookie(self, inet, addr):
+ """Tests that getsockopt SO_COOKIE can get cookie for all sockets."""
+ socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
+ for sock in socketpair:
+ diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
+ cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
+ self.assertEqual(diag_msg.id.cookie, cookie)
+
+ @unittest.skipUnless(LINUX_4_9_OR_ABOVE, "SO_COOKIE not supported")
+ def testGetsockoptcookie(self):
+ self.CheckSocketCookie(AF_INET, "127.0.0.1")
+ self.CheckSocketCookie(AF_INET6, "::1")
+
+ @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
+ def testDemonstrateUdpGetSockIdBug(self):
+ # TODO: this is because udp_dump_one mistakenly uses __udp[46]_lib_lookup
+ # by passing the source address as the source address argument.
+ # Unfortunately those functions are intended to match local sockets based
+ # on received packets, and the argument that ends up being compared with
+ # e.g., sk_daddr is actually saddr, not daddr. udp_diag_destroy does not
+ # have this bug. Upstream has confirmed that this will not be fixed:
+ # https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html
+ """Documents a bug: getting UDP sockets requires swapping src and dst."""
+ for version in [4, 5, 6]:
+ family = net_test.GetAddressFamily(version)
+ s = socket(family, SOCK_DGRAM, 0)
+ self.SelectInterface(s, self.RandomNetid(), "mark")
+ s.connect((self.GetRemoteSocketAddress(version), 53))
+
+ # Create a fully-specified diag req from our socket, including cookie if
+ # we can get it.
+ req = self.sock_diag.DiagReqFromSocket(s)
+ if LINUX_4_9_OR_ABOVE:
+ req.id.cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
+ else:
+        req.id.cookie = "\xff" * 8  # INET_DIAG_NOCOOKIE[2]
+
+ # As is, this request does not find anything.
+ with self.assertRaisesErrno(ENOENT):
+ self.sock_diag.GetSockInfo(req)
+
+ # But if we swap src and dst, the kernel finds our socket.
+ req.id.sport, req.id.dport = req.id.dport, req.id.sport
+ req.id.src, req.id.dst = req.id.dst, req.id.src
+
+ self.assertSockInfoMatchesSocket(s, self.sock_diag.GetSockInfo(req))
+
class SockDestroyTest(SockDiagBaseTest):
"""Tests that SOCK_DESTROY works correctly.
@@ -397,7 +521,7 @@
def run(self):
try:
self.operation(self.sock)
- except IOError, e:
+ except (IOError, AssertionError), e:
self.exception = e
@@ -421,12 +545,56 @@
self.assertTrue(children)
for child, unused_args in children:
self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
- self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
+ self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr),
child.id.dst)
- self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
+ self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr),
child.id.src)
+class TcpRcvWindowTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
+
+ RWND_SIZE = 64000 if LINUX_4_19_OR_ABOVE else 42000
+ TCP_DEFAULT_INIT_RWND = "/proc/sys/net/ipv4/tcp_default_init_rwnd"
+
+ def setUp(self):
+ super(TcpRcvWindowTest, self).setUp()
+ if LINUX_4_19_OR_ABOVE:
+ self.assertRaisesErrno(ENOENT, open, self.TCP_DEFAULT_INIT_RWND, "w")
+ return
+
+    f = open(self.TCP_DEFAULT_INIT_RWND, "w")
+    f.write("60")
+    f.close()
+
+ def checkInitRwndSize(self, version, netid):
+ self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
+ tcpInfo = TcpInfo(self.accepted.getsockopt(net_test.SOL_TCP,
+ net_test.TCP_INFO, len(TcpInfo)))
+    self.assertLess(self.RWND_SIZE, tcpInfo.tcpi_rcv_ssthresh,
+                    "TCP rwnd of netid=%d, version=%d is too small. "
+                    "Expected: %d, actual: %d" % (netid, version, self.RWND_SIZE,
+                                                  tcpInfo.tcpi_rcv_ssthresh))
+
+ def checkSynPacketWindowSize(self, version, netid):
+ s = self.BuildSocket(version, net_test.TCPSocket, netid, "mark")
+ myaddr = self.MyAddress(version, netid)
+ dstaddr = self.GetRemoteAddress(version)
+ dstsockaddr = self.GetRemoteSocketAddress(version)
+ desc, expected = packets.SYN(53, version, myaddr, dstaddr,
+ sport=None, seq=None)
+ self.assertRaisesErrno(EINPROGRESS, s.connect, (dstsockaddr, 53))
+ msg = "IPv%s TCP connect: expected %s on %s" % (
+ version, desc, self.GetInterfaceName(netid))
+ syn = self.ExpectPacketOn(netid, msg, expected)
+ self.assertLess(self.RWND_SIZE, syn.window)
+ s.close()
+
+ def testTcpCwndSize(self):
+ for version in [4, 5, 6]:
+ for netid in self.NETIDS:
+ self.checkInitRwndSize(version, netid)
+ self.checkSynPacketWindowSize(version, netid)
+
+
class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
def setUp(self):
@@ -599,9 +767,12 @@
"""Tests that accept() is interrupted by SOCK_DESTROY."""
for version in [4, 5, 6]:
self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
+ self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
self.assertRaisesErrno(EINVAL, self.s.accept)
+ # TODO: this should really return an error such as ENOTCONN...
+ self.assertEquals("", self.s.recv(4096))
def testReadInterrupted(self):
"""Tests that read() is interrupted by SOCK_DESTROY."""
@@ -609,7 +780,10 @@
self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
ECONNABORTED)
+ # Writing returns EPIPE, and reading returns EOF.
self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
+ self.assertEquals("", self.accepted.recv(4096))
+ self.assertEquals("", self.accepted.recv(4096))
def testConnectInterrupted(self):
"""Tests that connect() is interrupted by SOCK_DESTROY."""
@@ -617,15 +791,13 @@
family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
self.SelectInterface(s, self.netid, "mark")
- if version == 5:
- remoteaddr = "::ffff:" + self.GetRemoteAddress(4)
- version = 4
- else:
- remoteaddr = self.GetRemoteAddress(version)
+
+ remotesockaddr = self.GetRemoteSocketAddress(version)
+ remoteaddr = self.GetRemoteAddress(version)
s.bind(("", 0))
_, sport = s.getsockname()[:2]
self.CloseDuringBlockingCall(
- s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
+ s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
remoteaddr, sport=sport, seq=None)
self.ExpectPacketOn(self.netid, desc, syn)
@@ -633,6 +805,106 @@
self.ExpectNoPacketsOn(self.netid, msg)
+class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
+ """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs.
+
+ The behaviour of poll() in these cases is not what we might expect: if only
+ POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT
+ is (also) specified, it will only return POLLOUT.
+ """
+
+ POLLIN_OUT = select.POLLIN | select.POLLOUT
+ POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP
+
+ def setUp(self):
+ super(PollOnCloseTest, self).setUp()
+ self.netid = random.choice(self.tuns.keys())
+
+ POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
+ (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]
+
+ def PollResultToString(self, poll_events, ignoremask):
+ out = []
+ for fd, event in poll_events:
+ flags = [name for (flag, name) in self.POLL_FLAGS
+ if event & flag & ~ignoremask != 0]
+ out.append((fd, "|".join(flags)))
+ return out
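A Python 3 sketch of the flag decoding in PollResultToString, with the same flag table and the same ignoremask semantics (bits in ignoremask are masked out before naming):

```python
import select

POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
              (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]

def FlagNames(event, ignoremask=0):
  # Keep only the flags set in the event and not suppressed by ignoremask.
  return "|".join(name for flag, name in POLL_FLAGS
                  if event & flag & ~ignoremask)

assert FlagNames(select.POLLIN | select.POLLHUP) == "IN|HUP"
assert FlagNames(select.POLLIN | select.POLLHUP,
                 ignoremask=select.POLLHUP) == "IN"
```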
+
+ def BlockingPoll(self, sock, mask, expected, ignoremask):
+ p = select.poll()
+ p.register(sock, mask)
+ expected_fds = [(sock.fileno(), expected)]
+ # Don't block forever or we'll hang continuous test runs on failure.
+ # A 5-second timeout should be long enough not to be flaky.
+ actual_fds = p.poll(5000)
+ self.assertEqual(self.PollResultToString(expected_fds, ignoremask),
+ self.PollResultToString(actual_fds, ignoremask))
+
+ def RstDuringBlockingCall(self, sock, call, expected_errno):
+ self._EventDuringBlockingCall(
+ sock, call, expected_errno,
+ lambda _: self.ReceiveRstPacketOn(self.netid))
+
+ def assertSocketErrors(self, errno):
+ # The first operation returns the expected errno.
+ self.assertRaisesErrno(errno, self.accepted.recv, 4096)
+
+ # Subsequent operations behave as normal.
+ self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
+ self.assertEquals("", self.accepted.recv(4096))
+ self.assertEquals("", self.accepted.recv(4096))
+
+ def CheckPollDestroy(self, mask, expected, ignoremask):
+ """Interrupts a poll() with SOCK_DESTROY."""
+ for version in [4, 5, 6]:
+ self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
+ self.CloseDuringBlockingCall(
+ self.accepted,
+ lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
+ None)
+ self.assertSocketErrors(ECONNABORTED)
+
+ def CheckPollRst(self, mask, expected, ignoremask):
+ """Interrupts a poll() by receiving a TCP RST."""
+ for version in [4, 5, 6]:
+ self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
+ self.RstDuringBlockingCall(
+ self.accepted,
+ lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
+ None)
+ self.assertSocketErrors(ECONNRESET)
+
+ def testReadPollRst(self):
+ # Until 3d4762639d ("tcp: remove poll() flakes when receiving RST"), poll()
+ # would sometimes return POLLERR and sometimes POLLIN|POLLERR|POLLHUP. This
+ # is due to a race inside the kernel and thus is not visible on the VM, only
+ # on physical hardware.
+ if net_test.LINUX_VERSION < (4, 14, 0):
+ ignoremask = select.POLLIN | select.POLLHUP
+ else:
+ ignoremask = 0
+ self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)
+
+ def testWritePollRst(self):
+ self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)
+
+ def testReadWritePollRst(self):
+ self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)
+
+ def testReadPollDestroy(self):
+ # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
+ ignoremask = select.POLLIN | select.POLLHUP
+ self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)
+
+ def testWritePollDestroy(self):
+ self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)
+
+ def testReadWritePollDestroy(self):
+ self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)
+
+
+@unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
class SockDestroyUdpTest(SockDiagBaseTest):
"""Tests SOCK_DESTROY on UDP sockets.
@@ -672,7 +944,7 @@
def testSocketAddressesAfterClose(self):
for version in 4, 5, 6:
netid = random.choice(self.NETIDS)
- dst = self.GetRemoteAddress(version)
+ dst = self.GetRemoteSocketAddress(version)
family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
@@ -686,7 +958,7 @@
# Closing a socket bound to an IP address leaves the address as is.
s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
- src = self.MyAddress(version, netid)
+ src = self.MySocketAddress(version, netid)
s.bind((src, 0))
s.connect((dst, 53))
port = s.getsockname()[1]
@@ -702,7 +974,7 @@
# Closing a socket bound to IP address and port leaves both as is.
s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
- src = self.MyAddress(version, netid)
+ src = self.MySocketAddress(version, netid)
port = self.BindToRandomPort(s, src)
self.sock_diag.CloseSocketFromFd(s)
self.assertEqual((src, port), s.getsockname()[:2])
@@ -713,7 +985,7 @@
family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
s = net_test.UDPSocket(family)
self.SelectInterface(s, random.choice(self.NETIDS), "mark")
- addr = self.GetRemoteAddress(version)
+ addr = self.GetRemoteSocketAddress(version)
# Check that reads on connected sockets are interrupted.
s.connect((addr, 53))
@@ -749,6 +1021,7 @@
self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
+ @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
def testUdp(self):
self.CheckPermissions(SOCK_DGRAM)
@@ -767,8 +1040,6 @@
d545cac net: inet: diag: expose the socket mark to privileged processes.
"""
- IPPROTO_SCTP = 132
-
def FilterEstablishedSockets(self, mark, mask):
instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
bytecode = self.sock_diag.PackBytecode(instructions)
@@ -883,21 +1154,21 @@
# Other TCP states are tested in SockDestroyTcpTest.
# UDP sockets.
- s = socket(family, SOCK_DGRAM, 0)
- mark = self.SetRandomMark(s)
- s.connect(("", 53))
- self.assertSocketMarkIs(s, mark)
- s.close()
+ if HAVE_UDP_DIAG:
+ s = socket(family, SOCK_DGRAM, 0)
+ mark = self.SetRandomMark(s)
+ s.connect(("", 53))
+ self.assertSocketMarkIs(s, mark)
+ s.close()
# Basic test for SCTP. sctp_diag was only added in 4.7.
- if net_test.LINUX_VERSION >= (4, 7, 0):
- s = socket(family, SOCK_STREAM, self.IPPROTO_SCTP)
+ if HAVE_SCTP:
+ s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
s.bind((addr, 0))
s.listen(1)
mark = self.SetRandomMark(s)
self.assertSocketMarkIs(s, mark)
- sockets = self.sock_diag.DumpAllInetSockets(self.IPPROTO_SCTP,
- NO_BYTECODE)
+ sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
self.assertEqual(1, len(sockets))
self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
s.close()
diff --git a/net/test/tcp_fastopen_test.py b/net/test/tcp_fastopen_test.py
new file mode 100755
index 0000000..9257a19
--- /dev/null
+++ b/net/test/tcp_fastopen_test.py
@@ -0,0 +1,132 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from errno import *
+from socket import *
+from scapy import all as scapy
+
+import multinetwork_base
+import net_test
+import packets
+import tcp_metrics
+
+
+TCPOPT_FASTOPEN = 34
+TCP_FASTOPEN_CONNECT = 30
+
+
+class TcpFastOpenTest(multinetwork_base.MultiNetworkBaseTest):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TcpFastOpenTest, cls).setUpClass()
+ cls.tcp_metrics = tcp_metrics.TcpMetrics()
+
+ def TFOClientSocket(self, version, netid):
+ s = net_test.TCPSocket(net_test.GetAddressFamily(version))
+ net_test.DisableFinWait(s)
+ self.SelectInterface(s, netid, "mark")
+ s.setsockopt(IPPROTO_TCP, TCP_FASTOPEN_CONNECT, 1)
+ return s
+
+ def assertSocketNotConnected(self, sock):
+ self.assertRaisesErrno(ENOTCONN, sock.getpeername)
+
+ def assertSocketConnected(self, sock):
+ sock.getpeername() # No errors? Socket is alive and connected.
+
+ def clearTcpMetrics(self, version, netid):
+ saddr = self.MyAddress(version, netid)
+ daddr = self.GetRemoteAddress(version)
+ self.tcp_metrics.DelMetrics(saddr, daddr)
+ with self.assertRaisesErrno(ESRCH):
+ print self.tcp_metrics.GetMetrics(saddr, daddr)
+
+ def assertNoTcpMetrics(self, version, netid):
+ saddr = self.MyAddress(version, netid)
+ daddr = self.GetRemoteAddress(version)
+ with self.assertRaisesErrno(ENOENT):
+ self.tcp_metrics.GetMetrics(saddr, daddr)
+
+ def CheckConnectOption(self, version):
+ ip_layer = {4: scapy.IP, 6: scapy.IPv6}[version]
+ netid = self.RandomNetid()
+ s = self.TFOClientSocket(version, netid)
+
+ self.clearTcpMetrics(version, netid)
+
+ # Connect the first time.
+ remoteaddr = self.GetRemoteAddress(version)
+ with self.assertRaisesErrno(EINPROGRESS):
+ s.connect((remoteaddr, 53))
+ self.assertSocketNotConnected(s)
+
+ # Expect a SYN handshake with an empty TFO option.
+ myaddr = self.MyAddress(version, netid)
+ port = s.getsockname()[1]
+ self.assertNotEqual(0, port)
+ desc, syn = packets.SYN(53, version, myaddr, remoteaddr, port, seq=None)
+ syn.getlayer("TCP").options = [(TCPOPT_FASTOPEN, "")]
+ msg = "Fastopen connect: expected %s" % desc
+ syn = self.ExpectPacketOn(netid, msg, syn)
+ syn = ip_layer(str(syn))
+
+ # Receive a SYN+ACK with a TFO cookie and expect the connection to proceed
+ # as normal.
+ desc, synack = packets.SYNACK(version, remoteaddr, myaddr, syn)
+ synack.getlayer("TCP").options = [
+ (TCPOPT_FASTOPEN, "helloT"), ("NOP", None), ("NOP", None)]
+ self.ReceivePacketOn(netid, synack)
+ synack = ip_layer(str(synack))
+ desc, ack = packets.ACK(version, myaddr, remoteaddr, synack)
+ msg = "First connect: got SYN+ACK, expected %s" % desc
+ self.ExpectPacketOn(netid, msg, ack)
+ self.assertSocketConnected(s)
+ s.close()
+ desc, rst = packets.RST(version, myaddr, remoteaddr, synack)
+ msg = "Closing client socket, expecting %s" % desc
+ self.ExpectPacketOn(netid, msg, rst)
+
+ # Connect to the same destination again. Expect the connect to succeed
+ # without sending a SYN packet.
+ s = self.TFOClientSocket(version, netid)
+ s.connect((remoteaddr, 53))
+ self.assertSocketNotConnected(s)
+ self.ExpectNoPacketsOn(netid, "Second TFO connect, expected no packets")
+
+ # Issue a write and expect a SYN with data.
+ port = s.getsockname()[1]
+ s.send(net_test.UDP_PAYLOAD)
+ desc, syn = packets.SYN(53, version, myaddr, remoteaddr, port, seq=None)
+ t = syn.getlayer(scapy.TCP)
+ t.options = [(TCPOPT_FASTOPEN, "helloT"), ("NOP", None), ("NOP", None)]
+ t.payload = scapy.Raw(net_test.UDP_PAYLOAD)
+ msg = "TFO write, expected %s" % desc
+ self.ExpectPacketOn(netid, msg, syn)
+
+ @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not yet backported")
+ def testConnectOptionIPv4(self):
+ self.CheckConnectOption(4)
+
+ @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not yet backported")
+ def testConnectOptionIPv6(self):
+ self.CheckConnectOption(6)
+
+
+if __name__ == "__main__":
+ unittest.main()
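The first SYN in the test above is expected to carry an empty Fast Open option. As a standalone sketch (not part of the test suite), the on-the-wire encoding of that option per RFC 7413 is kind 34, a total-length byte, then the cookie; `tfo_option` is a hypothetical helper:

```python
import struct

TCPOPT_FASTOPEN = 34  # TCP option kind for Fast Open (RFC 7413)

def tfo_option(cookie=b""):
    """Encode a TCP Fast Open option: kind, total length, then the cookie.

    An empty cookie produces the 2-byte cookie-request form that the first
    SYN in the test is expected to carry.
    """
    return struct.pack("BB", TCPOPT_FASTOPEN, 2 + len(cookie)) + cookie

# Cookie request (empty cookie) is just kind + length.
assert tfo_option() == b"\x22\x02"
# A 6-byte cookie such as "helloT" yields an 8-byte option.
assert tfo_option(b"helloT") == b"\x22\x08helloT"
```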
diff --git a/net/test/tcp_metrics.py b/net/test/tcp_metrics.py
new file mode 100755
index 0000000..574a755
--- /dev/null
+++ b/net/test/tcp_metrics.py
@@ -0,0 +1,137 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Generic netlink interface to TCP metrics."""
+
+from socket import * # pylint: disable=wildcard-import
+import struct
+
+import cstruct
+import genetlink
+import net_test
+import netlink
+
+
+### TCP metrics constants. See include/uapi/linux/tcp_metrics.h.
+# Family name and version
+TCP_METRICS_GENL_NAME = "tcp_metrics"
+TCP_METRICS_GENL_VERSION = 1
+
+# Message types.
+TCP_METRICS_CMD_GET = 1
+TCP_METRICS_CMD_DEL = 2
+
+# Attributes.
+TCP_METRICS_ATTR_UNSPEC = 0
+TCP_METRICS_ATTR_ADDR_IPV4 = 1
+TCP_METRICS_ATTR_ADDR_IPV6 = 2
+TCP_METRICS_ATTR_AGE = 3
+TCP_METRICS_ATTR_TW_TSVAL = 4
+TCP_METRICS_ATTR_TW_TS_STAMP = 5
+TCP_METRICS_ATTR_VALS = 6
+TCP_METRICS_ATTR_FOPEN_MSS = 7
+TCP_METRICS_ATTR_FOPEN_SYN_DROPS = 8
+TCP_METRICS_ATTR_FOPEN_SYN_DROP_TS = 9
+TCP_METRICS_ATTR_FOPEN_COOKIE = 10
+TCP_METRICS_ATTR_SADDR_IPV4 = 11
+TCP_METRICS_ATTR_SADDR_IPV6 = 12
+TCP_METRICS_ATTR_PAD = 13
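These attribute numbers are mapped back to their names by `_GetConstantName` in the shared netlink code. A standalone sketch of that reverse lookup (`attr_name` is a hypothetical helper; it assumes the constant values are unique within the prefix):

```python
# A couple of the constants from tcp_metrics.h, reproduced for the sketch.
TCP_METRICS_ATTR_AGE = 3
TCP_METRICS_ATTR_VALS = 6

def attr_name(value, prefix="TCP_METRICS_ATTR_"):
    """Reverse-map an attribute number to its symbolic name by scanning
    module globals for constants with the given prefix."""
    for name, val in globals().items():
        if name.startswith(prefix) and val == value:
            return name
    return "unknown (%d)" % value

assert attr_name(3) == "TCP_METRICS_ATTR_AGE"
assert attr_name(99) == "unknown (99)"
```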
+
+
+class TcpMetrics(genetlink.GenericNetlink):
+
+ NL_DEBUG = ["ALL"]
+
+ def __init__(self):
+ super(TcpMetrics, self).__init__()
+ # Generic netlink family IDs are dynamically assigned. Find ours.
+ ctrl = genetlink.GenericNetlinkControl()
+ self.family = ctrl.GetFamily(TCP_METRICS_GENL_NAME)
+
+ def _Decode(self, command, msg, nla_type, nla_data):
+ """Decodes TCP metrics netlink attributes to human-readable format."""
+
+ name = self._GetConstantName(__name__, nla_type, "TCP_METRICS_ATTR_")
+
+ if name in ["TCP_METRICS_ATTR_ADDR_IPV4", "TCP_METRICS_ATTR_SADDR_IPV4"]:
+ data = inet_ntop(AF_INET, nla_data)
+ elif name in ["TCP_METRICS_ATTR_ADDR_IPV6", "TCP_METRICS_ATTR_SADDR_IPV6"]:
+ data = inet_ntop(AF_INET6, nla_data)
+ elif name in ["TCP_METRICS_ATTR_AGE"]:
+ data = struct.unpack("=Q", nla_data)[0]
+ elif name in ["TCP_METRICS_ATTR_TW_TSVAL", "TCP_METRICS_ATTR_TW_TS_STAMP"]:
+ data = struct.unpack("=I", nla_data)[0]
+ elif name == "TCP_METRICS_ATTR_FOPEN_MSS":
+ data = struct.unpack("=H", nla_data)[0]
+ elif name == "TCP_METRICS_ATTR_FOPEN_COOKIE":
+ data = nla_data
+ else:
+ data = nla_data.encode("hex")
+
+ return name, data
+
+ def MaybeDebugCommand(self, command, unused_flags, data):
+ if "ALL" not in self.NL_DEBUG and command not in self.NL_DEBUG:
+ return
+ parsed = self._ParseNLMsg(data, genetlink.Genlmsghdr)
+
+ def _NlAttrTcpMetricsAddr(self, address, is_source):
+ version = net_test.GetAddressVersion(address)
+ family = net_test.GetAddressFamily(version)
+ if version == 5:
+ address = address.replace("::ffff:", "")
+ nla_name = "TCP_METRICS_ATTR_%s_IPV%d" % (
+ "SADDR" if is_source else "ADDR", version)
+ nla_type = globals()[nla_name]
+ return self._NlAttrIPAddress(nla_type, family, address)
+
+ def _NlAttrAddr(self, address):
+ return self._NlAttrTcpMetricsAddr(address, False)
+
+ def _NlAttrSaddr(self, address):
+ return self._NlAttrTcpMetricsAddr(address, True)
+
+ def DumpMetrics(self):
+ """Dumps all TCP metrics."""
+ return self._Dump(self.family, TCP_METRICS_CMD_GET, 1)
+
+ def GetMetrics(self, saddr, daddr):
+ """Returns TCP metrics for the specified src/dst pair."""
+ data = self._NlAttrSaddr(saddr) + self._NlAttrAddr(daddr)
+ self._SendCommand(self.family, TCP_METRICS_CMD_GET, 1, data,
+ netlink.NLM_F_REQUEST)
+ hdr, attrs = self._GetMsg(genetlink.Genlmsghdr)
+ return attrs
+
+ def DelMetrics(self, saddr, daddr):
+ """Deletes TCP metrics for the specified src/dst pair."""
+ data = self._NlAttrSaddr(saddr) + self._NlAttrAddr(daddr)
+ self._SendCommand(self.family, TCP_METRICS_CMD_DEL, 1, data,
+ netlink.NLM_F_REQUEST)
+
+
+if __name__ == "__main__":
+ t = TcpMetrics()
+ print t.DumpMetrics()
diff --git a/net/test/tcp_test.py b/net/test/tcp_test.py
index e60f0d1..5043d46 100644
--- a/net/test/tcp_test.py
+++ b/net/test/tcp_test.py
@@ -64,6 +64,14 @@
super(TcpBaseTest, self).ReceivePacketOn(netid, packet)
self.last_packet = packet
+ def ReceiveRstPacketOn(self, netid):
+ # self.last_packet is the last packet we received. Invert direction twice.
+ _, ack = packets.ACK(self.version, self.myaddr, self.remoteaddr,
+ self.last_packet)
+ desc, rst = packets.RST(self.version, self.remoteaddr, self.myaddr,
+ ack)
+ super(TcpBaseTest, self).ReceivePacketOn(netid, rst)
+
def RstPacket(self):
return packets.RST(self.version, self.myaddr, self.remoteaddr,
self.last_packet)
@@ -78,7 +86,10 @@
self.end_state = end_state
remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
+ remotesockaddr = self.remotesockaddr = self.GetRemoteSocketAddress(version)
+
myaddr = self.myaddr = self.MyAddress(version, netid)
+ mysockaddr = self.mysockaddr = self.MySocketAddress(version, netid)
if version == 5: version = 4
self.version = version
diff --git a/net/test/tun_twister.py b/net/test/tun_twister.py
new file mode 100644
index 0000000..2ed25c9
--- /dev/null
+++ b/net/test/tun_twister.py
@@ -0,0 +1,214 @@
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A utility for "twisting" packets on a tun/tap interface.
+
+TunTwister and TapTwister echo packets on a tun/tap while swapping the source
+and destination at the ethernet and IP layers. This allows sockets to
+effectively loop back packets through the full networking stack, avoiding any
+shortcuts the kernel may take for actual IP loopback. Additionally, users can
+inspect each packet to assert testing invariants.
+"""
+
+import os
+import select
+import threading
+from scapy import all as scapy
+
+
+class TunTwister(object):
+ """TunTwister transports traffic travelling twixt two terminals.
+
+ TunTwister is a context manager that will read packets from a tun file
+ descriptor, swap the source and dest of the IP header, and write them back.
+ To use this class, tests also need to set up routing so that packets will be
+ routed to the tun interface.
+
+ Two sockets can communicate with each other through a TunTwister as if they
+ were each connecting to a remote endpoint. Both sockets will have the
+ perspective that the address of the other is a remote address.
+
+ Packet inspection can be done with a validator function. This can be any
+ function that takes a scapy packet object as its only argument. Exceptions
+ raised by your validator function will be re-raised on the main thread to fail
+ your tests.
+
+ NOTE: Exceptions raised by a validator function will supersede exceptions
+ raised in the context.
+
+ EXAMPLE:
+ def testFeatureFoo(self):
+ my_tun = MakeTunInterface()
+ # Set up routing so packets go to my_tun.
+
+ def ValidatePortNumber(packet):
+ self.assertEquals(8080, packet.getlayer(scapy.UDP).sport)
+ self.assertEquals(8080, packet.getlayer(scapy.UDP).dport)
+
+ with TunTwister(tun_fd=my_tun, validator=ValidatePortNumber):
+ sock = socket(AF_INET, SOCK_DGRAM, 0)
+ sock.bind(("0.0.0.0", 8080))
+ sock.settimeout(1.0)
+ sock.sendto("hello", ("1.2.3.4", 8080))
+ data, addr = sock.recvfrom(1024)
+ self.assertEquals("hello", data)
+ self.assertEquals(("1.2.3.4", 8080), addr)
+ """
+
+ # Hopefully larger than any packet.
+ _READ_BUF_SIZE = 2048
+ _POLL_TIMEOUT_SEC = 2.0
+ _POLL_FAST_TIMEOUT_MS = 100
+
+ def __init__(self, fd=None, validator=None):
+ """Construct a TunTwister.
+
+ The TunTwister will listen on the given TUN fd.
+ The validator is called for each packet *before* twisting. The packet is
+ passed in as a scapy packet object, and is the only argument passed to the
+ validator.
+
+ Args:
+ fd: File descriptor of a TUN interface.
+ validator: Function taking one scapy packet object argument.
+ """
+ self._fd = fd
+ # Use a pipe to signal the thread to exit.
+ self._signal_read, self._signal_write = os.pipe()
+ self._thread = threading.Thread(target=self._RunLoop, name="TunTwister")
+ self._validator = validator
+ self._error = None
+
+ def __enter__(self):
+ self._thread.start()
+
+ def __exit__(self, *args):
+ # Signal thread exit.
+ os.write(self._signal_write, "bye")
+ os.close(self._signal_write)
+ self._thread.join(TunTwister._POLL_TIMEOUT_SEC)
+ os.close(self._signal_read)
+ if self._thread.isAlive():
+ raise RuntimeError("Timed out waiting for thread exit")
+ # Re-raise any error thrown from our thread.
+ if isinstance(self._error, Exception):
+ raise self._error # pylint: disable=raising-bad-type
+
+ def _RunLoop(self):
+ """Twist packets until exit signal."""
+ try:
+ while True:
+ read_fds, _, _ = select.select([self._fd, self._signal_read], [], [],
+ TunTwister._POLL_TIMEOUT_SEC)
+ if self._signal_read in read_fds:
+ self._Flush()
+ return
+ if self._fd in read_fds:
+ self._ProcessPacket()
+ except Exception as e: # pylint: disable=broad-except
+ self._error = e
+
+ def _Flush(self):
+ """Ensure no packets are left in the buffer."""
+ p = select.poll()
+ p.register(self._fd, select.POLLIN)
+ while p.poll(TunTwister._POLL_FAST_TIMEOUT_MS):
+ self._ProcessPacket()
+
+ def _ProcessPacket(self):
+ """Read, twist, and write one packet on the tun/tap."""
+ # TODO: Handle EAGAIN "errors".
+ bytes_in = os.read(self._fd, TunTwister._READ_BUF_SIZE)
+ packet = self.DecodePacket(bytes_in)
+ # The user may wish to filter certain packets, such as
+ # Ethernet multicast packets.
+ if self._DropPacket(packet):
+ return
+
+ if self._validator:
+ self._validator(packet)
+ packet = self.TwistPacket(packet)
+ os.write(self._fd, packet.build())
+
+ def _DropPacket(self, packet):
+ """Determine by inspection whether to drop the provided packet."""
+ return False
+
+ @classmethod
+ def DecodePacket(cls, bytes_in):
+ """Decode a byte array into a scapy object."""
+ return cls._DecodeIpPacket(bytes_in)
+
+ @classmethod
+ def TwistPacket(cls, packet):
+ """Swap the src and dst in the IP header."""
+ ip_type = type(packet)
+ if ip_type not in (scapy.IP, scapy.IPv6):
+ raise TypeError("Expected an IPv4 or IPv6 packet.")
+ packet.src, packet.dst = packet.dst, packet.src
+ packet = ip_type(packet.build()) # Fix the IP checksum.
+ return packet
+
+ @staticmethod
+ def _DecodeIpPacket(packet_bytes):
+ """Decode 'packet_bytes' as an IPv4 or IPv6 scapy object."""
+ ip_ver = (ord(packet_bytes[0]) & 0xF0) >> 4
+ if ip_ver == 4:
+ return scapy.IP(packet_bytes)
+ elif ip_ver == 6:
+ return scapy.IPv6(packet_bytes)
+ else:
+ raise ValueError("packet_bytes is not a valid IPv4 or IPv6 packet")
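`_DecodeIpPacket` dispatches on the version nibble in the first byte of the raw packet. That check can be sketched without scapy (`ip_version` is a hypothetical standalone helper):

```python
def ip_version(packet_bytes):
    """Return the IP version from the first byte of a raw IP packet.

    The version lives in the top four bits of the first octet; bytearray()
    makes the byte indexing behave the same under Python 2 and 3.
    """
    return (bytearray(packet_bytes)[0] & 0xF0) >> 4

assert ip_version(b"\x45\x00\x00\x14") == 4  # IPv4, IHL 5
assert ip_version(b"\x60\x00\x00\x00") == 6  # IPv6
```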
+
+
+class TapTwister(TunTwister):
+ """Test util for tap interfaces.
+
+ TapTwister works just like TunTwister, except it operates on tap interfaces
+ instead of tuns. Ethernet headers will have their sources and destinations
+ swapped in addition to IP headers.
+ """
+
+ @staticmethod
+ def _IsMulticastPacket(eth_pkt):
+ return int(eth_pkt.dst.split(":")[0], 16) & 0x1
+
+ def __init__(self, fd=None, validator=None, drop_multicast=True):
+ """Construct a TapTwister.
+
+ TapTwister works just like TunTwister, but handles both ethernet and IP
+ headers.
+
+ Args:
+ fd: File descriptor of a TAP interface.
+ validator: Function taking one scapy packet object argument.
+ drop_multicast: Drop Ethernet multicast packets
+ """
+ super(TapTwister, self).__init__(fd=fd, validator=validator)
+ self._drop_multicast = drop_multicast
+
+ def _DropPacket(self, packet):
+ return self._drop_multicast and self._IsMulticastPacket(packet)
+
+ @classmethod
+ def DecodePacket(cls, bytes_in):
+ return scapy.Ether(bytes_in)
+
+ @classmethod
+ def TwistPacket(cls, packet):
+ """Swap the src and dst in the ethernet and IP headers."""
+ packet.src, packet.dst = packet.dst, packet.src
+ ip_layer = packet.payload
+ twisted_ip_layer = super(TapTwister, cls).TwistPacket(ip_layer)
+ packet.payload = twisted_ip_layer
+ return packet
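TapTwister's multicast filter tests the IEEE 802 group bit (the least-significant bit of the first address octet), and twisting is a plain source/destination swap. A minimal standalone sketch of both, with scapy omitted (`is_multicast` and `twist` are hypothetical helpers):

```python
def is_multicast(mac):
    """True if the group-address bit of the first octet is set, the same
    bit _IsMulticastPacket tests on eth_pkt.dst."""
    return bool(int(mac.split(":")[0], 16) & 0x1)

def twist(src, dst):
    """Swap source and destination, as TwistPacket does for addresses."""
    return dst, src

assert is_multicast("01:00:5e:00:00:fb")      # IPv4 multicast MAC prefix
assert not is_multicast("02:00:00:00:00:01")  # locally administered unicast
assert twist("aa:aa:aa:aa:aa:aa", "bb:bb:bb:bb:bb:bb") == (
    "bb:bb:bb:bb:bb:bb", "aa:aa:aa:aa:aa:aa")
```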
diff --git a/net/test/util.py b/net/test/util.py
new file mode 100644
index 0000000..cbcd2d0
--- /dev/null
+++ b/net/test/util.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def GetPadLength(block_size, length):
+ return (block_size - (length % block_size)) % block_size
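A standalone copy of the helper with a few spot checks, e.g. for padding an ESP payload up to a cipher block size:

```python
def GetPadLength(block_size, length):
    # Reproduced from util.py above: the number of bytes needed to round
    # `length` up to a multiple of `block_size` (0 when already aligned).
    return (block_size - (length % block_size)) % block_size

assert GetPadLength(4, 0) == 0
assert GetPadLength(4, 5) == 3
assert GetPadLength(16, 16) == 0
assert GetPadLength(16, 17) == 15
```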
+
+
+def InjectParameterizedTest(cls, param_list, name_generator):
+ """Injects parameterized tests into the provided class.
+
+ This method searches for all tests that start with the name "ParamTest",
+ and injects a test method for each set of parameters in param_list. Names
+ are generated via the use of the name_generator.
+
+ Args:
+ cls: the class for which to inject all parameterized tests
+ param_list: a list of tuples, where each tuple is a combination of
+ parameters to test (i.e. representing a single test case)
+ name_generator: A function that takes a combination of parameters and
+ returns a string that identifies the test case.
+ """
+ param_test_names = [name for name in dir(cls) if name.startswith("ParamTest")]
+
+ # Force param_list to an actual list; otherwise an iterator (such as one
+ # returned by itertools.product) would be exhausted after the first
+ # ParamTest* method, leaving the remaining methods unparameterized.
+ param_list = list(param_list)
+
+ # Parameterize each test method starting with "ParamTest"
+ for test_name in param_test_names:
+ func = getattr(cls, test_name)
+
+ for params in param_list:
+ # Give the test method a readable, debuggable name.
+ param_string = name_generator(*params)
+ new_name = "%s_%s" % (func.__name__.replace("ParamTest", "test"),
+ param_string)
+ new_name = new_name.replace("(", "-").replace(")", "") # remove parens
+
+ # Inject the test method
+ setattr(cls, new_name, _GetTestClosure(func, params))
+
+
+def _GetTestClosure(func, params):
+ """Creates a no-argument test method for the given function and parameters.
+
+ This must be separate from InjectParameterizedTest because of Python's
+ late-binding closure semantics: if TestClosure were defined inside that
+ method's loop, every generated test would end up bound to the same, final
+ set of parameters.
+
+ Args:
+ func: the function for which this test closure should run
+ params: the parameters for the run of this test function
+ """
+
+ def TestClosure(self):
+ func(self, *params)
+
+ return TestClosure
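The scoping issue here is Python's late-binding closures: a function defined in a loop captures the variable itself, not its value at that iteration. A minimal demonstration of the pitfall and the per-call-scope fix that `_GetTestClosure` uses:

```python
def make_late_bound():
    # BUG pattern: every closure shares the same `params` variable, so all
    # of them see whatever value it holds when they are finally called.
    closures = []
    for params in (1, 2, 3):
        closures.append(lambda: params)
    return [c() for c in closures]

def make_bound_per_call(params):
    # Fix: a separate function call freezes `params` for each closure,
    # just as _GetTestClosure does for each generated test.
    return lambda: params

assert make_late_bound() == [3, 3, 3]  # all closures see the final value
assert [make_bound_per_call(p)() for p in (1, 2, 3)] == [1, 2, 3]
```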
diff --git a/net/test/xfrm.py b/net/test/xfrm.py
index e9b4fce..acdfd4f 100755
--- a/net/test/xfrm.py
+++ b/net/test/xfrm.py
@@ -18,18 +18,15 @@
# pylint: disable=g-bad-todo
-import errno
import os
from socket import * # pylint: disable=wildcard-import
import struct
+import net_test
import csocket
import cstruct
-import net_test
import netlink
-# Base netlink constants. See include/uapi/linux/netlink.h.
-NETLINK_XFRM = 6
# Netlink constants. See include/uapi/linux/xfrm.h.
# Message types.
@@ -86,6 +83,10 @@
XFRMA_PROTO = 25
XFRMA_ADDRESS_FILTER = 26
XFRMA_PAD = 27
+XFRMA_OFFLOAD_DEV = 28
+XFRMA_OUTPUT_MARK = 29
+XFRMA_INPUT_MARK = 30
+XFRMA_IF_ID = 31
# Other netlink constants. See include/uapi/linux/xfrm.h.
@@ -113,10 +114,22 @@
XFRM_POLICY_ALLOW = 0
XFRM_POLICY_BLOCK = 1
-# Flags.
+# Policy flags.
XFRM_POLICY_LOCALOK = 1
XFRM_POLICY_ICMP = 2
+# State flags.
+XFRM_STATE_AF_UNSPEC = 32
+
+# XFRM algorithm names, as defined in net/xfrm/xfrm_algo.c.
+XFRM_EALG_CBC_AES = "cbc(aes)"
+XFRM_AALG_HMAC_MD5 = "hmac(md5)"
+XFRM_AALG_HMAC_SHA1 = "hmac(sha1)"
+XFRM_AALG_HMAC_SHA256 = "hmac(sha256)"
+XFRM_AALG_HMAC_SHA384 = "hmac(sha384)"
+XFRM_AALG_HMAC_SHA512 = "hmac(sha512)"
+XFRM_AEAD_GCM_AES = "rfc4106(gcm(aes))"
+
# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
XfrmSelector = cstruct.Struct(
@@ -134,14 +147,15 @@
XfrmAlgo = cstruct.Struct("XfrmAlgo", "=64AI", "name key_len")
-XfrmAlgoAuth = cstruct.Struct("XfrmAlgo", "=64AII", "name key_len trunc_len")
+XfrmAlgoAuth = cstruct.Struct("XfrmAlgoAuth", "=64AII",
+ "name key_len trunc_len")
XfrmAlgoAead = cstruct.Struct("XfrmAlgoAead", "=64AII", "name key_len icv_len")
XfrmStats = cstruct.Struct(
"XfrmStats", "=III", "replay_window replay integrity_failed")
-XfrmId = cstruct.Struct("XfrmId", "=16sIBxxx", "daddr spi proto")
+XfrmId = cstruct.Struct("XfrmId", "!16sIBxxx", "daddr spi proto")
XfrmUserTmpl = cstruct.Struct(
"XfrmUserTmpl", "=SHxx16sIBBBxIII",
@@ -156,14 +170,28 @@
"sel id saddr lft curlft stats seq reqid family mode replay_window flags",
[XfrmSelector, XfrmId, XfrmLifetimeCfg, XfrmLifetimeCur, XfrmStats])
-XfrmUsersaId = cstruct.Struct(
- "XfrmUsersaInfo", "=16sIHBx", "daddr spi family proto")
+XfrmUserSpiInfo = cstruct.Struct(
+ "XfrmUserSpiInfo", "=SII", "info min max", [XfrmUsersaInfo])
+# Technically the family is a 16-bit field, but only a few families are in use,
+# and if we pretend it's 8 bits (i.e., use "Bx" instead of "H") we can think
+# of the whole structure as being in network byte order.
+XfrmUsersaId = cstruct.Struct(
+ "XfrmUsersaId", "!16sIBxBx", "daddr spi family proto")
+
+# xfrm.h - struct xfrm_userpolicy_info
XfrmUserpolicyInfo = cstruct.Struct(
"XfrmUserpolicyInfo", "=SSSIIBBBBxxxx",
"sel lft curlft priority index dir action flags share",
[XfrmSelector, XfrmLifetimeCfg, XfrmLifetimeCur])
+XfrmUserpolicyId = cstruct.Struct(
+ "XfrmUserpolicyId", "=SIBxxx", "sel index dir", [XfrmSelector])
+
+XfrmUsersaFlush = cstruct.Struct("XfrmUsersaFlush", "=B", "proto")
+
+XfrmMark = cstruct.Struct("XfrmMark", "=II", "mark mask")
+
# Socket options. See include/uapi/linux/in.h.
IP_IPSEC_POLICY = 16
IP_XFRM_POLICY = 17
@@ -179,6 +207,23 @@
NO_LIFETIME_CFG = XfrmLifetimeCfg((_INF, _INF, _INF, _INF, 0, 0, 0, 0))
NO_LIFETIME_CUR = "\x00" * len(XfrmLifetimeCur)
+# IPsec constants.
+IPSEC_PROTO_ANY = 255
+
+# ESP header, not technically XFRM but we need a place for a protocol
+# header and this is the only one we have.
+# TODO: move this somewhere more appropriate when possible
+EspHdr = cstruct.Struct("EspHdr", "!II", "spi seqnum")
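EspHdr mirrors the 8-byte ESP header from RFC 4303: a 32-bit SPI followed by a 32-bit sequence number, both in network byte order. The same bytes produced with plain struct (`esp_hdr` is a hypothetical standalone helper, not the cstruct class):

```python
import struct

def esp_hdr(spi, seqnum):
    # "!II" = two unsigned 32-bit fields, network (big-endian) byte order,
    # matching the EspHdr cstruct definition above.
    return struct.pack("!II", spi, seqnum)

hdr = esp_hdr(0xdeadbeef, 1)
assert hdr == b"\xde\xad\xbe\xef\x00\x00\x00\x01"
assert len(hdr) == 8
```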
+
+# Local constants.
+_DEFAULT_REPLAY_WINDOW = 4
+ALL_ALGORITHMS = 0xffffffff
+
+# Policy-SA match method (for VTI/XFRM-I).
+MATCH_METHOD_ALL = "all"
+MATCH_METHOD_MARK = "mark"
+MATCH_METHOD_IFID = "ifid"
+
def RawAddress(addr):
"""Converts an IP address string to binary format."""
@@ -194,14 +239,99 @@
return padded
+XFRM_ADDR_ANY = PaddedAddress("::")
+
+
+def EmptySelector(family):
+ """A selector that matches all packets of the specified address family."""
+ return XfrmSelector(family=family)
+
+
+def SrcDstSelector(src, dst):
+ """A selector that matches packets between the specified IP addresses."""
+ srcver = csocket.AddressVersion(src)
+ dstver = csocket.AddressVersion(dst)
+ if srcver != dstver:
+ raise ValueError("Cross-address family selector specified: %s -> %s" %
+ (src, dst))
+ prefixlen = net_test.AddressLengthBits(srcver)
+ family = net_test.GetAddressFamily(srcver)
+ return XfrmSelector(saddr=PaddedAddress(src), daddr=PaddedAddress(dst),
+ prefixlen_s=prefixlen, prefixlen_d=prefixlen, family=family)
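SrcDstSelector derives the address family and a host-route prefix length from the address strings. A simplified standalone sketch of that branching (`addr_version` and `prefix_len_bits` are hypothetical helpers; the real csocket.AddressVersion additionally reports 5 for IPv4-mapped IPv6 addresses):

```python
def addr_version(addr):
    """Guess the IP version from the textual address, the way the xfrm
    helpers branch on the presence of ':' (simplified)."""
    return 6 if ":" in addr else 4

def prefix_len_bits(version):
    # A host-route selector covers the full address length.
    return {4: 32, 6: 128}[version]

assert addr_version("192.0.2.1") == 4
assert addr_version("2001:db8::1") == 6
assert prefix_len_bits(addr_version("2001:db8::1")) == 128
```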
+
+
+def UserPolicy(direction, selector):
+ """Create an IPsec policy.
+
+ Args:
+ direction: XFRM_POLICY_IN or XFRM_POLICY_OUT
+ selector: An XfrmSelector, the packets to transform.
+
+ Return: an XfrmUserpolicyInfo cstruct.
+ """
+ # Create a user policy that specifies that all packets in the specified
+ # direction matching the selector should be encrypted.
+ return XfrmUserpolicyInfo(
+ sel=selector,
+ lft=NO_LIFETIME_CFG,
+ curlft=NO_LIFETIME_CUR,
+ dir=direction,
+ action=XFRM_POLICY_ALLOW,
+ flags=XFRM_POLICY_LOCALOK,
+ share=XFRM_SHARE_UNIQUE)
+
+
+def UserTemplate(family, spi, reqid, tun_addrs):
+ """Create an ESP template.
+
+ Args:
+ spi: 32-bit SPI in host byte order
+ reqid: 32-bit ID matched against SAs
+ tun_addrs: A tuple of (local, remote) addresses for tunnel mode, or None
+ to request a transport mode SA.
+
+ Return: an XfrmUserTmpl.
+ """
+ # For transport mode, leave the template source and destination empty.
+ # For tunnel mode, explicitly specify source and destination addresses.
+ if tun_addrs is None:
+ mode = XFRM_MODE_TRANSPORT
+ saddr = XFRM_ADDR_ANY
+ daddr = XFRM_ADDR_ANY
+ else:
+ mode = XFRM_MODE_TUNNEL
+ saddr = PaddedAddress(tun_addrs[0])
+ daddr = PaddedAddress(tun_addrs[1])
+
+ # Create a template that specifies the SPI and the protocol.
+ xfrmid = XfrmId(daddr=daddr, spi=spi, proto=IPPROTO_ESP)
+ template = XfrmUserTmpl(
+ id=xfrmid,
+ family=family,
+ saddr=saddr,
+ reqid=reqid,
+ mode=mode,
+ share=XFRM_SHARE_UNIQUE,
+ optional=0,  # require
+ aalgos=ALL_ALGORITHMS,
+ ealgos=ALL_ALGORITHMS,
+ calgos=ALL_ALGORITHMS)
+
+ return template
+
+
+def ExactMatchMark(mark):
+ """An XfrmMark that matches only the specified mark."""
+ return XfrmMark((mark, 0xffffffff))
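An XfrmMark pairs a value with a mask, and a packet mark matches when the masked bits agree; ExactMatchMark simply uses an all-ones mask. A sketch of the matching rule (an assumption based on kernel xfrm mark semantics, not code from this change):

```python
def mark_matches(xfrm_mark, xfrm_mask, skb_mark):
    # Compare only the bits selected by the mask.
    return (skb_mark & xfrm_mask) == (xfrm_mark & xfrm_mask)

# ExactMatchMark(0x1234): full mask, so only the exact mark matches.
assert mark_matches(0x1234, 0xffffffff, 0x1234)
assert not mark_matches(0x1234, 0xffffffff, 0x1235)
# A zero mask matches any mark.
assert mark_matches(0, 0, 0xdeadbeef)
```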
+
+
class Xfrm(netlink.NetlinkSocket):
"""Netlink interface to xfrm."""
- FAMILY = NETLINK_XFRM
DEBUG = False
def __init__(self):
- super(Xfrm, self).__init__()
+ super(Xfrm, self).__init__(netlink.NETLINK_XFRM)
def _GetConstantName(self, value, prefix):
return super(Xfrm, self)._GetConstantName(__name__, value, prefix)
@@ -217,6 +347,10 @@
struct_type = XfrmUsersaId
elif command == XFRM_MSG_DELSA:
struct_type = XfrmUsersaId
+ elif command == XFRM_MSG_ALLOCSPI:
+ struct_type = XfrmUserSpiInfo
+ elif command == XFRM_MSG_NEWPOLICY:
+ struct_type = XfrmUserpolicyInfo
else:
struct_type = None
@@ -236,53 +370,345 @@
data = cstruct.Read(nla_data, XfrmAlgoAuth)[0]
elif name == "XFRMA_ENCAP":
data = cstruct.Read(nla_data, XfrmEncapTmpl)[0]
+ elif name == "XFRMA_MARK":
+ data = cstruct.Read(nla_data, XfrmMark)[0]
+ elif name == "XFRMA_OUTPUT_MARK":
+ data = struct.unpack("=I", nla_data)[0]
+ elif name == "XFRMA_TMPL":
+ data = cstruct.Read(nla_data, XfrmUserTmpl)[0]
+ elif name == "XFRMA_IF_ID":
+ data = struct.unpack("=I", nla_data)[0]
else:
data = nla_data
return name, data
- def AddSaInfo(self, selector, xfrm_id, saddr, lifetimes, reqid, family, mode,
- replay_window, flags, nlattrs):
- # The kernel ignores these on input.
- cur = "\x00" * len(XfrmLifetimeCur)
- stats = "\x00" * len(XfrmStats)
- seq = 0
- sa = XfrmUsersaInfo((selector, xfrm_id, saddr, lifetimes, cur, stats, seq,
- reqid, family, mode, replay_window, flags))
- msg = sa.Pack() + nlattrs
- flags = netlink.NLM_F_REQUEST | netlink.NLM_F_ACK
- self._SendNlRequest(XFRM_MSG_NEWSA, msg, flags)
+ def _UpdatePolicyInfo(self, msg, policy, tmpl, mark, xfrm_if_id):
+ """Send a policy to the Security Policy Database."""
+ nlattrs = []
+ if tmpl is not None:
+ nlattrs.append((XFRMA_TMPL, tmpl))
+ if mark is not None:
+ nlattrs.append((XFRMA_MARK, mark))
+ if xfrm_if_id is not None:
+ nlattrs.append((XFRMA_IF_ID, struct.pack("=I", xfrm_if_id)))
+ self.SendXfrmNlRequest(msg, policy, nlattrs)
- def AddMinimalSaInfo(self, src, dst, spi, proto, mode, reqid,
- encryption, encryption_key,
- auth_trunc, auth_trunc_key, encap):
- selector = XfrmSelector("\x00" * len(XfrmSelector))
+ def AddPolicyInfo(self, policy, tmpl, mark, xfrm_if_id=None):
+ """Add a new policy to the Security Policy Database.
+
+ If the policy exists, then return an error (EEXIST).
+
+ Args:
+ policy: an unpacked XfrmUserpolicyInfo
+ tmpl: an unpacked XfrmUserTmpl
+ mark: an unpacked XfrmMark
+ xfrm_if_id: the XFRM interface ID as an integer, or None
+ """
+ self._UpdatePolicyInfo(XFRM_MSG_NEWPOLICY, policy, tmpl, mark, xfrm_if_id)
+
+ def UpdatePolicyInfo(self, policy, tmpl, mark, xfrm_if_id):
+ """Update an existing policy in the Security Policy Database.
+
+ If the policy does not exist, then create it; otherwise, update the
+ existing policy record.
+
+ Args:
+ policy: an unpacked XfrmUserpolicyInfo
+ tmpl: an unpacked XfrmUserTmpl to update
+ mark: an unpacked XfrmMark to match the existing policy or None
+ xfrm_if_id: an XFRM interface ID or None
+ """
+ self._UpdatePolicyInfo(XFRM_MSG_UPDPOLICY, policy, tmpl, mark, xfrm_if_id)
+
+ def DeletePolicyInfo(self, selector, direction, mark, xfrm_if_id=None):
+ """Delete a policy from the Security Policy Database.
+
+ Args:
+ selector: an XfrmSelector matching the policy to delete
+ direction: policy direction
+ mark: an unpacked XfrmMark to match the policy or None
+ """
+ nlattrs = []
+ if mark is not None:
+ nlattrs.append((XFRMA_MARK, mark))
+ if xfrm_if_id is not None:
+ nlattrs.append((XFRMA_IF_ID, struct.pack("=I", xfrm_if_id)))
+ self.SendXfrmNlRequest(XFRM_MSG_DELPOLICY,
+ XfrmUserpolicyId(sel=selector, dir=direction),
+ nlattrs)
+
+ # TODO: this function really needs to be in netlink.py
+ def SendXfrmNlRequest(self, msg_type, req, nlattrs=None,
+ flags=netlink.NLM_F_ACK|netlink.NLM_F_REQUEST):
+ """Sends a netlink request message.
+
+ Args:
+ msg_type: an XFRM_MSG_* type
+ req: an unpacked netlink request message body cstruct
+ nlattrs: an unpacked list of two-tuples of (NLATTR_* type, body) where
+ the body is an unpacked cstruct
+ flags: a list of flags for the expected handling; if no flags are
+ provided, an ACK response is assumed.
+ """
+ msg = req.Pack()
+ if nlattrs is None:
+ nlattrs = []
+ for attr_type, attr_msg in nlattrs:
+ # TODO: find a better way to deal with the fact that many XFRM messages
+ # use nlattrs that aren't cstructs.
+ #
+ # This code allows callers to pass in either something that has a Pack()
+ # method or a packed netlink attr, but not other types of attributes.
+ # Alternatives include:
+ #
+ # 1. Require callers to marshal netlink attributes themselves and call
+ # _SendNlRequest directly. Delete this method.
+ # 2. Rename this function to _SendXfrmNlRequestCstructOnly (or other name
+ # that makes it clear that this only takes cstructs). Switch callers
+ # that need non-cstruct elements to calling _SendNlRequest directly.
+ # 3. Make this function somehow automatically detect what to do for
+ # all types of XFRM attributes today and in the future. This may be
+ # feasible because all XFRM attributes today occupy the same number
+ # space, but what about nested attributes? It is unlikely to be feasible
+ # via things like "if isinstance(attr_msg, str): ...", because that would
+ # not be able to determine the right size or byte order for non-struct
+ # types such as int.
+ # 4. Define fictitious cstructs which have no correspondence to actual
+ # kernel structs such as the following to represent a raw integer.
+ # XfrmAttrOutputMark = cstruct.Struct("=I", mark)
+ if hasattr(attr_msg, "Pack"):
+ attr_msg = attr_msg.Pack()
+ msg += self._NlAttr(attr_type, attr_msg)
+ return self._SendNlRequest(msg_type, msg, flags)
+
+ def AddSaInfo(self, src, dst, spi, mode, reqid, encryption, auth_trunc, aead,
+ encap, mark, output_mark, is_update=False, xfrm_if_id=None):
+ """Adds an IPsec security association.
+
+ Args:
+ src: A string, the source IP address. May be a wildcard in transport mode.
+ dst: A string, the destination IP address. Forms part of the XFRM ID, and
+ must match the destination address of the packets sent by this SA.
+ spi: An integer, the SPI.
+ mode: An IPsec mode such as XFRM_MODE_TRANSPORT.
+ reqid: A request ID. Can be used in policies to match the SA.
+ encryption: A tuple of an XfrmAlgo and raw key bytes, or None.
+ auth_trunc: A tuple of an XfrmAlgoAuth and raw key bytes, or None.
+ aead: A tuple of an XfrmAlgoAead and raw key bytes, or None.
+ encap: An XfrmEncapTmpl structure, or None.
+ mark: A mark match specifier, such as returned by ExactMatchMark(), or
+ None for an SA that matches all possible marks.
+ output_mark: An integer, the output mark. 0 means unset.
+ is_update: If true, update an existing SA otherwise create a new SA. For
+ compatibility reasons, this value defaults to False.
+ xfrm_if_id: The XFRM interface ID, or None.
+ """
+ proto = IPPROTO_ESP
xfrm_id = XfrmId((PaddedAddress(dst), spi, proto))
family = AF_INET6 if ":" in dst else AF_INET
- nlattrs = self._NlAttr(XFRMA_ALG_CRYPT,
- encryption.Pack() + encryption_key)
- nlattrs += self._NlAttr(XFRMA_ALG_AUTH_TRUNC,
- auth_trunc.Pack() + auth_trunc_key)
+
+ nlattrs = ""
+ if encryption is not None:
+ enc, key = encryption
+ nlattrs += self._NlAttr(XFRMA_ALG_CRYPT, enc.Pack() + key)
+
+ if auth_trunc is not None:
+ auth, key = auth_trunc
+ nlattrs += self._NlAttr(XFRMA_ALG_AUTH_TRUNC, auth.Pack() + key)
+
+ if aead is not None:
+ aead_alg, key = aead
+ nlattrs += self._NlAttr(XFRMA_ALG_AEAD, aead_alg.Pack() + key)
+
+ # If the user provides a mark, send the XFRMA_MARK attribute.
+ if mark is not None:
+ nlattrs += self._NlAttr(XFRMA_MARK, mark.Pack())
if encap is not None:
nlattrs += self._NlAttr(XFRMA_ENCAP, encap.Pack())
- self.AddSaInfo(selector, xfrm_id, PaddedAddress(src), NO_LIFETIME_CFG,
- reqid, family, mode, 4, 0, nlattrs)
+ if output_mark is not None:
+ nlattrs += self._NlAttrU32(XFRMA_OUTPUT_MARK, output_mark)
+ if xfrm_if_id is not None:
+ nlattrs += self._NlAttrU32(XFRMA_IF_ID, xfrm_if_id)
- def DeleteSaInfo(self, daddr, spi, proto):
- # TODO: deletes take a mark as well.
- family = AF_INET6 if ":" in daddr else AF_INET
- usersa_id = XfrmUsersaId((PaddedAddress(daddr), spi, family, proto))
+ # The kernel ignores these on input, so make them empty.
+ cur = XfrmLifetimeCur()
+ stats = XfrmStats()
+ seq = 0
+ replay = _DEFAULT_REPLAY_WINDOW
+
+ # The XFRM_STATE_AF_UNSPEC flag determines how AF_UNSPEC selectors behave.
+ #
+ # - If the flag is not set, an AF_UNSPEC selector has its family changed to
+ # the SA family, which in our case is the address family of dst.
+ # - If the flag is set, an AF_UNSPEC selector is left as is. In transport
+ # mode this fails with EPROTONOSUPPORT, but in tunnel mode, it results in
+ # a dual-stack SA that can tunnel both IPv4 and IPv6 packets.
+ #
+ # This allows us to pass an empty selector to the kernel regardless of which
+ # mode we're in: when creating transport mode SAs, the kernel will pick the
+ # selector family based on the SA family, and when creating tunnel mode SAs,
+ # we'll just create SAs that select both IPv4 and IPv6 traffic, and leave it
+ # up to the policy selectors to determine what traffic we actually want to
+ # transform.
+ flags = XFRM_STATE_AF_UNSPEC if mode == XFRM_MODE_TUNNEL else 0
+ selector = EmptySelector(AF_UNSPEC)
+
+ sa = XfrmUsersaInfo((selector, xfrm_id, PaddedAddress(src), NO_LIFETIME_CFG,
+ cur, stats, seq, reqid, family, mode, replay, flags))
+ msg = sa.Pack() + nlattrs
flags = netlink.NLM_F_REQUEST | netlink.NLM_F_ACK
- self._SendNlRequest(XFRM_MSG_DELSA, usersa_id.Pack(), flags)
+ nl_msg_type = XFRM_MSG_UPDSA if is_update else XFRM_MSG_NEWSA
+ self._SendNlRequest(nl_msg_type, msg, flags)
+
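Each attribute appended to nlattrs above is a standard netlink TLV: a 16-bit length (which includes the 4-byte header), a 16-bit type, the payload, then zero-padding to a 4-byte boundary. A minimal sketch of that packing, assuming the usual NLA_ALIGNTO of 4; the helper name and the example type value are illustrative, not the _NlAttr used in this tree:

```python
import struct

NLA_ALIGNTO = 4

def pack_nlattr(nla_type, data):
    """Pack one netlink attribute: (len, type) header + payload + padding."""
    nla_len = struct.calcsize("=HH") + len(data)  # length includes the header
    pad = (NLA_ALIGNTO - (nla_len % NLA_ALIGNTO)) % NLA_ALIGNTO
    return struct.pack("=HH", nla_len, nla_type) + data + b"\x00" * pad

# Example: a 32-bit output mark of 5 carried as a hypothetical type 0x1d.
attr = pack_nlattr(0x1d, struct.pack("=I", 5))
```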
+ def DeleteSaInfo(self, dst, spi, proto, mark=None, xfrm_if_id=None):
+ """Delete an SA from the SAD.
+
+ Args:
+ dst: A string, the destination IP address. Forms part of the XFRM ID, and
+ must match the destination address of the packets sent by this SA.
+ spi: An integer, the SPI.
+ proto: The protocol DB of the SA, such as IPPROTO_ESP.
+ mark: A mark match specifier, such as returned by ExactMatchMark(), or
+ None for an SA without a Mark attribute.
+ xfrm_if_id: An integer, the XFRM interface ID, or None.
+ """
+ family = AF_INET6 if ":" in dst else AF_INET
+ usersa_id = XfrmUsersaId((PaddedAddress(dst), spi, family, proto))
+ nlattrs = []
+ if mark is not None:
+ nlattrs.append((XFRMA_MARK, mark))
+ if xfrm_if_id is not None:
+ nlattrs.append((XFRMA_IF_ID, struct.pack("=I", xfrm_if_id)))
+ self.SendXfrmNlRequest(XFRM_MSG_DELSA, usersa_id, nlattrs)
+
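AddSaInfo and DeleteSaInfo both take a mark match specifier carrying a (value, mask) pair; the kernel considers an SA a match when the candidate mark agrees with the SA's mark on every bit set in the mask, so an all-ones mask (as produced by ExactMatchMark()) demands the exact value, while a zero mask matches anything. A standalone sketch of that comparison; the helper is illustrative, not kernel code:

```python
MARK_MASK_ALL = 0xffffffff

def mark_matches(sa_mark, sa_mask, candidate_mark):
    """True when candidate_mark agrees with sa_mark on all bits in sa_mask."""
    return (candidate_mark & sa_mask) == (sa_mark & sa_mask)

# An exact-match mark only accepts the identical value...
exact = mark_matches(0x1234, MARK_MASK_ALL, 0x1234)
# ...while a zero mask matches any mark (the wildcard / "None" case).
wildcard = mark_matches(0x1234, 0x0, 0xdead)
```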
+ def AllocSpi(self, dst, proto, min_spi, max_spi):
+ """Allocate (reserve) an SPI.
+
+ This sends an XFRM_MSG_ALLOCSPI message and returns the resulting
+ XfrmUsersaInfo struct.
+
+ Args:
+ dst: A string, the destination IP address. Forms part of the XFRM ID, and
+ must match the destination address of the packets sent by this SA.
+ proto: the protocol DB of the SA, such as IPPROTO_ESP.
+ min_spi: The minimum value of the acceptable SPI range (inclusive).
+ max_spi: The maximum value of the acceptable SPI range (inclusive).
+ """
+ spi = XfrmUserSpiInfo("\x00" * len(XfrmUserSpiInfo))
+ spi.min = min_spi
+ spi.max = max_spi
+ spi.info.id.daddr = PaddedAddress(dst)
+ spi.info.id.proto = proto
+ spi.info.family = AF_INET6 if ":" in dst else AF_INET
+
+ msg = spi.Pack()
+ flags = netlink.NLM_F_REQUEST
+ self._SendNlRequest(XFRM_MSG_ALLOCSPI, msg, flags)
+ # Read the response message.
+ data = self._Recv()
+ nl_hdr, data = cstruct.Read(data, netlink.NLMsgHdr)
+ if nl_hdr.type == XFRM_MSG_NEWSA:
+ return XfrmUsersaInfo(data)
+ if nl_hdr.type == netlink.NLMSG_ERROR:
+ error = netlink.NLMsgErr(data).error
+ raise IOError(error, os.strerror(-error))
+ raise ValueError("Unexpected netlink message type: %d" % nl_hdr.type)
def DumpSaInfo(self):
return self._Dump(XFRM_MSG_GETSA, None, XfrmUsersaInfo, "")
+ def DumpPolicyInfo(self):
+ return self._Dump(XFRM_MSG_GETPOLICY, None, XfrmUserpolicyInfo, "")
+
def FindSaInfo(self, spi):
sainfo = [sa for sa, attrs in self.DumpSaInfo() if sa.id.spi == spi]
return sainfo[0] if sainfo else None
+ def FlushPolicyInfo(self):
+ """Send a netlink request to flush all records from the SPD."""
+ flags = netlink.NLM_F_REQUEST | netlink.NLM_F_ACK
+ self._SendNlRequest(XFRM_MSG_FLUSHPOLICY, "", flags)
+
+ def FlushSaInfo(self):
+ usersa_flush = XfrmUsersaFlush((IPSEC_PROTO_ANY,))
+ flags = netlink.NLM_F_REQUEST | netlink.NLM_F_ACK
+ self._SendNlRequest(XFRM_MSG_FLUSHSA, usersa_flush.Pack(), flags)
+
+ def CreateTunnel(self, direction, selector, src, dst, spi, encryption,
+ auth_trunc, mark, output_mark, xfrm_if_id, match_method):
+ """Create an XFRM tunnel consisting of a policy and an SA.
+
+ Create a unidirectional XFRM tunnel, which entails one Policy and one
+ security association.
+
+ Args:
+ direction: XFRM_POLICY_IN or XFRM_POLICY_OUT
+ selector: An XfrmSelector that specifies the packets to be transformed.
+ This is only applied to the policy; the selector in the SA is always
+ empty. If the passed-in selector is None, then the tunnel is made
+ dual-stack. This requires two policies, one for IPv4 and one for IPv6.
+ src: The source address of the tunneled packets
+ dst: The destination address of the tunneled packets
+ spi: The SPI for the IPsec SA that encapsulates the tunneled packet
+ encryption: A tuple (XfrmAlgo, key), the encryption parameters.
+ auth_trunc: A tuple (XfrmAlgoAuth, key), the authentication parameters.
+ mark: An XfrmMark, the mark used for selecting packets to be tunneled, and
+ for matching the security policy. None means unspecified.
+ output_mark: The mark used to select the underlying network for packets
+ outbound from xfrm. None means unspecified.
+ xfrm_if_id: The ID of the XFRM interface to use or None.
+ match_method: One of MATCH_METHOD_[MARK | ALL | IFID]. This determines how
+ SAs and policies are matched.
+ """
+ outer_family = net_test.GetAddressFamily(net_test.GetAddressVersion(dst))
+
+ # SA mark is currently unused due to UPDSA not updating marks.
+ # Kept as documentation of ideal/desired behavior.
+ if match_method == MATCH_METHOD_MARK:
+ # sa_mark = mark
+ tmpl_spi = 0
+ if_id = None
+ elif match_method == MATCH_METHOD_ALL:
+ # sa_mark = mark
+ tmpl_spi = spi
+ if_id = xfrm_if_id
+ elif match_method == MATCH_METHOD_IFID:
+ # sa_mark = None
+ tmpl_spi = 0
+ if_id = xfrm_if_id
+ else:
+ raise ValueError("Unknown match_method supplied: %s" % match_method)
+
+ # Device code does not use marks: during AllocSpi the mark is unset, and
+ # UPDSA does not currently update marks. The real-world use case therefore
+ # has no mark set, so that is what we test here.
+ self.AddSaInfo(src, dst, spi, XFRM_MODE_TUNNEL, 0, encryption, auth_trunc,
+ None, None, None, output_mark, xfrm_if_id=xfrm_if_id)
+
+ if selector is None:
+ selectors = [EmptySelector(AF_INET), EmptySelector(AF_INET6)]
+ else:
+ selectors = [selector]
+
+ for selector in selectors:
+ policy = UserPolicy(direction, selector)
+ tmpl = UserTemplate(outer_family, tmpl_spi, 0, (src, dst))
+ self.AddPolicyInfo(policy, tmpl, mark, xfrm_if_id=xfrm_if_id)
+
+ def DeleteTunnel(self, direction, selector, dst, spi, mark, xfrm_if_id):
+ if mark is not None:
+ mark = ExactMatchMark(mark)
+
+ self.DeleteSaInfo(dst, spi, IPPROTO_ESP, mark, xfrm_if_id)
+ if selector is None:
+ selectors = [EmptySelector(AF_INET), EmptySelector(AF_INET6)]
+ else:
+ selectors = [selector]
+ for selector in selectors:
+ self.DeletePolicyInfo(selector, direction, mark, xfrm_if_id)
+
if __name__ == "__main__":
x = Xfrm()
print x.DumpSaInfo()
+ print x.DumpPolicyInfo()
diff --git a/net/test/xfrm_algorithm_test.py b/net/test/xfrm_algorithm_test.py
new file mode 100755
index 0000000..0176265
--- /dev/null
+++ b/net/test/xfrm_algorithm_test.py
@@ -0,0 +1,295 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
+from errno import * # pylint: disable=wildcard-import
+import os
+import itertools
+from scapy import all as scapy
+from socket import * # pylint: disable=wildcard-import
+import subprocess
+import threading
+import unittest
+
+import multinetwork_base
+import net_test
+from tun_twister import TapTwister
+import util
+import xfrm
+import xfrm_base
+
+# List of encryption algorithms for use in ParamTests.
+CRYPT_ALGOS = [
+ xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 128)),
+ xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 192)),
+ xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 256)),
+]
+
+# List of auth algorithms for use in ParamTests.
+AUTH_ALGOS = [
+ # RFC 4868 specifies that the only supported truncation length is half the
+ # hash size.
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 96)),
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 96)),
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 128)),
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 192)),
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 256)),
+ # Test larger truncation lengths for good measure.
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 128)),
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 160)),
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 256)),
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 384)),
+ xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 512)),
+]
+
+# List of aead algorithms for use in ParamTests.
+AEAD_ALGOS = [
+ # RFC 4106 specifies that key length must be 128, 192 or 256 bits,
+ # with an additional 4 bytes (32 bits) of salt. The salt must be unique
+ # for each new SA using the same key.
+ # RFC 4106 specifies that ICV length must be 8, 12, or 16 bytes
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 8*8)),
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 12*8)),
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 16*8)),
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 8*8)),
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 12*8)),
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 16*8)),
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 8*8)),
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 12*8)),
+ xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 16*8)),
+]
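As the RFC 4106 note above says, key_len for the GCM entries includes a 32-bit salt on top of the AES key, and the tests derive the raw key material as key_len / 8 bytes. A standalone sketch of that arithmetic:

```python
def gcm_key_material_len(aes_key_bits, salt_bits=32):
    """Bytes of key material the kernel expects: AES key plus trailing salt."""
    return (aes_key_bits + salt_bits) // 8

# 128-, 192- and 256-bit AES-GCM keys each carry 4 extra salt bytes.
lengths = [gcm_key_material_len(bits) for bits in (128, 192, 256)]
```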
+
+def InjectTests():
+ XfrmAlgorithmTest.InjectTests()
+
+
+class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest):
+ @classmethod
+ def InjectTests(cls):
+ VERSIONS = (4, 6)
+ TYPES = (SOCK_DGRAM, SOCK_STREAM)
+
+ # Tests all combinations of auth & crypt. Mutually exclusive with aead.
+ param_list = itertools.product(VERSIONS, TYPES, AUTH_ALGOS, CRYPT_ALGOS,
+ [None])
+ util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator)
+
+ # Tests all combinations of aead. Mutually exclusive with auth/crypt.
+ param_list = itertools.product(VERSIONS, TYPES, [None], [None], AEAD_ALGOS)
+ util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator)
+
+ @staticmethod
+ def TestNameGenerator(version, proto, auth, crypt, aead):
+ # Produce a unique and readable name for each test. e.g.
+ # testSocketPolicySimple_cbc-aes_256_hmac-sha512_512_256_IPv6_UDP
+ param_string = ""
+ if crypt is not None:
+ param_string += "%s_%d_" % (crypt.name, crypt.key_len)
+
+ if auth is not None:
+ param_string += "%s_%d_%d_" % (auth.name, auth.key_len,
+ auth.trunc_len)
+
+ if aead is not None:
+ param_string += "%s_%d_%d_" % (aead.name, aead.key_len,
+ aead.icv_len)
+
+ param_string += "%s_%s" % ("IPv4" if version == 4 else "IPv6",
+ "UDP" if proto == SOCK_DGRAM else "TCP")
+ return param_string
+
+ def ParamTestSocketPolicySimple(self, version, proto, auth, crypt, aead):
+ """Test two-way traffic using transport mode and socket policies."""
+
+ def AssertEncrypted(packet):
+ # This gives a free pass to ICMP and ICMPv6 packets, which show up
+ # nondeterministically in tests.
+ self.assertEquals(None,
+ packet.getlayer(scapy.UDP),
+ "UDP packet sent in the clear")
+ self.assertEquals(None,
+ packet.getlayer(scapy.TCP),
+ "TCP packet sent in the clear")
+
+ # We create a pair of sockets, "left" and "right", that will talk to each
+ # other using transport mode ESP. Because of TapTwister, both sockets
+ # perceive each other as owning "remote_addr".
+ netid = self.RandomNetid()
+ family = net_test.GetAddressFamily(version)
+ local_addr = self.MyAddress(version, netid)
+ remote_addr = self.GetRemoteSocketAddress(version)
+ auth_left = (xfrm.XfrmAlgoAuth((auth.name, auth.key_len, auth.trunc_len)),
+ os.urandom(auth.key_len / 8)) if auth else None
+ auth_right = (xfrm.XfrmAlgoAuth((auth.name, auth.key_len, auth.trunc_len)),
+ os.urandom(auth.key_len / 8)) if auth else None
+ crypt_left = (xfrm.XfrmAlgo((crypt.name, crypt.key_len)),
+ os.urandom(crypt.key_len / 8)) if crypt else None
+ crypt_right = (xfrm.XfrmAlgo((crypt.name, crypt.key_len)),
+ os.urandom(crypt.key_len / 8)) if crypt else None
+ aead_left = (xfrm.XfrmAlgoAead((aead.name, aead.key_len, aead.icv_len)),
+ os.urandom(aead.key_len / 8)) if aead else None
+ aead_right = (xfrm.XfrmAlgoAead((aead.name, aead.key_len, aead.icv_len)),
+ os.urandom(aead.key_len / 8)) if aead else None
+ spi_left = 0xbeefface
+ spi_right = 0xcafed00d
+ req_ids = [100, 200, 300, 400] # Used to match templates and SAs.
+
+ # Left outbound SA
+ self.xfrm.AddSaInfo(
+ src=local_addr,
+ dst=remote_addr,
+ spi=spi_right,
+ mode=xfrm.XFRM_MODE_TRANSPORT,
+ reqid=req_ids[0],
+ encryption=crypt_right,
+ auth_trunc=auth_right,
+ aead=aead_right,
+ encap=None,
+ mark=None,
+ output_mark=None)
+ # Right inbound SA
+ self.xfrm.AddSaInfo(
+ src=remote_addr,
+ dst=local_addr,
+ spi=spi_right,
+ mode=xfrm.XFRM_MODE_TRANSPORT,
+ reqid=req_ids[1],
+ encryption=crypt_right,
+ auth_trunc=auth_right,
+ aead=aead_right,
+ encap=None,
+ mark=None,
+ output_mark=None)
+ # Right outbound SA
+ self.xfrm.AddSaInfo(
+ src=local_addr,
+ dst=remote_addr,
+ spi=spi_left,
+ mode=xfrm.XFRM_MODE_TRANSPORT,
+ reqid=req_ids[2],
+ encryption=crypt_left,
+ auth_trunc=auth_left,
+ aead=aead_left,
+ encap=None,
+ mark=None,
+ output_mark=None)
+ # Left inbound SA
+ self.xfrm.AddSaInfo(
+ src=remote_addr,
+ dst=local_addr,
+ spi=spi_left,
+ mode=xfrm.XFRM_MODE_TRANSPORT,
+ reqid=req_ids[3],
+ encryption=crypt_left,
+ auth_trunc=auth_left,
+ aead=aead_left,
+ encap=None,
+ mark=None,
+ output_mark=None)
+
+ # Make two sockets.
+ sock_left = socket(family, proto, 0)
+ sock_left.settimeout(2.0)
+ sock_left.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+ self.SelectInterface(sock_left, netid, "mark")
+ sock_right = socket(family, proto, 0)
+ sock_right.settimeout(2.0)
+ sock_right.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+ self.SelectInterface(sock_right, netid, "mark")
+
+ # For TCP, set SO_LINGER to 0 to prevent sockets from hanging around in a
+ # TIME_WAIT state.
+ if proto == SOCK_STREAM:
+ net_test.DisableFinWait(sock_left)
+ net_test.DisableFinWait(sock_right)
+
+ # Apply the left outbound socket policy.
+ xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_OUT,
+ spi_right, req_ids[0], None)
+ # Apply right inbound socket policy.
+ xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_IN,
+ spi_right, req_ids[1], None)
+ # Apply right outbound socket policy.
+ xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_OUT,
+ spi_left, req_ids[2], None)
+ # Apply left inbound socket policy.
+ xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_IN,
+ spi_left, req_ids[3], None)
+
+ server_ready = threading.Event()
+ # Save exceptions thrown by the server. A one-element list is used because
+ # Python 2 closures cannot rebind an outer local variable (no nonlocal).
+ server_error = [None]
+
+ def TcpServer(sock, client_port):
+ try:
+ sock.listen(1)
+ server_ready.set()
+ accepted, peer = sock.accept()
+ self.assertEquals(remote_addr, peer[0])
+ self.assertEquals(client_port, peer[1])
+ data = accepted.recv(2048)
+ self.assertEquals("hello request", data)
+ accepted.send("hello response")
+ except Exception as e:
+ server_error[0] = e
+ finally:
+ sock.close()
+
+ def UdpServer(sock, client_port):
+ try:
+ server_ready.set()
+ data, peer = sock.recvfrom(2048)
+ self.assertEquals(remote_addr, peer[0])
+ self.assertEquals(client_port, peer[1])
+ self.assertEquals("hello request", data)
+ sock.sendto("hello response", peer)
+ except Exception as e:
+ server_error[0] = e
+ finally:
+ sock.close()
+
+ # Server and client need to know each other's port numbers in advance.
+ wildcard_addr = net_test.GetWildcardAddress(version)
+ sock_left.bind((wildcard_addr, 0))
+ sock_right.bind((wildcard_addr, 0))
+ left_port = sock_left.getsockname()[1]
+ right_port = sock_right.getsockname()[1]
+
+ # Start the appropriate server type on sock_right.
+ target = TcpServer if proto == SOCK_STREAM else UdpServer
+ server = threading.Thread(
+ target=target,
+ args=(sock_right, left_port),
+ name="SocketServer")
+ server.start()
+ # Wait for the server to be ready before attempting to connect. TCP retries
+ # hide this problem, but UDP will fail outright if the server socket has
+ # not bound when we send.
+ self.assertTrue(server_ready.wait(2.0),
+ "Timed out waiting for server thread")
+
+ with TapTwister(fd=self.tuns[netid].fileno(), validator=AssertEncrypted):
+ sock_left.connect((remote_addr, right_port))
+ sock_left.send("hello request")
+ data = sock_left.recv(2048)
+ self.assertEquals("hello response", data)
+ sock_left.close()
+ server.join()
+ if server_error[0]:
+ raise server_error[0]
+
+
+if __name__ == "__main__":
+ XfrmAlgorithmTest.InjectTests()
+ unittest.main()
diff --git a/net/test/xfrm_base.py b/net/test/xfrm_base.py
new file mode 100644
index 0000000..1eaa302
--- /dev/null
+++ b/net/test/xfrm_base.py
@@ -0,0 +1,314 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from socket import * # pylint: disable=wildcard-import
+from scapy import all as scapy
+import struct
+
+import csocket
+import cstruct
+import multinetwork_base
+import net_test
+import util
+import xfrm
+
+_ENCRYPTION_KEY_256 = ("308146eb3bd84b044573d60f5a5fd159"
+ "57c7d4fe567a2120f35bae0f9869ec22".decode("hex"))
+_AUTHENTICATION_KEY_128 = "af442892cdcd0ef650e9c299f9a8436a".decode("hex")
+
+_ALGO_AUTH_NULL = (xfrm.XfrmAlgoAuth(("digest_null", 0, 0)), "")
+_ALGO_HMAC_SHA1 = (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 128, 96)),
+ _AUTHENTICATION_KEY_128)
+
+_ALGO_CRYPT_NULL = (xfrm.XfrmAlgo(("ecb(cipher_null)", 0)), "")
+_ALGO_CBC_AES_256 = (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 256)),
+ _ENCRYPTION_KEY_256)
+
+# Match all bits of the mark
+MARK_MASK_ALL = 0xffffffff
+
+
+def SetPolicySockopt(sock, family, opt_data):
+ optlen = len(opt_data) if opt_data is not None else 0
+ if family == AF_INET:
+ csocket.Setsockopt(sock, IPPROTO_IP, xfrm.IP_XFRM_POLICY, opt_data, optlen)
+ else:
+ csocket.Setsockopt(sock, IPPROTO_IPV6, xfrm.IPV6_XFRM_POLICY, opt_data,
+ optlen)
+
+
+def ApplySocketPolicy(sock, family, direction, spi, reqid, tun_addrs):
+ """Create and apply an ESP policy to a socket.
+
+ A socket may have only one policy per direction, so applying a policy will
+ remove any policy that was previously applied in that direction.
+
+ Args:
+ sock: The socket that needs a policy
+ family: AF_INET or AF_INET6
+ direction: XFRM_POLICY_IN or XFRM_POLICY_OUT
+ spi: 32-bit SPI in host byte order
+ reqid: 32-bit ID matched against SAs
+ tun_addrs: A tuple of (local, remote) addresses for tunnel mode, or None
+ to request a transport mode SA.
+ """
+ # Create a selector that matches all packets of the specified address family.
+ selector = xfrm.EmptySelector(family)
+
+ # Create an XFRM policy and template.
+ policy = xfrm.UserPolicy(direction, selector)
+ template = xfrm.UserTemplate(family, spi, reqid, tun_addrs)
+
+ # Set the policy and template on our socket.
+ opt_data = policy.Pack() + template.Pack()
+
+ # The policy family might not match the socket family. For example, we might
+ # have an IPv4 policy on a dual-stack socket.
+ sockfamily = sock.getsockopt(SOL_SOCKET, net_test.SO_DOMAIN)
+ SetPolicySockopt(sock, sockfamily, opt_data)
+
+def _GetCryptParameters(crypt_alg):
+ """Looks up encryption algorithm's block and IV lengths.
+
+ Args:
+ crypt_alg: the encryption algorithm constant
+ Returns:
+ A tuple of the block size, and IV length
+ """
+ cryptParameters = {
+ _ALGO_CRYPT_NULL: (4, 0),
+ _ALGO_CBC_AES_256: (16, 16)
+ }
+
+ return cryptParameters.get(crypt_alg, (0, 0))
+
+def GetEspPacketLength(mode, version, udp_encap, payload,
+ auth_alg, crypt_alg):
+ """Calculates encrypted length of a UDP packet with the given payload.
+
+ Args:
+ mode: XFRM_MODE_TRANSPORT or XFRM_MODE_TUNNEL.
+ version: IPPROTO_IP for IPv4, IPPROTO_IPV6 for IPv6. The inner header.
+ udp_encap: whether UDP encap overhead should be accounted for. Since the
+ outermost IP header is ignored (payload only), only add for udp
+ encap'd packets.
+ payload: UDP payload bytes.
+ auth_alg: The xfrm_base authentication algorithm used in the SA.
+ crypt_alg: The xfrm_base encryption algorithm used in the SA.
+
+ Returns: the packet length.
+ """
+
+ crypt_blk_size, crypt_iv_len = _GetCryptParameters(crypt_alg)
+ auth_trunc_len = auth_alg[0].trunc_len
+
+ # Wrap in UDP payload
+ payload_len = len(payload) + net_test.UDP_HDR_LEN
+
+ # Size constants
+ esp_hdr_len = len(xfrm.EspHdr) # SPI + Seq number
+ icv_len = auth_trunc_len / 8
+
+ # Add inner IP header if tunnel mode
+ if mode == xfrm.XFRM_MODE_TUNNEL:
+ payload_len += net_test.GetIpHdrLength(version)
+
+ # Add ESP trailer
+ payload_len += 2 # Pad Length + Next Header fields
+
+ # Align to block size of encryption algorithm
+ payload_len += util.GetPadLength(crypt_blk_size, payload_len)
+
+ # Add initialization vector, header length and ICV length
+ payload_len += esp_hdr_len + crypt_iv_len + icv_len
+
+ # Add encap as needed
+ if udp_encap:
+ payload_len += net_test.UDP_HDR_LEN
+
+ return payload_len
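The bookkeeping in GetEspPacketLength can be cross-checked by hand. A self-contained sketch of the same arithmetic for a transport-mode packet, with the sizes assumed rather than imported (8-byte UDP and ESP headers, 16-byte block and IV for cbc(aes), 96-bit truncated ICV for hmac(sha1)):

```python
def pad_length(block_size, length):
    """Bytes needed to round length up to a multiple of block_size."""
    return (block_size - (length % block_size)) % block_size

def esp_transport_len(payload_bytes, blk=16, iv=16, icv_bits=96,
                      udp_hdr=8, esp_hdr=8):
    length = payload_bytes + udp_hdr   # inner UDP datagram
    length += 2                        # ESP trailer: pad length + next header
    length += pad_length(blk, length)  # align to the cipher block size
    return length + esp_hdr + iv + icv_bits // 8

# A 1-byte payload: 1 + 8 + 2 = 11, padded to 16, plus 8 + 16 + 12.
one_byte_payload_len = esp_transport_len(1)
```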
+
+
+def EncryptPacketWithNull(packet, spi, seq, tun_addrs):
+ """Apply null encryption to a packet.
+
+ This performs ESP encapsulation on the given packet. The returned packet will
+ be a tunnel mode packet if tun_addrs is provided.
+
+ The input packet is assumed to be a UDP packet. The input packet *MUST* have
+ its length and checksum fields in IP and UDP headers set appropriately. This
+ can be done by "rebuilding" the scapy object. e.g.,
+ ip6_packet = scapy.IPv6(str(ip6_packet))
+
+ TODO: Support TCP
+
+ Args:
+ packet: a scapy.IPv6 or scapy.IP packet
+ spi: security parameter index for ESP header in host byte order
+ seq: sequence number for ESP header
+ tun_addrs: A tuple of (local, remote) addresses for tunnel mode, or None
+ to request a transport mode packet.
+
+ Returns:
+ The encrypted packet (scapy.IPv6 or scapy.IP)
+ """
+ # The top-level packet changes in tunnel mode, which would invalidate
+ # the passed-in packet pointer. For consistency, this function now returns
+ # a new packet and does not modify the user's original packet.
+ packet = packet.copy()
+ udp_layer = packet.getlayer(scapy.UDP)
+ if not udp_layer:
+ raise ValueError("Expected a UDP packet")
+ # Build an ESP header.
+ esp_packet = scapy.Raw(xfrm.EspHdr((spi, seq)).Pack())
+
+ if tun_addrs:
+ tsrc_addr, tdst_addr = tun_addrs
+ outer_version = net_test.GetAddressVersion(tsrc_addr)
+ ip_type = {4: scapy.IP, 6: scapy.IPv6}[outer_version]
+ new_ip_layer = ip_type(src=tsrc_addr, dst=tdst_addr)
+ inner_layer = packet
+ esp_nexthdr = {scapy.IPv6: IPPROTO_IPV6,
+ scapy.IP: IPPROTO_IPIP}[type(packet)]
+ else:
+ new_ip_layer = None
+ inner_layer = udp_layer
+ esp_nexthdr = IPPROTO_UDP
+
+ # ESP padding per RFC 4303 section 2.4.
+ # For a null cipher with a block size of 1, padding is only necessary to
+ # ensure that the 1-byte Pad Length and Next Header fields are right aligned
+ # on a 4-byte boundary.
+ esplen = (len(inner_layer) + 2) # UDP length plus Pad Length and Next Header.
+ padlen = util.GetPadLength(4, esplen)
+ # The pad bytes are consecutive integers starting from 0x01.
+ padding = "".join((chr(i) for i in xrange(1, padlen + 1)))
+ trailer = padding + struct.pack("BB", padlen, esp_nexthdr)
+
+ # Assemble the packet.
+ esp_packet.payload = scapy.Raw(inner_layer)
+ packet = new_ip_layer if new_ip_layer else packet
+ packet.payload = scapy.Raw(str(esp_packet) + trailer)
+
+ # TODO: Can we simplify this and avoid the initial copy()?
+ # Fix the IPv4/IPv6 headers.
+ if type(packet) is scapy.IPv6:
+ packet.nh = IPPROTO_ESP
+ # Recompute plen.
+ packet.plen = None
+ packet = scapy.IPv6(str(packet))
+ elif type(packet) is scapy.IP:
+ packet.proto = IPPROTO_ESP
+ # Recompute IPv4 len and checksum.
+ packet.len = None
+ packet.chksum = None
+ packet = scapy.IP(str(packet))
+ else:
+ raise ValueError("First layer in packet should be IPv4 or IPv6: " + repr(packet))
+ return packet
+
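The RFC 4303 trailer assembled above (monotonic pad bytes starting at 0x01, then the one-byte Pad Length and Next Header fields, aligned to 4 bytes for a null cipher) can be sketched standalone; the helper name is illustrative and Python 3 bytes semantics are assumed:

```python
import struct

def esp_trailer(inner_len, next_hdr, align=4):
    """Build an ESP trailer: pad bytes 0x01..N, pad length, next header."""
    pad_len = (align - ((inner_len + 2) % align)) % align
    padding = bytes(range(1, pad_len + 1))  # consecutive bytes from 0x01
    return padding + struct.pack("BB", pad_len, next_hdr)

# A 13-byte inner datagram with next header IPPROTO_UDP (17):
# 13 + 2 = 15, so one pad byte brings the ESP payload to a 4-byte boundary.
trailer = esp_trailer(13, 17)
```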
+
+def DecryptPacketWithNull(packet):
+ """Apply null decryption to a packet.
+
+ This performs ESP decapsulation on the given packet. The input packet is
+ assumed to be a UDP packet. This function will remove the ESP header and
+ trailer bytes from an ESP packet.
+
+ TODO: Support TCP
+
+ Args:
+ packet: a scapy.IPv6 or scapy.IP packet
+
+ Returns:
+ A tuple of decrypted packet (scapy.IPv6 or scapy.IP) and EspHdr
+ """
+ esp_hdr, esp_data = cstruct.Read(str(packet.payload), xfrm.EspHdr)
+ # Parse and strip ESP trailer.
+ pad_len, esp_nexthdr = struct.unpack("BB", esp_data[-2:])
+ trailer_len = pad_len + 2 # Add the size of the pad_len and next_hdr fields.
+ LayerType = {
+ IPPROTO_IPIP: scapy.IP,
+ IPPROTO_IPV6: scapy.IPv6,
+ IPPROTO_UDP: scapy.UDP}[esp_nexthdr]
+ next_layer = LayerType(esp_data[:-trailer_len])
+ if esp_nexthdr in [IPPROTO_IPIP, IPPROTO_IPV6]:
+ # Tunnel mode decap is simple. Return the inner packet.
+ return next_layer, esp_hdr
+
+ # Cut out the ESP header.
+ packet.payload = next_layer
+ # Fix the IPv4/IPv6 headers.
+ if type(packet) is scapy.IPv6:
+ packet.nh = IPPROTO_UDP
+ packet.plen = None # Recompute packet length.
+ packet = scapy.IPv6(str(packet))
+ elif type(packet) is scapy.IP:
+ packet.proto = IPPROTO_UDP
+ packet.len = None # Recompute packet length.
+ packet.chksum = None # Recompute IPv4 checksum.
+ packet = scapy.IP(str(packet))
+ else:
+ raise ValueError("First layer in packet should be IPv4 or IPv6: " + repr(packet))
+ return packet, esp_hdr
+
+
+class XfrmBaseTest(multinetwork_base.MultiNetworkBaseTest):
+ """Base test class for all XFRM-related testing."""
+
+ def _ExpectEspPacketOn(self, netid, spi, seq, length, src_addr, dst_addr):
+ """Read a packet from a netid and verify its properties.
+
+ Args:
+ netid: netid from which to read an ESP packet
+ spi: SPI of the ESP packet in host byte order
+ seq: sequence number of the ESP packet
+ length: length of the packet's ESP payload or None to skip this check
+ src_addr: source address of the packet or None to skip this check
+ dst_addr: destination address of the packet or None to skip this check
+
+ Returns:
+ scapy.IP/IPv6: the read packet
+ """
+ packets = self.ReadAllPacketsOn(netid)
+ self.assertEquals(1, len(packets))
+ packet = packets[0]
+ if length is not None:
+ self.assertEquals(length, len(packet.payload))
+ if dst_addr is not None:
+ self.assertEquals(dst_addr, packet.dst)
+ if src_addr is not None:
+ self.assertEquals(src_addr, packet.src)
+ # extract the ESP header
+ esp_hdr, _ = cstruct.Read(str(packet.payload), xfrm.EspHdr)
+ self.assertEquals(xfrm.EspHdr((spi, seq)), esp_hdr)
+ return packet
+
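_ExpectEspPacketOn pulls the ESP header off the front of the payload; that header is just the SPI and sequence number, both in network byte order, so the parse can be sketched without cstruct (the helper is illustrative, not the EspHdr struct used here):

```python
import struct

def parse_esp_header(payload):
    """Return (spi, seq) from the first 8 bytes of an ESP payload."""
    spi, seq = struct.unpack("!II", payload[:8])
    return spi, seq

# Round-trip a header packed the same way the tests build EspHdr((spi, seq)).
hdr = struct.pack("!II", 0xbeefface, 1)
```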
+
+# TODO: delete this when we're more diligent about deleting our SAs.
+class XfrmLazyTest(XfrmBaseTest):
+ """Base class for XFRM tests that flushes XFRM state on setup and teardown."""
+ def setUp(self):
+ super(XfrmLazyTest, self).setUp()
+ self.xfrm = xfrm.Xfrm()
+ self.xfrm.FlushSaInfo()
+ self.xfrm.FlushPolicyInfo()
+
+ def tearDown(self):
+ super(XfrmLazyTest, self).tearDown()
+ self.xfrm.FlushSaInfo()
+ self.xfrm.FlushPolicyInfo()
diff --git a/net/test/xfrm_test.py b/net/test/xfrm_test.py
index 5700ad8..3a3d9b0 100755
--- a/net/test/xfrm_test.py
+++ b/net/test/xfrm_test.py
@@ -16,75 +16,59 @@
# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
from errno import * # pylint: disable=wildcard-import
-import os
-import random
-import re
from scapy import all as scapy
from socket import * # pylint: disable=wildcard-import
import struct
import subprocess
-import time
+import threading
import unittest
+import csocket
+import cstruct
import multinetwork_base
import net_test
-import netlink
import packets
import xfrm
+import xfrm_base
-XFRM_ADDR_ANY = 16 * "\x00"
-LOOPBACK = 15 * "\x00" + "\x01"
ENCRYPTED_PAYLOAD = ("b1c74998efd6326faebe2061f00f2c750e90e76001664a80c287b150"
"59e74bf949769cc6af71e51b539e7de3a2a14cb05a231b969e035174"
"d98c5aa0cef1937db98889ec0d08fa408fecf616")
-ENCRYPTION_KEY = ("308146eb3bd84b044573d60f5a5fd159"
- "57c7d4fe567a2120f35bae0f9869ec22".decode("hex"))
-AUTH_TRUNC_KEY = "af442892cdcd0ef650e9c299f9a8436a".decode("hex")
TEST_ADDR1 = "2001:4860:4860::8888"
TEST_ADDR2 = "2001:4860:4860::8844"
+# IP addresses to use for tunnel endpoints. For generality, these should be
+# different from the addresses we send packets to.
+TUNNEL_ENDPOINTS = {4: "8.8.4.4", 6: TEST_ADDR2}
+
TEST_SPI = 0x1234
+TEST_SPI2 = 0x1235
-ALL_ALGORITHMS = 0xffffffff
-ALGO_CBC_AES_256 = xfrm.XfrmAlgo(("cbc(aes)", 256))
-ALGO_HMAC_SHA1 = xfrm.XfrmAlgoAuth(("hmac(sha1)", 128, 96))
-class XfrmTest(multinetwork_base.MultiNetworkBaseTest):
- @classmethod
- def setUpClass(cls):
- super(XfrmTest, cls).setUpClass()
- cls.xfrm = xfrm.Xfrm()
-
- def setUp(self):
- # TODO: delete this when we're more diligent about deleting our SAs.
- super(XfrmTest, self).setUp()
- subprocess.call("ip xfrm state flush".split())
-
- def expectIPv6EspPacketOn(self, netid, spi, seq, length):
- packets = self.ReadAllPacketsOn(netid)
- self.assertEquals(1, len(packets))
- packet = packets[0]
- self.assertEquals(IPPROTO_ESP, packet.nh)
- spi_seq = struct.pack("!II", spi, seq)
- self.assertEquals(spi_seq, str(packet.payload)[:len(spi_seq)])
- self.assertEquals(length, len(packet.payload))
+class XfrmFunctionalTest(xfrm_base.XfrmLazyTest):
def assertIsUdpEncapEsp(self, packet, spi, seq, length):
self.assertEquals(IPPROTO_UDP, packet.proto)
- self.assertEquals(4500, packet.dport)
- # Skip UDP header. TODO: isn't there a better way to do this?
- payload = str(packet.payload)[8:]
- self.assertEquals(length, len(payload))
- spi_seq = struct.pack("!II", ntohl(spi), seq)
- self.assertEquals(spi_seq, str(payload)[:len(spi_seq)])
+ udp_hdr = packet[scapy.UDP]
+ self.assertEquals(4500, udp_hdr.dport)
+ self.assertEquals(length, len(udp_hdr))
+ esp_hdr, _ = cstruct.Read(str(udp_hdr.payload), xfrm.EspHdr)
+ # FIXME: this file currently swaps SPI byte order manually, so SPI needs to
+ # be double-swapped here.
+ self.assertEquals(xfrm.EspHdr((spi, seq)), esp_hdr)
+
+ def CreateNewSa(self, localAddr, remoteAddr, spi, reqId, encap_tmpl,
+ null_auth=False):
+ auth_algo = (
+ xfrm_base._ALGO_AUTH_NULL if null_auth else xfrm_base._ALGO_HMAC_SHA1)
+ self.xfrm.AddSaInfo(localAddr, remoteAddr, spi, xfrm.XFRM_MODE_TRANSPORT,
+ reqId, xfrm_base._ALGO_CBC_AES_256, auth_algo, None,
+ encap_tmpl, None, None)
def testAddSa(self):
- self.xfrm.AddMinimalSaInfo("::", TEST_ADDR1, htonl(TEST_SPI), IPPROTO_ESP,
- xfrm.XFRM_MODE_TRANSPORT, 3320,
- ALGO_CBC_AES_256, ENCRYPTION_KEY,
- ALGO_HMAC_SHA1, AUTH_TRUNC_KEY, None)
+ self.CreateNewSa("::", TEST_ADDR1, TEST_SPI, 3320, None)
expected = (
"src :: dst 2001:4860:4860::8888\n"
"\tproto esp spi 0x00001234 reqid 3320 mode transport\n"
@@ -92,99 +76,137 @@
"\tauth-trunc hmac(sha1) 0x%s 96\n"
"\tenc cbc(aes) 0x%s\n"
"\tsel src ::/0 dst ::/0 \n" % (
- AUTH_TRUNC_KEY.encode("hex"), ENCRYPTION_KEY.encode("hex")))
+ xfrm_base._AUTHENTICATION_KEY_128.encode("hex"),
+ xfrm_base._ENCRYPTION_KEY_256.encode("hex")))
actual = subprocess.check_output("ip xfrm state".split())
+    # Newer versions of the ip command also show the anti-replay context.
+    # Strip it so the comparison works either way.
+ actual = actual.replace(
+ "\tanti-replay context: seq 0x0, oseq 0x0, bitmap 0x00000000\n", "")
try:
self.assertMultiLineEqual(expected, actual)
finally:
- self.xfrm.DeleteSaInfo(TEST_ADDR1, htonl(TEST_SPI), IPPROTO_ESP)
+ self.xfrm.DeleteSaInfo(TEST_ADDR1, TEST_SPI, IPPROTO_ESP)
+ def testFlush(self):
+ self.assertEquals(0, len(self.xfrm.DumpSaInfo()))
+ self.CreateNewSa("::", "2000::", TEST_SPI, 1234, None)
+ self.CreateNewSa("0.0.0.0", "192.0.2.1", TEST_SPI, 4321, None)
+ self.assertEquals(2, len(self.xfrm.DumpSaInfo()))
+ self.xfrm.FlushSaInfo()
+ self.assertEquals(0, len(self.xfrm.DumpSaInfo()))
- @unittest.skipUnless(net_test.LINUX_VERSION < (4, 4, 0), "regression")
- def testSocketPolicy(self):
- # Open an IPv6 UDP socket and connect it.
- s = socket(AF_INET6, SOCK_DGRAM, 0)
- netid = random.choice(self.NETIDS)
+ def _TestSocketPolicy(self, version):
+ # Open a UDP socket and connect it.
+ family = net_test.GetAddressFamily(version)
+ s = socket(family, SOCK_DGRAM, 0)
+ netid = self.RandomNetid()
self.SelectInterface(s, netid, "mark")
- s.connect((TEST_ADDR1, 53))
+
+ remotesockaddr = self.GetRemoteSocketAddress(version)
+ s.connect((remotesockaddr, 53))
saddr, sport = s.getsockname()[:2]
daddr, dport = s.getpeername()[:2]
+ if version == 5:
+ saddr = saddr.replace("::ffff:", "")
+ daddr = daddr.replace("::ffff:", "")
- # Create a selector that matches all UDP packets. It's not actually used to
- # select traffic, that will be done by the socket policy, which selects the
- # SA entry (i.e., xfrm state) via the SPI and reqid.
- sel = xfrm.XfrmSelector((XFRM_ADDR_ANY, XFRM_ADDR_ANY, 0, 0, 0, 0,
- AF_INET6, 0, 0, IPPROTO_UDP, 0, 0))
+ reqid = 0
- # Create a user policy that specifies that all outbound packets matching the
- # (essentially no-op) selector should be encrypted.
- info = xfrm.XfrmUserpolicyInfo((sel,
- xfrm.NO_LIFETIME_CFG, xfrm.NO_LIFETIME_CUR,
- 100, 0,
- xfrm.XFRM_POLICY_OUT,
- xfrm.XFRM_POLICY_ALLOW,
- xfrm.XFRM_POLICY_LOCALOK,
- xfrm.XFRM_SHARE_UNIQUE))
+ desc, pkt = packets.UDP(version, saddr, daddr, sport=sport)
+ s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
+ self.ExpectPacketOn(netid, "Send after socket, expected %s" % desc, pkt)
- # Create a template that specifies the SPI and the protocol.
- xfrmid = xfrm.XfrmId((XFRM_ADDR_ANY, htonl(TEST_SPI), IPPROTO_ESP))
- tmpl = xfrm.XfrmUserTmpl((xfrmid, AF_INET6, XFRM_ADDR_ANY, 0,
- xfrm.XFRM_MODE_TRANSPORT, xfrm.XFRM_SHARE_UNIQUE,
- 0, # require
- ALL_ALGORITHMS, # auth algos
- ALL_ALGORITHMS, # encryption algos
- ALL_ALGORITHMS)) # compression algos
-
- # Set the policy and template on our socket.
- data = info.Pack() + tmpl.Pack()
- s.setsockopt(IPPROTO_IPV6, xfrm.IPV6_XFRM_POLICY, data)
+ # Using IPv4 XFRM on a dual-stack socket requires setting an AF_INET policy
+ # that's written in terms of IPv4 addresses.
+ xfrm_version = 4 if version == 5 else version
+ xfrm_family = net_test.GetAddressFamily(xfrm_version)
+ xfrm_base.ApplySocketPolicy(s, xfrm_family, xfrm.XFRM_POLICY_OUT,
+ TEST_SPI, reqid, None)
# Because the policy has level set to "require" (the default), attempting
# to send a packet results in an error, because there is no SA that
# matches the socket policy we set.
self.assertRaisesErrno(
EAGAIN,
- s.sendto, net_test.UDP_PAYLOAD, (TEST_ADDR1, 53))
+ s.sendto, net_test.UDP_PAYLOAD, (remotesockaddr, 53))
+
+    # If there is a userspace key manager, calling sendto() after applying
+    # the socket policy creates an SA in state XFRM_STATE_ACQ, so delete it
+    # here. If there is no key manager, the delete fails with ESRCH.
+    try:
+      self.xfrm.DeleteSaInfo(self.GetRemoteAddress(xfrm_version), TEST_SPI,
+                             IPPROTO_ESP)
+    except IOError as e:
+      self.assertEquals(ESRCH, e.errno, "Unexpected error when deleting ACQ SA")
# Adding a matching SA causes the packet to go out encrypted. The SA's
# SPI must match the one in our template, and the destination address must
# match the packet's destination address (in tunnel mode, it has to match
# the tunnel destination).
- reqid = 0
- self.xfrm.AddMinimalSaInfo("::", TEST_ADDR1, htonl(TEST_SPI), IPPROTO_ESP,
- xfrm.XFRM_MODE_TRANSPORT, reqid,
- ALGO_CBC_AES_256, ENCRYPTION_KEY,
- ALGO_HMAC_SHA1, AUTH_TRUNC_KEY, None)
+ self.CreateNewSa(
+ net_test.GetWildcardAddress(xfrm_version),
+ self.GetRemoteAddress(xfrm_version), TEST_SPI, reqid, None)
- s.sendto(net_test.UDP_PAYLOAD, (TEST_ADDR1, 53))
- self.expectIPv6EspPacketOn(netid, TEST_SPI, 1, 84)
+ s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
+ expected_length = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TRANSPORT,
+ version, False,
+ net_test.UDP_PAYLOAD,
+ xfrm_base._ALGO_HMAC_SHA1,
+ xfrm_base._ALGO_CBC_AES_256)
+ self._ExpectEspPacketOn(netid, TEST_SPI, 1, expected_length, None, None)
# Sending to another destination doesn't work: again, no matching SA.
+ remoteaddr2 = self.GetOtherRemoteSocketAddress(version)
self.assertRaisesErrno(
EAGAIN,
- s.sendto, net_test.UDP_PAYLOAD, (TEST_ADDR2, 53))
+ s.sendto, net_test.UDP_PAYLOAD, (remoteaddr2, 53))
# Sending on another socket without the policy applied results in an
# unencrypted packet going out.
- s2 = socket(AF_INET6, SOCK_DGRAM, 0)
+ s2 = socket(family, SOCK_DGRAM, 0)
self.SelectInterface(s2, netid, "mark")
- s2.sendto(net_test.UDP_PAYLOAD, (TEST_ADDR1, 53))
- packets = self.ReadAllPacketsOn(netid)
- self.assertEquals(1, len(packets))
- packet = packets[0]
- self.assertEquals(IPPROTO_UDP, packet.nh)
+ s2.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
+ pkts = self.ReadAllPacketsOn(netid)
+ self.assertEquals(1, len(pkts))
+ packet = pkts[0]
+
+ protocol = packet.nh if version == 6 else packet.proto
+ self.assertEquals(IPPROTO_UDP, protocol)
# Deleting the SA causes the first socket to return errors again.
- self.xfrm.DeleteSaInfo(TEST_ADDR1, htonl(TEST_SPI), IPPROTO_ESP)
+ self.xfrm.DeleteSaInfo(self.GetRemoteAddress(xfrm_version), TEST_SPI,
+ IPPROTO_ESP)
self.assertRaisesErrno(
EAGAIN,
- s.sendto, net_test.UDP_PAYLOAD, (TEST_ADDR1, 53))
+ s.sendto, net_test.UDP_PAYLOAD, (remotesockaddr, 53))
+ # Clear the socket policy and expect a cleartext packet.
+ xfrm_base.SetPolicySockopt(s, family, None)
+ s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
+ self.ExpectPacketOn(netid, "Send after clear, expected %s" % desc, pkt)
- def testUdpEncapWithSocketPolicy(self):
- # TODO: test IPv6 instead of IPv4.
- netid = random.choice(self.NETIDS)
+    # Clearing the policy twice is safe.
+ xfrm_base.SetPolicySockopt(s, family, None)
+ s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
+ self.ExpectPacketOn(netid, "Send after clear 2, expected %s" % desc, pkt)
+
+    # Clearing a policy that was never set is also safe.
+    s = socket(AF_INET6, SOCK_DGRAM, 0)
+    xfrm_base.SetPolicySockopt(s, family, None)
+
+ def testSocketPolicyIPv4(self):
+ self._TestSocketPolicy(4)
+
+ def testSocketPolicyIPv6(self):
+ self._TestSocketPolicy(6)
+
+ def testSocketPolicyMapped(self):
+ self._TestSocketPolicy(5)
+
+  # Sets up the sockets and marks them with the correct netid.
+ def _SetupUdpEncapSockets(self):
+ netid = self.RandomNetid()
myaddr = self.MyAddress(4, netid)
remoteaddr = self.GetRemoteAddress(4)
@@ -192,80 +214,66 @@
# packets works without this (and potentially can send packets with a source
# port belonging to another application), but receiving requires the port to
# be bound and the encapsulation socket option enabled.
- encap_socket = net_test.Socket(AF_INET, SOCK_DGRAM, 0)
- encap_socket.bind((myaddr, 0))
- encap_port = encap_socket.getsockname()[1]
- encap_socket.setsockopt(IPPROTO_UDP, xfrm.UDP_ENCAP,
- xfrm.UDP_ENCAP_ESPINUDP)
+ encap_sock = net_test.Socket(AF_INET, SOCK_DGRAM, 0)
+ encap_sock.bind((myaddr, 0))
+ encap_port = encap_sock.getsockname()[1]
+ encap_sock.setsockopt(IPPROTO_UDP, xfrm.UDP_ENCAP, xfrm.UDP_ENCAP_ESPINUDP)
# Open a socket to send traffic.
s = socket(AF_INET, SOCK_DGRAM, 0)
self.SelectInterface(s, netid, "mark")
s.connect((remoteaddr, 53))
- # Create a UDP encap policy and template inbound and outbound and apply
- # them to s.
- sel = xfrm.XfrmSelector((XFRM_ADDR_ANY, XFRM_ADDR_ANY, 0, 0, 0, 0,
- AF_INET, 0, 0, IPPROTO_UDP, 0, 0))
+ return netid, myaddr, remoteaddr, encap_sock, encap_port, s
- # Use the same SPI both inbound and outbound because this lets us receive
- # encrypted packets by simply replaying the packets the kernel sends.
+  # Sets up the SAs and applies socket policies to the given socket.
+ def _SetupUdpEncapSaPair(self, myaddr, remoteaddr, in_spi, out_spi,
+ encap_port, s, use_null_auth):
in_reqid = 123
- in_spi = htonl(TEST_SPI)
out_reqid = 456
- out_spi = htonl(TEST_SPI)
-
- # Start with the outbound policy.
- # TODO: what happens without XFRM_SHARE_UNIQUE?
- info = xfrm.XfrmUserpolicyInfo((sel,
- xfrm.NO_LIFETIME_CFG, xfrm.NO_LIFETIME_CUR,
- 100, 0,
- xfrm.XFRM_POLICY_OUT,
- xfrm.XFRM_POLICY_ALLOW,
- xfrm.XFRM_POLICY_LOCALOK,
- xfrm.XFRM_SHARE_UNIQUE))
- xfrmid = xfrm.XfrmId((XFRM_ADDR_ANY, out_spi, IPPROTO_ESP))
- usertmpl = xfrm.XfrmUserTmpl((xfrmid, AF_INET, XFRM_ADDR_ANY, out_reqid,
- xfrm.XFRM_MODE_TRANSPORT, xfrm.XFRM_SHARE_UNIQUE,
- 0, # require
- ALL_ALGORITHMS, # auth algos
- ALL_ALGORITHMS, # encryption algos
- ALL_ALGORITHMS)) # compression algos
-
- data = info.Pack() + usertmpl.Pack()
- s.setsockopt(IPPROTO_IP, xfrm.IP_XFRM_POLICY, data)
-
- # Uncomment for debugging.
- # subprocess.call("ip xfrm policy".split())
# Create inbound and outbound SAs that specify UDP encapsulation.
encaptmpl = xfrm.XfrmEncapTmpl((xfrm.UDP_ENCAP_ESPINUDP, htons(encap_port),
htons(4500), 16 * "\x00"))
- self.xfrm.AddMinimalSaInfo(myaddr, remoteaddr, out_spi, IPPROTO_ESP,
- xfrm.XFRM_MODE_TRANSPORT, out_reqid,
- ALGO_CBC_AES_256, ENCRYPTION_KEY,
- ALGO_HMAC_SHA1, AUTH_TRUNC_KEY, encaptmpl)
+ self.CreateNewSa(myaddr, remoteaddr, out_spi, out_reqid, encaptmpl,
+ use_null_auth)
# Add an encap template that's the mirror of the outbound one.
encaptmpl.sport, encaptmpl.dport = encaptmpl.dport, encaptmpl.sport
- self.xfrm.AddMinimalSaInfo(remoteaddr, myaddr, in_spi, IPPROTO_ESP,
- xfrm.XFRM_MODE_TRANSPORT, in_reqid,
- ALGO_CBC_AES_256, ENCRYPTION_KEY,
- ALGO_HMAC_SHA1, AUTH_TRUNC_KEY, encaptmpl)
+ self.CreateNewSa(remoteaddr, myaddr, in_spi, in_reqid, encaptmpl,
+ use_null_auth)
+
+ # Apply socket policies to s.
+ xfrm_base.ApplySocketPolicy(s, AF_INET, xfrm.XFRM_POLICY_OUT, out_spi,
+ out_reqid, None)
+
+ # TODO: why does this work without a per-socket policy applied?
+ # The received packet obviously matches an SA, but don't inbound packets
+ # need to match a policy as well? (b/71541609)
+ xfrm_base.ApplySocketPolicy(s, AF_INET, xfrm.XFRM_POLICY_IN, in_spi,
+ in_reqid, None)
# Uncomment for debugging.
# subprocess.call("ip xfrm state".split())
+ # Check that packets can be sent and received.
+ def _VerifyUdpEncapSocket(self, netid, remoteaddr, myaddr, encap_port, sock,
+ in_spi, out_spi, null_auth, seq_num):
# Now send a packet.
- s.sendto("foo", (remoteaddr, 53))
- srcport = s.getsockname()[1]
- # s.send("foo") # TODO: WHY DOES THIS NOT WORK?
+ sock.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
+ srcport = sock.getsockname()[1]
# Expect to see an UDP encapsulated packet.
- packets = self.ReadAllPacketsOn(netid)
- self.assertEquals(1, len(packets))
- packet = packets[0]
- self.assertIsUdpEncapEsp(packet, out_spi, 1, 52)
+ pkts = self.ReadAllPacketsOn(netid)
+ self.assertEquals(1, len(pkts))
+ packet = pkts[0]
+
+ auth_algo = (
+ xfrm_base._ALGO_AUTH_NULL if null_auth else xfrm_base._ALGO_HMAC_SHA1)
+ expected_len = xfrm_base.GetEspPacketLength(
+ xfrm.XFRM_MODE_TRANSPORT, 4, True, net_test.UDP_PAYLOAD, auth_algo,
+ xfrm_base._ALGO_CBC_AES_256)
+ self.assertIsUdpEncapEsp(packet, out_spi, seq_num, expected_len)
# Now test the receive path. Because we don't know how to decrypt packets,
# we just play back the encrypted packet that kernel sent earlier. We swap
@@ -276,46 +284,521 @@
# be sent from srcport to port 53. Open another socket on that port, and
# apply the inbound policy to it.
twisted_socket = socket(AF_INET, SOCK_DGRAM, 0)
- net_test.SetSocketTimeout(twisted_socket, 100)
+ csocket.SetSocketTimeout(twisted_socket, 100)
twisted_socket.bind(("0.0.0.0", 53))
- # TODO: why does this work even without the per-socket policy applied? The
- # received packet obviously matches an SA, but don't inbound packets need to
- # match a policy as well?
- info.dir = xfrm.XFRM_POLICY_IN
- xfrmid.spi = in_spi
- usertmpl.reqid = in_reqid
- data = info.Pack() + usertmpl.Pack()
- twisted_socket.setsockopt(IPPROTO_IP, xfrm.IP_XFRM_POLICY, data)
-
- # Save the payload of the packet so we can replay it back to ourselves.
+ # Save the payload of the packet so we can replay it back to ourselves, and
+ # replace the SPI with our inbound SPI.
payload = str(packet.payload)[8:]
- spi_seq = struct.pack("!II", ntohl(in_spi), 1)
+ spi_seq = xfrm.EspHdr((in_spi, seq_num)).Pack()
payload = spi_seq + payload[len(spi_seq):]
- # Tamper with the packet and check that it's dropped and counted as invalid.
sainfo = self.xfrm.FindSaInfo(in_spi)
- self.assertEquals(0, sainfo.stats.integrity_failed)
- broken = payload[:25] + chr((ord(payload[25]) + 1) % 256) + payload[26:]
- incoming = (scapy.IP(src=remoteaddr, dst=myaddr) /
- scapy.UDP(sport=4500, dport=encap_port) / broken)
- self.ReceivePacketOn(netid, incoming)
- sainfo = self.xfrm.FindSaInfo(in_spi)
- self.assertEquals(1, sainfo.stats.integrity_failed)
+ start_integrity_failures = sainfo.stats.integrity_failed
# Now play back the valid packet and check that we receive it.
incoming = (scapy.IP(src=remoteaddr, dst=myaddr) /
scapy.UDP(sport=4500, dport=encap_port) / payload)
+ incoming = scapy.IP(str(incoming))
self.ReceivePacketOn(netid, incoming)
- data, src = twisted_socket.recvfrom(4096)
- self.assertEquals("foo", data)
- self.assertEquals((remoteaddr, srcport), src)
- # Check that unencrypted packets are not received.
- unencrypted = (scapy.IP(src=remoteaddr, dst=myaddr) /
- scapy.UDP(sport=srcport, dport=53) / "foo")
+ sainfo = self.xfrm.FindSaInfo(in_spi)
+
+    # TODO: break this out into a separate test.
+    # If our SPIs differ and we aren't using null authentication, we expect
+    # the packet to be dropped and the integrity failure counter to
+    # increase, because the SPI is part of the authenticated
+    # (integrity-verified) portion of the packet.
+ if not null_auth and in_spi != out_spi:
+ self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096)
+ self.assertEquals(start_integrity_failures + 1,
+ sainfo.stats.integrity_failed)
+ else:
+ data, src = twisted_socket.recvfrom(4096)
+ self.assertEquals(net_test.UDP_PAYLOAD, data)
+ self.assertEquals((remoteaddr, srcport), src)
+ self.assertEquals(start_integrity_failures, sainfo.stats.integrity_failed)
+
+ # Check that unencrypted packets on twisted_socket are not received.
+ unencrypted = (
+ scapy.IP(src=remoteaddr, dst=myaddr) / scapy.UDP(
+ sport=srcport, dport=53) / net_test.UDP_PAYLOAD)
self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096)
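The replay trick above overwrites the SPI/sequence header of a captured ESP payload and leaves the (possibly encrypted) body untouched. The splice can be sketched standalone (helper name is ours); as the test notes, the result only passes integrity checks when authentication is null or the SPI is unchanged, since the SPI is covered by the ICV:

```python
import struct

def rewrite_esp_spi_seq(esp_payload, new_spi, new_seq):
    # Replace the 8-byte ESP header at the front of the payload,
    # keeping the remainder of the packet byte-for-byte identical.
    hdr = struct.pack("!II", new_spi, new_seq)
    return hdr + esp_payload[len(hdr):]
```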
+ def _RunEncapSocketPolicyTest(self, in_spi, out_spi, use_null_auth):
+ netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
+ self._SetupUdpEncapSockets()
+
+ self._SetupUdpEncapSaPair(myaddr, remoteaddr, in_spi, out_spi, encap_port,
+ s, use_null_auth)
+
+ # Check that UDP encap sockets work with socket policy and given SAs
+ self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s, in_spi,
+ out_spi, use_null_auth, 1)
+
+ # TODO: Add tests for ESP (non-encap) sockets.
+ def testUdpEncapSameSpisNullAuth(self):
+    # Use the same SPI both inbound and outbound because this lets us receive
+    # encrypted packets by simply replaying the packets the kernel sends,
+    # without having to disable authentication.
+ self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI, True)
+
+ def testUdpEncapSameSpis(self):
+ self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI, False)
+
+ def testUdpEncapDifferentSpisNullAuth(self):
+ self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI2, True)
+
+ def testUdpEncapDifferentSpis(self):
+ self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI2, False)
+
+ def testUdpEncapRekey(self):
+ # Select the two SPIs that will be used
+ start_spi = TEST_SPI
+ rekey_spi = TEST_SPI2
+
+ # Setup sockets
+ netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
+ self._SetupUdpEncapSockets()
+
+    # The SAs must use null authentication, since we change SPIs on the fly.
+    # Without null authentication, this would result in an ESP authentication
+    # error, since the SPI is part of the authenticated section, and the
+    # packet would then be dropped.
+ self._SetupUdpEncapSaPair(myaddr, remoteaddr, start_spi, start_spi,
+ encap_port, s, True)
+
+ # Check that UDP encap sockets work with socket policy and given SAs
+ self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
+ start_spi, start_spi, True, 1)
+
+    # Rekey this socket using the make-before-break paradigm: first create
+    # the new SAs and update the per-socket policies, and only then remove
+    # the old SAs. This lets us switch to the new SA without breaking the
+    # outbound path.
+ self._SetupUdpEncapSaPair(myaddr, remoteaddr, rekey_spi, rekey_spi,
+ encap_port, s, True)
+
+ # Check that UDP encap socket works with updated socket policy, sending
+ # using new SA, but receiving on both old and new SAs
+ self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
+ rekey_spi, rekey_spi, True, 1)
+ self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
+ start_spi, rekey_spi, True, 2)
+
+ # Delete old SAs
+ self.xfrm.DeleteSaInfo(remoteaddr, start_spi, IPPROTO_ESP)
+ self.xfrm.DeleteSaInfo(myaddr, start_spi, IPPROTO_ESP)
+
+ # Check that UDP encap socket works with updated socket policy and new SAs
+ self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
+ rekey_spi, rekey_spi, True, 3)
+
+ def testAllocSpecificSpi(self):
+ spi = 0xABCD
+ new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
+ self.assertEquals(spi, new_sa.id.spi)
+
+ def testAllocSpecificSpiUnavailable(self):
+ """Attempt to allocate the same SPI twice."""
+ spi = 0xABCD
+ new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
+ self.assertEquals(spi, new_sa.id.spi)
+ with self.assertRaisesErrno(ENOENT):
+ new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
+
+ def testAllocRangeSpi(self):
+ start, end = 0xABCD0, 0xABCDF
+ new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end)
+ spi = new_sa.id.spi
+ self.assertGreaterEqual(spi, start)
+ self.assertLessEqual(spi, end)
+
+ def testAllocRangeSpiUnavailable(self):
+ """Attempt to allocate N+1 SPIs from a range of size N."""
+ start, end = 0xABCD0, 0xABCDF
+ range_size = end - start + 1
+ spis = set()
+ # Assert that allocating SPI fails when none are available.
+ with self.assertRaisesErrno(ENOENT):
+      # Allocating range_size + 1 SPIs is guaranteed to fail. Due to the way
+      # the kernel picks random SPIs, this has a high probability of failing
+      # before reaching that limit.
+ for i in xrange(range_size + 1):
+ new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end)
+ spi = new_sa.id.spi
+ self.assertNotIn(spi, spis)
+ spis.add(spi)
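The exhaustion behaviour tested above can be modelled in a few lines (names are ours; the real allocation happens in the kernel, which may additionally fail early with some probability because it retries random picks a bounded number of times):

```python
import random

def alloc_spi(allocated, start, end):
    # Pick an unused SPI in [start, end], or fail once the range is
    # exhausted (the kernel reports this as ENOENT). 'allocated'
    # stands in for the kernel's SA table.
    free = [spi for spi in range(start, end + 1) if spi not in allocated]
    if not free:
        raise LookupError("no free SPI in range")
    spi = random.choice(free)
    allocated.add(spi)
    return spi
```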
+
+ def testSocketPolicyDstCacheV6(self):
+ self._TestSocketPolicyDstCache(6)
+
+ def testSocketPolicyDstCacheV4(self):
+ self._TestSocketPolicyDstCache(4)
+
+ def _TestSocketPolicyDstCache(self, version):
+ """Test that destination cache is cleared with socket policy.
+
+ This relies on the fact that connect() on a UDP socket populates the
+ destination cache.
+ """
+
+ # Create UDP socket.
+ family = net_test.GetAddressFamily(version)
+ netid = self.RandomNetid()
+ s = socket(family, SOCK_DGRAM, 0)
+ self.SelectInterface(s, netid, "mark")
+
+ # Populate the socket's destination cache.
+ remote = self.GetRemoteAddress(version)
+ s.connect((remote, 53))
+
+ # Apply a policy to the socket. Should clear dst cache.
+ reqid = 123
+ xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT,
+ TEST_SPI, reqid, None)
+
+ # Policy with no matching SA should result in EAGAIN. If destination cache
+ # failed to clear, then the UDP packet will be sent normally.
+ with self.assertRaisesErrno(EAGAIN):
+ s.send(net_test.UDP_PAYLOAD)
+ self.ExpectNoPacketsOn(netid, "Packet not blocked by policy")
+
+ def _CheckNullEncryptionTunnelMode(self, version):
+ family = net_test.GetAddressFamily(version)
+ netid = self.RandomNetid()
+ local_addr = self.MyAddress(version, netid)
+ remote_addr = self.GetRemoteAddress(version)
+
+ # Borrow the address of another netId as the source address of the tunnel
+ tun_local = self.MyAddress(version, self.RandomNetid(netid))
+ # For generality, pick a tunnel endpoint that's not the address we
+ # connect the socket to.
+ tun_remote = TUNNEL_ENDPOINTS[version]
+
+ # Output
+ self.xfrm.AddSaInfo(
+ tun_local, tun_remote, 0xABCD, xfrm.XFRM_MODE_TUNNEL, 123,
+ xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
+ None, None, None, netid)
+ # Input
+ self.xfrm.AddSaInfo(
+ tun_remote, tun_local, 0x9876, xfrm.XFRM_MODE_TUNNEL, 456,
+ xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
+ None, None, None, None)
+
+ sock = net_test.UDPSocket(family)
+ self.SelectInterface(sock, netid, "mark")
+ sock.bind((local_addr, 0))
+ local_port = sock.getsockname()[1]
+ remote_port = 5555
+
+ xfrm_base.ApplySocketPolicy(
+ sock, family, xfrm.XFRM_POLICY_OUT, 0xABCD, 123,
+ (tun_local, tun_remote))
+ xfrm_base.ApplySocketPolicy(
+ sock, family, xfrm.XFRM_POLICY_IN, 0x9876, 456,
+ (tun_remote, tun_local))
+
+ # Create and receive an ESP packet.
+ IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
+ input_pkt = (IpType(src=remote_addr, dst=local_addr) /
+ scapy.UDP(sport=remote_port, dport=local_port) /
+ "input hello")
+ input_pkt = IpType(str(input_pkt)) # Compute length, checksum.
+ input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, 0x9876,
+ 1, (tun_remote, tun_local))
+
+ self.ReceivePacketOn(netid, input_pkt)
+ msg, addr = sock.recvfrom(1024)
+ self.assertEquals("input hello", msg)
+ self.assertEquals((remote_addr, remote_port), addr[:2])
+
+ # Send and capture a packet.
+ sock.sendto("output hello", (remote_addr, remote_port))
+ packets = self.ReadAllPacketsOn(netid)
+ self.assertEquals(1, len(packets))
+ output_pkt = packets[0]
+ output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(output_pkt)
+    # The UDP length is the payload length plus the 8-byte UDP header.
+    self.assertEquals(output_pkt[scapy.UDP].len, len("output hello") + 8)
+    self.assertEquals(remote_addr, output_pkt.dst)
+    self.assertEquals(remote_port, output_pkt[scapy.UDP].dport)
+    self.assertEquals("output hello", str(output_pkt[scapy.UDP].payload))
+ self.assertEquals(0xABCD, esp_hdr.spi)
+
+ def testNullEncryptionTunnelMode(self):
+ """Verify null encryption in tunnel mode.
+
+ This test verifies both manual assembly and disassembly of UDP packets
+ with ESP in IPsec tunnel mode.
+ """
+ for version in [4, 6]:
+ self._CheckNullEncryptionTunnelMode(version)
+
+ def _CheckNullEncryptionTransportMode(self, version):
+ family = net_test.GetAddressFamily(version)
+ netid = self.RandomNetid()
+ local_addr = self.MyAddress(version, netid)
+ remote_addr = self.GetRemoteAddress(version)
+
+ # Output
+ self.xfrm.AddSaInfo(
+ local_addr, remote_addr, 0xABCD, xfrm.XFRM_MODE_TRANSPORT, 123,
+ xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
+ None, None, None, None)
+ # Input
+ self.xfrm.AddSaInfo(
+ remote_addr, local_addr, 0x9876, xfrm.XFRM_MODE_TRANSPORT, 456,
+ xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
+ None, None, None, None)
+
+ sock = net_test.UDPSocket(family)
+ self.SelectInterface(sock, netid, "mark")
+ sock.bind((local_addr, 0))
+ local_port = sock.getsockname()[1]
+ remote_port = 5555
+
+ xfrm_base.ApplySocketPolicy(
+ sock, family, xfrm.XFRM_POLICY_OUT, 0xABCD, 123, None)
+ xfrm_base.ApplySocketPolicy(
+ sock, family, xfrm.XFRM_POLICY_IN, 0x9876, 456, None)
+
+ # Create and receive an ESP packet.
+ IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
+ input_pkt = (IpType(src=remote_addr, dst=local_addr) /
+ scapy.UDP(sport=remote_port, dport=local_port) /
+ "input hello")
+ input_pkt = IpType(str(input_pkt)) # Compute length, checksum.
+ input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, 0x9876, 1, None)
+
+ self.ReceivePacketOn(netid, input_pkt)
+ msg, addr = sock.recvfrom(1024)
+ self.assertEquals("input hello", msg)
+ self.assertEquals((remote_addr, remote_port), addr[:2])
+
+ # Send and capture a packet.
+ sock.sendto("output hello", (remote_addr, remote_port))
+ packets = self.ReadAllPacketsOn(netid)
+ self.assertEquals(1, len(packets))
+ output_pkt = packets[0]
+ output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(output_pkt)
+    # The UDP length is the payload length plus the 8-byte UDP header.
+    self.assertEquals(output_pkt[scapy.UDP].len, len("output hello") + 8)
+ self.assertEquals(remote_addr, output_pkt.dst)
+ self.assertEquals(remote_port, output_pkt[scapy.UDP].dport)
+ self.assertEquals("output hello", str(output_pkt[scapy.UDP].payload))
+ self.assertEquals(0xABCD, esp_hdr.spi)
+
+ def testNullEncryptionTransportMode(self):
+ """Verify null encryption in transport mode.
+
+ This test verifies both manual assembly and disassembly of UDP packets
+ with ESP in IPsec transport mode.
+ """
+ for version in [4, 6]:
+ self._CheckNullEncryptionTransportMode(version)
+
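The manual assembly that `EncryptPacketWithNull` performs can be sketched as follows, assuming the null cipher of RFC 2410 together with null authentication, so there is no IV and no ICV (the helper name is ours):

```python
import struct

def null_esp_encap(spi, seq, payload, next_header):
    # With null encryption and null auth, the ESP packet is just:
    #   ESP header | payload | padding | pad-length | next-header
    # Padding brings (payload + 2 trailer bytes) to a 4-byte boundary,
    # and the pad bytes are the monotonic sequence 1, 2, 3, ... (RFC 4303).
    pad_len = -(len(payload) + 2) % 4
    padding = bytes(bytearray(range(1, pad_len + 1)))
    trailer = padding + struct.pack("!BB", pad_len, next_header)
    return struct.pack("!II", spi, seq) + payload + trailer
```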
+ def _CheckGlobalPoliciesByMark(self, version):
+ """Tests that global policies may differ by only the mark."""
+ family = net_test.GetAddressFamily(version)
+ sel = xfrm.EmptySelector(family)
+ # Pick 2 arbitrary mark values.
+ mark1 = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
+ mark2 = xfrm.XfrmMark(mark=0xf00d, mask=xfrm_base.MARK_MASK_ALL)
+ # Create a global policy.
+ policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
+ tmpl = xfrm.UserTemplate(AF_UNSPEC, 0xfeed, 0, None)
+ # Create the policy with the first mark.
+ self.xfrm.AddPolicyInfo(policy, tmpl, mark1)
+ # Create the same policy but with the second (different) mark.
+ self.xfrm.AddPolicyInfo(policy, tmpl, mark2)
+ # Delete the policies individually
+ self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark1)
+ self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark2)
+
+ def testGlobalPoliciesByMarkV4(self):
+ self._CheckGlobalPoliciesByMark(4)
+
+ def testGlobalPoliciesByMarkV6(self):
+ self._CheckGlobalPoliciesByMark(6)
+
+ def _CheckUpdatePolicy(self, version):
+ """Tests that we can can update the template on a policy."""
+ family = net_test.GetAddressFamily(version)
+ tmpl1 = xfrm.UserTemplate(family, 0xdead, 0, None)
+ tmpl2 = xfrm.UserTemplate(family, 0xbeef, 0, None)
+ sel = xfrm.EmptySelector(family)
+ policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
+ mark = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
+
+ def _CheckTemplateMatch(tmpl):
+ """Dump the SPD and match a single template on a single policy."""
+ dump = self.xfrm.DumpPolicyInfo()
+ self.assertEquals(1, len(dump))
+ _, attributes = dump[0]
+ self.assertEquals(attributes['XFRMA_TMPL'], tmpl)
+
+ # Create a new policy using update.
+ self.xfrm.UpdatePolicyInfo(policy, tmpl1, mark, None)
+ # NEWPOLICY will not update the existing policy. This checks both that
+ # UPDPOLICY created a policy and that NEWPOLICY will not perform updates.
+ _CheckTemplateMatch(tmpl1)
+ with self.assertRaisesErrno(EEXIST):
+ self.xfrm.AddPolicyInfo(policy, tmpl2, mark, None)
+ # Update the policy using UPDPOLICY.
+ self.xfrm.UpdatePolicyInfo(policy, tmpl2, mark, None)
+ # There should only be one policy after update, and it should have the
+ # updated template.
+ _CheckTemplateMatch(tmpl2)
+
+ def testUpdatePolicyV4(self):
+ self._CheckUpdatePolicy(4)
+
+ def testUpdatePolicyV6(self):
+ self._CheckUpdatePolicy(6)
+
+  def _CheckPolicyDifferByDirection(self, version):
+ """Tests that policies can differ only by direction."""
+ family = net_test.GetAddressFamily(version)
+ tmpl = xfrm.UserTemplate(family, 0xdead, 0, None)
+ sel = xfrm.EmptySelector(family)
+ mark = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
+ policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
+ self.xfrm.AddPolicyInfo(policy, tmpl, mark)
+ policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_IN, sel)
+ self.xfrm.AddPolicyInfo(policy, tmpl, mark)
+
+ def testPolicyDifferByDirectionV4(self):
+ self._CheckPolicyDifferByDirection(4)
+
+ def testPolicyDifferByDirectionV6(self):
+ self._CheckPolicyDifferByDirection(6)
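The `XfrmMark(mark, mask)` pairs used throughout these tests follow the usual kernel matching rule: a policy or SA applies when the packet's mark, ANDed with the policy's mask, equals the policy's mark value. A sketch (helper name is ours):

```python
def mark_matches(policy_mark, policy_mask, packet_mark):
    # Mask both sides so a policy with mask 0 matches any packet mark,
    # while MARK_MASK_ALL requires an exact match.
    return (packet_mark & policy_mask) == (policy_mark & policy_mask)
```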
+
+class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest):
+
+ def _CheckTunnelModeOutputMark(self, version, tunsrc, mark, expected_netid):
+ """Tests sending UDP packets to tunnel mode SAs with output marks.
+
+ Opens a UDP socket and binds it to a random netid, then sets up tunnel mode
+ SAs with an output_mark of mark and sets a socket policy to use the SA.
+ Then checks that sending on those SAs sends a packet on expected_netid,
+    or, if expected_netid is None, checks that sending fails with ENETUNREACH.
+
+ Args:
+ version: 4 or 6.
+ tunsrc: A string, the source address of the tunnel.
+ mark: An integer, the output_mark to set in the SA.
+ expected_netid: An integer, the netid to expect the kernel to send the
+ packet on. If None, expect that sendto will fail with ENETUNREACH.
+ """
+ # Open a UDP socket and bind it to a random netid.
+ family = net_test.GetAddressFamily(version)
+ s = socket(family, SOCK_DGRAM, 0)
+ self.SelectInterface(s, self.RandomNetid(), "mark")
+
+ # For generality, pick a tunnel endpoint that's not the address we
+ # connect the socket to.
+ tundst = TUNNEL_ENDPOINTS[version]
+ tun_addrs = (tunsrc, tundst)
+
+ # Create a tunnel mode SA and use XFRM_OUTPUT_MARK to bind it to netid.
+ spi = TEST_SPI * mark
+ reqid = 100 + spi
+ self.xfrm.AddSaInfo(tunsrc, tundst, spi, xfrm.XFRM_MODE_TUNNEL, reqid,
+ xfrm_base._ALGO_CBC_AES_256, xfrm_base._ALGO_HMAC_SHA1,
+ None, None, None, mark)
+
+ # Set a socket policy to use it.
+ xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT, spi, reqid,
+ tun_addrs)
+
+ # Send a packet and check that we see it on the wire.
+ remoteaddr = self.GetRemoteAddress(version)
+
+ packetlen = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TUNNEL, version,
+ False, net_test.UDP_PAYLOAD,
+ xfrm_base._ALGO_HMAC_SHA1,
+ xfrm_base._ALGO_CBC_AES_256)
+
+ if expected_netid is not None:
+ s.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
+ self._ExpectEspPacketOn(expected_netid, spi, 1, packetlen, tunsrc, tundst)
+ else:
+ with self.assertRaisesErrno(ENETUNREACH):
+ s.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
+
+ def testTunnelModeOutputMarkIPv4(self):
+ for netid in self.NETIDS:
+ tunsrc = self.MyAddress(4, netid)
+ self._CheckTunnelModeOutputMark(4, tunsrc, netid, netid)
+
+ def testTunnelModeOutputMarkIPv6(self):
+ for netid in self.NETIDS:
+ tunsrc = self.MyAddress(6, netid)
+ self._CheckTunnelModeOutputMark(6, tunsrc, netid, netid)
+
+ def testTunnelModeOutputNoMarkIPv4(self):
+ tunsrc = self.MyAddress(4, self.RandomNetid())
+ self._CheckTunnelModeOutputMark(4, tunsrc, 0, None)
+
+ def testTunnelModeOutputNoMarkIPv6(self):
+ tunsrc = self.MyAddress(6, self.RandomNetid())
+ self._CheckTunnelModeOutputMark(6, tunsrc, 0, None)
+
+ def testTunnelModeOutputInvalidMarkIPv4(self):
+ tunsrc = self.MyAddress(4, self.RandomNetid())
+ self._CheckTunnelModeOutputMark(4, tunsrc, 9999, None)
+
+ def testTunnelModeOutputInvalidMarkIPv6(self):
+ tunsrc = self.MyAddress(6, self.RandomNetid())
+ self._CheckTunnelModeOutputMark(6, tunsrc, 9999, None)
+
+ def testTunnelModeOutputMarkAttributes(self):
+ mark = 1234567
+ self.xfrm.AddSaInfo(TEST_ADDR1, TUNNEL_ENDPOINTS[6], 0x1234,
+ xfrm.XFRM_MODE_TUNNEL, 100, xfrm_base._ALGO_CBC_AES_256,
+ xfrm_base._ALGO_HMAC_SHA1, None, None, None, mark)
+ dump = self.xfrm.DumpSaInfo()
+ self.assertEquals(1, len(dump))
+ sainfo, attributes = dump[0]
+ self.assertEquals(mark, attributes["XFRMA_OUTPUT_MARK"])
+
+ def testInvalidAlgorithms(self):
+ key = "af442892cdcd0ef650e9c299f9a8436a".decode("hex")
+ invalid_auth = (xfrm.XfrmAlgoAuth(("invalid(algo)", 128, 96)), key)
+ invalid_crypt = (xfrm.XfrmAlgo(("invalid(algo)", 128)), key)
+ with self.assertRaisesErrno(ENOSYS):
+ self.xfrm.AddSaInfo(TEST_ADDR1, TEST_ADDR2, 0x1234,
+ xfrm.XFRM_MODE_TRANSPORT, 0, xfrm_base._ALGO_CBC_AES_256,
+ invalid_auth, None, None, None, 0)
+ with self.assertRaisesErrno(ENOSYS):
+ self.xfrm.AddSaInfo(TEST_ADDR1, TEST_ADDR2, 0x1234,
+ xfrm.XFRM_MODE_TRANSPORT, 0, invalid_crypt,
+ xfrm_base._ALGO_HMAC_SHA1, None, None, None, 0)
+
+ def testUpdateSaAddMark(self):
+ """Test that when an SA has no mark, it can be updated to add a mark."""
+ for version in [4, 6]:
+ spi = 0xABCD
+ # Test that an SA created with ALLOCSPI can be updated with the mark.
+ new_sa = self.xfrm.AllocSpi(net_test.GetWildcardAddress(version),
+ IPPROTO_ESP, spi, spi)
+ mark = xfrm.ExactMatchMark(0xf00d)
+ self.xfrm.AddSaInfo(net_test.GetWildcardAddress(version),
+ net_test.GetWildcardAddress(version),
+ spi, xfrm.XFRM_MODE_TUNNEL, 0,
+ xfrm_base._ALGO_CBC_AES_256,
+ xfrm_base._ALGO_HMAC_SHA1,
+ None, None, mark, 0, is_update=True)
+ dump = self.xfrm.DumpSaInfo()
+      self.assertEquals(1, len(dump))  # Check that the update did not add a new SA.
+ sainfo, attributes = dump[0]
+ self.assertEquals(mark, attributes["XFRMA_MARK"])
+ self.xfrm.DeleteSaInfo(net_test.GetWildcardAddress(version),
+ spi, IPPROTO_ESP, mark)
+
+ # TODO: we might also need to update the mark for a VALID SA.
if __name__ == "__main__":
unittest.main()
diff --git a/net/test/xfrm_tunnel_test.py b/net/test/xfrm_tunnel_test.py
new file mode 100755
index 0000000..eb1a46e
--- /dev/null
+++ b/net/test/xfrm_tunnel_test.py
@@ -0,0 +1,953 @@
+#!/usr/bin/python
+#
+# Copyright 2017 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
+from errno import * # pylint: disable=wildcard-import
+from socket import * # pylint: disable=wildcard-import
+
+import random
+import itertools
+import struct
+import unittest
+
+from scapy import all as scapy
+from tun_twister import TunTwister
+import csocket
+import iproute
+import multinetwork_base
+import net_test
+import packets
+import util
+import xfrm
+import xfrm_base
+
+_LOOPBACK_IFINDEX = 1
+_TEST_XFRM_IFNAME = "ipsec42"
+_TEST_XFRM_IF_ID = 42
+
+# Does the kernel support xfrmi interfaces?
+def HaveXfrmInterfaces():
+ try:
+ i = iproute.IPRoute()
+ i.CreateXfrmInterface(_TEST_XFRM_IFNAME, _TEST_XFRM_IF_ID,
+ _LOOPBACK_IFINDEX)
+ i.DeleteLink(_TEST_XFRM_IFNAME)
+ try:
+ i.GetIfIndex(_TEST_XFRM_IFNAME)
+      assert False, "Deleted interface %s still exists!" % _TEST_XFRM_IFNAME
+ except IOError:
+ pass
+ return True
+ except IOError:
+ return False
+
+HAVE_XFRM_INTERFACES = HaveXfrmInterfaces()
+
+# Parameters for setting up tunnels as special networks.
+_TUNNEL_NETID_OFFSET = 0xFC00 # Matches reserved netid range for IpSecService
+_BASE_TUNNEL_NETID = {4: 40, 6: 60}
+_BASE_VTI_OKEY = 2000000100
+_BASE_VTI_IKEY = 2000000200
+
+_TEST_OUT_SPI = 0x1234
+_TEST_IN_SPI = _TEST_OUT_SPI
+
+_TEST_OKEY = 2000000100
+_TEST_IKEY = 2000000200
+
+_TEST_REMOTE_PORT = 1234
+
+_SCAPY_IP_TYPE = {4: scapy.IP, 6: scapy.IPv6}
+
+
+def _GetLocalInnerAddress(version):
+ return {4: "10.16.5.15", 6: "2001:db8:1::1"}[version]
+
+
+def _GetRemoteInnerAddress(version):
+ return {4: "10.16.5.20", 6: "2001:db8:2::1"}[version]
+
+
+def _GetRemoteOuterAddress(version):
+ return {4: net_test.IPV4_ADDR, 6: net_test.IPV6_ADDR}[version]
+
+
+def _GetNullAuthCryptTunnelModePkt(inner_version, src_inner, src_outer,
+ src_port, dst_inner, dst_outer,
+ dst_port, spi, seq_num, ip_hdr_options=None):
+ if ip_hdr_options is None:
+ ip_hdr_options = {}
+
+ ip_hdr_options.update({'src': src_inner, 'dst': dst_inner})
+
+  # Build an ESP packet destined for the inner socket.
+ IpType = {4: scapy.IP, 6: scapy.IPv6}[inner_version]
+ input_pkt = (
+ IpType(**ip_hdr_options) / scapy.UDP(sport=src_port, dport=dst_port) /
+ net_test.UDP_PAYLOAD)
+ input_pkt = IpType(str(input_pkt)) # Compute length, checksum.
+ input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, spi, seq_num,
+ (src_outer, dst_outer))
+
+ return input_pkt
+
+
+def _CreateReceiveSock(version, port=0):
+ # Create a socket to receive packets.
+ read_sock = socket(net_test.GetAddressFamily(version), SOCK_DGRAM, 0)
+ read_sock.bind((net_test.GetWildcardAddress(version), port))
+ # The second parameter of the tuple is the port number regardless of AF.
+ local_port = read_sock.getsockname()[1]
+ # Guard against the eventuality of the receive failing.
+ csocket.SetSocketTimeout(read_sock, 500)
+
+ return read_sock, local_port
+
+
+def _SendPacket(testInstance, netid, version, remote, remote_port):
+ # Send a packet out via the tunnel-backed network, bound for the port number
+ # of the input socket.
+ write_sock = socket(net_test.GetAddressFamily(version), SOCK_DGRAM, 0)
+ testInstance.SelectInterface(write_sock, netid, "mark")
+ write_sock.sendto(net_test.UDP_PAYLOAD, (remote, remote_port))
+ local_port = write_sock.getsockname()[1]
+
+ return local_port
+
+
+def InjectTests():
+ InjectParameterizedTests(XfrmTunnelTest)
+ InjectParameterizedTests(XfrmInterfaceTest)
+ InjectParameterizedTests(XfrmVtiTest)
+
+
+def InjectParameterizedTests(cls):
+ VERSIONS = (4, 6)
+ param_list = itertools.product(VERSIONS, VERSIONS)
+
+ def NameGenerator(*args):
+ return "IPv%d_in_IPv%d" % tuple(args)
+
+ util.InjectParameterizedTest(cls, param_list, NameGenerator)
+
+
+class XfrmTunnelTest(xfrm_base.XfrmLazyTest):
+
+ def _CheckTunnelOutput(self, inner_version, outer_version, underlying_netid,
+ netid, local_inner, remote_inner, local_outer,
+ remote_outer, write_sock):
+
+ write_sock.sendto(net_test.UDP_PAYLOAD, (remote_inner, 53))
+ self._ExpectEspPacketOn(underlying_netid, _TEST_OUT_SPI, 1, None,
+ local_outer, remote_outer)
+
+ def _CheckTunnelInput(self, inner_version, outer_version, underlying_netid,
+ netid, local_inner, remote_inner, local_outer,
+ remote_outer, read_sock):
+
+ # The second parameter of the tuple is the port number regardless of AF.
+ local_port = read_sock.getsockname()[1]
+
+ # Build and receive an ESP packet destined for the inner socket
+ input_pkt = _GetNullAuthCryptTunnelModePkt(
+ inner_version, remote_inner, remote_outer, _TEST_REMOTE_PORT,
+ local_inner, local_outer, local_port, _TEST_IN_SPI, 1)
+ self.ReceivePacketOn(underlying_netid, input_pkt)
+
+ # Verify that the packet data and src are correct
+ data, src = read_sock.recvfrom(4096)
+ self.assertEquals(net_test.UDP_PAYLOAD, data)
+ self.assertEquals((remote_inner, _TEST_REMOTE_PORT), src[:2])
+
+ def _TestTunnel(self, inner_version, outer_version, func, direction,
+ test_output_mark_unset):
+    """Test a unidirectional XFRM tunnel with explicit selectors."""
+ # Select the underlying netid, which represents the external
+ # interface from/to which to route ESP packets.
+ u_netid = self.RandomNetid()
+    # Select a random netid that will originate traffic locally and that
+    # represents the network on which the plaintext is sent.
+ netid = self.RandomNetid(exclude=u_netid)
+
+ local_inner = self.MyAddress(inner_version, netid)
+ remote_inner = _GetRemoteInnerAddress(inner_version)
+ local_outer = self.MyAddress(outer_version, u_netid)
+ remote_outer = _GetRemoteOuterAddress(outer_version)
+
+ output_mark = u_netid
+ if test_output_mark_unset:
+ output_mark = None
+ self.SetDefaultNetwork(u_netid)
+
+ try:
+      # Create input/output SPs, SAs and sockets to simulate a more
+      # realistic environment.
+ self.xfrm.CreateTunnel(
+ xfrm.XFRM_POLICY_IN, xfrm.SrcDstSelector(remote_inner, local_inner),
+ remote_outer, local_outer, _TEST_IN_SPI, xfrm_base._ALGO_CRYPT_NULL,
+ xfrm_base._ALGO_AUTH_NULL, None, None, None, xfrm.MATCH_METHOD_ALL)
+
+ self.xfrm.CreateTunnel(
+ xfrm.XFRM_POLICY_OUT, xfrm.SrcDstSelector(local_inner, remote_inner),
+ local_outer, remote_outer, _TEST_OUT_SPI, xfrm_base._ALGO_CBC_AES_256,
+ xfrm_base._ALGO_HMAC_SHA1, None, output_mark, None, xfrm.MATCH_METHOD_ALL)
+
+ write_sock = socket(net_test.GetAddressFamily(inner_version), SOCK_DGRAM, 0)
+ self.SelectInterface(write_sock, netid, "mark")
+ read_sock, _ = _CreateReceiveSock(inner_version)
+
+ sock = write_sock if direction == xfrm.XFRM_POLICY_OUT else read_sock
+ func(inner_version, outer_version, u_netid, netid, local_inner,
+ remote_inner, local_outer, remote_outer, sock)
+ finally:
+ if test_output_mark_unset:
+ self.ClearDefaultNetwork()
+
+ def ParamTestTunnelInput(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelInput,
+ xfrm.XFRM_POLICY_IN, False)
+
+ def ParamTestTunnelOutput(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelOutput,
+ xfrm.XFRM_POLICY_OUT, False)
+
+ def ParamTestTunnelOutputNoSetMark(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelOutput,
+ xfrm.XFRM_POLICY_OUT, True)
+
+
+@unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
+class XfrmAddDeleteVtiTest(xfrm_base.XfrmBaseTest):
+ def _VerifyVtiInfoData(self, vti_info_data, version, local_addr, remote_addr,
+ ikey, okey):
+ self.assertEquals(vti_info_data["IFLA_VTI_IKEY"], ikey)
+ self.assertEquals(vti_info_data["IFLA_VTI_OKEY"], okey)
+
+ family = AF_INET if version == 4 else AF_INET6
+ self.assertEquals(inet_ntop(family, vti_info_data["IFLA_VTI_LOCAL"]),
+ local_addr)
+ self.assertEquals(inet_ntop(family, vti_info_data["IFLA_VTI_REMOTE"]),
+ remote_addr)
+
+ def testAddVti(self):
+ """Test the creation of a Virtual Tunnel Interface."""
+ for version in [4, 6]:
+ netid = self.RandomNetid()
+ local_addr = self.MyAddress(version, netid)
+ self.iproute.CreateVirtualTunnelInterface(
+ dev_name=_TEST_XFRM_IFNAME,
+ local_addr=local_addr,
+ remote_addr=_GetRemoteOuterAddress(version),
+ o_key=_TEST_OKEY,
+ i_key=_TEST_IKEY)
+ self._VerifyVtiInfoData(
+ self.iproute.GetIfinfoData(_TEST_XFRM_IFNAME), version, local_addr,
+ _GetRemoteOuterAddress(version), _TEST_IKEY, _TEST_OKEY)
+
+ new_remote_addr = {4: net_test.IPV4_ADDR2, 6: net_test.IPV6_ADDR2}
+ new_okey = _TEST_OKEY + _TEST_XFRM_IF_ID
+ new_ikey = _TEST_IKEY + _TEST_XFRM_IF_ID
+ self.iproute.CreateVirtualTunnelInterface(
+ dev_name=_TEST_XFRM_IFNAME,
+ local_addr=local_addr,
+ remote_addr=new_remote_addr[version],
+ o_key=new_okey,
+ i_key=new_ikey,
+ is_update=True)
+
+ self._VerifyVtiInfoData(
+ self.iproute.GetIfinfoData(_TEST_XFRM_IFNAME), version, local_addr,
+ new_remote_addr[version], new_ikey, new_okey)
+
+ if_index = self.iproute.GetIfIndex(_TEST_XFRM_IFNAME)
+
+ # Validate that the netlink interface matches the ioctl interface.
+ self.assertEquals(net_test.GetInterfaceIndex(_TEST_XFRM_IFNAME), if_index)
+ self.iproute.DeleteLink(_TEST_XFRM_IFNAME)
+ with self.assertRaises(IOError):
+ self.iproute.GetIfIndex(_TEST_XFRM_IFNAME)
+
+ def _QuietDeleteLink(self, ifname):
+ try:
+ self.iproute.DeleteLink(ifname)
+ except IOError:
+ # The link was not present.
+ pass
+
+ def tearDown(self):
+ super(XfrmAddDeleteVtiTest, self).tearDown()
+ self._QuietDeleteLink(_TEST_XFRM_IFNAME)
+
+
+class SaInfo(object):
+
+ def __init__(self, spi):
+ self.spi = spi
+ self.seq_num = 1
+
+
+class IpSecBaseInterface(object):
+
+ def __init__(self, iface, netid, underlying_netid, local, remote, version):
+ self.iface = iface
+ self.netid = netid
+ self.underlying_netid = underlying_netid
+ self.local, self.remote = local, remote
+
+ # XFRM interfaces technically do not have a version. This keeps track of
+ # the IP version of the local and remote addresses.
+ self.version = version
+ self.rx = self.tx = 0
+ self.addrs = {}
+
+ self.iproute = iproute.IPRoute()
+ self.xfrm = xfrm.Xfrm()
+
+ def Teardown(self):
+ self.TeardownXfrm()
+ self.TeardownInterface()
+
+ def TeardownInterface(self):
+ self.iproute.DeleteLink(self.iface)
+
+ def SetupXfrm(self, use_null_crypt):
+ rand_spi = random.randint(0, 0x7fffffff)
+ self.in_sa = SaInfo(rand_spi)
+ self.out_sa = SaInfo(rand_spi)
+
+ # Select algorithms:
+ if use_null_crypt:
+ auth, crypt = xfrm_base._ALGO_AUTH_NULL, xfrm_base._ALGO_CRYPT_NULL
+ else:
+ auth, crypt = xfrm_base._ALGO_HMAC_SHA1, xfrm_base._ALGO_CBC_AES_256
+
+ self._SetupXfrmByType(auth, crypt)
+
+ def Rekey(self, outer_family, new_out_sa, new_in_sa):
+    """Rekeys the tunnel interface.
+
+ Creates new SAs and updates the outbound security policy to use new SAs.
+
+ Args:
+ outer_family: AF_INET or AF_INET6
+ new_out_sa: An SaInfo struct representing the new outbound SA's info
+ new_in_sa: An SaInfo struct representing the new inbound SA's info
+ """
+ self._Rekey(outer_family, new_out_sa, new_in_sa)
+
+ # Update Interface object
+ self.out_sa = new_out_sa
+ self.in_sa = new_in_sa
+
+ def TeardownXfrm(self):
+ raise NotImplementedError("Subclasses should implement this")
+
+ def _SetupXfrmByType(self, auth_algo, crypt_algo):
+ raise NotImplementedError("Subclasses should implement this")
+
+ def _Rekey(self, outer_family, new_out_sa, new_in_sa):
+ raise NotImplementedError("Subclasses should implement this")
+
+
+class VtiInterface(IpSecBaseInterface):
+
+ def __init__(self, iface, netid, underlying_netid, _, local, remote, version):
+ super(VtiInterface, self).__init__(iface, netid, underlying_netid, local,
+ remote, version)
+
+ self.ikey = _TEST_IKEY + netid
+ self.okey = _TEST_OKEY + netid
+
+ self.SetupInterface()
+ self.SetupXfrm(False)
+
+ def SetupInterface(self):
+ return self.iproute.CreateVirtualTunnelInterface(
+ self.iface, self.local, self.remote, self.ikey, self.okey)
+
+ def _SetupXfrmByType(self, auth_algo, crypt_algo):
+    # For the VTI, the inner selectors are wildcard: packets are selected
+    # for the tunnel based solely on whether they carry the appropriate
+    # mark.
+ self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_OUT, None, self.local, self.remote,
+ self.out_sa.spi, crypt_algo, auth_algo,
+ xfrm.ExactMatchMark(self.okey),
+ self.underlying_netid, None, xfrm.MATCH_METHOD_ALL)
+
+ self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_IN, None, self.remote, self.local,
+ self.in_sa.spi, crypt_algo, auth_algo,
+ xfrm.ExactMatchMark(self.ikey), None, None,
+ xfrm.MATCH_METHOD_MARK)
+
+ def TeardownXfrm(self):
+ self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_OUT, None, self.remote,
+ self.out_sa.spi, self.okey, None)
+ self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_IN, None, self.local,
+ self.in_sa.spi, self.ikey, None)
+
+ def _Rekey(self, outer_family, new_out_sa, new_in_sa):
+ # TODO: Consider ways to share code with xfrm.CreateTunnel(). It's mostly
+ # the same, but rekeys are asymmetric, and only update the outbound
+ # policy.
+ self.xfrm.AddSaInfo(self.local, self.remote, new_out_sa.spi,
+ xfrm.XFRM_MODE_TUNNEL, 0, xfrm_base._ALGO_CRYPT_NULL,
+ xfrm_base._ALGO_AUTH_NULL, None, None,
+ xfrm.ExactMatchMark(self.okey), self.underlying_netid)
+
+ self.xfrm.AddSaInfo(self.remote, self.local, new_in_sa.spi,
+ xfrm.XFRM_MODE_TUNNEL, 0, xfrm_base._ALGO_CRYPT_NULL,
+ xfrm_base._ALGO_AUTH_NULL, None, None,
+ xfrm.ExactMatchMark(self.ikey), None)
+
+ # Create new policies for IPv4 and IPv6.
+ for sel in [xfrm.EmptySelector(AF_INET), xfrm.EmptySelector(AF_INET6)]:
+ # Add SPI-specific output policy to enforce using new outbound SPI
+ policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
+ tmpl = xfrm.UserTemplate(outer_family, new_out_sa.spi, 0,
+ (self.local, self.remote))
+ self.xfrm.UpdatePolicyInfo(policy, tmpl, xfrm.ExactMatchMark(self.okey),
+ 0)
+
+ def DeleteOldSaInfo(self, outer_family, old_in_spi, old_out_spi):
+ self.xfrm.DeleteSaInfo(self.local, old_in_spi, IPPROTO_ESP,
+ xfrm.ExactMatchMark(self.ikey))
+ self.xfrm.DeleteSaInfo(self.remote, old_out_spi, IPPROTO_ESP,
+ xfrm.ExactMatchMark(self.okey))
+
+
+@unittest.skipUnless(HAVE_XFRM_INTERFACES, "XFRM interfaces unsupported")
+class XfrmAddDeleteXfrmInterfaceTest(xfrm_base.XfrmBaseTest):
+ """Test the creation of an XFRM Interface."""
+
+ def testAddXfrmInterface(self):
+ self.iproute.CreateXfrmInterface(_TEST_XFRM_IFNAME, _TEST_XFRM_IF_ID,
+ _LOOPBACK_IFINDEX)
+ if_index = self.iproute.GetIfIndex(_TEST_XFRM_IFNAME)
+ net_test.SetInterfaceUp(_TEST_XFRM_IFNAME)
+
+ # Validate that the netlink interface matches the ioctl interface.
+ self.assertEquals(net_test.GetInterfaceIndex(_TEST_XFRM_IFNAME), if_index)
+ self.iproute.DeleteLink(_TEST_XFRM_IFNAME)
+ with self.assertRaises(IOError):
+ self.iproute.GetIfIndex(_TEST_XFRM_IFNAME)
+
+
+class XfrmInterface(IpSecBaseInterface):
+
+ def __init__(self, iface, netid, underlying_netid, ifindex, local, remote,
+ version):
+ super(XfrmInterface, self).__init__(iface, netid, underlying_netid, local,
+ remote, version)
+
+ self.ifindex = ifindex
+ self.xfrm_if_id = netid
+
+ self.SetupInterface()
+ self.SetupXfrm(False)
+
+ def SetupInterface(self):
+ """Create an XFRM interface."""
+ return self.iproute.CreateXfrmInterface(self.iface, self.netid, self.ifindex)
+
+ def _SetupXfrmByType(self, auth_algo, crypt_algo):
+ self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_OUT, None, self.local, self.remote,
+ self.out_sa.spi, crypt_algo, auth_algo, None,
+ self.underlying_netid, self.xfrm_if_id,
+ xfrm.MATCH_METHOD_ALL)
+ self.xfrm.CreateTunnel(xfrm.XFRM_POLICY_IN, None, self.remote, self.local,
+ self.in_sa.spi, crypt_algo, auth_algo, None, None,
+ self.xfrm_if_id, xfrm.MATCH_METHOD_IFID)
+
+ def TeardownXfrm(self):
+ self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_OUT, None, self.remote,
+ self.out_sa.spi, None, self.xfrm_if_id)
+ self.xfrm.DeleteTunnel(xfrm.XFRM_POLICY_IN, None, self.local,
+ self.in_sa.spi, None, self.xfrm_if_id)
+
+ def _Rekey(self, outer_family, new_out_sa, new_in_sa):
+ # TODO: Consider ways to share code with xfrm.CreateTunnel(). It's mostly
+ # the same, but rekeys are asymmetric, and only update the outbound
+ # policy.
+ self.xfrm.AddSaInfo(
+ self.local, self.remote, new_out_sa.spi, xfrm.XFRM_MODE_TUNNEL, 0,
+ xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL, None, None,
+ None, self.underlying_netid, xfrm_if_id=self.xfrm_if_id)
+
+ self.xfrm.AddSaInfo(
+ self.remote, self.local, new_in_sa.spi, xfrm.XFRM_MODE_TUNNEL, 0,
+ xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL, None, None,
+ None, None, xfrm_if_id=self.xfrm_if_id)
+
+ # Create new policies for IPv4 and IPv6.
+ for sel in [xfrm.EmptySelector(AF_INET), xfrm.EmptySelector(AF_INET6)]:
+ # Add SPI-specific output policy to enforce using new outbound SPI
+ policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
+ tmpl = xfrm.UserTemplate(outer_family, new_out_sa.spi, 0,
+ (self.local, self.remote))
+ self.xfrm.UpdatePolicyInfo(policy, tmpl, None, self.xfrm_if_id)
+
+ def DeleteOldSaInfo(self, outer_family, old_in_spi, old_out_spi):
+ self.xfrm.DeleteSaInfo(self.local, old_in_spi, IPPROTO_ESP, None,
+ self.xfrm_if_id)
+ self.xfrm.DeleteSaInfo(self.remote, old_out_spi, IPPROTO_ESP, None,
+ self.xfrm_if_id)
+
+
+class XfrmTunnelBase(xfrm_base.XfrmBaseTest):
+
+ @classmethod
+ def setUpClass(cls):
+ xfrm_base.XfrmBaseTest.setUpClass()
+ # Tunnel interfaces use marks extensively, so configure realistic packet
+ # marking rules to make the test representative, make PMTUD work, etc.
+ cls.SetInboundMarks(True)
+ cls.SetMarkReflectSysctls(1)
+
+ # Group by tunnel version to ensure that we test at least one IPv4 and one
+ # IPv6 tunnel
+ cls.tunnelsV4 = {}
+ cls.tunnelsV6 = {}
+ for i, underlying_netid in enumerate(cls.tuns):
+ for version in 4, 6:
+ netid = _BASE_TUNNEL_NETID[version] + _TUNNEL_NETID_OFFSET + i
+ iface = "ipsec%s" % netid
+ local = cls.MyAddress(version, underlying_netid)
+ if version == 4:
+ remote = (net_test.IPV4_ADDR if (i % 2) else net_test.IPV4_ADDR2)
+ else:
+ remote = (net_test.IPV6_ADDR if (i % 2) else net_test.IPV6_ADDR2)
+
+ ifindex = cls.ifindices[underlying_netid]
+ tunnel = cls.INTERFACE_CLASS(iface, netid, underlying_netid, ifindex,
+ local, remote, version)
+ cls._SetInboundMarking(netid, iface, True)
+ cls._SetupTunnelNetwork(tunnel, True)
+
+ if version == 4:
+ cls.tunnelsV4[netid] = tunnel
+ else:
+ cls.tunnelsV6[netid] = tunnel
+
+ @classmethod
+ def tearDownClass(cls):
+ # The sysctls are restored by MultinetworkBaseTest.tearDownClass.
+ cls.SetInboundMarks(False)
+ for tunnel in cls.tunnelsV4.values() + cls.tunnelsV6.values():
+ cls._SetInboundMarking(tunnel.netid, tunnel.iface, False)
+ cls._SetupTunnelNetwork(tunnel, False)
+ tunnel.Teardown()
+ xfrm_base.XfrmBaseTest.tearDownClass()
+
+ def randomTunnel(self, outer_version):
+ version_dict = self.tunnelsV4 if outer_version == 4 else self.tunnelsV6
+ return random.choice(version_dict.values())
+
+ def setUp(self):
+ multinetwork_base.MultiNetworkBaseTest.setUp(self)
+ self.iproute = iproute.IPRoute()
+ self.xfrm = xfrm.Xfrm()
+
+ def tearDown(self):
+ multinetwork_base.MultiNetworkBaseTest.tearDown(self)
+
+ def _SwapInterfaceAddress(self, ifname, old_addr, new_addr):
+ """Exchange two addresses on a given interface.
+
+ Args:
+ ifname: Name of the interface
+ old_addr: An address to be removed from the interface
+ new_addr: An address to be added to an interface
+ """
+ version = 6 if ":" in new_addr else 4
+ ifindex = net_test.GetInterfaceIndex(ifname)
+ self.iproute.AddAddress(new_addr,
+ net_test.AddressLengthBits(version), ifindex)
+ self.iproute.DelAddress(old_addr,
+ net_test.AddressLengthBits(version), ifindex)
+
+ @classmethod
+ def _GetLocalAddress(cls, version, netid):
+ if version == 4:
+ return cls._MyIPv4Address(netid - _TUNNEL_NETID_OFFSET)
+ else:
+ return cls.OnlinkPrefix(6, netid - _TUNNEL_NETID_OFFSET) + "1"
+
+ @classmethod
+ def _SetupTunnelNetwork(cls, tunnel, is_add):
+    """Set up or tear down rules and routes for a tunnel network.
+
+    Takes an interface and, depending on the boolean value of is_add,
+    either adds or removes the rules and routes that make a tunnel
+    interface behave like an Android network for testing purposes.
+
+ Args:
+ tunnel: A VtiInterface or XfrmInterface, the tunnel to set up.
+ is_add: Boolean that causes this method to perform setup if True or
+ teardown if False
+ """
+ if is_add:
+      # Disable router solicitations to avoid occasional spurious packets
+      # arriving on the underlying network. There are two possible failure
+      # modes when that occurs: either only the RA packet is read, and when
+      # it is echoed back to the tunnel the test fails because UDP_PAYLOAD
+      # is never received; or two packets arrive on the underlying network,
+      # which fails the assertion that only one ESP packet is received.
+ cls.SetSysctl(
+ "/proc/sys/net/ipv6/conf/%s/router_solicitations" % tunnel.iface, 0)
+ net_test.SetInterfaceUp(tunnel.iface)
+
+ for version in [4, 6]:
+ ifindex = net_test.GetInterfaceIndex(tunnel.iface)
+ table = tunnel.netid
+
+ # Set up routing rules.
+ start, end = cls.UidRangeForNetid(tunnel.netid)
+ cls.iproute.UidRangeRule(version, is_add, start, end, table,
+ cls.PRIORITY_UID)
+ cls.iproute.OifRule(version, is_add, tunnel.iface, table, cls.PRIORITY_OIF)
+ cls.iproute.FwmarkRule(version, is_add, tunnel.netid, cls.NETID_FWMASK,
+ table, cls.PRIORITY_FWMARK)
+
+ # Configure IP addresses.
+ addr = cls._GetLocalAddress(version, tunnel.netid)
+ prefixlen = net_test.AddressLengthBits(version)
+ tunnel.addrs[version] = addr
+ if is_add:
+ cls.iproute.AddAddress(addr, prefixlen, ifindex)
+ cls.iproute.AddRoute(version, table, "default", 0, None, ifindex)
+ else:
+ cls.iproute.DelRoute(version, table, "default", 0, None, ifindex)
+ cls.iproute.DelAddress(addr, prefixlen, ifindex)
+
+ def assertReceivedPacket(self, tunnel, sa_info):
+ tunnel.rx += 1
+ self.assertEquals((tunnel.rx, tunnel.tx),
+ self.iproute.GetRxTxPackets(tunnel.iface))
+ sa_info.seq_num += 1
+
+ def assertSentPacket(self, tunnel, sa_info):
+ tunnel.tx += 1
+ self.assertEquals((tunnel.rx, tunnel.tx),
+ self.iproute.GetRxTxPackets(tunnel.iface))
+ sa_info.seq_num += 1
+
+ def _CheckTunnelInput(self, tunnel, inner_version, local_inner, remote_inner,
+ sa_info=None, expect_fail=False):
+ """Test null-crypt input path over an IPsec interface."""
+ if sa_info is None:
+ sa_info = tunnel.in_sa
+ read_sock, local_port = _CreateReceiveSock(inner_version)
+
+ input_pkt = _GetNullAuthCryptTunnelModePkt(
+ inner_version, remote_inner, tunnel.remote, _TEST_REMOTE_PORT,
+ local_inner, tunnel.local, local_port, sa_info.spi, sa_info.seq_num)
+ self.ReceivePacketOn(tunnel.underlying_netid, input_pkt)
+
+ if expect_fail:
+ self.assertRaisesErrno(EAGAIN, read_sock.recv, 4096)
+ else:
+ # Verify that the packet data and src are correct
+ data, src = read_sock.recvfrom(4096)
+ self.assertReceivedPacket(tunnel, sa_info)
+ self.assertEquals(net_test.UDP_PAYLOAD, data)
+ self.assertEquals((remote_inner, _TEST_REMOTE_PORT), src[:2])
+
+ def _CheckTunnelOutput(self, tunnel, inner_version, local_inner,
+ remote_inner, sa_info=None):
+ """Test null-crypt output path over an IPsec interface."""
+ if sa_info is None:
+ sa_info = tunnel.out_sa
+ local_port = _SendPacket(self, tunnel.netid, inner_version, remote_inner,
+ _TEST_REMOTE_PORT)
+
+ # Read a tunneled IP packet on the underlying (outbound) network
+ # verifying that it is an ESP packet.
+ pkt = self._ExpectEspPacketOn(tunnel.underlying_netid, sa_info.spi,
+ sa_info.seq_num, None, tunnel.local,
+ tunnel.remote)
+
+    # Get and update the IP headers on the inner payload so that we can do
+    # a simple comparison of byte data. Unfortunately, due to the scapy
+    # version this runs on, we cannot parse past the ESP header to the
+    # inner IP header, so we have to work around it in this manner.
+ if inner_version == 4:
+ ip_hdr_options = {
+ 'id': scapy.IP(str(pkt.payload)[8:]).id,
+ 'flags': scapy.IP(str(pkt.payload)[8:]).flags
+ }
+ else:
+ ip_hdr_options = {'fl': scapy.IPv6(str(pkt.payload)[8:]).fl}
+
+ expected = _GetNullAuthCryptTunnelModePkt(
+ inner_version, local_inner, tunnel.local, local_port, remote_inner,
+ tunnel.remote, _TEST_REMOTE_PORT, sa_info.spi, sa_info.seq_num,
+ ip_hdr_options)
+
+ # Check outer header manually (Avoids having to overwrite outer header's
+ # id, flags or flow label)
+ self.assertSentPacket(tunnel, sa_info)
+ self.assertEquals(expected.src, pkt.src)
+ self.assertEquals(expected.dst, pkt.dst)
+ self.assertEquals(len(expected), len(pkt))
+
+ # Check everything else
+ self.assertEquals(str(expected.payload), str(pkt.payload))
+
+ def _CheckTunnelEncryption(self, tunnel, inner_version, local_inner,
+ remote_inner):
+    """Test both input and output paths over an encrypted IPsec interface.
+
+    This test specifically makes sure that encryption and decryption work
+    together, as opposed to _CheckTunnel(Input|Output), where the input
+    and output paths are tested separately and with null encryption.
+    """
+ src_port = _SendPacket(self, tunnel.netid, inner_version, remote_inner,
+ _TEST_REMOTE_PORT)
+
+ # Make sure it appeared on the underlying interface
+ pkt = self._ExpectEspPacketOn(tunnel.underlying_netid, tunnel.out_sa.spi,
+ tunnel.out_sa.seq_num, None, tunnel.local,
+ tunnel.remote)
+
+ # Check that packet is not sent in plaintext
+ self.assertTrue(str(net_test.UDP_PAYLOAD) not in str(pkt))
+
+ # Check src/dst
+ self.assertEquals(tunnel.local, pkt.src)
+ self.assertEquals(tunnel.remote, pkt.dst)
+
+ # Check that the interface statistics recorded the outbound packet
+ self.assertSentPacket(tunnel, tunnel.out_sa)
+
+ try:
+ # Swap the interface addresses to pretend we are the remote
+ self._SwapInterfaceAddress(
+ tunnel.iface, new_addr=remote_inner, old_addr=local_inner)
+
+ # Swap the packet's IP headers and write it back to the underlying
+ # network.
+ pkt = TunTwister.TwistPacket(pkt)
+ read_sock, local_port = _CreateReceiveSock(inner_version,
+ _TEST_REMOTE_PORT)
+ self.ReceivePacketOn(tunnel.underlying_netid, pkt)
+
+ # Verify that the packet data and src are correct
+ data, src = read_sock.recvfrom(4096)
+ self.assertEquals(net_test.UDP_PAYLOAD, data)
+ self.assertEquals((local_inner, src_port), src[:2])
+
+ # Check that the interface statistics recorded the inbound packet
+ self.assertReceivedPacket(tunnel, tunnel.in_sa)
+ finally:
+ # Swap the interface addresses to pretend we are the remote
+ self._SwapInterfaceAddress(
+ tunnel.iface, new_addr=local_inner, old_addr=remote_inner)
+
+ def _CheckTunnelIcmp(self, tunnel, inner_version, local_inner, remote_inner,
+ sa_info=None):
+ """Test ICMP error path over an IPsec interface."""
+ if sa_info is None:
+ sa_info = tunnel.out_sa
+ # Now attempt to provoke an ICMP error.
+ # TODO: deduplicate with multinetwork_test.py.
+ dst_prefix, intermediate = {
+ 4: ("172.19.", "172.16.9.12"),
+ 6: ("2001:db8::", "2001:db8::1")
+ }[tunnel.version]
+
+ local_port = _SendPacket(self, tunnel.netid, inner_version, remote_inner,
+ _TEST_REMOTE_PORT)
+ pkt = self._ExpectEspPacketOn(tunnel.underlying_netid, sa_info.spi,
+ sa_info.seq_num, None, tunnel.local,
+ tunnel.remote)
+ self.assertSentPacket(tunnel, sa_info)
+
+ myaddr = self.MyAddress(tunnel.version, tunnel.underlying_netid)
+ _, toobig = packets.ICMPPacketTooBig(tunnel.version, intermediate, myaddr,
+ pkt)
+ self.ReceivePacketOn(tunnel.underlying_netid, toobig)
+
+ # Check that the packet too big reduced the MTU.
+ routes = self.iproute.GetRoutes(tunnel.remote, 0, tunnel.underlying_netid, None)
+ self.assertEquals(1, len(routes))
+ rtmsg, attributes = routes[0]
+ self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
+ self.assertEquals(packets.PTB_MTU, attributes["RTA_METRICS"]["RTAX_MTU"])
+
+ # Clear PMTU information so that future tests don't have to worry about it.
+ self.InvalidateDstCache(tunnel.version, tunnel.underlying_netid)
+
+ def _CheckTunnelEncryptionWithIcmp(self, tunnel, inner_version, local_inner,
+ remote_inner):
+    """Test the combined encryption path with ICMP errors over an IPsec tunnel."""
+ self._CheckTunnelEncryption(tunnel, inner_version, local_inner,
+ remote_inner)
+ self._CheckTunnelIcmp(tunnel, inner_version, local_inner, remote_inner)
+ self._CheckTunnelEncryption(tunnel, inner_version, local_inner,
+ remote_inner)
+
+ def _TestTunnel(self, inner_version, outer_version, func, use_null_crypt):
+ """Bootstrap method to setup and run tests for the given parameters."""
+ tunnel = self.randomTunnel(outer_version)
+
+ try:
+      # Some tests require that out_seq_num and in_seq_num be equal
+      # (specifically the encrypted tests), so rebuild the SAs to ensure
+      # that both sequence numbers start at 1.
+      #
+      # Until we get better scapy support, the only way we can build an
+      # encrypted packet is to send it out and read it back from the wire.
+      # We then generally use that as the "inbound" encrypted packet,
+      # injecting it into the interface on which it is expected.
+      #
+      # As such, this is required to ensure that encrypted packets (which
+      # we currently have no easy way to modify) are not considered replay
+      # attacks by the inbound SA. (e.g. if we received 3 packets,
+      # seq_num_in = 3, but sent only 1, seq_num_out = 1, the inbound SA
+      # would treat the reinjected packet as a replay attack.)
+ tunnel.TeardownXfrm()
+ tunnel.SetupXfrm(use_null_crypt)
+
+ local_inner = tunnel.addrs[inner_version]
+ remote_inner = _GetRemoteInnerAddress(inner_version)
+
+ for _ in range(2):
+ func(tunnel, inner_version, local_inner, remote_inner)
+ finally:
+ if use_null_crypt:
+ tunnel.TeardownXfrm()
+ tunnel.SetupXfrm(False)
+
+ def _CheckTunnelRekey(self, tunnel, inner_version, local_inner, remote_inner):
+ old_out_sa = tunnel.out_sa
+ old_in_sa = tunnel.in_sa
+
+ # Check to make sure that both directions work before rekey
+ self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+ old_in_sa)
+ self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
+ old_out_sa)
+
+ # Rekey
+ outer_family = net_test.GetAddressFamily(tunnel.version)
+
+ # Create new SAs, distinguished from the old ones by new SPIs.
+ new_out_sa = SaInfo(old_out_sa.spi + 1)
+ new_in_sa = SaInfo(old_in_sa.spi + 1)
+
+ # Perform Rekey
+ tunnel.Rekey(outer_family, new_out_sa, new_in_sa)
+
+ # Expect that the old SPI still works for inbound packets
+ self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+ old_in_sa)
+
+ # Test both paths with new SPIs, expect outbound to use new SPI
+ self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+ new_in_sa)
+ self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
+ new_out_sa)
+
+ # Delete old SAs
+ tunnel.DeleteOldSaInfo(outer_family, old_in_sa.spi, old_out_sa.spi)
+
+ # Test both paths with new SPIs; should still work
+ self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+ new_in_sa)
+ self._CheckTunnelOutput(tunnel, inner_version, local_inner, remote_inner,
+ new_out_sa)
+
+ # Expect failure upon trying to receive a packet with the deleted SPI
+ self._CheckTunnelInput(tunnel, inner_version, local_inner, remote_inner,
+ old_in_sa, True)
+
+ def _TestTunnelRekey(self, inner_version, outer_version):
+ """Test the SA rekey procedure over an IPsec tunnel interface."""
+ tunnel = self.randomTunnel(outer_version)
+
+ try:
+ # Always use null_crypt, so we can check input and output separately
+ tunnel.TeardownXfrm()
+ tunnel.SetupXfrm(True)
+
+ local_inner = tunnel.addrs[inner_version]
+ remote_inner = _GetRemoteInnerAddress(inner_version)
+
+ self._CheckTunnelRekey(tunnel, inner_version, local_inner, remote_inner)
+ finally:
+ tunnel.TeardownXfrm()
+ tunnel.SetupXfrm(False)
+
+
+@unittest.skipUnless(net_test.LINUX_VERSION >= (3, 18, 0), "VTI Unsupported")
+class XfrmVtiTest(XfrmTunnelBase):
+
+ INTERFACE_CLASS = VtiInterface
+
+ def ParamTestVtiInput(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelInput, True)
+
+ def ParamTestVtiOutput(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelOutput,
+ True)
+
+ def ParamTestVtiInOutEncrypted(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelEncryption,
+ False)
+
+ def ParamTestVtiIcmp(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelIcmp, False)
+
+ def ParamTestVtiEncryptionWithIcmp(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version,
+ self._CheckTunnelEncryptionWithIcmp, False)
+
+ def ParamTestVtiRekey(self, inner_version, outer_version):
+ self._TestTunnelRekey(inner_version, outer_version)
+
+
+@unittest.skipUnless(HAVE_XFRM_INTERFACES, "XFRM interfaces unsupported")
+class XfrmInterfaceTest(XfrmTunnelBase):
+
+ INTERFACE_CLASS = XfrmInterface
+
+ def ParamTestXfrmIntfInput(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelInput, True)
+
+ def ParamTestXfrmIntfOutput(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelOutput,
+ True)
+
+ def ParamTestXfrmIntfInOutEncrypted(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelEncryption,
+ False)
+
+ def ParamTestXfrmIntfIcmp(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version, self._CheckTunnelIcmp, False)
+
+ def ParamTestXfrmIntfEncryptionWithIcmp(self, inner_version, outer_version):
+ self._TestTunnel(inner_version, outer_version,
+ self._CheckTunnelEncryptionWithIcmp, False)
+
+ def ParamTestXfrmIntfRekey(self, inner_version, outer_version):
+ self._TestTunnelRekey(inner_version, outer_version)
+
+
+if __name__ == "__main__":
+ InjectTests()
+ unittest.main()