Simulating 1,000 Instances of Conway’s Game of Life with JAX
Recently, I have been exploring Google’s JAX, a high-performance library for scientific computing and machine learning research. Its syntax is similar to NumPy’s, but it runs natively on GPUs and TPUs, enabling significant performance gains. Google uses JAX to train its large language models, including Gemma and Gemini. [1]
Starting with the basics, I wanted to run a simple simulation to test out the library. Most simulations consist of a set of predefined rules, an initial state, interacting elements, and output data. Conway’s Game of Life fits this description perfectly: it is a well-known cellular automaton that runs on a 2D grid, where each cell is either alive or dead and the whole system follows just four rules. [2]
| Rule | Outcome |
|---|---|
| A live cell with fewer than two neighbors | Dies from isolation |
| A live cell with more than three neighbors | Dies due to overcrowding |
| A dead cell with exactly three neighbors | Becomes alive |
| A live cell with two or three neighbors | Continues to the next generation (survives) |
JAX has some cool functions that I will use in this simulation: just-in-time compilation (jit), auto-vectorization (vmap), and sequential looping (scan) for stepping the system through time. I will discuss the implementation of each of these functions in the context of Conway’s Game of Life, and along the way analyze the evolution of 1,000 randomly initialized instances.
Simulation Setup
| Parameter | Value / Description |
|---|---|
| Number of Instances | 1,000 |
| Grid Size | 64 × 64 cells |
| Time Steps | 1,000 generations |
| Initial State | Bernoulli distribution, p = 0.5 (50% chance each cell is alive) |
| Boundary Conditions | Toroidal (wrap-around edges using jnp.pad(..., mode="wrap")) |
| Data Type | jnp.int8 (memory-efficient representation of alive/dead cells) |
| Hardware Used | NVIDIA T4 GPU instance with 15 GB vRAM |
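As a concrete starting point, here is a minimal sketch of how this initial batch might be drawn (the names NUM_INSTANCES, GRID_SIZE, and the seed are my own; the post’s notebook may differ):

    import jax
    import jax.numpy as jnp

    NUM_INSTANCES = 1_000
    GRID_SIZE = 64

    # Each cell is alive with probability 0.5; cast to int8 to keep memory low.
    key = jax.random.PRNGKey(0)
    initial_state = jax.random.bernoulli(
        key, p=0.5, shape=(NUM_INSTANCES, GRID_SIZE, GRID_SIZE)
    ).astype(jnp.int8)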
Core Logic
The fundamental operation is computing the grid’s next state at each time step. For each iteration, we count the eight neighbors of each cell. To do that, we can use JAX’s convolve2d function with a 3×3 kernel of ones with a zero in the center. This kernel sums the neighbors of every cell in a single operation. [3]
import jax
import jax.numpy as jnp
import jax.scipy.signal  # jax.scipy submodules must be imported explicitly

def step(grid):
    """
    Performs a single step of the Conway's Game of Life simulation.

    Args:
        grid: A 2D JAX numpy array representing the current state of the grid.
              1 represents a living cell, 0 represents a dead cell.

    Returns:
        A 2D JAX numpy array representing the grid after one step of the simulation.
    """
    # 3x3 kernel of ones with a zero center: convolving with it counts the 8 neighbors.
    kernel = jnp.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=jnp.int8)
    # Wrap-pad by one cell so edge cells see their toroidal neighbors.
    padded_grid = jnp.pad(grid, pad_width=1, mode='wrap')
    number_of_neighbors = jax.scipy.signal.convolve2d(padded_grid, kernel, mode='valid')

    ########### Building Conway Game of Life Rules ###########
    is_alive = (grid == 1)
    survives = is_alive & ((number_of_neighbors == 2) | (number_of_neighbors == 3))
    is_dead = (grid == 0)
    reproduces = is_dead & (number_of_neighbors == 3)
    new_grid = (survives | reproduces).astype(jnp.int8)
    return new_grid
The rules are implemented as array-wide boolean operations that act as masks, identifying the cells that will survive or be born in the next generation.
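As a quick sanity check (my own addition, not from the original post), we can feed the classic blinker into step and confirm that it flips orientation after one generation:

    # A horizontal blinker should become vertical after one step.
    blinker = jnp.zeros((5, 5), dtype=jnp.int8).at[2, 1:4].set(1)
    next_grid = step(blinker)
    assert int(next_grid[1:4, 2].sum()) == 3  # now a vertical bar in the middle column
    assert int(next_grid.sum()) == 3          # population stays at three cells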
Parallelization
To scale from a single grid to thousands, I used JAX’s vmap to transform the step function, which is written for a single 2D grid, into a new function that processes a 3D batch of grids with shape (num_simulations, height, width) in parallel.
parallelize = jax.vmap(step)
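For illustration (using the shapes from the setup table above), the vmapped function maps over the leading batch axis:

    # parallelize applies `step` independently along the leading axis.
    batch = jnp.zeros((1000, 64, 64), dtype=jnp.int8)  # (num_simulations, height, width)
    next_batch = parallelize(batch)                    # same shape: (1000, 64, 64)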
Simulation Over Time
To simulate the system over time, I used JAX’s scan function (jax.lax.scan). It repeatedly applies a function, passing the output of one iteration as the input to the next, which lets us run the simulation for thousands of steps with minimal overhead.
@jax.jit
def run_simulation(initial_state):
    steps = 1000  # defined inside the function so the scan length stays static under JIT (see problem log below)
    def scan_step(state, _):
        new_state = parallelize(state)  # advance all instances one generation
        return new_state, new_state     # carry the new state and record it in the history
    final_states, history = jax.lax.scan(scan_step, initial_state, None, length=steps)
    return final_states, history
High-Performance Compilation
As shown in the previous code snippet, the entire simulation function is decorated with @jax.jit. This tells JAX to compile the whole sequence of operations into XLA machine code, which is highly optimized for the GPU. The first time the JIT-compiled function is invoked, it takes some time to compile; subsequent calls are fast because the compiled XLA code is cached.
final_grid_states, all_history = run_simulation(initial_state)
# The first run will be slow because of JIT compilation. Subsequent runs will be instant.
final_grid_states.block_until_ready()
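To see the compile-then-cache behavior concretely, a rough timing loop (my own sketch, not from the post) looks like this:

    import time

    for run in range(2):
        t0 = time.perf_counter()
        states, _ = run_simulation(initial_state)
        states.block_until_ready()  # JAX dispatches asynchronously; wait for the GPU to finish
        print(f"run {run}: {time.perf_counter() - t0:.2f} s")
    # run 0 includes XLA compilation; run 1 reuses the cached executable.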
Logs of Problems Encountered
Memory overflow
Problem: My initial attempt was to create 1,000,000 instances, but that immediately crashed the system.
Root Cause: JAX’s default integer dtype is int32, so every cell consumed four bytes.
Solution: Used a more memory-efficient datatype such as int8, reducing the memory footprint by 75%.
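As a back-of-envelope check (my own arithmetic): 1,000,000 grids × 64 × 64 cells × 4 bytes (int32) is roughly 16.4 GB for a single state, already beyond the T4’s 15 GB of vRAM before any history is stored; at int8 the same state is about 4.1 GB.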
Convolve2D function limitation
Problem: The simulation required wrap-around (toroidal) behavior, which wasn’t working as expected.
Root Cause: Unlike SciPy’s convolve2d, the JAX version does not support the boundary="wrap" option.
Solution: Implemented a wrapping logic manually using jnp.pad(grid, pad_width = 1, mode = "wrap") before convolution and changed the mode in the convolve2d function to mode = "valid".
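A tiny illustration of the wrap padding (a toy 3×3 example of my own):

    g = jnp.arange(9, dtype=jnp.int8).reshape(3, 3)
    padded = jnp.pad(g, pad_width=1, mode="wrap")
    print(padded.shape)  # (5, 5): each border now repeats the opposite edge,
    # so a mode="valid" convolution sees toroidal neighbors at every border cell.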
JIT Concretization Error
Problem: A ConcretizationTypeError was raised during compilation.
Root Cause: JAX requires loop lengths to be static at compile time, which makes total sense, but I mistakenly passed a dynamic loop length.
Solution: I considered marking the steps argument as static with static_argnums=1 in the jax.jit decorator, but each distinct value of steps would have triggered a recompilation, so I kept steps defined inside the function itself.
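For reference, the static_argnums alternative mentioned above would look roughly like this (a sketch, not the code used in the post):

    from functools import partial

    @partial(jax.jit, static_argnums=1)  # `steps` becomes a compile-time constant
    def run_simulation_static(initial_state, steps):
        def scan_step(state, _):
            new_state = parallelize(state)
            return new_state, new_state
        # Works because `steps` is a concrete Python int at trace time,
        # but every new value of `steps` triggers a fresh XLA compilation.
        return jax.lax.scan(scan_step, initial_state, None, length=steps)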
Code Playground
To play around with the code, the complete notebook is available on Google Colab.
Simulation Results & Analysis
Before wrapping up, let’s analyze the results of the 1,000 simulated instances. Do they match the known behaviour of Conway’s Game of Life? The Game of Life, it turns out, favors simplicity: random noise mostly resolves into static patterns and simple oscillators, and only very rarely into complex moving patterns called spaceships. Let’s take a closer look.
Population Dynamics
A plot of the average population (number of cells alive) over time reveals a sharp initial decline.
Starting from a random 50% initial density, the number of living cells drops significantly within the first 100-200 steps before almost leveling out. This shows that random initial states in the Game of Life are unstable and rapidly simplify.
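The average-population curve can be computed directly from the scan history (a sketch assuming all_history has shape (steps, num_simulations, height, width), as returned by run_simulation above):

    # Live cells per instance per step, averaged across the 1,000 instances.
    population = all_history.sum(axis=(-2, -1))  # shape: (steps, num_simulations)
    avg_population = population.mean(axis=-1)    # shape: (steps,)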
Final States
After 1,000 steps or generations, the final states of each instance can be categorized as extinct, stable or active.
Based on the previous graph, the vast majority of the instances are still changing and in a dynamic state. No instances died out completely. Meanwhile, only two instances froze into a still life pattern.
The initial results, given the simulation parameters used, suggest that true stability is rare and most Game of Life instances remain in a dynamic state. The next step would be to analyze the active universes.
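Classifying the final states is a few lines of array logic (again a sketch of my own, built on the all_history array from above):

    last, prev = all_history[-1], all_history[-2]  # final two generations
    pop = last.sum(axis=(-2, -1))                  # live cells per instance
    extinct = pop == 0                             # died out completely
    stable = (pop > 0) & (jnp.abs(last - prev).sum(axis=(-2, -1)) == 0)  # frozen still lifes
    active = ~extinct & ~stable                    # everything else is still changing
    print(int(extinct.sum()), int(stable.sum()), int(active.sum()))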
Patterns: Basic States and Period-2 Oscillators
Stable Instances
Out of 1,000 instances, only two became stable. Many still-life patterns, as well as period-2 oscillators, can be found in the results. In the code playground, you will find kernel definitions for each of these patterns along with the results.
Active Instances
In the experiment, I observed 998 active instances. The main question now is: are these active instances chaotic, or do they settle into patterns? One of the most common patterns in active Game of Life instances is the period-2 oscillator, the simplest of which is the “blinker”.
Out of the active instances, 495 settled into simple oscillators. To find these, we apply two conditions to the last three states of the simulation for all active instances: first, the grid must be in the same state it was two steps ago; second, the grid must differ from the previous step. If both conditions hold, we have a period-2 oscillator. In the snippet below, the three states are sliced from the all_history array returned by run_simulation.
# The last three recorded generations, each with shape (num_simulations, H, W).
state_t_final = all_history[-1]
state_t_minus_1 = all_history[-2]
state_t_minus_2 = all_history[-3]

# Condition 1: The grid returned to its state from 2 steps ago.
returned_to_state = jnp.sum(jnp.abs(state_t_final - state_t_minus_2), axis=(-2, -1)) == 0
# Condition 2: The grid was different from the previous step (i.e., not a still life).
was_not_stable = jnp.sum(jnp.abs(state_t_final - state_t_minus_1), axis=(-2, -1)) > 0
# A universe is a period-2 oscillator if both conditions are true.
is_period_2_oscillator = returned_to_state & was_not_stable
count_p2 = jnp.sum(is_period_2_oscillator)
The following graph shows a few examples of active instances that settled into oscillators.
This shows that nearly half of the active instances are no longer chaotic but have settled into “blinkers”, i.e. period-2 oscillators. So the universes do not simply die or stabilize; they can remain active in predictable, periodic loops.
Patterns: Any Spaceships Discovered?
Spaceships are more complex structures that translate across the grid while repeating their shape after a fixed period; the glider, for example, repeats every four generations while moving diagonally. They can spontaneously occur when simulations run for a very large number of steps or generations. However, given that each of the 1,000 Game of Life instances ran for only 1,000 steps, I did not observe such a complex pattern.
True complexity and information-carrying patterns like spaceships are fragile. They are unlikely to arise spontaneously from random noise and likely require more specific, less dense starting conditions or a much larger grid to survive.
What’s Next?
→ Scaling the experiment and modifying the initial conditions.
References