| """Core data structures for compiled code templates.""" |
| import dataclasses |
| import enum |
| import sys |
| |
| import _schema |
| |
| |
@enum.unique
class HoleValue(enum.Enum):
    """
    The set of "base" values a hole can be patched with.

    A base value is usually combined with a symbol's address and/or an addend
    to form the final patched value.
    """

    # Start of the current uop's machine code (exposed as _JIT_ENTRY):
    CODE = enum.auto()
    # Start of the next uop's machine code (exposed as _JIT_CONTINUE):
    CONTINUE = enum.auto()
    # Start of this uop's read-only data:
    DATA = enum.auto()
    # Address of the current executor (exposed as _JIT_EXECUTOR):
    EXECUTOR = enum.auto()
    # Start of the "global" offset table, which lives in the read-only data.
    # Final stencils shouldn't contain this value, because every use is
    # replaced with an equivalent DATA value:
    GOT = enum.auto()
    # Oparg of the current uop (exposed as _JIT_OPARG):
    OPARG = enum.auto()
    # Operand of the current uop on 64-bit platforms (exposed as _JIT_OPERAND):
    OPERAND = enum.auto()
    # Operand of the current uop on 32-bit platforms (exposed as
    # _JIT_OPERAND_HI and _JIT_OPERAND_LO):
    OPERAND_HI = enum.auto()
    OPERAND_LO = enum.auto()
    # Target of the current uop (exposed as _JIT_TARGET):
    TARGET = enum.auto()
    # Start of the jump target's machine code (exposed as _JIT_JUMP_TARGET):
    JUMP_TARGET = enum.auto()
    # Start of the error jump target's machine code (exposed as
    # _JIT_ERROR_TARGET):
    ERROR_TARGET = enum.auto()
    # Index of the exit to jump through (exposed as _JIT_EXIT_INDEX):
    EXIT_INDEX = enum.auto()
    # Start of the first uop's machine code (exposed as _JIT_TOP):
    TOP = enum.auto()
    # A constant zero (used for symbol lookups):
    ZERO = enum.auto()
| |
| |
@dataclasses.dataclass
class Hole:
    """
    A location in a stencil that gets patched with a value computed at runtime.

    Plays the same role as a relocation record in an object file.
    """

    # Where the patch is applied:
    offset: int
    # How the patch is applied:
    kind: _schema.HoleKind
    # The base value to patch in:
    value: HoleValue
    # ...to which the address of this symbol (if any) is added:
    symbol: str | None
    # ...to which this addend is added:
    addend: int
    # Convenience method, so that hole.replace(...) returns a modified copy:
    replace = dataclasses.replace

    def as_c(self) -> str:
        """Render this hole as a brace-enclosed initializer for a C Hole struct."""
        # A missing (or empty) symbol is emitted as a NULL pointer:
        symbol_ref = f"&{self.symbol}" if self.symbol else "NULL"
        # The addend is shown as a signed 64-bit quantity:
        addend_text = f"{_signed(self.addend):#x}"
        fields = (
            f"{self.offset:#x}",
            f"HoleKind_{self.kind}",
            f"HoleValue_{self.value.name}",
            symbol_ref,
            addend_text,
        )
        return "{" + ", ".join(fields) + "}"
| |
| |
| @dataclasses.dataclass |
| class Stencil: |
| """ |
| A contiguous block of machine code or data to be copied-and-patched. |
| |
| Analogous to a section or segment in an object file. |
| """ |
| |
| body: bytearray = dataclasses.field(default_factory=bytearray, init=False) |
| holes: list[Hole] = dataclasses.field(default_factory=list, init=False) |
| disassembly: list[str] = dataclasses.field(default_factory=list, init=False) |
| |
| def pad(self, alignment: int) -> None: |
| """Pad the stencil to the given alignment.""" |
| offset = len(self.body) |
| padding = -offset % alignment |
| self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}") |
| self.body.extend([0] * padding) |
| |
| def emit_aarch64_trampoline(self, hole: Hole) -> None: |
| """Even with the large code model, AArch64 Linux insists on 28-bit jumps.""" |
| base = len(self.body) |
| where = slice(hole.offset, hole.offset + 4) |
| instruction = int.from_bytes(self.body[where], sys.byteorder) |
| instruction &= 0xFC000000 |
| instruction |= ((base - hole.offset) >> 2) & 0x03FFFFFF |
| self.body[where] = instruction.to_bytes(4, sys.byteorder) |
| self.disassembly += [ |
| f"{base + 4 * 0:x}: d2800008 mov x8, #0x0", |
| f"{base + 4 * 0:016x}: R_AARCH64_MOVW_UABS_G0_NC {hole.symbol}", |
| f"{base + 4 * 1:x}: f2a00008 movk x8, #0x0, lsl #16", |
| f"{base + 4 * 1:016x}: R_AARCH64_MOVW_UABS_G1_NC {hole.symbol}", |
| f"{base + 4 * 2:x}: f2c00008 movk x8, #0x0, lsl #32", |
| f"{base + 4 * 2:016x}: R_AARCH64_MOVW_UABS_G2_NC {hole.symbol}", |
| f"{base + 4 * 3:x}: f2e00008 movk x8, #0x0, lsl #48", |
| f"{base + 4 * 3:016x}: R_AARCH64_MOVW_UABS_G3 {hole.symbol}", |
| f"{base + 4 * 4:x}: d61f0100 br x8", |
| ] |
| for code in [ |
| 0xD2800008.to_bytes(4, sys.byteorder), |
| 0xF2A00008.to_bytes(4, sys.byteorder), |
| 0xF2C00008.to_bytes(4, sys.byteorder), |
| 0xF2E00008.to_bytes(4, sys.byteorder), |
| 0xD61F0100.to_bytes(4, sys.byteorder), |
| ]: |
| self.body.extend(code) |
| for i, kind in enumerate( |
| [ |
| "R_AARCH64_MOVW_UABS_G0_NC", |
| "R_AARCH64_MOVW_UABS_G1_NC", |
| "R_AARCH64_MOVW_UABS_G2_NC", |
| "R_AARCH64_MOVW_UABS_G3", |
| ] |
| ): |
| self.holes.append(hole.replace(offset=base + 4 * i, kind=kind)) |
| |
| def remove_jump(self, *, alignment: int = 1) -> None: |
| """Remove a zero-length continuation jump, if it exists.""" |
| hole = max(self.holes, key=lambda hole: hole.offset) |
| match hole: |
| case Hole( |
| offset=offset, |
| kind="IMAGE_REL_AMD64_REL32", |
| value=HoleValue.GOT, |
| symbol="_JIT_CONTINUE", |
| addend=-4, |
| ) as hole: |
| # jmp qword ptr [rip] |
| jump = b"\x48\xFF\x25\x00\x00\x00\x00" |
| offset -= 3 |
| case Hole( |
| offset=offset, |
| kind="IMAGE_REL_I386_REL32" | "X86_64_RELOC_BRANCH", |
| value=HoleValue.CONTINUE, |
| symbol=None, |
| addend=-4, |
| ) as hole: |
| # jmp 5 |
| jump = b"\xE9\x00\x00\x00\x00" |
| offset -= 1 |
| case Hole( |
| offset=offset, |
| kind="R_AARCH64_JUMP26", |
| value=HoleValue.CONTINUE, |
| symbol=None, |
| addend=0, |
| ) as hole: |
| # b #4 |
| jump = b"\x00\x00\x00\x14" |
| case Hole( |
| offset=offset, |
| kind="R_X86_64_GOTPCRELX", |
| value=HoleValue.GOT, |
| symbol="_JIT_CONTINUE", |
| addend=addend, |
| ) as hole: |
| assert _signed(addend) == -4 |
| # jmp qword ptr [rip] |
| jump = b"\xFF\x25\x00\x00\x00\x00" |
| offset -= 2 |
| case _: |
| return |
| if self.body[offset:] == jump and offset % alignment == 0: |
| self.body = self.body[:offset] |
| self.holes.remove(hole) |
| |
| |
@dataclasses.dataclass
class StencilGroup:
    """
    The code stencil and data stencil belonging to one micro-opcode.

    Plays the same role as an entire object file.
    """

    code: Stencil = dataclasses.field(default_factory=Stencil, init=False)
    data: Stencil = dataclasses.field(default_factory=Stencil, init=False)
    # Maps each known symbol to the (base value, addend) pair it resolves to:
    symbols: dict[int | str, tuple[HoleValue, int]] = dataclasses.field(
        default_factory=dict, init=False
    )
    # Maps each symbol needing a global offset table entry to the entry's
    # offset from the start of the table:
    _got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)

    def process_relocations(self, *, alignment: int = 1) -> None:
        """
        Resolve every GOT and internal relocation of this stencil group.

        Trampolines are emitted for AArch64 calls and jumps with a ZERO base
        value, the trailing continuation jump is removed, both stencils are
        padded, holes are rewritten in place, the global offset table is
        appended to the data, and finally the holes are sorted by offset.

        Raises ValueError for an unresolved IMAGE_REL_AMD64_REL32 hole with a
        ZERO base value.
        """
        trampoline_kinds = {"R_AARCH64_CALL26", "R_AARCH64_JUMP26"}
        # Iterate over a snapshot, since the list of holes is edited below:
        for hole in list(self.code.holes):
            if hole.kind not in trampoline_kinds or hole.value is not HoleValue.ZERO:
                continue
            self.code.pad(alignment)
            self.code.emit_aarch64_trampoline(hole)
            self.code.holes.remove(hole)
        self.code.remove_jump(alignment=alignment)
        self.code.pad(alignment)
        self.data.pad(8)
        for stencil in (self.code, self.data):
            for hole in stencil.holes:
                if hole.value is HoleValue.GOT:
                    # Point at the symbol's table entry in the data instead:
                    assert hole.symbol is not None
                    hole.value = HoleValue.DATA
                    hole.addend += self._global_offset_table_lookup(hole.symbol)
                    hole.symbol = None
                    continue
                if hole.symbol in self.symbols:
                    # Replace the symbol with the value and addend it maps to:
                    base, extra = self.symbols[hole.symbol]
                    hole.value = base
                    hole.addend += extra
                    hole.symbol = None
                    continue
                if (
                    hole.kind == "IMAGE_REL_AMD64_REL32"
                    and hole.value is HoleValue.ZERO
                ):
                    raise ValueError(
                        f"Add PyAPI_FUNC(...) or PyAPI_DATA(...) to declaration of {hole.symbol}!"
                    )
        self._emit_global_offset_table()
        for stencil in (self.code, self.data):
            stencil.holes.sort(key=lambda hole: hole.offset)

    def _global_offset_table_lookup(self, symbol: str) -> int:
        """
        Return the offset of the table entry for symbol, measured from the
        start of the data. A new 8-byte entry is reserved on first use.
        """
        entry = self._got.setdefault(symbol, 8 * len(self._got))
        return len(self.data.body) + entry

    def _emit_global_offset_table(self) -> None:
        """
        Append one zeroed 8-byte entry to the data for each reserved symbol,
        along with the hole that fills it and a disassembly line describing it.
        """
        table_start = len(self.data.body)
        for name, entry in self._got.items():
            if name in self.symbols:
                value, addend = self.symbols[name]
                symbol = None
            else:
                value, symbol = symbol_to_value(name)
                addend = 0
            self.data.holes.append(
                Hole(table_start + entry, "R_X86_64_64", value, symbol, addend)
            )
            # Build the "BASE+&symbol+addend" description of this entry:
            base_text = "" if value is HoleValue.ZERO else value.name
            if base_text and not symbol and not addend:
                rest_text = ""
            else:
                # Force an explicit sign only when the addend follows a symbol:
                sign = "" if symbol is None else "+"
                symbol_text = f"&{symbol}" if symbol else ""
                rest_text = symbol_text + f"{_signed(addend):{sign}#x}"
                if base_text:
                    base_text += "+"
            self.data.disassembly.append(
                f"{len(self.data.body):x}: {base_text}{rest_text}"
            )
            self.data.body.extend(bytes(8))
| |
| |
def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
    """
    Split a symbol name into a HoleValue and a (possibly absent) symbol name.

    A name of the form "_JIT_<NAME>", where <NAME> is a HoleValue member, maps
    to that member with no symbol. Every other name maps to HoleValue.ZERO
    paired with the name itself.
    """
    prefix = "_JIT_"
    if symbol.startswith(prefix):
        special = HoleValue.__members__.get(symbol[len(prefix) :])
        if special is not None:
            return special, None
    return HoleValue.ZERO, symbol
| |
| |
| def _signed(value: int) -> int: |
| value %= 1 << 64 |
| if value & (1 << 63): |
| value -= 1 << 64 |
| return value |