blob: c2f8984425055ce2ccec11dbbcd3232f0f90a783 [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright (C) 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import collections
import functools
import itertools
import os
import subprocess
import sys
def symbol_sort(symbols):
    """Return the unique entries of `symbols` in a reproducible, sort(1)-like order.

    The ordering is case insensitive and ignores underscores, which keeps
    symbols with related names close to each other. Symbols that differ only
    in their number of leading underscores are ordered with the most
    underscores first, e.g.

        __blk_mq_end_request, _blk_mq_end_request, blk_mq_end_request

    Symbols that are indistinguishable under those rules (e.g. `foo_bar` and
    `foobar`) are ordered by their exact spelling, so the result does not
    depend on the iteration order of the intermediate set.

    Args:
      symbols: an iterable of non-empty symbol names.

    Returns:
      A new sorted list without duplicates.
    """

    def _key(symbol):
        """Create a key for comparison of symbols."""
        # If the caller passes None or an empty string something is odd, so
        # assert; when asserts are disabled fall back to a key that sorts
        # first and is still comparable with all the other (tuple) keys.
        assert symbol
        if not symbol:
            return ("", "")

        lowered = symbol.lower()
        leading_underscores = len(lowered) - len(lowered.lstrip("_"))

        # Underscore-prefixed symbols sort along with those without, but
        # before them: append one trailing underscore for every missing
        # leading one (out of five) and strip all other underscores. E.g.
        # __blk_mq_end_request, _blk_mq_end_request, blk_mq_end_request become
        # blkmqendrequest___, blkmqendrequest____, blkmqendrequest_____ and
        # are compared lexicographically. More than five leading underscores
        # yield a non-positive repeat count, i.e. no padding at all.
        primary = lowered.replace("_", "") + (5 - leading_underscores) * "_"

        # The raw symbol breaks ties between symbols whose primary keys are
        # equal; without it their relative order would follow the set's
        # iteration order, which is not stable across runs.
        return (primary, symbol)

    return sorted(set(symbols), key=_key)
def find_binaries(directory):
    """Walk `directory` and collect the kernel binaries found beneath it.

    Args:
      directory: root of the tree to search (searched recursively).

    Returns:
      A `(vmlinux, modules)` tuple. `vmlinux` is the path of a file named
      exactly "vmlinux", or None if there is none; if several exist, the one
      visited last wins. `modules` is the list of paths of all files whose
      name ends in ".ko", in the order they were visited.
    """
    kernel_image = None
    module_paths = []
    for parent, _subdirs, names in os.walk(directory):
        for name in names:
            path = os.path.join(parent, name)
            if name.endswith(".ko"):
                module_paths.append(path)
            elif name == "vmlinux":
                kernel_image = path
    return kernel_image, module_paths
def extract_undefined_symbols(modules):
    """Collect the undefined symbols of every module file in `modules`.

    `nm --undefined-only` is run once per module (in sorted path order)
    rather than once for the whole list, to stay clear of command line length
    limits with long lists of modules. The second whitespace-separated field
    of every output line is taken as the symbol name.

    Args:
      modules: an iterable of paths to module files.

    Returns:
      A dict mapping each module's basename to the `symbol_sort`ed list of
      its undefined symbols. Modules sharing a basename overwrite each other.

    Raises:
      subprocess.CalledProcessError: if `nm` exits with a non-zero status.
    """
    undefined_by_module = {}
    for path in sorted(modules):
        nm_output = subprocess.check_output(
            ["nm", "--undefined-only", path],
            encoding="ascii",
            stderr=subprocess.DEVNULL)
        names = [entry.strip().split()[1] for entry in nm_output.splitlines()]
        undefined_by_module[os.path.basename(path)] = symbol_sort(names)
    return undefined_by_module
def extract_exported_symbols(binary):
    """Collect the symbols a kernel binary exports through its ksymtab.

    Runs `nm --defined-only` on `binary` and, for every output line that
    contains " __ksymtab_", takes the text following the first occurrence of
    that marker as the name of an exported symbol.

    Args:
      binary: path to the binary to inspect.

    Returns:
      The `symbol_sort`ed list of exported symbol names.

    Raises:
      subprocess.CalledProcessError: if `nm` exits with a non-zero status.
    """
    marker = " __ksymtab_"
    nm_output = subprocess.check_output(
        ["nm", "--defined-only", binary],
        encoding="ascii",
        stderr=subprocess.DEVNULL)

    exported = []
    for entry in nm_output.splitlines():
        _before, found, name = entry.partition(marker)
        if found:
            exported.append(name)
    return symbol_sort(exported)
def extract_exported_in_modules(modules):
    """Map every module in `modules` to its ksymtab exported symbols.

    Args:
      modules: an iterable of module paths.

    Returns:
      A dict keyed by the module path exactly as given, whose values are the
      results of `extract_exported_symbols` for that path.
    """
    exports_by_module = {}
    for path in modules:
        exports_by_module[path] = extract_exported_symbols(path)
    return exports_by_module
def report_missing(module_symbols, exported):
    """Report symbols that are undefined, but not known in any binary.

    Prints one line to stdout for every symbol that a module requires but
    that is not contained in `exported`.

    Args:
      module_symbols: a dict mapping a module name to the symbols it requires.
      exported: a container of the symbols that are provided.
    """
    template = "Symbol {} required by {} but not provided"
    for module, symbols in module_symbols.items():
        missing = (symbol for symbol in symbols if symbol not in exported)
        for symbol in missing:
            print(template.format(symbol, module))
def create_whitelist(whitelist, undefined_symbols, exported):
    """Create a symbol whitelist for libabigail.

    The file starts with the `[abi_whitelist]` header, followed by a section
    of commonly used symbols (exported symbols that occur more than once
    across all per-module lists) and then one section per module with the
    remaining exported symbols that module requires. Symbols that are not in
    `exported` are left out; modules without any remaining symbol get no
    section.

    Args:
      whitelist: path of the file to write; it is opened in "w" mode.
      undefined_symbols: a dict mapping a module name to the symbols that
        module requires.
      exported: an iterable of the symbols eligible for the whitelist.
    """
    # Membership is tested once per required symbol, and callers may hand in
    # a list; a set keeps every one of those tests O(1).
    exported = set(exported)

    symbol_counter = collections.Counter(
        itertools.chain.from_iterable(undefined_symbols.values()))

    with open(whitelist, "w") as wl:
        common_wl_section = symbol_sort([
            symbol for symbol, count in symbol_counter.items()
            if count > 1 and symbol in exported
        ])
        # Same reasoning as for `exported`: the per-module filter below
        # consults this for every symbol of every module.
        common_symbols = set(common_wl_section)

        # always write the header
        wl.write("[abi_whitelist]\n")
        if common_wl_section:
            wl.write("# commonly used symbols\n ")
            wl.write("\n ".join(common_wl_section))
            wl.write("\n")
        else:
            # ensure the whitelist contains at least one symbol to not be
            # ignored
            wl.write(" dummy_symbol\n")

        for module, symbols in undefined_symbols.items():
            new_wl_section = symbol_sort([
                symbol for symbol in symbols
                if symbol in exported and symbol not in common_symbols
            ])

            if not new_wl_section:
                continue

            wl.write("\n# required by {}\n ".format(module))
            wl.write("\n ".join(new_wl_section))
            wl.write("\n")
def main():
    """Extract the required symbols for a directory full of kernel modules.

    Parses the command line, locates vmlinux and the kernel modules below the
    given directory (the current working directory by default), optionally
    reports required-but-not-exported symbols and writes the symbol
    whitelist.

    Returns:
      1 if the directory argument is not a directory or no vmlinux file was
      found, None otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "directory",
        nargs="?",
        default=os.getcwd(),
        help="the directory to search for kernel binaries")
    parser.add_argument(
        "--skip-report-missing",
        action="store_false",
        dest="report_missing",
        default=True,
        help="Do not report symbols required by modules, but missing from vmlinux"
    )
    parser.add_argument(
        "--include-module-exports",
        action="store_true",
        default=False,
        help="Include inter-module symbols")
    parser.add_argument(
        "--whitelist",
        default="/dev/stdout",
        help="The whitelist to create")
    args = parser.parse_args()

    search_root = args.directory
    if not os.path.isdir(search_root):
        print("Expected a directory to search for binaries, but got %s" %
              search_root)
        return 1

    # Locate the kernel binaries.
    vmlinux, modules = find_binaries(search_root)
    if not (vmlinux is not None and os.path.isfile(vmlinux)):
        print("Could not find a suitable vmlinux file.")
        return 1

    # What the modules require ...
    undefined_symbols = extract_undefined_symbols(modules)

    # ... and what is actually defined and exported.
    exported_in_vmlinux = extract_exported_symbols(vmlinux)
    exported_in_modules = extract_exported_in_modules(modules)

    # The union of everything exported by vmlinux and by the modules.
    all_exported = set(exported_in_vmlinux)
    for module_exports in exported_in_modules.values():
        all_exported.update(module_exports)

    # For sanity, check for inconsistencies between required and exported
    # symbols.
    if args.report_missing:
        report_missing(undefined_symbols, all_exported)

    # If specified, create the whitelist; inter-module exports only make it
    # in on request.
    if args.whitelist:
        if args.include_module_exports:
            eligible = all_exported
        else:
            eligible = exported_in_vmlinux
        create_whitelist(args.whitelist, undefined_symbols, eligible)
if __name__ == "__main__":
    # Run only when executed as a script; main()'s return value becomes the
    # process exit status.
    exit_status = main()
    sys.exit(exit_status)