GPUs can crunch numbers orders of magnitude faster than a CPU. While their original use case was graphics, their range of applications keeps growing. A set of tools parallel to those used for computer graphics has been developed for high-performance linear algebra and differentiable programming, and these tools have a number of useful capabilities that extend beyond the scope for which they were designed.

Why write GPU compute code in an ML framework?

Flexibility

Array abstraction makes it really fast to prototype and try new ideas.
For example, to create a shader that subtracts one image from another, sums all the resulting pixels, and then execute it, simply write:

(img_a - img_b).sum()

Scalability

Use tools which are regularly scaled up to massive clusters with hundreds of GPUs.

Portability

Run the exact same code on CPU or GPU. Super useful for testing and small experiments.
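
As a minimal sketch (not from the demo below), the same jnp code runs on whichever backend jax picked up:

import jax
import jax.numpy as jnp

# Identical array code on any backend; jax.default_backend() reports which one is in use.
print(jax.default_backend())              # e.g. "cpu" or "gpu"
x = jnp.arange(1_000_000, dtype=jnp.float32)
print((x * 2.0).sum())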

Ecosystem

The python / numpy / scipy ecosystem is enormous and spans almost every domain of science. Code for just about anything you might want to do is likely on tap.

Why render / rasterize graphics in an ML framework? Reasons above plus:

  • Simpler than bringing in a whole extra graphics stack
  • Avoids overhead and difficulties of passing GPU memory between different contexts
  • Autodiff comes along for the ride. Differentiable rendering is always on the menu (toy sketch below).
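
A toy sketch of that last point (made up for illustration, not part of the demo below): a tiny renderer that splats a gaussian blob at a continuous position is differentiable end to end, so jax.grad gives the gradient of an image loss with respect to the blob's position.

import jax
import jax.numpy as jnp

def render_blob(pos, size=32, sigma=2.0):
  # Splat one gaussian blob at a continuous (x, y) position.
  ys, xs = jnp.mgrid[0:size, 0:size]
  d2 = (xs - pos[0]) ** 2 + (ys - pos[1]) ** 2
  return jnp.exp(-d2 / (2 * sigma ** 2))

target = render_blob(jnp.array([20.0, 12.0]))
loss = lambda p: ((render_blob(p) - target) ** 2).sum()
print(jax.grad(loss)(jnp.array([10.0, 10.0])))   # d(loss)/d(position), for free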

Why jax?

  • numpy compatibility unlocks the largest ecosystem
  • transformations - vmap, grad, etc. (toy sketch after this list)
  • JIT is a first-class citizen, which is necessary for just about anything that isn’t bound by matrix multiplies.
  • Sparse arrays supported for many operations and transformations
  • CPU, GPU, TPU support
  • Easy multi-GPU via p/xmap
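
A toy sketch of the transformations bullet (not from the demo): grad, vmap and jit compose freely.

from jax import grad, jit, vmap
import jax.numpy as jnp

f = lambda x: jnp.sin(x) ** 2
df = grad(f)                        # derivative of a scalar function
fast_batched_df = jit(vmap(df))     # vectorize over a batch, then compile
print(fast_batched_df(jnp.linspace(0.0, 1.0, 5)))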

How do you render many sprites efficiently?

The secret is sparse arrays + JIT!
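
The core idea, as a minimal sketch (a hypothetical 2x2 single-channel sprite): each sprite becomes a BCOO sparse array whose indices are its pixel coordinates within the full canvas; summing those sparse arrays and densifying scatters every sprite onto one image.

import jax.numpy as jnp
from jax.experimental import sparse

data = jnp.ones(4)                                         # a 2x2 white sprite, one channel
idx = jnp.array([[10, 10], [10, 11], [11, 10], [11, 11]])  # its pixel coordinates on the canvas
canvas = sparse.todense(sparse.BCOO((data, idx), shape=(64, 64)))
print(canvas[9:13, 9:13])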

Limitations

  • NO BACKEND SUPPORTS REAL-TIME RENDERING. This is the big one. How can we make this happen?
  • Sometimes it’s tricky to get good performance

Alternatives

Taichi is another option with a lot of potential. The backends and real-time support are wonderful, but it is missing the convenience of array programming that numpy and other ML frameworks have. Autodiff support is also not as good (you can get incorrect gradients without warning if you don’t follow the correct patterns), and some features are only supported on certain backends.
Pytorch is nice and even has an experimental vulkan backend, but jit / functional code needs more emphasis. Its departure from numpy also isn’t great.

Example

Run this interactive demo on colab
Notebook is also available on github

import jax.numpy as jnp
from jax import jit, vmap
from jax.experimental import sparse
import jax
from tqdm import tqdm
from IPython.display import HTML
from base64 import b64encode
import mediapy as media
from einops import reduce

“Draw” a sprite. This does not actually draw anything to a full-sized image yet. It creates a sparse array which says where this particular sprite is positioned within a yet-to-be-initialized full-size image array.

def draw_single_sprite(pos, sprite, sp_width, sp_height, out_dims):
  channels = out_dims[2]
  numel = sp_width * sp_height * channels
  data = sprite.reshape(numel)
  raw_indices = jnp.indices((channels, sp_height, sp_width), dtype=jnp.int16)
  indices = jnp.flip( raw_indices.T.reshape(-1, channels), 1)
  indices = indices.at[:, 0].set( indices[:, 0] + pos[0] )
  indices = indices.at[:, 1].set( indices[:, 1] + pos[1] )
  return sparse.BCOO((data, indices), shape=out_dims)
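
A hedged usage sketch (tiny made-up sprite and canvas sizes): place a 2x2x3 sprite at (5, 7) on a 16x16x3 canvas; a dense image only materializes when todense is called.

tiny = jnp.ones((2, 2, 3))
sp = draw_single_sprite(jnp.array([5, 7], dtype=jnp.int16), tiny, 2, 2, (16, 16, 3))
img = sparse.todense(sp)            # (16, 16, 3), zeros everywhere except the sprite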

Create an array of these sprites at each position, then trigger them all to be rasterized onto a single image array by calling “sum”.
JIT to make it fast.

def draw_all_sprites(all_pos, all_indices, sprite_sheet, canv_dims):
  draw_bound = lambda p, sprite: draw_single_sprite(
      p, sprite, sprite_sheet.shape[1], sprite_sheet.shape[2], canv_dims)
  draw_all = vmap(draw_bound, in_axes=(0))
  render = draw_all(all_pos.astype(jnp.int16), sprite_sheet[all_indices]).sum(0)
  return jnp.clip(sparse.todense(render), 0, 255)  # .astype(jnp.uint8)

fast_draw_sprites = jit(draw_all_sprites, static_argnums=(3,))
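
A hedged usage sketch (a made-up sheet with one 2x2 sprite, drawn at three positions):

sheet = jnp.ones((1, 2, 2, 3))
positions = jnp.array([[1.0, 1.0], [5.0, 5.0], [9.0, 3.0]])
which = jnp.zeros(3, dtype=int)                                  # all three use sprite 0
img = fast_draw_sprites(positions, which, sheet, (16, 16, 3))    # (16, 16, 3)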

Render at a multiple of the output resolution and then average down (MSAA-style supersampling).
This compensates for the fact that sprites are drawn at integer coordinates.

“r_scale” can be tuned for a performance/quality tradeoff. It is hardcoded because there was an issue passing it to the jit as static.

def scaled_render(pos, indices, sprites, dim):
  r_scale = 4
  render_res = dim * r_scale
  img = fast_draw_sprites(
          pos * r_scale - 0.5 * sprites.shape[0], 
          indices, 
          sprites, 
          (render_res , render_res, 3)
        )
  return reduce( 
      img, "(h sh) (w sw) c -> h w c", "mean", sh=r_scale, sw=r_scale)
  
fast_scaled_render = jit(scaled_render, static_argnums=(3,))
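
A hedged usage sketch (same made-up sheet, fractional positions in canvas coordinates):

sheet = jnp.ones((1, 2, 2, 3))
pos = jnp.array([[4.25, 4.0], [20.5, 20.0], [40.0, 12.75]])
img = fast_scaled_render(pos, jnp.zeros(3, dtype=int), sheet, 64)   # (64, 64, 3)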

Compute the inverse-square force between particles. Double the work is done computing the symmetric distance matrix (each pair of particles has its distance computed twice) because it’s very easy to vectorize this way. Potential room for improvement.

def compute_forces(pos, scale, eps=0.1):
  a, b = jnp.expand_dims(pos, 1), jnp.expand_dims(pos, 0)
  diff = a - b
  dist = (diff * diff).sum(axis=-1) ** 0.5
  dist = jnp.expand_dims(dist, 2)
  force = diff / ((dist * scale) ** 3 + eps)
  return force.sum(0)

fast_compute_forces = jit(compute_forces, static_argnames=("eps"))
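
A hedged usage sketch (100 made-up random particles):

key = jax.random.PRNGKey(0)
forces = fast_compute_forces(jax.random.normal(key, (100, 2)), 5.0)
print(forces.shape)    # (100, 2): net force on each particle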

Integrate particle positions and velocities using the Euler method.
JIT into a nice cozy burrito.

def sim_update_force(
    parts_pos, parts_vel, 
    t_delta=0.05, scale=5, 
    repel_mag=0.1, center_mag=2.5,
    steps=10, damp=0.99):

  p_p = jnp.array(parts_pos)
  p_v = jnp.array(parts_vel)
  # jax.experimental.loops
  for _ in range(steps):
    p_p = p_p + t_delta * p_v
    force = fast_compute_forces(p_p, scale)
    center_diff = p_p
    centering_force = center_diff / ((center_diff ** 2).sum() ** 0.5)
    p_v = damp * p_v - t_delta * (force * repel_mag + centering_force * center_mag)
  return p_p, p_v

fast_sim_update_force = jit(sim_update_force, static_argnames=("steps", "scale"))
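
A hedged usage sketch (100 made-up particles starting at rest):

pos0 = jax.random.uniform(jax.random.PRNGKey(1), (100, 2), minval=-0.5, maxval=0.5)
vel0 = jnp.zeros((100, 2))
pos1, vel1 = fast_sim_update_force(pos0, vel0, steps=4)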

We could render any sprites we want, but for this example we’ll just render each particle as a nice smooth gaussian circle.

def gaussian_kern(kernlen, nsig):
  """Returns a 2D Gaussian kernel."""
  x = jnp.linspace(-nsig, nsig, kernlen+1)
  kern1d = jnp.diff(jax.scipy.stats.norm.cdf(x))
  kern2d = jnp.outer(kern1d, kern1d)
  return kern2d/kern2d.sum()
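
A hedged usage sketch: a 27x27 kernel, normalized to sum to one.

kern = gaussian_kern(27, 3.0)
print(kern.shape, float(kern.sum()))   # (27, 27) 1.0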

A helper function to run a simulation and render it to a video. Nice default parameters for easy tweaking.

def generate_video(
    name="test_parts.mp4", p_count=800, 
    sprite_dim=27, vid_dim=300, brightness=80,
    t_delta=0.05, scale=25, center_mag=0.5, 
    repel_mag=0.05, damp=0.997, 
    total_steps=500, steps=4, seed=144):
  
  key = jax.random.PRNGKey(seed)
  p_state = jax.random.uniform(key, (p_count, 2), minval=-0.5, maxval=0.5)
  v_state = jnp.zeros((p_count, 2))
  sprite_indices = jnp.zeros((p_count,), dtype=int)
  gaussian_sprites = jnp.tile(
      gaussian_kern(
          sprite_dim, 3.0
      ).reshape(1, sprite_dim, sprite_dim, 1), 3) * brightness

  with media.VideoWriter(
      name, (vid_dim, vid_dim), crf=30, fps=45) as vw:

    for i in tqdm(range(total_steps)):
      render = fast_scaled_render(
        (p_state * 0.9 + 0.5) * vid_dim,
        sprite_indices, 
        gaussian_sprites, 
        vid_dim
      )
      p_state, v_state = fast_sim_update_force(
          p_state, v_state, 
          t_delta=t_delta, scale=scale, 
          center_mag=center_mag, repel_mag=repel_mag, 
          damp=damp,
          steps=steps
      )
      vw.add_image(render)

Generate a video!
Included are a couple extra examples demonstrating different rendering and sim parameters.

generate_video()

# blob
#generate_video(p_count=16000, vid_dim=1024, sprite_dim=21, scale=150, center_mag=0.25, damp=0.998, steps=4, brightness=40, total_steps=1000)

# galaxy
#generate_video(p_count=16000, vid_dim=1024, sprite_dim=21, scale=150, center_mag=0.0, repel_mag=-0.01, damp=1, steps=4, brightness=15, total_steps=1000)

mp4 = open('test_parts.mp4','rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=300 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)
100%|██████████| 500/500 [00:02<00:00, 176.34it/s]
!nvidia-smi
Thu Jul 21 02:58:28 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   59C    P0    28W /  70W |  13660MiB / 15109MiB |     19%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

How can this be run on a GPU backend (ideally non-cuda) that supports real-time rendering?

Jax uses XLA to optimize code and target various backends. XLA has its own IR format called “HLO” (High Level Operations), and there is a guide for implementing new backends. For GPUs, however, XLA only targets CUDA via LLVM and PTX. This is great but not good enough. Either an emulator/translator for PTX is needed (nvidia has an interest in keeping that difficult), or, ideally, something that can compile XLA’s HLO to a highly portable backend like vulkan. Not all of the HLO instruction set would be needed. Some example HLO for code similar to the above is dumped below.
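
A hedged sketch of how a dump like the one below can be produced (the toy draw_points here is made up; jax.xla_computation was one way to get HLO text at the time of writing):

def draw_points(pos):
  # Scatter-add each point onto a 64x64 canvas.
  canvas = jnp.zeros((64, 64))
  idx = pos.astype(jnp.int16)
  return canvas.at[idx[:, 0], idx[:, 1]].add(1.0)

print(jax.xla_computation(draw_points)(jnp.zeros((4096, 2))).as_hlo_text())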

HloModule xla_computation_draw_points.130

%region_1.83 (Arg_0.84: f32[], Arg_1.85: f32[]) -> f32[] {
  %Arg_0.84 = f32[] parameter(0)
  %Arg_1.85 = f32[] parameter(1)
  ROOT %add.86 = f32[] add(f32[] %Arg_0.84, f32[] %Arg_1.85)
}

%input_fused_computation_scatter (param_0.1: pred[4096,16], param_1.3: s16[4096,16,2]) -> f32[64,64] {
  %constant_7 = f32[] constant(0)
  %broadcast.11 = f32[64,64] broadcast(f32[] %constant_7), dimensions={}
  %param_1.3 = s16[4096,16,2] parameter(1)
  %bitcast.14 = s16[65536,2] bitcast(s16[4096,16,2] %param_1.3)
  %slice.1 = s16[65536,1] slice(s16[65536,2] %bitcast.14), slice={[0:65536], [0:1]}
  %bitcast.13 = s16[65536] bitcast(s16[65536,1] %slice.1)
  %constant_5 = s16[] constant(0)
  %broadcast.10 = s16[65536] broadcast(s16[] %constant_5), dimensions={}
  %compare.1 = pred[65536] compare(s16[65536] %bitcast.13, s16[65536] %broadcast.10), direction=LT
  %constant_3 = s16[] constant(64)
  %broadcast.9 = s16[65536] broadcast(s16[] %constant_3), dimensions={}
  %add.1 = s16[65536] add(s16[65536] %bitcast.13, s16[65536] %broadcast.9)
  %select.3 = s16[65536] select(pred[65536] %compare.1, s16[65536] %add.1, s16[65536] %bitcast.13)
  %convert.3 = s32[65536] convert(s16[65536] %select.3)
  %bitcast.12 = s32[65536,1] bitcast(s32[65536] %convert.3)
  %slice.0 = s16[65536,1] slice(s16[65536,2] %bitcast.14), slice={[0:65536], [1:2]}
  %bitcast.11 = s16[65536] bitcast(s16[65536,1] %slice.0)
  %compare.0 = pred[65536] compare(s16[65536] %bitcast.11, s16[65536] %broadcast.10), direction=LT
  %add.0 = s16[65536] add(s16[65536] %bitcast.11, s16[65536] %broadcast.9)
  %select.2 = s16[65536] select(pred[65536] %compare.0, s16[65536] %add.0, s16[65536] %bitcast.11)
  %convert.2 = s32[65536] convert(s16[65536] %select.2)
  %bitcast.10 = s32[65536,1] bitcast(s32[65536] %convert.2)
  %concatenate.0 = s32[65536,2] concatenate(s32[65536,1] %bitcast.12, s32[65536,1] %bitcast.10), dimensions={1}
  %param_0.1 = pred[4096,16] parameter(0)
  %constant_1 = f32[] constant(1)
  %broadcast.8 = f32[4096,16] broadcast(f32[] %constant_1), dimensions={}
  %broadcast.6 = f32[4096,16] broadcast(f32[] %constant_7), dimensions={}
  %select.1 = f32[4096,16] select(pred[4096,16] %param_0.1, f32[4096,16] %broadcast.8, f32[4096,16] %broadcast.6)
  %bitcast.9 = f32[65536] bitcast(f32[4096,16] %select.1)
  ROOT %scatter.0 = f32[64,64] scatter(f32[64,64] %broadcast.11, s32[65536,2] %concatenate.0, f32[65536] %bitcast.9), update_window_dims={}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%region_1.83
}

%region_0.52 (Arg_0.53: pred[], Arg_1.54: pred[]) -> pred[] {
  %Arg_0.53 = pred[] parameter(0)
  %Arg_1.54 = pred[] parameter(1)
  ROOT %and.55 = pred[] and(pred[] %Arg_0.53, pred[] %Arg_1.54)
}

%fused_computation (param_0.16: f32[4096,2]) -> pred[4096,16] {
  %iota.2 = s32[4] iota(), iota_dimension=0
  %convert.10 = f32[4] convert(s32[4] %iota.2)
  %broadcast.39 = f32[4096,4] broadcast(f32[4] %convert.10), dimensions={1}
  %param_0.16 = f32[4096,2] parameter(0)
  %slice.7 = f32[4096,1] slice(f32[4096,2] %param_0.16), slice={[0:4096], [0:1]}
  %bitcast.23 = f32[4096] bitcast(f32[4096,1] %slice.7)
  %broadcast.36 = f32[4096,4] broadcast(f32[4096] %bitcast.23), dimensions={0}
  %add.7 = f32[4096,4] add(f32[4096,4] %broadcast.39, f32[4096,4] %broadcast.36)
  %broadcast.34 = f32[4096,4,4,1] broadcast(f32[4096,4] %add.7), dimensions={0,2}
  %slice.6 = f32[4096,1] slice(f32[4096,2] %param_0.16), slice={[0:4096], [1:2]}
  %bitcast.22 = f32[4096] bitcast(f32[4096,1] %slice.6)
  %broadcast.31 = f32[4096,4] broadcast(f32[4096] %bitcast.22), dimensions={0}
  %add.6 = f32[4096,4] add(f32[4096,4] %broadcast.39, f32[4096,4] %broadcast.31)
  %broadcast.29 = f32[4096,4,4,1] broadcast(f32[4096,4] %add.6), dimensions={0,1}
  %concatenate.3 = f32[4096,4,4,2] concatenate(f32[4096,4,4,1] %broadcast.34, f32[4096,4,4,1] %broadcast.29), dimensions={3}
  %convert.9 = s16[4096,4,4,2] convert(f32[4096,4,4,2] %concatenate.3)
  %bitcast.21 = s16[4096,16,2] bitcast(s16[4096,4,4,2] %convert.9)
  %convert.4 = s32[4096,16,2] convert(s16[4096,16,2] %bitcast.21)
  %constant_14 = s32[] constant(64)
  %broadcast.13 = s32[4096,16,2] broadcast(s32[] %constant_14), dimensions={}
  %compare.2 = pred[4096,16,2] compare(s32[4096,16,2] %convert.4, s32[4096,16,2] %broadcast.13), direction=LT
  %constant_12 = pred[] constant(true)
  ROOT %reduce.0 = pred[4096,16] reduce(pred[4096,16,2] %compare.2, pred[] %constant_12), dimensions={2}, to_apply=%region_0.52
}

%fused_computation.2 (param_0.19: s32[2,1], param_1.25: f32[4096,2]) -> s16[4096,16,2] {
  %iota.4 = s32[4] iota(), iota_dimension=0
  %convert.14 = f32[4] convert(s32[4] %iota.4)
  %broadcast.49 = f32[4096,4] broadcast(f32[4] %convert.14), dimensions={1}
  %param_1.25 = f32[4096,2] parameter(1)
  %slice.11 = f32[4096,1] slice(f32[4096,2] %param_1.25), slice={[0:4096], [0:1]}
  %bitcast.29 = f32[4096] bitcast(f32[4096,1] %slice.11)
  %broadcast.48 = f32[4096,4] broadcast(f32[4096] %bitcast.29), dimensions={0}
  %add.11 = f32[4096,4] add(f32[4096,4] %broadcast.49, f32[4096,4] %broadcast.48)
  %broadcast.47 = f32[4096,4,4,1] broadcast(f32[4096,4] %add.11), dimensions={0,2}
  %slice.10 = f32[4096,1] slice(f32[4096,2] %param_1.25), slice={[0:4096], [1:2]}
  %bitcast.28 = f32[4096] bitcast(f32[4096,1] %slice.10)
  %broadcast.46 = f32[4096,4] broadcast(f32[4096] %bitcast.28), dimensions={0}
  %add.10 = f32[4096,4] add(f32[4096,4] %broadcast.49, f32[4096,4] %broadcast.46)
  %broadcast.45 = f32[4096,4,4,1] broadcast(f32[4096,4] %add.10), dimensions={0,1}
  %concatenate.5 = f32[4096,4,4,2] concatenate(f32[4096,4,4,1] %broadcast.47, f32[4096,4,4,1] %broadcast.45), dimensions={3}
  %convert.13 = s16[4096,4,4,2] convert(f32[4096,4,4,2] %concatenate.5)
  %bitcast.27 = s16[4096,16,2] bitcast(s16[4096,4,4,2] %convert.13)
  %param_0.19 = s32[2,1] parameter(0)
  ROOT %gather.0 = s16[4096,16,2] gather(s16[4096,16,2] %bitcast.27, s32[2,1] %param_0.19), offset_dims={0,1}, collapsed_slice_dims={2}, start_index_map={2}, index_vector_dim=1, slice_sizes={4096,16,1}
}

ENTRY %main.89 (Arg_0.1: f32[4096,2]) -> (f32[64,64]) {
  %Arg_0.1 = f32[4096,2] parameter(0)
  %fusion = pred[4096,16] fusion(f32[4096,2] %Arg_0.1), kind=kLoop, calls=%fused_computation
  %constant_9 = s32[2,1] constant({ {0}, {1} })
  %fusion.2 = s16[4096,16,2] fusion(s32[2,1] %constant_9, f32[4096,2] %Arg_0.1), kind=kLoop, calls=%fused_computation.2
  %input_fusion_scatter = f32[64,64] fusion(pred[4096,16] %fusion, s16[4096,16,2] %fusion.2), kind=kInput, calls=%input_fused_computation_scatter
  ROOT %tuple.88 = (f32[64,64]) tuple(f32[64,64] %input_fusion_scatter)
}