JAX, aka NumPy on steroids

14 Jan, Simone Scardapane (translated by Stefano Di Pietro)

In the age of the 'big ones' (TensorFlow, PyTorch, ...), introducing and studying a new machine learning library might seem counterproductive. Yet JAX, a brand-new research project by Google, has several features that make it interesting to a large audience. First, it looks like a NumPy wrapper, which makes the transition from NumPy almost immediate. Second, it puts efficiency at the forefront, thanks to the transparent use of XLA, a linear algebra compiler originally developed for TensorFlow. Finally, and perhaps most intriguingly, it is one of the first libraries whose soul is purely functional.

JAX - Comparing with the Big Ones

Basically, every neural network library on the market can be categorised based on four key elements:

  1. How it lets you write and manipulate operations on tensors (e.g. computational graphs vs. eager execution);
  2. Tools and methods for automatic gradient calculation (e.g. higher-order gradients);
  3. The ability to speed up the code on GPUs, distributed systems, etc.;
  4. The availability of high-level modules to build and train neural networks.

These four points lead to enormous differentiation in the ecosystem. Keras, for example, was originally designed to focus almost entirely on point (4), leaving the other tasks to a backend engine. In 2015, on the other hand, Autograd focused on the first two points: it let users write code using only "classic" Python and NumPy constructs, and it offered many options for point (2). Autograd's simplicity greatly influenced the libraries that followed, but it was held back by its clear lack of points (3) and (4), i.e. adequate techniques to speed up the code and sufficiently abstract modules for neural network development.

JAX is essentially Autograd 2.0: it takes up its whole philosophy and improves on it with acceleration techniques for GPU/TPU and with small high-level libraries for model prototyping and optimization. It can therefore be of great interest to different groups of users: those who only want to speed up their NumPy code, as well as those who love to develop "low level" without straying too far from the familiarity of NumPy. And, last but not least, those looking for a library with a purely functional soul.

In the rest of this article we will quickly explore the main features of JAX available as of today (version 0.14), relying on parts of the original tutorial and extending them when necessary.

You can find the code used in this article in a Google Colab notebook.

How to install JAX

Installing JAX requires you to compile XLA on your architecture, following the instructions on the website. On Google Colab you can find a ready-made binary version that can be easily installed by running:

!pip install --upgrade https://storage.googleapis.com/jax-wheels/cuda92/jaxlib-0.1.4-py3-none-linux_x86_64.whl

!pip install --upgrade jax

The rest of this tutorial assumes you are using this environment.

JAX core 1: NumPy wrapper

Let's start with the basics: some random NumPy instructions.

import numpy as np

x = np.ones((5000, 5000))
y = np.arange(5000)

z = np.sin(x) + np.cos(y)

In JAX we only have to import the appropriate wrapper for NumPy:

import jax.numpy as np  # The only difference!

x = np.ones((5000, 5000))
y = np.arange(5000)

z = np.sin(x) + np.cos(y) 

In the last code block, the underlying XLA engine provides a good speed-up out of the box: on Colab's GPU backend it runs in 30 ms, compared to 480 ms for its NumPy equivalent. Not all NumPy/SciPy functions have been implemented yet, but they should be ready by the first stable release of the library. If, in addition to accelerating code, we want to use all the features described below (e.g. automatic differentiation), there are some additional constraints on the code we can write, partly deriving from the functional nature of the library: for example, you cannot modify the values of an array in place by assigning to its indices.
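The indexing constraint can be seen in a minimal sketch. It assumes a recent JAX release: the `.at[...]` update syntax postdates this article, and the `jnp` alias is the convention used in current JAX code rather than the `np` alias used above. Assigning to an index of a JAX array raises a `TypeError`, while the functional alternative returns a new, updated array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# In-place indexed assignment is rejected: JAX arrays are immutable
try:
    x[0] = 1.0
    inplace_ok = True
except TypeError:
    inplace_ok = False

# Functional updates: each call returns a NEW array, x itself is unchanged
y = x.at[0].set(1.0)     # set a single element
z = y.at[1:3].add(2.0)   # add to a slice

print(inplace_ok)  # False
print(x)           # [0. 0. 0. 0. 0.]
print(y)           # [1. 0. 0. 0. 0.]
print(z)           # [1. 2. 2. 0. 0.]
```

Because every update produces a fresh value, functions built this way stay free of side effects, which is what lets transformations such as `grad` and `jit` reason about them safely.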

JAX core 2: JIT compiler

The last block of code we have seen accelerates each instruction separately using the XLA engine, but in general you may want to accelerate entire blocks of code, exploiting any parallelization opportunities that span several instructions. For this purpose, JAX provides a tracing compiler very similar to PyTorch's JIT compiler.

Using it is very simple: we can apply a decorator (or call an explicit function) to tell JAX what to compile:

from jax import jit
import jax.numpy as np

@jit
def fn(x, y):
  z = np.sin(x)
  w = np.cos(y)
  return z + w

# Or, without using a decorator:
# fn = jit(fn)

In this case JAX compiles the function the first time it is called and directly reuses the optimized version from the second call onwards (as long as the input shapes and types stay the same). On the Colab GPU backend, JIT compilation yields an additional 30% speed-up, reducing the average execution time to 20 ms. However, to leverage the JIT compiler we need to obey some additional constraints, particularly when it comes to indexing and control flow instructions.
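The constraint on conditional statements is easiest to understand with an example. The sketch below assumes a recent JAX release (`jax.errors.ConcretizationTypeError`, `jax.lax.cond` and `static_argnums` are taken from current JAX, and the function names are hypothetical). Under `jit`, arguments are replaced by abstract tracers, so a Python `if` on their value cannot be evaluated. The usual workarounds are an element-wise `jnp.where`, the structured `lax.cond`, or marking an argument as static.

```python
from functools import partial

import jax
import jax.numpy as jnp


def relu_python(x):
    # Python `if` on the *value* of x: fine eagerly, not under jit
    if x > 0:
        return x
    return 0.0


# Under jit, x is an abstract tracer with no concrete value, so `x > 0`
# cannot be turned into a Python bool and tracing fails
try:
    jax.jit(relu_python)(3.0)
    python_if_ok = True
except jax.errors.ConcretizationTypeError:
    python_if_ok = False


@jax.jit
def relu_where(x):
    # Element-wise select: both branches are computed, no Python branching
    return jnp.where(x > 0, x, 0.0)


@jax.jit
def relu_cond(x):
    # Structured control flow on a scalar predicate; both branches must
    # return arrays with the same shape and dtype
    return jax.lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)


@partial(jax.jit, static_argnums=1)
def scale(x, mode):
    # `mode` is static: a plain Python value at trace time, so `if` works;
    # each new value of `mode` triggers a fresh compilation
    if mode == "double":
        return 2 * x
    return x


print(python_if_ok)  # False
print(relu_where(jnp.array([-1.0, 2.0])))
print(relu_cond(jnp.array(-3.0)), relu_cond(jnp.array(3.0)))
print(scale(jnp.array([1.0, 3.0]), "double"))
```

`jnp.where` is usually the simplest choice for element-wise conditions, while static arguments suit configuration flags that take only a few distinct values.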

JAX medium 1: automatic differentiation

The automatic differentiation mechanism is not very different from what is available in similar libraries. Given a Python function whose instructions are a sequence of tensor manipulations, we can automatically obtain a new function that computes its gradient:

from jax import grad
import jax.numpy as np

def simple_fun(x):
  return np.sin(x) / x

# Return the gradient of simple_fun with respect to x
grad_simple_fun = grad(simple_fun)

We can call grad multiple times to get higher-order gradients:

# Calculate the second derivative of simple_fun
grad_grad_simple_fun = grad(grad(simple_fun))
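As a sanity check, we can compare the functions produced by grad with derivatives worked out by hand. For f(x) = sin(x)/x, the quotient rule gives f'(x) = (x cos x - sin x) / x^2 and f''(x) = -sin x / x - 2 cos x / x^2 + 2 sin x / x^3. The sketch below (helper names are illustrative) evaluates both at x = 2.0; note that grad expects a floating-point input.

```python
import jax
import jax.numpy as jnp


def simple_fun(x):
    return jnp.sin(x) / x


grad_simple_fun = jax.grad(simple_fun)
grad_grad_simple_fun = jax.grad(jax.grad(simple_fun))


def first_derivative(x):
    # Hand-derived: (x cos x - sin x) / x^2
    return (x * jnp.cos(x) - jnp.sin(x)) / x**2


def second_derivative(x):
    # Hand-derived: -sin x / x - 2 cos x / x^2 + 2 sin x / x^3
    return -jnp.sin(x) / x - 2 * jnp.cos(x) / x**2 + 2 * jnp.sin(x) / x**3


x0 = 2.0  # grad needs a float; an integer input would raise an error
print(grad_simple_fun(x0), first_derivative(x0))
print(grad_grad_simple_fun(x0), second_derivative(x0))
```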

And of course we can plot everything!

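import matplotlib.pyplot as plt
# Hypothetical evaluation grid (not shown in the original); 200 points avoid x = 0
x_range = np.linspace(-8, 8, 200)
plt.plot(x_range, simple_fun(x_range), 'b')
plt.plot(x_range, [grad_simple_fun(xi) for xi in x_range], '--r')
plt.plot(x_range, [grad_grad_simple_fun(xi) for xi in x_range], '--g')
plt.show()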

JAX medium 2: advanced vectorization with vmap

JAX provides a third acceleration mechanism, to be used when we want to apply the same function along one or more axes of a tensor.

Let's look at a practical example using the gradient calculation we have seen before:

# Gradient calculation (naive)
[grad_simple_fun(xi) for xi in x_range]

Like many other libraries, JAX's grad assumes that the function being differentiated has a single (scalar) output. So, to compute the gradient at many points, we had to call it independently for each value. We can achieve the same result using the vmap operator:

from jax import vmap
grad_vect_simple_fun = vmap(grad_simple_fun)(x_range)

vmap returns a new function that applies the original function (grad_simple_fun) to an entire vector at once, mapping over its leading axis. In this simple way, we get a 100x speedup in execution (4 ms against 400 ms)!
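vmap becomes even more useful with its `in_axes` argument, which says which inputs are batched and which are shared. A common case is computing per-example gradients: below, a hypothetical squared-error loss for a linear model is written for a single example, and vmap maps it over the rows of a small made-up dataset while keeping the weights fixed. For this loss the gradient is 2 (w·x - y) x, so the two rows give [2, 0, 4] and [0, -3, -3].

```python
import jax
import jax.numpy as jnp


def example_loss(w, x, y):
    # Squared error of a linear model on ONE example (x: shape (3,), y: scalar)
    return (jnp.dot(w, x) - y) ** 2


w = jnp.array([1.0, -2.0, 0.5])
X = jnp.array([[1.0, 0.0, 2.0],
               [0.0, 1.0, 1.0]])
Y = jnp.array([1.0, 0.0])

# in_axes=(None, 0, 0): share w, map over the rows of X and the entries of Y
per_example_grads = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0))(w, X, Y)

# Equivalent (slower) Python loop, for comparison
loop_grads = jnp.stack([jax.grad(example_loss)(w, x, y) for x, y in zip(X, Y)])

print(per_example_grads.shape)  # (2, 3)
```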

grad, jit and vmap are three examples of what JAX calls composable transformations, i.e. operators that can be applied to a generic function and combined with one another.
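Since these transformations can be combined freely, a single expression can differentiate, vectorize and compile a function. In this small sketch (the `cube` function is just an illustration), the second derivative of x^3 is 6x, evaluated over a batch of points and compiled with jit.

```python
import jax
import jax.numpy as jnp


def cube(x):
    return x ** 3


# Second derivative -> vectorized over a batch -> compiled with XLA
d2_cube = jax.jit(jax.vmap(jax.grad(jax.grad(cube))))

xs = jnp.array([0.0, 1.0, 2.0])
print(d2_cube(xs))  # values 6 * x: 0, 6 and 12
```

The order of composition matters: grad must wrap the scalar function before vmap batches it, because grad only accepts scalar-output functions.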

Figure: schema of the "lifecycle" of a function in JAX. Source: JAX GitHub.

JAX core 2.5: pseudo-random number generator

Before talking about some high-level modules for neural network training, we need to briefly discuss how JAX handles pseudo-random numbers. JAX implements its own PRNG which, unlike NumPy's, has a purely functional interface, i.e. one without side effects: among other things, a call to a pseudo-random function (e.g. random.normal) does not change the internal state of the generator.

JAX users must therefore explicitly create and manage the state of the PRNG through a key:

from jax import random

# Generate a key
key = random.PRNGKey(0)

# The key must be explicitly passed to create an array of pseudo-random numbers
print(random.normal(key, shape=(3,)))

An important peculiarity, as mentioned, is that the key is not modified by the call to random.normal: subsequent calls to the function with the same key produce the same array. To get new random numbers, we must split the key with a specific call:

# Get two different keys
key, new_key = random.split(key)

# Using two different keys we get two different results
print(random.normal(key, shape=(3,)))
print(random.normal(new_key, shape=(3,)))
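In practice, the usual discipline is to split before every draw: carry one half of the key forward and consume the other, never reusing a key that has already been used. A minimal sketch, assuming a recent JAX release (which still provides `PRNGKey`, alongside the newer `jax.random.key`):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Same key in, same numbers out: sampling does not advance any hidden state
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# Split before every draw: keep `key` for later, consume `subkey` now
samples = []
for _ in range(3):
    key, subkey = jax.random.split(key)
    samples.append(jax.random.normal(subkey, shape=(3,)))

# split can also return several keys at once, e.g. one per batch element
keys = jax.random.split(jax.random.PRNGKey(1), 4)
batch = jax.vmap(lambda k: jax.random.uniform(k, shape=(2,)))(keys)

print(bool(jnp.array_equal(a, b)))                    # True
print(bool(jnp.array_equal(samples[0], samples[1])))  # False
print(batch.shape)                                    # (4, 2)
```

Making the state explicit is what keeps random sampling reproducible under jit and vmap, at the cost of some extra bookkeeping.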

As of today, this is perhaps the least intuitive and most error-prone aspect of the library, and it might be modified or improved in the future.

JAX advanced 1: building neural networks with STAX

JAX also contains mini-libraries that showcase its potential. One of these, STAX, can be used to build neural networks with an interface similar to other deep learning frameworks. For example, we can build a network as a "stack" of different "layers":

from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

net_init, net_apply = stax.serial(
    Dense(10), Relu,
    Dense(3), LogSoftmax,
)

Unlike other frameworks, however, the resulting neural network exposes a functional interface: the network is defined by a pair of functions, one for parameter initialization and one for prediction.

# Initialise the network for inputs with four features (-1 leaves the batch size unspecified)
out_shape, net_params = net_init((-1, 4))

# Get the network's predictions (Xtrain is assumed to be an array of shape (N, 4))
print(net_apply(net_params, Xtrain))
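The snippet above reflects the API at the time of writing. In current JAX releases stax lives in `jax.example_libraries.stax`, and the initialization function also takes a PRNG key as its first argument. The following sketch assumes such a release and uses a small random array as a hypothetical stand-in for `Xtrain`; since the last layer is a LogSoftmax, each row of the output should exponentiate to a probability distribution.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax

net_init, net_apply = stax.serial(
    Dense(10), Relu,
    Dense(3), LogSoftmax,
)

# Current releases: init takes an explicit PRNG key, then the input shape
rng = jax.random.PRNGKey(0)
out_shape, net_params = net_init(rng, (-1, 4))
print(out_shape)  # (-1, 3)

# Hypothetical stand-in for Xtrain: 5 random examples with 4 features each
X_demo = jax.random.normal(jax.random.PRNGKey(1), (5, 4))
log_probs = net_apply(net_params, X_demo)
print(log_probs.shape)  # (5, 3)
```

`net_params` is just a nested list of tuples of arrays (one entry per layer, empty for Relu and LogSoftmax), so it can be passed directly to grad, jit or an optimiser.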

JAX advanced 2: optimisation with minmax

The second JAX library, minmax, lets us optimize cost functions. Suppose we define a cost function for our network (e.g. cross-entropy):

def loss(params):
  # LogSoftmax outputs are log-probabilities; ytrain is assumed to be one-hot
  predictions = net_apply(params, Xtrain)
  return -np.mean(np.sum(ytrain * predictions, axis=1))

Within minmax we find several algorithms already implemented, including Adam:

from jax.experimental import minmax
opt_init, opt_update = minmax.adam(step_size=0.01)

Like the neural network, the optimiser is not an object but is defined by two functions: one for initialization and one for the update step (given the gradients). Let's look at the code for a single optimisation step:

def step(i, opt_state):
  # Extract the current network parameters from the optimiser state
  params = minmax.get_params(opt_state)
  # Gradient of the loss function
  g = grad(loss)(params)
  # Update step
  return opt_update(i, g, opt_state)

And the overall optimisation code:

# Optimiser initialisation
opt_state = opt_init(net_params)
for i in range(100):
  # Train step
  opt_state = step(i, opt_state)
# Final parameters after training
net_params = minmax.get_params(opt_state)
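minmax has since been renamed: in current releases the equivalent module is `jax.example_libraries.optimizers`, where `adam` returns a third function, `get_params`, instead of exposing a module-level helper. The sketch below assumes such a release and puts all the pieces together on hypothetical toy data (three Gaussian blobs in 4-D, standing in for `Xtrain`/`ytrain`); the step function is also compiled with jit.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers, stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax

# Hypothetical toy data: 3 well-separated blobs, 20 examples each
k_data, k_init = jax.random.split(jax.random.PRNGKey(0))
centers = jnp.array([[2.0, 0.0, 0.0, 0.0],
                     [0.0, 2.0, 0.0, 0.0],
                     [0.0, 0.0, 2.0, 0.0]])
labels = jnp.repeat(jnp.arange(3), 20)
Xtrain = centers[labels] + 0.3 * jax.random.normal(k_data, (60, 4))
ytrain = jax.nn.one_hot(labels, 3)  # one-hot targets

net_init, net_apply = stax.serial(Dense(10), Relu, Dense(3), LogSoftmax)
_, net_params = net_init(k_init, (-1, 4))


def loss(params):
    # Mean cross-entropy: LogSoftmax outputs are log-probabilities
    predictions = net_apply(params, Xtrain)
    return -jnp.mean(jnp.sum(ytrain * predictions, axis=1))


# Current API: (init, update, get_params)
opt_init, opt_update, get_params = optimizers.adam(step_size=0.01)


@jax.jit
def step(i, opt_state):
    params = get_params(opt_state)   # current network parameters
    g = jax.grad(loss)(params)       # gradient of the loss
    return opt_update(i, g, opt_state)


opt_state = opt_init(net_params)
initial_loss = loss(net_params)
for i in range(100):
    opt_state = step(i, opt_state)
net_params = get_params(opt_state)
final_loss = loss(net_params)
print(initial_loss, final_loss)
```

The structure is the same as in the article; only the module name, the extra PRNG key for initialization and the way parameters are read back from the optimiser state differ.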


JAX is a very young but extremely promising library with a potentially large audience: both those who want to implement models "from scratch" using its strong acceleration features, and those who feel inspired by its purely functional interface. While we wait for the first stable release, we hope this short tutorial has got you excited!

About the authors

Simone Scardapane is a research fellow at Sapienza University, and co-founder of the Italian Association for Machine Learning. Stefano Di Pietro is an IT and Machine Learning consultant and an Artificial Intelligence enthusiast.

If you liked our article, remember that subscribing to the Italian Association for Machine Learning is free! You can follow us daily on Facebook, LinkedIn, and Twitter.
