Skip to content

Commit

Permalink
pythonGH-115802: Use the GHC calling convention in JIT code (pythonGH…
Browse files Browse the repository at this point in the history
  • Loading branch information
brandtbucher authored May 1, 2024
1 parent beb653c commit 49baa65
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 27 deletions.
1 change: 1 addition & 0 deletions Include/cpython/optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ typedef struct _PyExecutorObject {
uint32_t code_size;
size_t jit_size;
void *jit_code;
void *jit_side_entry;
_PyExitData exits[1];
} _PyExecutorObject;

Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_perf_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def is_unwinding_reliable():
cflags = sysconfig.get_config_var("PY_CORE_CFLAGS")
if not cflags:
return False
return "no-omit-frame-pointer" in cflags
return "no-omit-frame-pointer" in cflags and "_Py_JIT" not in cflags


def perf_command_works():
Expand Down
52 changes: 37 additions & 15 deletions Python/jit.c
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction *trace, size
{
// Loop once to find the total compiled size:
size_t instruction_starts[UOP_MAX_TRACE_LENGTH];
size_t code_size = 0;
size_t data_size = 0;
size_t code_size = trampoline.code.body_size;
size_t data_size = trampoline.data.body_size;
for (size_t i = 0; i < length; i++) {
_PyUOpInstruction *instruction = (_PyUOpInstruction *)&trace[i];
const StencilGroup *group = &stencil_groups[instruction->opcode];
Expand All @@ -408,11 +408,29 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction *trace, size
// Loop again to emit the code:
unsigned char *code = memory;
unsigned char *data = memory + code_size;
{
// Compile the trampoline, which handles converting between the native
// calling convention and the calling convention used by jitted code
// (which may be different for efficiency reasons). On platforms where
// we don't change calling conventions, the trampoline is empty and
// nothing is emitted here:
const StencilGroup *group = &trampoline;
// Think of patches as a dictionary mapping HoleValue to uintptr_t:
uintptr_t patches[] = GET_PATCHES();
patches[HoleValue_CODE] = (uintptr_t)code;
patches[HoleValue_CONTINUE] = (uintptr_t)code + group->code.body_size;
patches[HoleValue_DATA] = (uintptr_t)data;
patches[HoleValue_EXECUTOR] = (uintptr_t)executor;
patches[HoleValue_TOP] = (uintptr_t)memory + trampoline.code.body_size;
patches[HoleValue_ZERO] = 0;
emit(group, patches);
code += group->code.body_size;
data += group->data.body_size;
}
assert(trace[0].opcode == _START_EXECUTOR || trace[0].opcode == _COLD_EXIT);
for (size_t i = 0; i < length; i++) {
_PyUOpInstruction *instruction = (_PyUOpInstruction *)&trace[i];
const StencilGroup *group = &stencil_groups[instruction->opcode];
// Think of patches as a dictionary mapping HoleValue to uintptr_t:
uintptr_t patches[] = GET_PATCHES();
patches[HoleValue_CODE] = (uintptr_t)code;
patches[HoleValue_CONTINUE] = (uintptr_t)code + group->code.body_size;
Expand Down Expand Up @@ -454,25 +472,28 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction *trace, size
code += group->code.body_size;
data += group->data.body_size;
}
// Protect against accidental buffer overrun into data:
const StencilGroup *group = &stencil_groups[_FATAL_ERROR];
uintptr_t patches[] = GET_PATCHES();
patches[HoleValue_CODE] = (uintptr_t)code;
patches[HoleValue_CONTINUE] = (uintptr_t)code;
patches[HoleValue_DATA] = (uintptr_t)data;
patches[HoleValue_EXECUTOR] = (uintptr_t)executor;
patches[HoleValue_TOP] = (uintptr_t)code;
patches[HoleValue_ZERO] = 0;
emit(group, patches);
code += group->code.body_size;
data += group->data.body_size;
{
// Protect against accidental buffer overrun into data:
const StencilGroup *group = &stencil_groups[_FATAL_ERROR];
uintptr_t patches[] = GET_PATCHES();
patches[HoleValue_CODE] = (uintptr_t)code;
patches[HoleValue_CONTINUE] = (uintptr_t)code;
patches[HoleValue_DATA] = (uintptr_t)data;
patches[HoleValue_EXECUTOR] = (uintptr_t)executor;
patches[HoleValue_TOP] = (uintptr_t)code;
patches[HoleValue_ZERO] = 0;
emit(group, patches);
code += group->code.body_size;
data += group->data.body_size;
}
assert(code == memory + code_size);
assert(data == memory + code_size + data_size);
if (mark_executable(memory, total_size)) {
jit_free(memory, total_size);
return -1;
}
executor->jit_code = memory;
executor->jit_side_entry = memory + trampoline.code.body_size;
executor->jit_size = total_size;
return 0;
}
Expand All @@ -484,6 +505,7 @@ _PyJIT_Free(_PyExecutorObject *executor)
size_t size = executor->jit_size;
if (memory) {
executor->jit_code = NULL;
executor->jit_side_entry = NULL;
executor->jit_size = 0;
if (jit_free(memory, size)) {
PyErr_WriteUnraisable(NULL);
Expand Down
2 changes: 2 additions & 0 deletions Python/optimizer.c
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,7 @@ make_executor_from_uops(_PyUOpInstruction *buffer, int length, const _PyBloomFil
#endif
#ifdef _Py_JIT
executor->jit_code = NULL;
executor->jit_side_entry = NULL;
executor->jit_size = 0;
if (_PyJIT_Compile(executor, executor->trace, length)) {
Py_DECREF(executor);
Expand Down Expand Up @@ -1219,6 +1220,7 @@ init_cold_exit_executor(_PyExecutorObject *executor, int oparg)
#endif
#ifdef _Py_JIT
executor->jit_code = NULL;
executor->jit_side_entry = NULL;
executor->jit_size = 0;
if (_PyJIT_Compile(executor, executor->trace, 1)) {
return -1;
Expand Down
54 changes: 45 additions & 9 deletions Tools/jit/_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class _Target(typing.Generic[_S, _R]):
_: dataclasses.KW_ONLY
alignment: int = 1
args: typing.Sequence[str] = ()
ghccc: bool = False
prefix: str = ""
debug: bool = False
force: bool = False
Expand Down Expand Up @@ -85,7 +86,11 @@ async def _parse(self, path: pathlib.Path) -> _stencils.StencilGroup:
sections: list[dict[typing.Literal["Section"], _S]] = json.loads(output)
for wrapped_section in sections:
self._handle_section(wrapped_section["Section"], group)
assert group.symbols["_JIT_ENTRY"] == (_stencils.HoleValue.CODE, 0)
# The trampoline's entry point is just named "_ENTRY", since on some
# platforms we later assume that any function starting with "_JIT_" uses
# the GHC calling convention:
entry_symbol = "_JIT_ENTRY" if "_JIT_ENTRY" in group.symbols else "_ENTRY"
assert group.symbols[entry_symbol] == (_stencils.HoleValue.CODE, 0)
if group.data.body:
line = f"0: {str(bytes(group.data.body)).removeprefix('b')}"
group.data.disassembly.append(line)
Expand All @@ -103,6 +108,9 @@ def _handle_relocation(
async def _compile(
self, opname: str, c: pathlib.Path, tempdir: pathlib.Path
) -> _stencils.StencilGroup:
# "Compile" the trampoline to an empty stencil group if it's not needed:
if opname == "trampoline" and not self.ghccc:
return _stencils.StencilGroup()
o = tempdir / f"{opname}.o"
args = [
f"--target={self.triple}",
Expand Down Expand Up @@ -130,13 +138,38 @@ async def _compile(
"-fno-plt",
# Don't call stack-smashing canaries that we can't find or patch:
"-fno-stack-protector",
"-o",
f"{o}",
"-std=c11",
f"{c}",
*self.args,
]
await _llvm.run("clang", args, echo=self.verbose)
if self.ghccc:
# This is a bit of an ugly workaround, but it makes the code much
# smaller and faster, so it's worth it. We want to use the GHC
# calling convention, but Clang doesn't support it. So, we *first*
# compile the code to LLVM IR, perform some text replacements on the
# IR to change the calling convention(!), and then compile *that*.
# Once we have access to Clang 19, we can get rid of this and use
# __attribute__((preserve_none)) directly in the C code instead:
ll = tempdir / f"{opname}.ll"
args_ll = args + [
# -fomit-frame-pointer is necessary because the GHC calling
# convention uses RBP to pass arguments:
"-S", "-emit-llvm", "-fomit-frame-pointer", "-o", f"{ll}", f"{c}"
]
await _llvm.run("clang", args_ll, echo=self.verbose)
ir = ll.read_text()
# This handles declarations, definitions, and calls to named symbols
# starting with "_JIT_":
ir = re.sub(r"(((noalias|nonnull|noundef) )*ptr @_JIT_\w+\()", r"ghccc \1", ir)
# This handles calls to anonymous callees, since anything with
# "musttail" needs to use the same calling convention:
ir = ir.replace("musttail call", "musttail call ghccc")
# Sometimes *both* replacements happen at the same site, so fix it:
ir = ir.replace("ghccc ghccc", "ghccc")
ll.write_text(ir)
args_o = args + ["-Wno-unused-command-line-argument", "-o", f"{o}", f"{ll}"]
else:
args_o = args + ["-o", f"{o}", f"{c}"]
await _llvm.run("clang", args_o, echo=self.verbose)
return await self._parse(o)

async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]:
Expand All @@ -146,6 +179,8 @@ async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]:
with tempfile.TemporaryDirectory() as tempdir:
work = pathlib.Path(tempdir).resolve()
async with asyncio.TaskGroup() as group:
coro = self._compile("trampoline", TOOLS_JIT / "trampoline.c", work)
tasks.append(group.create_task(coro, name="trampoline"))
for opname in opnames:
coro = self._compile(opname, TOOLS_JIT_TEMPLATE_C, work)
tasks.append(group.create_task(coro, name=opname))
Expand Down Expand Up @@ -445,6 +480,7 @@ def _handle_relocation(

def get_target(host: str) -> _COFF | _ELF | _MachO:
"""Build a _Target for the given host "triple" and options."""
# ghccc currently crashes Clang when combined with musttail on aarch64. :(
if re.fullmatch(r"aarch64-apple-darwin.*", host):
return _MachO(host, alignment=8, prefix="_")
if re.fullmatch(r"aarch64-pc-windows-msvc", host):
Expand All @@ -455,13 +491,13 @@ def get_target(host: str) -> _COFF | _ELF | _MachO:
return _ELF(host, alignment=8, args=args)
if re.fullmatch(r"i686-pc-windows-msvc", host):
args = ["-DPy_NO_ENABLE_SHARED"]
return _COFF(host, args=args, prefix="_")
return _COFF(host, args=args, ghccc=True, prefix="_")
if re.fullmatch(r"x86_64-apple-darwin.*", host):
return _MachO(host, prefix="_")
return _MachO(host, ghccc=True, prefix="_")
if re.fullmatch(r"x86_64-pc-windows-msvc", host):
args = ["-fms-runtime-lib=dll"]
return _COFF(host, args=args)
return _COFF(host, args=args, ghccc=True)
if re.fullmatch(r"x86_64-.*-linux-gnu", host):
args = ["-fpic"]
return _ELF(host, args=args)
return _ELF(host, args=args, ghccc=True)
raise ValueError(host)
4 changes: 4 additions & 0 deletions Tools/jit/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ def _dump_footer(opnames: typing.Iterable[str]) -> typing.Iterator[str]:
yield ""
yield "static const StencilGroup stencil_groups[512] = {"
for opname in opnames:
if opname == "trampoline":
continue
yield f" [{opname}] = INIT_STENCIL_GROUP({opname}),"
yield "};"
yield ""
yield "static const StencilGroup trampoline = INIT_STENCIL_GROUP(trampoline);"
yield ""
yield "#define GET_PATCHES() { \\"
for value in _stencils.HoleValue:
yield f" [HoleValue_{value.name}] = (uintptr_t)0xBADBADBADBADBADB, \\"
Expand Down
4 changes: 2 additions & 2 deletions Tools/jit/template.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
do { \
OPT_STAT_INC(traces_executed); \
__attribute__((musttail)) \
return ((jit_func)((EXECUTOR)->jit_code))(frame, stack_pointer, tstate); \
return ((jit_func)((EXECUTOR)->jit_side_entry))(frame, stack_pointer, tstate); \
} while (0)

#undef GOTO_TIER_ONE
Expand All @@ -65,7 +65,7 @@ do { \

#define PATCH_VALUE(TYPE, NAME, ALIAS) \
PyAPI_DATA(void) ALIAS; \
TYPE NAME = (TYPE)(uint64_t)&ALIAS;
TYPE NAME = (TYPE)(uintptr_t)&ALIAS;

#define PATCH_JUMP(ALIAS) \
do { \
Expand Down
25 changes: 25 additions & 0 deletions Tools/jit/trampoline.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "Python.h"

#include "pycore_ceval.h"
#include "pycore_frame.h"
#include "pycore_jit.h"

// This is where the calling convention changes, on platforms that require it.
// The actual change is patched in while the JIT compiler is being built, in
// Tools/jit/_targets.py. On other platforms, this function compiles to nothing.
_Py_CODEUNIT *
_ENTRY(_PyInterpreterFrame *frame, PyObject **stack_pointer, PyThreadState *tstate)
{
// This is subtle. The actual trace will return to us once it exits, so we
// need to make sure that we stay alive until then. If our trace side-exits
// into another trace, and this trace is then invalidated, the code for
// *this function* will be freed and we'll crash upon return:
PyAPI_DATA(void) _JIT_EXECUTOR;
PyObject *executor = (PyObject *)(uintptr_t)&_JIT_EXECUTOR;
Py_INCREF(executor);
// Note that this is *not* a tail call:
PyAPI_DATA(void) _JIT_CONTINUE;
_Py_CODEUNIT *target = ((jit_func)&_JIT_CONTINUE)(frame, stack_pointer, tstate);
Py_SETREF(tstate->previous_executor, executor);
return target;
}

0 comments on commit 49baa65

Please sign in to comment.