blob: bb81e027cc04ccde886b183296d7c48cde13054a [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright (C) 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import collections
import functools
import itertools
import os
import re
import subprocess
import sys
_ALWAYS_INCLUDED = [
"module_layout", # is exported even if CONFIG_TRIM_UNUSED_KSYMS is enabled
"__put_task_struct", # this allows us to keep `struct task_struct` stable
]
_ABIGAIL_HEADER = "[abi_symbol_list]"
def symbol_sort(symbols):
    """Returns the unique symbols of `symbols` as a sorted list.

    The order mimics `sort`: case insensitive and ignoring underscores, so
    symbols with related names end up next to each other. It is a bit brute
    force, but it gets the job done.
    """

    def sort_key(name):
        """Builds the comparison key for a single symbol name.

        Underscore-prefixed symbols sort together with their unprefixed
        counterpart, but before it. To achieve that, all underscores are
        dropped and a number of trailing underscores is appended that shrinks
        as the count of leading underscores grows.
        E.g. __blk_mq_end_request, _blk_mq_end_request and blk_mq_end_request
        become blkmqendrequest___, blkmqendrequest____ and
        blkmqendrequest_____, which compare in exactly that order. The
        original name breaks any remaining ties.
        """
        # A None or empty name means something upstream went wrong: assert,
        # and if asserts are disabled just hand the value back unchanged.
        assert (name)
        if not name:
            return name
        lowered = name.lower()
        # Index of the first non-underscore character. For a name made only
        # of underscores this falls back to the last index.
        leading = next(
            (pos for pos, char in enumerate(lowered) if char != "_"),
            len(lowered) - 1)
        padding = "_" * (5 - leading)
        return (lowered.replace("_", "") + padding, name)

    unique = set(symbols)
    return sorted(unique, key=sort_key)
def find_binaries(directory):
    """Walks `directory` looking for vmlinux and kernel modules (*.ko).

    Returns:
      A (vmlinux, modules) tuple: the path of the last file named exactly
      "vmlinux" seen during the walk (None if there is none), and the list of
      paths of all files ending in ".ko".
    """
    kernel = None
    ko_files = []
    for dirpath, _subdirs, filenames in os.walk(directory):
        for name in filenames:
            full_path = os.path.join(dirpath, name)
            if name.endswith(".ko"):
                ko_files.append(full_path)
            elif name == "vmlinux":
                kernel = full_path
    return kernel, ko_files
def extract_undefined_symbols(modules):
    """Maps each module's basename to its sorted undefined symbols.

    Runs `llvm-nm --undefined-only` on every module (in sorted path order)
    and takes the second whitespace-separated field of each output line as
    the symbol name.
    """
    # All modules could go to a single nm call, but one call per module
    # avoids hitting shell limits with long module lists.
    undefined = {}
    for path in sorted(modules):
        listing = subprocess.check_output(
            ["llvm-nm", "--undefined-only", path],
            stderr=subprocess.DEVNULL).decode("ascii")
        names = [row.strip().split()[1] for row in listing.splitlines()]
        undefined[os.path.basename(path)] = symbol_sort(names)
    return undefined
def extract_exported_symbols(binary):
    """Returns the sorted ksymtab-exported symbols of a kernel binary.

    Every `llvm-nm --defined-only` line containing " __ksymtab_" contributes
    the text following that marker as an exported symbol name.
    """
    marker = " __ksymtab_"
    raw = subprocess.check_output(["llvm-nm", "--defined-only", binary],
                                  stderr=subprocess.DEVNULL)
    exported = []
    for entry in raw.decode("ascii").splitlines():
        start = entry.find(marker)
        if start == -1:
            continue
        exported.append(entry[start + len(marker):])
    return symbol_sort(exported)
def extract_generic_exports(vmlinux, modules):
    """Collects the ksymtab exported symbols of vmlinux and a set of modules.

    The result is the vmlinux exports followed by each module's exports, in
    the order given. It is not re-sorted or de-duplicated across binaries.
    """
    exported = []
    for binary in [vmlinux, *modules]:
        exported.extend(extract_exported_symbols(binary))
    return exported
def extract_exported_in_modules(modules):
    """Maps each module path to the ksymtab symbols it exports."""
    exports = {}
    for module_path in modules:
        exports[module_path] = extract_exported_symbols(module_path)
    return exports
def report_missing(module_symbols, exported):
    """Prints one line per undefined symbol that no binary exports.

    Args:
      module_symbols: dict mapping a module name to its undefined symbols.
      exported: collection of all known exported symbols.
    """
    for module, symbols in module_symbols.items():
        missing = (name for name in symbols if name not in exported)
        for name in missing:
            print("Symbol {} required by {} but not provided".format(
                name, module))
def add_dependent_symbols(module_symbols, exported):
    """Checks the undefined symbols, and adds more to enforce missing dependencies.

    Tracepoints are exposed in the ABI through their matching struct
    tracepoint. Sadly this exposes callback functions as void * pointers,
    which makes the ABI tooling ineffective at monitoring tracepoint changes.
    To let ABI checks cover tracepoints, the matching __traceiter symbols,
    which are defined with full types, are added as well.

    For every `__tracepoint_<name>` a module references, `__traceiter_<name>`
    is appended to that module's list. This happens only when the traceiter
    symbol is in `exported` and not already in the list.

    Args:
      module_symbols: dict mapping a module name to its list of undefined
        symbols. The lists are extended in place.
      exported: collection of exported symbol names, used for `in` checks.
    """
    tracepoint_prefix = "__tracepoint_"
    traceiter_prefix = "__traceiter_"
    for symbols in module_symbols.values():
        # Snapshot as a set so each "already listed?" check is O(1) instead
        # of a scan of the whole list per tracepoint.
        present = set(symbols)
        additions = []
        for symbol in symbols:
            if not symbol.startswith(tracepoint_prefix):
                continue
            # Swap only the leading prefix. str.replace would also rewrite
            # any later occurrence of the prefix inside the name.
            iterator = traceiter_prefix + symbol[len(tracepoint_prefix):]
            if iterator in exported and iterator not in present:
                additions.append(iterator)
        symbols.extend(additions)
def read_symbol_list(symbol_list):
    """Reads a previously created libabigail symbol list.

    Returns the stripped, non-empty lines in file order, skipping comment
    lines ("#...") and section headers ("[...").
    """
    with open(symbol_list) as handle:
        stripped = [raw.strip() for raw in handle]
    return [entry for entry in stripped
            if entry and not entry.startswith(("#", "["))]
def create_symbol_list(symbol_list, undefined_symbols, exported,
                       emit_module_symbol_lists, module_grouping,
                       additions_only):
    """Creates a symbol list for libabigail.

    The file starts with `_ABIGAIL_HEADER`, followed by indented symbols, one
    per line. With module grouping, exported symbols used by more than one
    module (plus `_ALWAYS_INCLUDED`) form a "commonly used" section. The
    remaining exported symbols go to a "# required by <module>" section per
    module. Without grouping, every exported symbol lands in the first section.

    Args:
      symbol_list: path of the symbol list to (over)write.
      undefined_symbols: dict mapping a module name to the symbols it needs.
      exported: collection of symbols eligible for the list. Only symbols
        from `undefined_symbols` that are in it are emitted.
      emit_module_symbol_lists: also write `<symbol_list>_<module stem>`
        containing the exported symbols each module needs.
      module_grouping: split the list into common and per-module sections.
      additions_only: keep every symbol already present in `symbol_list`,
        listing those that would otherwise disappear in a trailing section.
    """
    # Read the old list first: the open(..., "w") below truncates it.
    precious_symbols = set()
    if additions_only:
        precious_symbols.update(read_symbol_list(symbol_list))

    # `exported` may be a plain list. Membership is tested for every
    # (module, symbol) pair below, so convert it once to keep lookups O(1).
    exported = set(exported)

    symbol_counter = collections.Counter(
        itertools.chain.from_iterable(undefined_symbols.values()))

    with open(symbol_list, "w") as wl:
        common_wl_section = symbol_sort([
            symbol for symbol, count in symbol_counter.items()
            if (count > 1 or not module_grouping) and symbol in exported
        ] + _ALWAYS_INCLUDED)
        # Set view of the common section for O(1) "already emitted?" checks.
        common_wl_symbols = set(common_wl_section)

        # write the header
        wl.write(_ABIGAIL_HEADER)
        wl.write("\n")
        if module_grouping:
            wl.write("# commonly used symbols\n")
        wl.write(" ")
        wl.write("\n ".join(common_wl_section))
        wl.write("\n")
        precious_symbols.difference_update(common_wl_symbols)

        for module, symbols in undefined_symbols.items():
            if emit_module_symbol_lists:
                mod_wl_file = symbol_list + "_" + os.path.splitext(module)[0]
                with open(mod_wl_file, "w") as mod_wl:
                    # write the header
                    mod_wl.write(_ABIGAIL_HEADER)
                    mod_wl.write("\n ")
                    # Sort like the main list does. The incoming lists may
                    # have extra symbols appended at their tail (e.g. by
                    # add_dependent_symbols).
                    mod_wl.write("\n ".join(
                        symbol_sort(s for s in symbols if s in exported)))
                    mod_wl.write("\n")

            new_wl_section = symbol_sort([
                symbol for symbol in symbols
                if symbol in exported and symbol not in common_wl_symbols
            ])
            if not new_wl_section:
                continue

            wl.write("\n# required by {}\n ".format(module))
            wl.write("\n ".join(new_wl_section))
            wl.write("\n")
            precious_symbols.difference_update(new_wl_section)

        if precious_symbols:
            wl.write("\n# preserved by --additions-only\n ")
            wl.write("\n ".join(symbol_sort(precious_symbols)))
            wl.write("\n")
def _build_argument_parser():
    """Returns the command line parser used by main()."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "directory",
        nargs="?",
        default=os.getcwd(),
        help="the directory to search for kernel binaries")
    parser.add_argument(
        "--skip-report-missing",
        action="store_false",
        dest="report_missing",
        help="Do not report symbols required by modules, but missing from vmlinux"
    )
    parser.add_argument(
        "--include-module-exports",
        action="store_true",
        help="Include inter-module symbols")
    parser.add_argument(
        "--full-gki-abi",
        action="store_true",
        help="Assume all vmlinux and GKI module symbols are part of the ABI")
    parser.add_argument(
        "--symbol-list", "--whitelist",
        help="The symbol list to create")
    parser.add_argument(
        "--additions-only",
        action="store_true",
        help="Read the existing symbol list and ensure no symbols get removed")
    parser.add_argument(
        "--print-modules",
        action="store_true",
        help="Emit the names of the processed modules")
    parser.add_argument(
        "--emit-module-symbol-lists", "--emit-module-whitelists",
        action="store_true",
        help="Emit a separate symbol list for each module")
    parser.add_argument(
        "--skip-module-grouping",
        action="store_false",
        dest="module_grouping",
        help="Do not group symbols by module.")
    parser.add_argument(
        "--module-filter",
        action="append",
        dest="module_filters",
        help="Only process modules matching the filter. Can be passed multiple times."
    )
    parser.add_argument(
        "--gki-modules",
        help="List of GKI modules which must be provided when the search directory contains both vendor and GKI modules")
    return parser


def main():
    """Extracts the required symbols for a directory full of kernel modules.

    Returns:
      1 if the arguments or the search directory are unusable, otherwise None
      (which sys.exit() treats as success).
    """
    args = _build_argument_parser().parse_args()

    if not os.path.isdir(args.directory):
        print("Expected a directory to search for binaries, but got %s" %
              args.directory)
        return 1

    if args.emit_module_symbol_lists and not args.symbol_list:
        print("Emitting module symbol lists requires the --symbol-list parameter.")
        return 1

    # --additions-only reads the existing list back before overwriting it.
    # Without --symbol-list that would be the /dev/stdout fallback below,
    # i.e. an attempt to read from our own output stream.
    if args.additions_only and not args.symbol_list:
        print("--additions-only requires the --symbol-list parameter.")
        return 1

    if args.symbol_list is None:
        args.symbol_list = "/dev/stdout"

    # Locate the Kernel Binaries
    vmlinux, modules = find_binaries(args.directory)

    if args.module_filters:
        # Compile each pattern once rather than once per module. An invalid
        # pattern is reported right away.
        filters = [re.compile(f) for f in args.module_filters]
        modules = [
            mod for mod in modules
            if any(f.search(os.path.basename(mod)) for f in filters)
        ]

    # Partition vendor and GKI modules in two lists. GKI modules are matched
    # by basename against the entries of the --gki-modules file.
    gki_modules = []
    if args.gki_modules is not None:
        with open(args.gki_modules) as f:
            gki_names = {
                os.path.basename(mod) for mod in f.read().splitlines()}
        gki_modules = [
            mod for mod in modules if os.path.basename(mod) in gki_names]
        modules = [
            mod for mod in modules if os.path.basename(mod) not in gki_names]

    if vmlinux is None or not os.path.isfile(vmlinux):
        print("Could not find a suitable vmlinux file.")
        return 1

    # Get required symbols of all modules
    undefined_symbols = extract_undefined_symbols(modules)

    # Get the actually defined and exported symbols
    generic_exports = extract_generic_exports(vmlinux, gki_modules)
    local_exports = extract_exported_in_modules(modules)

    # Build the set of all exported symbols (generic and local)
    all_exported = set(generic_exports).union(*local_exports.values())

    add_dependent_symbols(undefined_symbols, all_exported)

    # For sanity, check for inconsistencies between required and exported
    # symbols. Do not do this analysis if module_filters are in place, as
    # the filters likely break inter-module dependencies.
    if args.report_missing and not args.module_filters:
        report_missing(undefined_symbols, all_exported)

    # If specified, create the symbol list
    if args.symbol_list:
        create_symbol_list(
            args.symbol_list,
            {"full-gki-abi": generic_exports}
            if args.full_gki_abi else undefined_symbols,
            all_exported if args.include_module_exports else generic_exports,
            args.emit_module_symbol_lists,
            args.module_grouping,
            args.additions_only)

    if args.print_modules:
        print("These modules have been considered when creating the symbol list:")
        print(" " +
              "\n ".join(sorted([os.path.basename(mod) for mod in modules])))
if __name__ == "__main__":
    # main() returns 1 on error and None on success. sys.exit() maps None to
    # exit status 0.
    exit_code = main()
    sys.exit(exit_code)