blob: 3f030845ec730aa083ebd68a6c4d63b62ad48a5e [file] [log] [blame]
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This script will take any number of trace files generated by strace(1)
# and output a system call filtering policy suitable for use with Minijail.
"""Helper tool to generate a minijail seccomp filter from strace output."""
from __future__ import print_function
import argparse
import collections
import re
import sys
# License header printed verbatim at the top of every generated policy.
NOTICE = """# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# Policy line that allows a syscall unconditionally; '%s' is the syscall name.
ALLOW = '%s: 1'
# This ignores any leading PID tag and trailing <unfinished ...>, and extracts
# the syscall name and the argument list. The argument group stops at the
# first ')' or '<', so it may be truncated (e.g. 'NULL, 4096, ' or '').
LINE_RE = re.compile(r'^\s*(?:\[[^]]*\]|\d+)?\s*([a-zA-Z0-9_]+)\(([^)<]*)')
# Socket operations that 32-bit x86 multiplexes through socketcall(2). For
# traces treated as 32-bit x86, these names are folded into 'socketcall'.
# The list mirrors the SYS_* calls documented in socketcall(2), including
# accept4, recvmmsg and sendmmsg.
SOCKETCALLS = {
    'accept', 'accept4', 'bind', 'connect', 'getpeername', 'getsockname',
    'getsockopt', 'listen', 'recv', 'recvfrom', 'recvmmsg', 'recvmsg',
    'send', 'sendmmsg', 'sendmsg', 'sendto', 'setsockopt', 'shutdown',
    'socket', 'socketpair',
}
# Reference copy of the C library's protocol-family constants. It is
# presumably kept as a guide for reading the 'socket' domain argument
# recorded via ArgInspectionEntry below.
# /* Protocol families. */
# #define PF_UNSPEC 0 /* Unspecified. */
# #define PF_LOCAL 1 /* Local to host (pipes and file-domain). */
# #define PF_UNIX PF_LOCAL /* POSIX name for PF_LOCAL. */
# #define PF_FILE PF_LOCAL /* Another non-standard name for PF_LOCAL. */
# #define PF_INET 2 /* IP protocol family. */
# #define PF_AX25 3 /* Amateur Radio AX.25. */
# #define PF_IPX 4 /* Novell Internet Protocol. */
# #define PF_APPLETALK 5 /* Appletalk DDP. */
# #define PF_NETROM 6 /* Amateur radio NetROM. */
# #define PF_BRIDGE 7 /* Multiprotocol bridge. */
# #define PF_ATMPVC 8 /* ATM PVCs. */
# #define PF_X25 9 /* Reserved for X.25 project. */
# #define PF_INET6 10 /* IP version 6. */
# #define PF_ROSE 11 /* Amateur Radio X.25 PLP. */
# #define PF_DECnet 12 /* Reserved for DECnet project. */
# #define PF_NETBEUI 13 /* Reserved for 802.2LLC project. */
# #define PF_SECURITY 14 /* Security callback pseudo AF. */
# #define PF_KEY 15 /* PF_KEY key management API. */
# #define PF_NETLINK 16
# Which argument of a syscall to record (arg_index, 0-based, emitted as
# 'arg<N>' in the policy) and the set of distinct textual values seen for it.
ArgInspectionEntry = collections.namedtuple('ArgInspectionEntry',
                                            ('arg_index', 'value_set'))
def get_parser():
    """Build the command-line parser used by main().

    The parser accepts one or more positional strace log paths, stored
    under the 'traces' attribute of the parsed namespace.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument('traces', help='The strace logs.', nargs='+')
    return cli
def _is_socketcall_trace(trace_filename):
    """Guess from its file name whether a trace came from 32-bit x86.

    A name counts as 32-bit x86 if it contains 'i386', or contains 'x86'
    without '64'. Only the file name is inspected, not the contents.
    """
    return 'i386' in trace_filename or ('x86' in trace_filename and
                                        '64' not in trace_filename)


def _parse_traces(trace_filenames, uses_socketcall, arg_inspection):
    """Count syscall occurrences in the given strace logs.

    Args:
      trace_filenames: paths of strace(1) output files.
      uses_socketcall: if true, names in SOCKETCALLS are counted as
        'socketcall'.
      arg_inspection: dict of syscall name -> ArgInspectionEntry. Each
        entry's value_set is updated in place with every distinct, non-empty
        argument string seen at entry.arg_index.

    Returns:
      A collections.defaultdict(int) mapping syscall name -> occurrence count.

    Raises:
      IOError/OSError if a trace file cannot be opened.
    """
    syscalls = collections.defaultdict(int)
    for trace_filename in trace_filenames:
        # 'with' guarantees the file is closed even if parsing fails.
        with open(trace_filename) as trace_file:
            for line in trace_file:
                matches = LINE_RE.match(line)
                if not matches:
                    continue
                syscall, args = matches.groups()
                if uses_socketcall and syscall in SOCKETCALLS:
                    syscall = 'socketcall'
                syscalls[syscall] += 1
                entry = arg_inspection.get(syscall)
                if entry is None:
                    continue
                args = [arg.strip() for arg in args.split(',')]
                # '<unfinished ...>' lines may stop before the inspected
                # argument, or leave it empty (e.g. 'read(3, <unfinished').
                if entry.arg_index < len(args) and args[entry.arg_index]:
                    entry.value_set.add(args[entry.arg_index])
    return syscalls


def _order_syscalls(syscalls, basic_set):
    """Return syscall names, most frequent first, with basic_set spliced in.

    Members of basic_set that never appeared in the traces are inserted just
    before the first syscall seen fewer than two times. When every syscall
    was seen at least twice, they are appended at the end.
    """
    # sorted() is stable even with reverse=True, so ties keep the dict's
    # iteration order.
    by_frequency = sorted(syscalls, key=syscalls.__getitem__, reverse=True)
    split = next((i for i, name in enumerate(by_frequency)
                  if syscalls[name] < 2), len(by_frequency))
    missing = [name for name in basic_set if name not in syscalls]
    return by_frequency[:split] + missing + by_frequency[split:]


def _policy_line(syscall, arg_inspection):
    """Format the Minijail policy line for one syscall.

    Syscalls with recorded argument values get an 'argN == v1 || ...' filter,
    with values sorted so the output is reproducible. A syscall that is
    inspected but had no usable value recorded falls back to ALLOW, and a
    warning is printed to stderr. The alternative, an empty filter, would be
    an invalid policy line.
    """
    entry = arg_inspection.get(syscall)
    if entry is None:
        return ALLOW % syscall
    if not entry.value_set:
        print('warning: no arg%d values recorded for %s; allowing it '
              'unconditionally' % (entry.arg_index, syscall),
              file=sys.stderr)
        return ALLOW % syscall
    arg_filter = ' || '.join('arg%d == %s' % (entry.arg_index, arg_value)
                             for arg_value in sorted(entry.value_set))
    return syscall + ': ' + arg_filter


def main(argv):
    """Print a Minijail seccomp policy for the syscalls seen in strace logs.

    Args:
      argv: command-line arguments without the program name; see get_parser().

    Returns:
      None, so a caller passing the result to sys.exit() exits with status 0.
    """
    parser = get_parser()
    opts = parser.parse_args(argv)
    # Emitted even if absent from the traces, presumably because nearly every
    # process needs them to return from signals and to exit.
    basic_set = [
        'restart_syscall', 'exit', 'exit_group', 'rt_sigreturn',
    ]
    # Syscalls whose policy line filters on one argument instead of allowing
    # every value.
    arg_inspection = {
        'socket': ArgInspectionEntry(0, set()),  # int domain
        'ioctl': ArgInspectionEntry(1, set()),  # int request
        'prctl': ArgInspectionEntry(0, set()),  # int option
        'mmap': ArgInspectionEntry(2, set()),  # int prot
        'mmap2': ArgInspectionEntry(2, set()),  # int prot
        'mprotect': ArgInspectionEntry(2, set()),  # int prot
    }
    # A single policy targets one architecture. If any trace looks like
    # 32-bit x86, fold socket operations from every trace into 'socketcall'.
    # The decision no longer depends on the order the traces are listed in.
    uses_socketcall = any(_is_socketcall_trace(name) for name in opts.traces)
    syscalls = _parse_traces(opts.traces, uses_socketcall, arg_inspection)
    print(NOTICE)
    # Frequent syscalls come first, which speeds up the filter slightly.
    for syscall in _order_syscalls(syscalls, basic_set):
        print(_policy_line(syscall, arg_inspection))
if __name__ == '__main__':
    # Script entry point: drop the program name and use main()'s return
    # value as the process exit status.
    _cli_args = sys.argv[1:]
    sys.exit(main(_cli_args))