Reinforcing Learning: Using Mojo🔥 to learn 3x faster
Rewriting a Monte Carlo RL approach in Mojo for a 3x speed increase, with some drawbacks...
At the minute I'm spending my evenings on reinforcement learning, working through the canonical textbook - Sutton & Barto. The highlights so far have been the programming exercises: you're given a problem, work through it, and see the theory actually work.
The Problem
The latest problem I worked on (5.12) involved getting a simple little car to drive down a course. Here the red line is the start, green is the finish, and yellow is the track. If the car runs off course it'll be set back to the start line. Sounds like a pretty easy job - but if we try it out with a random approach (just trying to move in any direction at each time step) it doesn't work well...
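(For concreteness, "any direction" here means one of the nine acceleration actions: nudge the vertical and horizontal velocity by -1, 0 or +1. A rough Python sketch of that baseline - the action encoding is illustrative, not the exact representation from my code:)
# Python
import random

# The random baseline: at every time step, pick one of the nine acceleration
# actions (change vertical and horizontal velocity by -1, 0 or +1) at random.
ACTIONS = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1)]

def random_action():
    return random.choice(ACTIONS)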
The Solution
So instead, we need to come up with an approach (referred to as a policy) for our car that will let it make proper decisions - i.e. "if I'm coming up on the turn, it's time to slow down vertically and start pushing right." For this I'll be using a Monte Carlo approach, where we simulate many episodes - I did 50 million - and use the results of each episode to update our policy.
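In rough Python terms, the training loop looks something like this. This is a hedged sketch rather than my actual code: the state shape is illustrative, and I'm assuming a simulate_episode helper that returns the visited states, actions and rewards (the real one later in this post works differently).
# Python
import numpy as np

N_ACTIONS = 9                                  # accelerate by -1/0/+1 vertically and horizontally
STATE_SHAPE = (32, 17, 5, 5)                   # (y, x, vy, vx) - illustrative sizes

def simulate_episode(pi, eps):
    """Placeholder: roll out one episode following an eps-greedy version of pi,
    returning parallel lists of states (4-tuples), actions and rewards."""
    raise NotImplementedError

def train(n_episodes=50_000_000, gamma=1.0, eps=0.05):
    Q = np.zeros(STATE_SHAPE + (N_ACTIONS,))   # action-value estimates
    N = np.zeros_like(Q)                       # visit counts, for incremental averaging
    pi = np.zeros(STATE_SHAPE, dtype=np.int8)  # greedy policy

    for _ in range(n_episodes):
        states, actions, rewards = simulate_episode(pi, eps)
        G = 0.0
        for t in reversed(range(len(states))):  # walk the episode backwards
            G = gamma * G + rewards[t]
            s, a = states[t], actions[t]
            N[s + (a,)] += 1
            Q[s + (a,)] += (G - Q[s + (a,)]) / N[s + (a,)]  # running mean of returns
            pi[s] = np.argmax(Q[s])             # greedy policy improvement
    return pi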
And it works pretty well! Our car has learned to navigate this track like a heat-seeking missile, bombing straight for the finish line each time from anywhere on the start line.
The Pain
A well-adjusted person would look at their car happily making its way and call that job done, but I was irritated. Getting to this result took training over 50 million episodes. That took 3 hours to run on my beefy M2 MacBook Pro. I'm learning on the side, so that's an evening gone, plus a MacBook that's hot enough to fry an egg on in the meantime.
So I decided to test out the new Python-esque language Mojo. The Mojo team have advanced some hefty claims about the language's speed (e.g. this one on outpacing Rust). I'm an RL noob and an optimisation idiot, so I'm not going to make any sweeping claims - but I wanted to give it a try and see how I got on.
The Challenger
Mojo is a new AI-first language spearheaded by Chris Lattner of Swift/LLVM fame. The language is intended to eventually be a superset of Python, supporting the same syntax. For now, the syntax is similar to Python, but with caveats. Take, for example, building a dictionary:
# Python
value_dict = {" ": 0, "*": 1, "S": 2, "F": 3, "C": 4}
# Mojo
var value_dict = Dict[String, Int8]()
value_dict[" "] = 0
value_dict["*"] = 1
value_dict["S"] = 2
value_dict["F"] = 3
value_dict["C"] = 4
Mojo is immediately more involved. While Python handles variable declarations implicitly, Mojo uses a var keyword. (I don't get why... var is used for all declarations.)
Mojo is also strongly typed. Our Mojo dictionary here will only accept String-type keys and Int8-type values. A Python dictionary, by contrast, will let you map basically anything to anything, even in the same dict. That's flexible, but it also means the compiler can't make any optimising assumptions - everything just gets hashed out at runtime.
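To be clear about what "anything to anything" means, this is perfectly legal Python (a toy example, not from the project):
# Python
# One dict, three different key types and three different value types.
mixed = {" ": 0, 1: "one", (2, 3): [4.0]}
print(mixed[(2, 3)])  # [4.0]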
You'll also notice that a Mojo dict needs every element to be added individually, rather than using a simple Python-style instantiation. That kind of jank is typical of current Mojo: it's a nascent language with regular breaking changes, missing functionality (we'll get to some of that) and bugs.
The Solution Part 2: Good, Bad and Ugly
So I buckled down and spent about 6 hours rewriting my Python implementation in Mojo.
The Ugly
As someone used to Python and its stack (Numpy, Pandas etc.), rewriting in Mojo was mostly comfortable, but there were real pain points. Easy Python one-liners had to be properly considered, implemented from scratch, and tested for correctness. Here are a few examples:
Converting an ASCII track representation to a 2D array
In Python, this took a list comprehension passed into Numpy. In Mojo, that logic had to be written as a nested for loop, with a properly tuned index to translate between strings and tensor entries. There were some unintuitive rubs too - like needing to wrap the indices in an Index object to index into a Mojo array.
Reading in an array
# Python
with open(filepath, "r") as f:
    start_grid = np.array(
        [[value_dict[c] for c in line.strip()] for line in f.readlines()]
    )
start_grid = np.flip(start_grid, axis=0)
# Mojo
with open(filepath, "r") as f:
    var grid_text = f.read()
    var grid_split = grid_text.splitlines()
    var len_y = len(grid_split)
    var len_x = len(grid_split[0])
    var track = Tensor[DType.int8].rand(TensorShape(len_y, len_x))
    for i_y in range(len_y):
        for i_x in range(len_x):
            var str_value = grid_split[len_y - i_y - 1][i_x]
            track[Index(i_y, i_x)] = value_dict[str_value]
Python: A list comprehension with a numpy utility function flip
Mojo: Traversing the input/output arrays manually and filling in an empty array element-by-element
Generating random values
Mojo doesn't have a function you can call to get a random int. Instead, the workflow is to allocate a pointer, fill the address with a random number, and then pull that number back out - surprisingly esoteric for such a common operation.
Generating a random int
# Python
>>> np.random.randint(low=0, high=5)
3
# Mojo
>>> var p1 = DTypePointer[DType.int8].alloc(1)
>>> randint[DType.int8](p1, size=1, low=0, high=5)
>>> p1[0]
3
Tensor slicing
Mojo has a builtin Tensor type, which makes sense for an AI-first language that doesn't yet have a package manager or a good community alternative. Unfortunately, the Tensor type is missing a lot of the usual bells and whistles you'd expect from Numpy or PyTorch. Slicing is a good example. In the code below, Numpy uses 4-D coordinates to index into a 5-D array, and assigns the argmax to a single element in a 4-D array. In Mojo, that becomes an 8-line from-scratch implementation of the same process, with less elegant syntax.
# Python
pi[S_t] = np.argmax(Q[S_t]) # S_t is a 4-tuple of coordinates
# Mojo
var max_val: Float32 = -500
var argmax: Int8 = -1
for a in range(9):
    var Q_val = Q[Index(S_t[0], S_t[1], S_t[2], S_t[3], a)]
    if Q_val > max_val:
        max_val = Q_val
        argmax = a
pi[Index(S_t[0], S_t[1], S_t[2], S_t[3])] = argmax
Globals
Mojo can't handle expressions at the global level, which can make it awkward to use a global state. It's not a big deal, but in Python I can pass a few relevant arguments and mutate a global state, whereas Mojo needs everything to be explicitly passed.
# Python
def simulate_episode(
    policy: np.ndarray,
    max_steps: int = 500,
    noise_prob: float = 0.1,
    eps: float = 0.05,
) -> bool:
    ...

terminated = simulate_episode(policy=pi)
# Mojo
fn simulate_episode[max_steps: Int = 500](
    policy: Tensor[DType.int8],
    inout state: SIMD[DType.int8, 4],
    inout states_hist: InlineList[SIMD[DType.int8, 4], max_steps],
    inout actions_hist: InlineList[Int8, max_steps],
    inout t: Int,
    y_max: Int,
    x_max: Int,
    track: Tensor[DType.int8],
    value_dict: Dict[String, Int8],
    start_coords: Tuple[Int8, Int8],
    finish_bounds: Tuple[Int8, Int8, Int8],
    noise_prob: Float16 = 0.1,
    eps: Float16 = 0.05,
) raises -> Bool:
    ...

var terminated = simulate_episode[max_steps=500](policy, state, states_hist, actions_hist, T, y_max, x_max, track, value_dict, start_coords, finish_bounds, eps=eps)
The Bad
I only ran into one real bug while working on this problem - iffy rounding. It's not a big deal, but it did cost me about half an hour of confusion when I was trying to benchmark performance.
# Python
>>> round(1.34, 1)
1.3
# Mojo
>>> round(1.34, 1)
1.0
The Good
That's been a lot of gripes, but in the end Mojo did deliver what I wanted: a 3.2x speedup compared to my usual Python/Numpy stack. What took 3 hours to run in Python was just shy of an hour in Mojo. What's more, CPU/RAM utilisation in Mojo was lower (no egg-frying heat), so there could probably be more performance to squeeze out of it with parallelisation and other optimisations.
Now, thanks to the sharp edges, my rewrite did take about 6 hours, so it was a net loss for this one example. That said, if this were a common process, or a big training job, or any number of other factors were different, this kind of speed-up would be a no-brainer. My hope is that as Modular polishes the language, the performance benefit will stay high while the drawbacks diminish. The language is seeing a lot of updates at the minute, so I reckon it'll get there sooner rather than later.
The Mojo code I used for training is available in this repo.