tools/compile_seccomp_policy: Coalesce contiguous syscall actions am: b21da7a85a am: e62a7f2cfd
am: ec63e09089

Change-Id: Ia1cf14b96e83f7d4724c3aab90f156bd4d209ba0
diff --git a/tools/compiler.py b/tools/compiler.py
index 996b0da..73053d2 100644
--- a/tools/compiler.py
+++ b/tools/compiler.py
@@ -64,21 +64,85 @@
                             *args)
 
 
+class SyscallPolicyRange:
+    """A contiguous range of SyscallPolicyEntries that have the same action."""
+
+    def __init__(self, *entries):
+        self.numbers = (entries[0].number, entries[-1].number + 1)  # [lo, hi)
+        self.frequency = sum(e.frequency for e in entries)
+        self.accumulated = 0  # prefix sum of frequency, set by the BST compiler
+        self.filter = entries[0].filter  # assumes entries share equal filters
+
+    def __repr__(self):
+        return 'SyscallPolicyRange<numbers: %r, frequency: %d, filter: %r>' % (
+            self.numbers, self.frequency, self.filter.instructions
+            if self.filter else None)
+
+    def simulate(self, arch, syscall_number, *args):
+        """Simulate the policy with the given arguments."""
+        if not self.filter:  # no filter means unconditional ALLOW
+            return (0, 'ALLOW')
+        return self.filter.simulate(arch, syscall_number, *args)
+
+
+def _convert_to_ranges(entries):  # yields SyscallPolicyRange objects
+    entries = sorted(entries, key=lambda e: e.number)
+    lower = 0  # index of the first entry in the current run
+    while lower < len(entries):
+        upper = lower + 1  # one past the last entry in the current run
+        while upper < len(entries):
+            if entries[upper - 1].filter != entries[upper].filter:
+                break
+            if entries[upper - 1].number + 1 != entries[upper].number:
+                break
+            upper += 1
+        yield SyscallPolicyRange(*entries[lower:upper])
+        lower = upper
+
+
+def _compile_single_range(entry,  # a SyscallPolicyRange
+                          accept_action,
+                          reject_action,
+                          visitor,
+                          lower_bound=0,  # inclusive
+                          upper_bound=1e99):  # exclusive; default: unbounded
+    action = accept_action
+    if entry.filter:
+        entry.filter.accept(visitor)
+        action = entry.filter  # on a match, jump into the range's filter
+    if entry.numbers[1] - entry.numbers[0] == 1:
+        # Single syscall.
+        # Accept if |X == nr|.
+        return bpf.SyscallEntry(
+            entry.numbers[0], action, reject_action, op=bpf.BPF_JEQ)
+    elif entry.numbers[0] == lower_bound:
+        # Syscall range aligned with the lower bound.
+        # Accept if |X < nr[1]|.
+        return bpf.SyscallEntry(
+            entry.numbers[1], reject_action, action, op=bpf.BPF_JGE)
+    elif entry.numbers[1] == upper_bound:
+        # Syscall range aligned with the upper bound.
+        # Accept if |X >= nr[0]|.
+        return bpf.SyscallEntry(
+            entry.numbers[0], action, reject_action, op=bpf.BPF_JGE)
+    # Syscall range in the middle.
+    # Accept if |nr[0] <= X < nr[1]|.
+    upper_entry = bpf.SyscallEntry(
+        entry.numbers[1], reject_action, action, op=bpf.BPF_JGE)
+    return bpf.SyscallEntry(
+        entry.numbers[0], upper_entry, reject_action, op=bpf.BPF_JGE)
+
+
 def _compile_entries_linear(entries, accept_action, reject_action, visitor):
     # Compiles the list of entries into a simple linear list of comparisons. In
     # order to make the generated code a bit more efficient, we sort the
     # entries by frequency, so that the most frequently-called syscalls appear
     # earlier in the chain.
     next_action = reject_action
-    entries.sort(key=lambda e: -e.frequency)
-    for entry in entries[::-1]:
-        if entry.filter:
-            next_action = bpf.SyscallEntry(entry.number, entry.filter,
-                                           next_action)
-            entry.filter.accept(visitor)
-        else:
-            next_action = bpf.SyscallEntry(entry.number, accept_action,
-                                           next_action)
+    ranges = sorted(_convert_to_ranges(entries), key=lambda r: -r.frequency)
+    for entry in ranges[::-1]:  # chain built back to front: hottest first
+        next_action = _compile_single_range(entry, accept_action, next_action,
+                                            visitor)
     return next_action
 
 
@@ -103,40 +167,32 @@
     # TODO(lhchavez): There is one further possible optimization, which is to
     # hoist any syscalls that are more frequent than all other syscalls in the
     # BST combined into a linear chain before entering the BST.
-    entries.sort(key=lambda e: e.number)
+    ranges = list(_convert_to_ranges(entries))
+
     accumulated = 0
-    for entry in entries:
+    for entry in ranges:
         accumulated += entry.frequency
         entry.accumulated = accumulated
 
     # Recursively create the internal nodes.
-    def _generate_syscall_bst(entries, lower_bound=0, upper_bound=2**64 - 1):
-        assert entries
-        if len(entries) == 1:
-            # This is a single entry, but the interval we are currently
-            # considering contains other syscalls that we want to reject. So
-            # instead of an internal node, create a leaf node with an equality
-            # comparison.
+    def _generate_syscall_bst(ranges, lower_bound=0, upper_bound=2**64 - 1):
+        assert ranges
+        if len(ranges) == 1:
+            # This is a single syscall entry range, but the interval we are
+            # currently considering contains other syscalls that we want to
+            # reject. So instead of an internal node, create one or more leaf
+            # nodes that check the range.
             assert lower_bound < upper_bound
-            if entries[0].filter:
-                entries[0].filter.accept(visitor)
-                return bpf.SyscallEntry(
-                    entries[0].number,
-                    entries[0].filter,
-                    reject_action,
-                    op=bpf.BPF_JEQ)
-            return bpf.SyscallEntry(
-                entries[0].number,
-                accept_action,
-                reject_action,
-                op=bpf.BPF_JEQ)
+            return _compile_single_range(ranges[0], accept_action,
+                                         reject_action, visitor, lower_bound,
+                                         upper_bound)
 
         # Find the midpoint that minimizes the difference between accumulated
         # costs in the left and right subtrees.
-        previous_accumulated = entries[0].accumulated - entries[0].frequency
-        last_accumulated = entries[-1].accumulated - previous_accumulated
+        previous_accumulated = ranges[0].accumulated - ranges[0].frequency
+        last_accumulated = ranges[-1].accumulated - previous_accumulated
         best = (1e99, -1)
-        for i, entry in enumerate(entries):
+        for i, entry in enumerate(ranges):
             if not i:
                 continue
             left_accumulated = entry.accumulated - previous_accumulated
@@ -145,42 +201,40 @@
         midpoint = best[1]
         assert midpoint >= 1, best
 
+        cutoff_bound = ranges[midpoint].numbers[0]
+
         # Now we build the right and left subtrees independently. If any of the
         # subtrees consist of a single entry _and_ the bounds are tight around
         # that entry (that is, the bounds contain _only_ the syscall we are
         # going to consider), we can avoid emitting a leaf node and instead
         # have the comparison jump directly into the action that would be taken
         # by the entry.
-        if entries[midpoint].number == upper_bound:
-            if entries[midpoint].filter:
-                entries[midpoint].filter.accept(visitor)
-                right_subtree = entries[midpoint].filter
+        if (cutoff_bound, upper_bound) == ranges[midpoint].numbers:
+            if ranges[midpoint].filter:
+                ranges[midpoint].filter.accept(visitor)
+                right_subtree = ranges[midpoint].filter
             else:
                 right_subtree = accept_action
         else:
-            right_subtree = _generate_syscall_bst(
-                entries[midpoint:], entries[midpoint].number, upper_bound)
+            right_subtree = _generate_syscall_bst(ranges[midpoint:],
+                                                  cutoff_bound, upper_bound)
 
-        if lower_bound == entries[midpoint].number - 1:
-            assert entries[midpoint - 1].number == lower_bound
-            if entries[midpoint - 1].filter:
-                entries[midpoint - 1].filter.accept(visitor)
-                left_subtree = entries[midpoint - 1].filter
+        if (lower_bound, cutoff_bound) == ranges[midpoint - 1].numbers:
+            if ranges[midpoint - 1].filter:
+                ranges[midpoint - 1].filter.accept(visitor)
+                left_subtree = ranges[midpoint - 1].filter
             else:
                 left_subtree = accept_action
         else:
-            left_subtree = _generate_syscall_bst(
-                entries[:midpoint], lower_bound, entries[midpoint].number - 1)
+            left_subtree = _generate_syscall_bst(ranges[:midpoint],
+                                                 lower_bound, cutoff_bound)
 
         # Finally, now that both subtrees have been generated, we can create
         # the internal node of the binary search tree.
         return bpf.SyscallEntry(
-            entries[midpoint].number,
-            right_subtree,
-            left_subtree,
-            op=bpf.BPF_JGE)
+            cutoff_bound, right_subtree, left_subtree, op=bpf.BPF_JGE)
 
-    return _generate_syscall_bst(entries)
+    return _generate_syscall_bst(ranges)
 
 
 class PolicyCompiler:
diff --git a/tools/compiler_unittest.py b/tools/compiler_unittest.py
index 2f0b0bd..cfa2b8d 100755
--- a/tools/compiler_unittest.py
+++ b/tools/compiler_unittest.py
@@ -19,6 +19,8 @@
 from __future__ import print_function
 
 import os
+import random
+import shutil
 import tempfile
 import unittest
 
@@ -358,6 +360,40 @@
                 bpf.simulate(program.instructions, self.arch.arch_nr,
                              self.arch.syscalls['read'], 0)[1], 'KILL_THREAD')
 
+    def test_compile_simulate(self):
+        """Compile random policies and check each syscall's simulated action."""
+        iterations = 10
+        for i in range(iterations):
+            num_entries = len(self.arch.syscalls) * (i + 1) // iterations
+            syscalls = dict(
+                zip(
+                    random.sample(sorted(self.arch.syscalls), num_entries),
+                    (random.randint(1, 1024) for _ in range(num_entries)),
+                ))
+
+            frequency_contents = '\n'.join(
+                '%s: %d' % s for s in syscalls.items())
+            policy_contents = '@frequency ./test.frequency\n' + '\n'.join(
+                '%s: 1' % name for name in syscalls)
+
+            self._write_file('test.frequency', frequency_contents)
+            path = self._write_file('test.policy', policy_contents)
+
+            for strategy in compiler.OptimizationStrategy:
+                program = self.compiler.compile_file(
+                    path,
+                    optimization_strategy=strategy,
+                    kill_action=bpf.KillProcess())
+                for name, number in self.arch.syscalls.items():
+                    expected_result = ('ALLOW'
+                                       if name in syscalls else 'KILL_PROCESS')
+                    self.assertEqual(
+                        bpf.simulate(program.instructions, self.arch.arch_nr,
+                                     number, 0)[1], expected_result,
+                        ('syscall name: %s, syscall number: %d, '
+                         'strategy: %s, policy:\n%s') %
+                        (name, number, strategy, policy_contents))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tools/testdata/arch_64.json b/tools/testdata/arch_64.json
index bd3e2f4..10c9855 100644
--- a/tools/testdata/arch_64.json
+++ b/tools/testdata/arch_64.json
@@ -6,7 +6,103 @@
     "read": 0,
     "write": 1,
     "open": 2,
-    "close": 3
+    "close": 3,
+    "syscall_4": 4,
+    "syscall_5": 5,
+    "syscall_6": 6,
+    "syscall_7": 7,
+    "syscall_8": 8,
+    "syscall_9": 9,
+    "syscall_10": 10,
+    "syscall_11": 11,
+    "syscall_12": 12,
+    "syscall_13": 13,
+    "syscall_14": 14,
+    "syscall_15": 15,
+    "syscall_16": 16,
+    "syscall_17": 17,
+    "syscall_18": 18,
+    "syscall_19": 19,
+    "syscall_20": 20,
+    "syscall_21": 21,
+    "syscall_22": 22,
+    "syscall_23": 23,
+    "syscall_24": 24,
+    "syscall_25": 25,
+    "syscall_26": 26,
+    "syscall_27": 27,
+    "syscall_28": 28,
+    "syscall_29": 29,
+    "syscall_30": 30,
+    "syscall_31": 31,
+    "syscall_32": 32,
+    "syscall_33": 33,
+    "syscall_34": 34,
+    "syscall_35": 35,
+    "syscall_36": 36,
+    "syscall_37": 37,
+    "syscall_38": 38,
+    "syscall_39": 39,
+    "syscall_40": 40,
+    "syscall_41": 41,
+    "syscall_42": 42,
+    "syscall_43": 43,
+    "syscall_44": 44,
+    "syscall_45": 45,
+    "syscall_46": 46,
+    "syscall_47": 47,
+    "syscall_48": 48,
+    "syscall_49": 49,
+    "syscall_50": 50,
+    "syscall_51": 51,
+    "syscall_52": 52,
+    "syscall_53": 53,
+    "syscall_54": 54,
+    "syscall_55": 55,
+    "syscall_56": 56,
+    "syscall_57": 57,
+    "syscall_58": 58,
+    "syscall_59": 59,
+    "syscall_60": 60,
+    "syscall_61": 61,
+    "syscall_62": 62,
+    "syscall_63": 63,
+    "syscall_64": 64,
+    "syscall_65": 65,
+    "syscall_66": 66,
+    "syscall_67": 67,
+    "syscall_68": 68,
+    "syscall_69": 69,
+    "syscall_70": 70,
+    "syscall_71": 71,
+    "syscall_72": 72,
+    "syscall_73": 73,
+    "syscall_74": 74,
+    "syscall_75": 75,
+    "syscall_76": 76,
+    "syscall_77": 77,
+    "syscall_78": 78,
+    "syscall_79": 79,
+    "syscall_80": 80,
+    "syscall_81": 81,
+    "syscall_82": 82,
+    "syscall_83": 83,
+    "syscall_84": 84,
+    "syscall_85": 85,
+    "syscall_86": 86,
+    "syscall_87": 87,
+    "syscall_88": 88,
+    "syscall_89": 89,
+    "syscall_90": 90,
+    "syscall_91": 91,
+    "syscall_92": 92,
+    "syscall_93": 93,
+    "syscall_94": 94,
+    "syscall_95": 95,
+    "syscall_96": 96,
+    "syscall_97": 97,
+    "syscall_98": 98,
+    "syscall_99": 99
   },
   "constants": {
     "ENOSYS": 38,