This post shows how to write a basic JIT compiler for the Python bytecode, from scratch, using nothing but stock Python modules.
At the end of this post, we will be able to compile branchless Python functions that perform arithmetic on signed 64-bit values:
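    # The running example in this post (reconstructed from the results shown
    # further down): with foo(2, 3) it computes 2*2 - 3*3 = -5.
    def foo(a, b):
        return a*a - b*b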
Our strategy is to translate Python bytecode to an intermediate representation, which will then be optimized before being emitted as x86-64 machine code. So the first part will be to understand how the Python bytecode works.
You can see the raw bytecode for the foo function at the top in Python 3 by inspecting its code object:
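    >>> foo.__code__.co_code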
In Python 2.7, that would be
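    >>> foo.func_code.co_code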
Because the two bytecode sequences are nearly identical, it doesn't matter which one we use for the explanation. I've picked Python 2.7 for the remainder of this post, but the GitHub code supports both 2.7 and 3+.
Let's have a look at the disassembly of foo (shown here for Python 2.7; the details vary slightly between versions):
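    >>> import dis
    >>> dis.dis(foo)
      2           0 LOAD_FAST                0 (a)
                  3 LOAD_FAST                0 (a)
                  6 BINARY_MULTIPLY
                  7 LOAD_FAST                1 (b)
                 10 LOAD_FAST                1 (b)
                 13 BINARY_MULTIPLY
                 14 BINARY_SUBTRACT
                 15 RETURN_VALUE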
The leftmost number 2 is the Python source code line number. The next column contains the bytecode offsets. We clearly see that the LOAD_FAST instruction takes three bytes: one for the opcode (which instruction it is) and two for a 16-bit argument. That argument is zero, referring to the first function argument.
CPython, like the JVM, CLR, Forth and many others, is implemented as a stack
machine. All the bytecode instructions operate on a stack of
objects. For example,
LOAD_FAST will push a reference to the variable
on the stack, while
BINARY_MULTIPLY will pop off two, multiply them together
and put their product on the stack. For our purposes, we will treat the stack
as holding values.
A beautiful property of postfix systems is that operations can be serialized. For example, to compute an infix expression like 2*2 - 3*3, we have to jump back and forth, calculating the products before subtracting. But in a postfix system, we need only scan forwards. For example, the above expression can be translated to Reverse Polish Notation (RPN) using the shunting-yard algorithm:
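    2 2 * 3 3 * -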
Moving from left to right, we push 2 on the stack, then another 2. For the
operation we pop them both off and push their product 4. Push 3 and 3, pop them
off and push their product 9. The stack will now contain 9 on the top and 4 at
the bottom. For the final subtraction, we pop them off, perform the subtraction
and push the result -5 on the stack.
Now, imagine that the expression was actually written in a programming language, as nested function calls:
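    subtract(multiply(2, 2), multiply(3, 3))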
The thing is, in postfix form, the evaluation order becomes explicit:
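    multiply(2, 2)      # pushes 4
    multiply(3, 3)      # pushes 9
    subtract()          # pops 9 and 4, pushes 4 - 9 = -5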
The multiply and subtract functions find their arguments on the stack. For subtract, the two arguments consist of the products of 2*2 and 3*3.
The use of a stack makes it possible to execute instructions linearly, and this is essentially how stack machines operate. With that, you will probably understand most of the CPython opcodes and its interpreter loop.
We will now translate the bytecode instructions to an intermediate representation (IR), that is, a form suitable for performing things like analysis, translation and optimization. Ours will be blissfully naive: we will forego things like static single assignment (SSA) and register allocation for the sake of simplicity.
Our IR will consist of pseudo-assembly instructions in a list, with a faint resemblance to three-address code (TAC). For example:
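    ("mov", "rax", "rbx")     # mov rax, rbx
    ("push", "rax", None)     # push rax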
Contrary to TAC, we put the operation first, followed by the destination and source
registers. We use
None to indicate unused registers and arguments. It would
be a very good idea to use unique, abstract register names like reg1, reg2 and so on, because that facilitates register allocation. That's out of scope for this post, though.
We will reserve registers RAX and RBX for menial work like arithmetic, pushing and popping. RAX must also hold the return value, because that's the convention. The CPU already has a stack, so we'll use that as our data stack mechanism.
Registers RDI, RSI, RDX and RCX will be reserved for variables and arguments. Per AMD64 convention, we expect to see function arguments passed in those registers, in that order. In real programs, the matter is a bit more involved.
Constants in the bytecode can be looked up through the code object's co_consts tuple:
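    >>> foo.func_code.co_consts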
We can now build a compiler that translates Python bytecode to our intermediate representation. Its general form will be
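    # A sketch of the general form; class and method names here are
    # assumptions, not necessarily those used in the accompanying GitHub code.
    import dis

    class Compiler(object):
        def __init__(self, bytecode, constants):
            self.bytecode = bytecode      # a list of integer byte values
            self.constants = constants    # the code object's co_consts
            self.index = 0                # current position in the bytecode

        def fetch(self):
            # Grab the next raw byte and advance the position.
            byte = self.bytecode[self.index]
            self.index += 1
            return byte

        def decode(self):
            # Fetch the opcode, look up its name and fetch its argument, if any.
            opcode = self.fetch()
            argument = None
            if opcode >= dis.HAVE_ARGUMENT:
                # In Python 2.7, arguments are 16-bit values, low byte first.
                low, high = self.fetch(), self.fetch()
                argument = (high << 8) | low
            return dis.opname[opcode], argument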
It takes some bytecode and constants, and keeps a running index of the current
bytecode position. It is wise to split the translation up into fetch and decode stages: the fetch method simply retrieves the next byte of bytecode, while decode fetches the opcode, looks up its name and fetches any arguments.
We also need to look up which register holds which variable:
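    def variable(self, number):
        # (Inside the Compiler class.) Map a local variable index to one of
        # the argument registers reserved above.
        return ("rdi", "rsi", "rdx", "rcx")[number]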
The main compilation method will look something like this:
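    def compile(self):
        # (Inside the Compiler class.) A sketch: translate one opcode at a
        # time into IR tuples; the remaining opcodes are filled in below.
        instructions = []
        while self.index < len(self.bytecode):
            opname, argument = self.decode()
            if opname == "LOAD_FAST":
                instructions.append(("push", self.variable(argument), None))
            # ... more opcodes ...
        return instructions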
Here you can already see how we translate
LOAD_FAST. We just push the
corresponding register onto the stack. So, if the function we are compiling has
one argument, the bytecode will refer to the zeroth variable. Through the
variable method, we see that this is register RDI. So it will output a push rdi instruction.
The STORE_FAST instruction does the reverse. It pops a value off the stack
and stores it in the argument register:
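    # (Continuing the compile sketch above.)
    elif opname == "STORE_FAST":
        instructions.append(("pop", "rax", None))
        instructions.append(("mov", self.variable(argument), "rax"))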
A binary instruction will pop two values off the stack. For example, BINARY_MULTIPLY:
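    # (Continuing the compile sketch above.)
    elif opname == "BINARY_MULTIPLY":
        instructions.append(("pop", "rax", None))
        instructions.append(("pop", "rbx", None))
        instructions.append(("imul", "rax", "rbx"))
        instructions.append(("push", "rax", None))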
That's just about it.
LOAD_CONST will use a special instruction for storing
immediate values (i.e., constant integers). The entire method then looks more or less like this:
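    def compile(self):
        # A sketch of the whole method, collecting the cases above (inside the
        # Compiler class; the version in the repository differs in details).
        instructions = []
        while self.index < len(self.bytecode):
            opname, argument = self.decode()

            if opname == "LOAD_FAST":
                instructions.append(("push", self.variable(argument), None))
            elif opname == "STORE_FAST":
                instructions.append(("pop", "rax", None))
                instructions.append(("mov", self.variable(argument), "rax"))
            elif opname == "LOAD_CONST":
                # Load the constant as an immediate value, then push it.
                instructions.append(("immediate", "rax", self.constants[argument]))
                instructions.append(("push", "rax", None))
            elif opname == "BINARY_MULTIPLY":
                instructions.append(("pop", "rax", None))
                instructions.append(("pop", "rbx", None))
                instructions.append(("imul", "rax", "rbx"))
                instructions.append(("push", "rax", None))
            elif opname == "BINARY_ADD":
                instructions.append(("pop", "rax", None))
                instructions.append(("pop", "rbx", None))
                instructions.append(("add", "rax", "rbx"))
                instructions.append(("push", "rax", None))
            elif opname == "BINARY_SUBTRACT":
                # The top of the stack holds the right-hand operand.
                instructions.append(("pop", "rbx", None))
                instructions.append(("pop", "rax", None))
                instructions.append(("sub", "rax", "rbx"))
                instructions.append(("push", "rax", None))
            elif opname == "RETURN_VALUE":
                # The return value goes in RAX, per the calling convention.
                instructions.append(("pop", "rax", None))
                instructions.append(("ret", None, None))
            else:
                raise NotImplementedError(opname)
        return instructions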
We can now compile the
foo function at the top to our IR.
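Reconstructed by hand from the translation rules above (so the exact output may differ slightly from the post's code), the result comes out roughly like this:

    push rdi
    push rdi
    pop rax
    pop rbx
    imul rax, rbx
    push rax
    push rsi
    push rsi
    pop rax
    pop rbx
    imul rax, rbx
    push rax
    pop rbx
    pop rax
    sub rax, rbx
    push rax
    pop rax
    ret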
Wow, that sure is a lot of stack operations!
We're going to perform peephole optimizations on our IR. Such optimizations work on only a few instructions at a time, and translate them to equivalent but better code. We will go for fewer instructions.
In the IR above, we see an obvious improvement. Instructions like push rdi immediately followed by pop rax can surely be translated to a single mov rax, rdi.
Let's write a function for that. We'll also eliminate nonsensical instructions like mov rax, rax.
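A sketch of a single pass, assuming the IR tuples from earlier (the extended variant discussed below would be added on top of this):

    def optimize(ir):
        optimized = []
        index = 0
        while index < len(ir):
            name, dst, src = ir[index]
            # Drop nonsensical moves like mov rax, rax.
            if name == "mov" and dst == src:
                index += 1
                continue
            # Collapse a push directly followed by a pop into a single mov.
            if name == "push" and index + 1 < len(ir) and ir[index + 1][0] == "pop":
                optimized.append(("mov", ir[index + 1][1], dst))
                index += 2
                continue
            optimized.append((name, dst, src))
            index += 1
        return optimized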
Instead of showing that this actually works, we'll just throw in a few other optimizations. Just note that writing such optimizations is deceptively simple: it's very easy to do something that seems to make sense, only to see your program crash.
Another construct of the same flavor can likewise be collapsed into fewer instructions, so we'll add that as well.
Finally, the short-circuiting of push and pop can be extended so that it works over one or several unrelated instructions. Take a sequence like the following (the registers here are chosen to match the rewrite below):
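    push rax
    mov rsi, rax
    pop rbx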
Since RAX isn't modified in
mov rsi, rax, we can just write
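    mov rsi, rax
    mov rbx, rax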
We have to be careful that the middle instruction isn't a push, though, so we'll add a check for that.
There is nothing wrong with supporting an indefinite number of middle instructions, but we won't do that here.
With these rules in place, let's try to optimize the above IR. The complete optimization function is just a combination of the patterns described here.
Running the IR code above through one optimization pass saves us four instructions. But we still have a few spots left: the first three instructions should be optimizable. So let's run two passes on the IR.
We've now saved seven instructions. Our optimizer won't be able to improve this
code any further. We could add even more peephole optimizations, but another
good technique would be to use a real register allocator so that we use the
full spectrum of available registers. The IR compiler could then just assign
values to unique registers like reg1, reg2 and so on, and then the allocator
would choose how to populate the available registers properly. This is actually
a hot topic for research, and especially for JIT compilation because the
general problem is NP-complete.
So, we have translated Python bytecode to our IR and we have done some optimizations on it. We are finally ready to assemble it to machine code!
Our approach will be to write an assembler class that emits instructions. If we use the same name for the emitter methods as in the IR, and use the same signature for all, then we can just blindly assemble the whole IR in a short loop:
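    # A sketch: ir holds the (name, destination, source) tuples from before.
    assembler = Assembler()
    for name, a, b in ir:
        emit = getattr(assembler, name)
        emit(a, b)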
If the instruction is
mov rax, rbx, then
emit will point to the assembler's mov method, and the call will therefore be emit("rax", "rbx").
Let's write an assembler class. We'll copy the code for
little_endian and import
create_block from the code in the previous post.
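A sketch of the skeleton; for illustration, the raw bytes are just collected in a bytearray here, while the real class writes them into the executable block obtained from create_block:

    class Assembler(object):
        def __init__(self):
            self.block = bytearray()    # the raw machine code bytes

        def emit(self, *values):
            # Append raw bytes to the code block.
            self.block.extend(values)

        def ret(self, a, b):
            # retq is encoded as the single byte 0xc3.
            self.emit(0xc3)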
assembler.ret(None, None) will set the first machine code byte to
0xc3. That's how
retq is encoded. To find the encoding of other instructions,
I mainly used the NASM assembler. Putting the following in a file
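    ; A reconstructed probe file; the original contents are not shown in this
    ; extract, but a couple of register-to-register movs is all we need.
    section .text
    mov rax, rbx
    mov rbx, rcx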
I assembled it with nasm -felf64 (or -fmacho64 for macOS) and dumped the raw machine code to inspect the encoding.
It seems like the 64-bit
movq (which we just call
mov) is encoded with
0x48 0x89 with the source and destination registers stored in the
last byte. Digging into a few manuals, we see that they are encoded using
three bits each. We'll write a method for that.
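Something along these lines; the helper name and the register table are assumptions:

    # (Inside the Assembler class.)
    REGISTERS = ("rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi")

    def registers(self, source, destination):
        # Each register is numbered 0 through 7; pack the source into bits 3-5
        # and the destination into bits 0-2 of the ModRM byte.
        return (self.REGISTERS.index(source) << 3) | self.REGISTERS.index(destination)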
For the movq instruction, we can now write:
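    def mov(self, destination, source):
        # (Inside the Assembler class.) REX.W prefix, the 0x89 opcode, then a
        # ModRM byte with mod=11 for a register-to-register move.
        self.emit(0x48, 0x89, 0xc0 | self.registers(source, destination))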
The rest of the instructions are done in a similar manner, except for moving immediate (i.e., constant) values into registers.
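That method might look roughly like this (using the struct module here; the post reuses its own little_endian helper instead):

    import struct

    def immediate(self, destination, number):
        # (Inside the Assembler class.) 0x48 followed by 0xb8 + register moves
        # a full 64-bit immediate; the value itself follows in little-endian order.
        self.emit(0x48, 0xb8 | self.REGISTERS.index(destination))
        self.emit(*bytearray(struct.pack("<q", number)))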
The only special thing about the last method is that we have to use a different prefix and encode the number in little-endian format.
Finally, we can tie everything together. Given the foo function at the top, we first extract its Python bytecode and constants, using ord to map the bytecode bytes to integers. We then compile that to IR, perform a few optimization passes, assemble the result to native code and make the memory block executable. Last, we use ctypes to set the correct signature and cast the code to a callable Python function. We can get the number of arguments with co_argcount, and we treat input arguments as signed 64-bit integers.
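Assembled from the sketches above, the whole pipeline might look roughly like this. Note that create_block from the previous post is replaced by a plain mmap call here, and all the helper names are assumptions:

    import ctypes, mmap

    _blocks = []   # keep the executable memory alive

    def compile_native(function):
        # Python 2.7 flavour: map the raw bytecode characters to integers.
        code = function.func_code
        bytecode = [ord(byte) for byte in code.co_code]

        # Compile to IR and run a couple of peephole passes.
        ir = Compiler(bytecode, code.co_consts).compile()
        for _ in range(2):
            ir = optimize(ir)

        # Assemble to native code.
        assembler = Assembler()
        for name, a, b in ir:
            getattr(assembler, name)(a, b)

        # Put the machine code in executable memory.
        block = mmap.mmap(-1, mmap.PAGESIZE,
                          prot=mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC)
        block.write(bytes(assembler.block))
        _blocks.append(block)

        # Cast it to a callable taking and returning signed 64-bit integers.
        address = ctypes.addressof(ctypes.c_char.from_buffer(block))
        signature = ctypes.CFUNCTYPE(ctypes.c_int64,
                                     *[ctypes.c_int64] * code.co_argcount)
        return signature(address)

    print(compile_native(foo)(2, 3))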
It prints -5. Yay!
To disassemble the code, we can use the Capstone disassembler right
within Python. It's not a built-in module, so you need to install it yourself.
Or you can break into the Python process with a debugger. With Capstone, it goes something like this:
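    # assembler.block holds the emitted machine code from the sketches above.
    import capstone

    md = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
    for instruction in md.disasm(bytes(assembler.block), 0):
        print("%s %s" % (instruction.mnemonic, instruction.op_str))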
You can try out different functions, for example:
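    def bar(a, b, c):
        # An arbitrary branchless example; any signed 64-bit arithmetic on up
        # to four arguments should work with the scheme described here.
        return a*b + c - 10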
You may wonder how fast this runs. The answer is: slow. The reason is that there is inherent overhead when calling into native code with ctypes, which needs to convert arguments and so on. I also believe (but haven't double-checked) that it saves some registers, because per the convention, we should have restored RBX before exiting.
But it would be interesting to compile larger functions with native function calls, loops and so on, and compare that with Python. At that point, I believe you'll see the native code going much faster.
On /r/compsci there was a comment that this really isn't just-in-time compilation until there is some mechanism that automatically swaps out a function with a compiled version. So let's try to do something about that.
A pretty obvious approach is to require a little help from the user. Use a decorator. Recall that a decorator is really just a function that gets the freshly defined object as the first argument. If we install a little closure there that remembers the original function, we can then literally compile just-in-time when it is called for the first time. Again, only the decorated functions that are actually called will be compiled to native code.
We'll start with a decorator that doesn't do anything yet:
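    def jit(function):
        # The decorator name is an assumption. For now it does nothing: it
        # simply hands the original function back unchanged.
        return function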
With this, we can mark functions that we want to be compiled:
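    @jit
    def foo(a, b):
        return a*a - b*b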
So the inner
frontend function then just needs to check if the function has
already been compiled. If not, compile it and install it as the local function
reference. If the compilation fails, don't swap out anything. The complete
decorator looks something like this:
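    def jit(function):
        # A sketch with assumed names; the decorator in the repository differs
        # in details. Compile lazily on the first call, and fall back to the
        # original Python function if compilation fails.
        state = {"compiled": None}

        def frontend(*args):
            if state["compiled"] is None:
                try:
                    state["compiled"] = compile_native(function)
                except Exception:
                    state["compiled"] = function
            return state["compiled"](*args)

        return frontend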
See the GitHub repository for an example program that uses this.
I believe this is good for learning, so play around a bit. Try to make a register allocator, for example. Create more peephole optimizations. Add support for calling other functions, or for loops.
With a decorator, you should be able to swap out class methods on the fly with compiled ones. That's exactly what Numba does, but ours is just a drop in the ocean compared to that.
While we took the approach of translating Python bytecode, another good
technique is to use the
ast module to traverse the abstract syntax tree. Ben
Hoyt did exactly that in pyast64, and I strongly recommend taking a look at his excellent code and post.