Making a simple VM interpreter in Python

By Christian Stigen Larsen
Posted 28 Jan 2015 — updated 23 Jan 2016

I was asked if I could write a simple blog post on how to make a (process) virtual machine, specifically a stack machine. Using Python, I'm going to show you how easy it is!

Update: Based on comments from r/Python, I've made some small code changes. Thanks to robin-gvx, bs4h and Dagur! The code in this post is now available on GitHub.

A stack machine doesn't have registers. Instead, it puts values on a stack and operates on them. Stack machines are extremely simple, but very powerful. There's a reason why Python, Java, PostScript, Forth and many other languages have chosen a stack machine as their VM.

Anyway, let's talk a little bit more about the stacks. Besides the data stack, we need an instruction pointer stack, which we'll use to store return addresses. This is so that we can call a subroutine (a function, for instance) and then jump back to where we came from. We could have used self-modifying code instead, like Donald Knuth's original MIX did, but then you'd have to manage a stack yourself if you wanted recursion to work. In this post, I won't actually implement subroutine calls, but they're trivial anyway (consider it an exercise).

With the data stack you get a lot of stuff for free. For instance, consider an expression like (2+3)*4. On a stack machine, the equivalent code for this expression would be 2 3 + 4 *: Push two and three on the stack, then the next instruction is +, so pop off two numbers and push back the sum. Next, push four on the stack, then pop two values and push their product. Easy!
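To make that concrete, here's a minimal sketch of the same evaluation using a plain Python list as the stack. The eval_postfix function and the BINARY_OPS table are made up for this illustration; they are not part of the VM we build below.

```python
import operator

# Map each operator token to a two-argument function.
BINARY_OPS = {"+": operator.add, "*": operator.mul}

def eval_postfix(tokens):
    """Evaluate a postfix token list using a plain list as the stack."""
    stack = []
    for token in tokens:
        if token in BINARY_OPS:
            right = stack.pop()  # pushed last
            left = stack.pop()   # pushed first
            stack.append(BINARY_OPS[token](left, right))
        else:
            stack.append(token)  # numbers go straight onto the stack
    return stack.pop()

print(eval_postfix([2, 3, "+", 4, "*"]))  # prints 20
```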

So, let's start writing a simple class for a stack. We'll just inherit from collections.deque:

from collections import deque

class Stack(deque):
    push = deque.append

    def top(self):
        return self[-1]

We can now push and pop to this stack and peek at the top value with top.
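Here's a quick usage sketch, repeating the class so it runs on its own. Note that pop comes for free from deque, and that top only peeks.

```python
from collections import deque

class Stack(deque):
    push = deque.append  # alias so we can write stack.push(x)

    def top(self):
        return self[-1]

stack = Stack()
stack.push(1)
stack.push(2)
print(stack.top())  # prints 2; the value stays on the stack
print(stack.pop())  # prints 2; deque.pop removes the rightmost item
print(stack.top())  # prints 1
```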

Next, let's make a class for the virtual machine itself. As noted, we need two stacks and also some memory for the code itself. We'll rely on Python's dynamic typing so we can put anything in the list. The only problem is that we can't really tell strings apart from built-in functions, which is why string literals have to carry their own quotes. The correct way would be to insert actual Python functions in the list. Perhaps I'll do that in a future update.
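For the curious, here's a rough sketch of what a code list of actual Python functions could look like. This is only one possible design, not the one used in the rest of this post: the names push, plus, mul and run_callables are made up for the example, and it skips the instruction pointer entirely, so it can't jump.

```python
def push(value):
    """Build an instruction that pushes a constant."""
    def instruction(stack):
        stack.append(value)
    return instruction

def plus(stack):
    stack.append(stack.pop() + stack.pop())

def mul(stack):
    stack.append(stack.pop() * stack.pop())

def run_callables(code):
    """Run a list of callables against a fresh stack and return the stack."""
    stack = []
    for instruction in code:
        instruction(stack)
    return stack

print(run_callables([push(2), push(3), plus, push(4), mul]))  # prints [20]
```

With this layout a string is always data and a function is always an instruction, so there is nothing to confuse.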

We also need an instruction pointer, pointing to the next item to execute in the code.

class Machine:
    def __init__(self, code):
        self.data_stack = Stack()
        self.return_addr_stack = Stack()
        self.instruction_pointer = 0
        self.code = code
        self.dispatch_map = {...} # see below for values

Now we'll create some convenience functions to reduce some typing:

    def pop(self):
        return self.data_stack.pop()

    def push(self, value):
        self.data_stack.push(value)

    def top(self):
        return self.data_stack.top()

We'll create a dispatch function that takes an "opcode" (we don't really use codes, just dynamic types, but you get my point) and executes it. But first, let's just create the interpreter loop:

    def run(self):
        while self.instruction_pointer < len(self.code):
            opcode = self.code[self.instruction_pointer]
            self.instruction_pointer += 1
            self.dispatch(opcode)

As you can see, it does one thing, and does it pretty well: Fetch the next instruction, increment the instruction pointer and then dispatch based on the opcode. The dispatch function is a tad longer:

    def dispatch(self, op):
        if op in self.dispatch_map:
            self.dispatch_map[op]()
        elif isinstance(op, int):
            # push numbers on the data stack
            self.push(op)
        elif isinstance(op, str) and op[0]==op[-1]=='"':
            # push quoted strings on the data stack
            self.push(op[1:-1])
        else:
            raise RuntimeError("Unknown opcode: '%s'" % op)

Update: The dispatch_map referenced in __init__ should contain the values below. It used to be constructed in dispatch, but as EnTerr pointed out in the comments below, that is grossly inefficient.

    self.dispatch_map = {
        "%":        self.mod,
        "*":        self.mul,
        "+":        self.plus,
        "-":        self.minus,
        "/":        self.div,
        "==":       self.eq,
        "cast_int": self.cast_int,
        "cast_str": self.cast_str,
        "drop":     self.drop,
        "dup":      self.dup,
        "if":       self.if_stmt,
        "jmp":      self.jmp,
        "over":     self.over,
        "print":    self.print,
        "println":  self.println,
        "read":     self.read,
        "stack":    self.dump_stack,
        "swap":     self.swap,
    }

Basically, dispatch looks up the opcode and sees if it's a built-in function like * or drop or dup. By the way, those are words from Forth, a brilliant language that you should check out. In fact, the code you'll see here is basically a simple Forth.

Anyway, it looks up an opcode like *, sees it should call self.mul then executes it. It looks like this:

    def mul(self):
        self.push(self.pop() * self.pop())

All the other functions are like this. If it can't find the operation in the dispatch map, it will first see if it's a number. Numbers are automatically pushed on the data stack. If it's a quoted string, it will push that.

So, there you have it! Congrats!

Let's define a few more operations, then write a program using our newly designed virtual machine and p-code language:

    # At the top of the file: on Python 2, this allows us to use "print"
    # as a name for our own method. We also need sys for the output below.
    from __future__ import print_function
    import sys

    # ...

    def plus(self):
        self.push(self.pop() + self.pop())

    def minus(self):
        last = self.pop()
        self.push(self.pop() - last)

    def mul(self):
        self.push(self.pop() * self.pop())

    def div(self):
        last = self.pop()
        self.push(self.pop() / last)

    def print(self):
        sys.stdout.write(str(self.pop()))
        sys.stdout.flush()

    def println(self):
        sys.stdout.write("%s\n" % self.pop())
        sys.stdout.flush()

Let's write that print((2+3)*4) example using our virtual machine:

Machine([2, 3, "+", 4, "*", "println"]).run()

You can even try running it!

Now, let's introduce a jump operation, a go-to operation:

    def jmp(self):
        addr = self.pop()
        if isinstance(addr, int) and 0 <= addr < len(self.code):
            self.instruction_pointer = addr
        else:
            raise RuntimeError("JMP address must be a valid integer.")

It just changes the instruction pointer. Now let's look at branching:

    def if_stmt(self):
        false_clause = self.pop()
        true_clause = self.pop()
        test = self.pop()
        self.push(true_clause if test else false_clause)

This is also very straightforward. If you wanted to add a conditional jump, you'd simply write test-value true-value false-value IF JMP, where the true and false values are the addresses to jump to. (As branching is performed often, many VMs provide instructions like JNE, "jump if not equal", for convenience.)
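Here's a small sketch of the IF JMP combination in action. To keep it self-contained, it uses a stripped-down stand-in for the Machine class: the run_sketch function is made up for this example, treats anything it doesn't recognize as a literal, and collects the println output in a list instead of writing it out.

```python
def run_sketch(code):
    """Tiny stand-in for Machine: runs code and returns the printed values."""
    stack, output, ip = [], [], 0
    while ip < len(code):
        op = code[ip]
        ip += 1
        if op == "if":
            false_value = stack.pop()
            true_value = stack.pop()
            test = stack.pop()
            stack.append(true_value if test else false_value)
        elif op == "jmp":
            ip = stack.pop()
        elif op == "==":
            stack.append(stack.pop() == stack.pop())
        elif op == "println":
            output.append(stack.pop())
        else:
            stack.append(op)  # anything else is treated as a literal
    return output

program = [
    1, 1, "==",                     # 0-2:   the test
    7, 11, "if", "jmp",             # 3-6:   jump to 7 if true, else to 11
    "equal", "println", 13, "jmp",  # 7-10:  true branch, then skip the false one
    "not equal", "println",         # 11-12: false branch
    "done", "println",              # 13-14: both branches end up here
]
print(run_sketch(program))  # prints ['equal', 'done']
```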

Here's a program that asks the user for two numbers, then prints their sum and product:

Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"Enter another number: "', "print", "read", "cast_int",
    "over", "over",
    '"Their sum is: "', "print", "+", "println",
    '"Their product is: "', "print", "*", "println"
]).run()

The over, read and cast_int operations look like this:

    def cast_int(self):
        self.push(int(self.pop()))

    def over(self):
        b = self.pop()
        a = self.pop()
        self.push(a)
        self.push(b)
        self.push(a)

    def read(self):
        self.push(raw_input())  # Python 2; on Python 3, use input()

Here's a simple program that asks the user for a number, prints whether it's even or odd, then loops:

Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"The number "', "print", "dup", "print", '" is "', "print",
    2, "%", 0, "==", '"even."', '"odd."', "if", "println",
    0, "jmp" # loop forever!
]).run()

Now, some exercises for you: Create call and return commands. The call will push its current address on the return stack, then call self.jmp(). The return operation should simply pop the return stack and set the instruction pointer to this value (jumps back, or returns from a call). When you've done that, you've got subroutines.
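If you want something to compare your solution against, here's one possible sketch. It is not the only way to do it. Again it uses a made-up, stripped-down interpreter (run_with_calls) rather than the Machine class, and it assumes an exit opcode that simply stops the program so the main code doesn't run into the subroutine.

```python
def run_with_calls(code):
    """Stand-in interpreter with a return stack; returns the printed values."""
    data, returns, output, ip = [], [], [], 0
    while ip < len(code):
        op = code[ip]
        ip += 1
        if op == "call":
            returns.append(ip)  # ip already points past the call
            ip = data.pop()     # jump to the subroutine
        elif op == "return":
            ip = returns.pop()  # jump back to the caller
        elif op == "exit":
            break
        elif op == "dup":
            data.append(data[-1])
        elif op == "*":
            data.append(data.pop() * data.pop())
        elif op == "println":
            output.append(data.pop())
        else:
            data.append(op)     # anything else is treated as a literal
    return output

program = [
    12, 5, "call",         # 0-2: push the argument, call the code at address 5
    "println",             # 3:   execution resumes here after the return
    "exit",                # 4:   stop before running into the subroutine
    "dup", "*", "return",  # 5-7: a "square" subroutine
]
print(run_with_calls(program))  # prints [144]
```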

A simple parser for this language

Let's create a small language that mimics the code. We'll compile to our machine code:

import tokenize
from StringIO import StringIO  # Python 2; on Python 3: from io import StringIO

# ...

def parse(text):
    tokens = tokenize.generate_tokens(StringIO(text).readline)
    for toknum, tokval, _, _, _ in tokens:
        if toknum == tokenize.NUMBER:
            yield int(tokval)
        elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
            yield tokval
        elif toknum == tokenize.ENDMARKER:
            break
        else:
            raise RuntimeError("Unknown token %s: '%s'" %
                    (tokenize.tok_name[toknum], tokval))
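If you're on Python 3, the parser needs two small adjustments. The sketch below is my adaptation under that assumption (parse_py3 is a made-up name): StringIO lives in the io module, and recent versions of tokenize emit a NEWLINE token at the end of the input, which we skip. It also shows what the token stream looks like.

```python
import io
import tokenize

def parse_py3(text):
    """Python 3 variant of parse(): yields ints and strings for the VM."""
    tokens = tokenize.generate_tokens(io.StringIO(text).readline)
    for toknum, tokval, _, _, _ in tokens:
        if toknum == tokenize.NUMBER:
            yield int(tokval)
        elif toknum in (tokenize.OP, tokenize.STRING, tokenize.NAME):
            yield tokval
        elif toknum in (tokenize.NEWLINE, tokenize.NL):
            continue  # line endings carry no meaning in our language
        elif toknum == tokenize.ENDMARKER:
            break
        else:
            raise RuntimeError("Unknown token %s: '%s'" %
                    (tokenize.tok_name[toknum], tokval))

print(list(parse_py3("2 3 + 4 * println")))
print(list(parse_py3('"Hello, world!" dup println')))
```

Note that string tokens keep their surrounding quotes, which is exactly what dispatch expects.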

A simple optimizer: Constant folding

Constant folding is an example of peephole optimization: it looks for obvious pieces of code that can be precomputed at compile time, such as mathematical expressions involving only constants (e.g. 2 3 +). That should be quite simple to implement.

def constant_fold(code):
    """Constant-folds simple mathematical expressions like 2 3 + to 5."""
    while True:
        # Find two consecutive numbers and an arithmetic operator
        for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
            if isinstance(a, int) and isinstance(b, int) \
                    and op in {"+", "-", "*", "/"}:
                m = Machine((a, b, op))
                m.run()
                code[i:i+3] = [m.top()]
                print("Constant-folded %s%s%s to %s" % (a,op,b,m.top()))
                break
        else:
            break
    return code

The only problem with this approach is that we would have to update jump locations, and that is hard to do in many cases (e.g. read cast_int jmp). There are many solutions to this, but a simple work-around is to only allow jumping to named labels in the language, then resolve their actual locations after doing optimizations.
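As a sketch of the label idea, suppose the code may contain markers like "loop:" and references like "@loop". This syntax is invented for the example, and the parser shown earlier would need to be extended to produce such items. After optimizing, a hypothetical resolve_labels pass strips the markers and replaces each reference with the address the label ended up at:

```python
def resolve_labels(code):
    """Strip 'name:' markers and replace '@name' references with addresses."""
    addresses = {}
    stripped = []
    # First pass: record where each label points in the label-free code.
    for item in code:
        if isinstance(item, str) and len(item) > 1 and item.endswith(":"):
            addresses[item[:-1]] = len(stripped)
        else:
            stripped.append(item)
    # Second pass: substitute the references (unknown labels raise KeyError).
    resolved = []
    for item in stripped:
        if isinstance(item, str) and item.startswith("@"):
            resolved.append(addresses[item[1:]])
        else:
            resolved.append(item)
    return resolved

# Code as it might look after constant folding has already shrunk it:
folded = ["loop:", 5, "println", "@loop", "jmp"]
print(resolve_labels(folded))  # prints [5, 'println', 0, 'jmp']
```

Because addresses are only computed at the very end, the optimizer is free to shrink the code in between.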

If you implement Forth words, or functions, you can do more optimizations, like removing code that is provably never used (also known as dead code elimination).

A REPL

Now we can make a simple REPL like so:

def repl():
    print('Hit CTRL+D or type "exit" to quit.')

    while True:
        try:
            source = raw_input("> ")  # Python 2; on Python 3, use input("> ")
            if source.strip() == "exit":
                break
            code = list(parse(source))
            code = constant_fold(code)
            Machine(code).run()
        except EOFError:
            # CTRL+D ends the session
            break
        except (RuntimeError, IndexError) as e:
            print("%s: %s" % (type(e).__name__, e))
        except KeyboardInterrupt:
            print("\nKeyboardInterrupt")

We can thus test some simple programs:

> 2 3 + 4 * println
Constant-folded 2+3 to 5
Constant-folded 5*4 to 20
20
> 12 dup * println
144
> "Hello, world!" dup println println
Hello, world!
Hello, world!

As you can see, the constant-folder seems to work great! In the first example, it optimizes away all the code down to simply 20 println.

Next steps

When you have added call and return, you can let the user define new functions. In Forth, functions are called words; a definition begins with a colon and a name, and ends with a semicolon. For example, an integer square word would be:

: square dup * ;

You can actually try this out in a program like Gforth:

$ gforth
Gforth 0.7.3, Copyright (C) 1995-2008 Free Software Foundation, Inc.
Gforth comes with ABSOLUTELY NO WARRANTY; for details type `license'
Type `bye' to exit
: square dup * ;  ok
12 square . 144  ok

You can add support for this by looking for : in the parser. When you find one, record the name that follows along with an address (e.g., the current position in the code) and insert both into a symbol table. For simplicity, you could even just store the whole code up to the concluding semicolon in a dictionary, e.g.:

symbol_table = {
  "square": ["dup", "*"]
  # ...
}
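A minimal sketch of that dictionary approach could look like the following. The extract_words function is a made-up name; it assumes the token list comes from a parser like the one above, which yields ":" and ";" as plain strings.

```python
def extract_words(tokens):
    """Split a token stream into main code and a table of word definitions."""
    main_code, symbol_table = [], {}
    tokens = iter(tokens)
    for token in tokens:
        if token == ":":
            name = next(tokens, None)
            if name is None:
                raise RuntimeError("Expected a name after ':'")
            body = []
            # Consume tokens from the same iterator up to the semicolon.
            for inner in tokens:
                if inner == ";":
                    break
                body.append(inner)
            else:
                raise RuntimeError("Missing ';' in definition of '%s'" % name)
            symbol_table[name] = body
        else:
            main_code.append(token)
    return main_code, symbol_table

main_code, symbol_table = extract_words(
    [":", "square", "dup", "*", ";", 12, "square", "println"])
print(main_code)     # prints [12, 'square', 'println']
print(symbol_table)  # prints {'square': ['dup', '*']}
```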

When you're finished parsing, you can link the program: Go through the main code and look for calls to user-defined functions in the symbol table. Whenever you find one, add that code to the end of the main code, unless it's already there. Then replace the square operation with an <address> call, where <address> is the location where the word/function was inserted.
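Here's one way the linking step might be sketched. The link function is hypothetical, and it assumes the call and return opcodes from the exercise plus an exit opcode that stops the program. The only tricky part is that every word use grows from one item to two (<address> call), so the sizes have to be computed before any addresses are handed out.

```python
def link(main_code, symbol_table):
    """Append the used words after the main code; replace uses with calls."""
    def is_word(item):
        return isinstance(item, str) and item in symbol_table

    def size(seq):
        # Each word use expands to two items: <address> call
        return sum(2 if is_word(item) else 1 for item in seq)

    # Collect the words reachable from the main code, in order of first use.
    used, queue = [], [main_code]
    while queue:
        for item in queue.pop():
            if is_word(item) and item not in used:
                used.append(item)
                queue.append(symbol_table[item])

    # Lay out: main code, "exit", then each word followed by "return".
    addresses, address = {}, size(main_code) + 1
    for name in used:
        addresses[name] = address
        address += size(symbol_table[name]) + 1

    def expand(seq):
        out = []
        for item in seq:
            out.extend([addresses[item], "call"] if is_word(item) else [item])
        return out

    linked = expand(main_code) + ["exit"]
    for name in used:
        linked += expand(symbol_table[name]) + ["return"]
    return linked

print(link([12, "square", "println"], {"square": ["dup", "*"]}))
# prints [12, 5, 'call', 'println', 'exit', 'dup', '*', 'return']
```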

For this to work correctly, you should consider removing the jmp instruction. Otherwise, you'll have to resolve jump addresses as well. It can work, but then you have to keep track of their references in the same order that the user wrote the program. I.e., if you want to move around subroutines, you have to be a bit careful. It's fully doable, though. You should probably also add an exit function to stop the program (perhaps with an exit-code to the OS?), so that the main code execution won't continue running into the subroutines.

Actually, a good program layout would probably be to put the main code in a subroutine itself, called main. Or whatever, you decide!

As you can see, this is a lot of fun and teaches you a lot about code generation, linking, program space layout and so on.

Even more cool things to do

You could use a Python bytecode generation library to attempt to translate the VM code to native Python bytecode. Or implement it in Java and do it on the JVM; then you'll get JITing for free!

Also, it would be cool to try to make a register machine. You could try to implement a call stack with stack frames and establish a calling convention.

Finally, if you don't like the Forth-like language defined here, you can create your own simple language that compiles down to this VM. For instance, you should be able to convert infix notation like (2+3)*4 to 2 3 + 4 * and emit code for your VM. You could allow for C-style code blocks a la { ... } so that a statement like if ( test ) { ... } else { ... } would be translated to:

<true/false test>
<address of true block>
<address of false block>
if
jmp

<true block>
<address of end of entire if-statement> jmp

<false block>
<address of end of entire if-statement> jmp

For instance,

Address  Code
-------  ----
 0       2 3 >
 3       7        # Address of true-block
 4       11       # Address of false-block
 5       if
 6       jmp      # Conditional jump based on test

# True-block
 7       "Two is greater than three."
 8       println
 9       15       # Continue main program
10       jmp

# False-block ("else { ... }")
11       "Two is less than three."
12       println
13       15       # Continue main program
14       jmp

# If-statement finished, main program continues here
15       ...
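The addresses in that table can be computed mechanically from the lengths of the three pieces. Here's a small sketch; emit_if_else is a made-up helper that follows the template above and assumes the if-statement starts at address start.

```python
def emit_if_else(test, true_block, false_block, start=0):
    """Lay out an if/else at address `start`, following the template above."""
    header_size = len(test) + 4  # test, two addresses, "if", "jmp"
    true_addr = start + header_size
    false_addr = true_addr + len(true_block) + 2  # block plus "<end> jmp"
    end_addr = false_addr + len(false_block) + 2
    return (list(test) + [true_addr, false_addr, "if", "jmp"]
            + list(true_block) + [end_addr, "jmp"]
            + list(false_block) + [end_addr, "jmp"])

code = emit_if_else(
    [2, 3, ">"],
    ['"Two is greater than three."', "println"],
    ['"Two is less than three."', "println"])
print(code[3:7])  # prints [7, 11, 'if', 'jmp']
print(len(code))  # prints 15, the address where the main program continues
```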

Oh, you also need to add the comparison operators != < <= > >= for this to work.
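The comparison operators all follow the same pattern as minus: pop the right operand first, then the left. A sketch using the standard operator module, shown against a plain list for brevity (the COMPARISONS table and the compare function are made up for this example; in the Machine class you'd add one method per operator to the dispatch map):

```python
import operator

COMPARISONS = {
    "==": operator.eq, "!=": operator.ne,
    "<": operator.lt, "<=": operator.le,
    ">": operator.gt, ">=": operator.ge,
}

def compare(stack, op):
    """Pop two values and push the result of `left op right`."""
    right = stack.pop()  # pushed last
    left = stack.pop()   # pushed first
    stack.append(COMPARISONS[op](left, right))

stack = [2, 3]
compare(stack, ">")
print(stack)  # prints [False], since 2 > 3 is false
```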

I've done some of these things in my C++ stack machine, so you can get some hints from there.

I've also turned the code here into a project called Crianza. It has more optimizations and an experimental mode to compile down to Python bytecode.

Good luck!

The complete code

The complete code is available as vm.py on GitHub.