
Jax, Flax, & Macs


JAX is a powerful tensor manipulation and autograd library that has seen a surge in popularity recently. It's also one of the key dependencies of Google's neural network library, Flax. However, installing it on an M1 Mac can be a bit tricky. Today I managed to get it running on a 2021 MacBook Pro with an M1 Pro chip without too many problems, so I've shared the Python environment (`flax-requirements.txt`) here to help others facing the same problem. Once you've downloaded it and activated a virtual environment, you can install the key dependencies by running:

```shell
pip install --upgrade pip
pip install -r flax-requirements.txt
```

And that should (hopefully) do it.
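Before benchmarking anything, it can be worth confirming that the environment looks right. One frequent source of trouble on Apple Silicon is a Python interpreter running under Rosetta 2 (which reports `x86_64`) instead of a native `arm64` build. The sketch below uses only the standard library to print the architecture and the installed versions of `jax`, `jaxlib`, and `flax`; the `check_environment` helper is hypothetical and not part of any of these packages.

```python
import platform
from importlib import metadata


def check_environment(packages=("jax", "jaxlib", "flax")):
    """Report the CPU architecture Python runs on and installed package versions.

    On an Apple Silicon Mac, a native Python reports "arm64"; a Python running
    under Rosetta 2 reports "x86_64" instead.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return {"machine": platform.machine(), "versions": versions}


if __name__ == "__main__":
    report = check_environment()
    print("Architecture:", report["machine"])
    for name, version in report["versions"].items():
        print(f"{name}: {version or 'not installed'}")
```

Once JAX itself imports, `jax.devices()` lists the devices it will use; with a CPU-only install like this one you would expect to see a single CPU device.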

Out of curiosity, I ran a quick test using a CNN on MNIST to see how the M1 Pro compares with a Colab GPU. I ran the example code and recorded how long 10 training and evaluation epochs took with a fixed batch size of 32.
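The timings below are simple wall-clock measurements. If you want to reproduce something similar, one detail matters with JAX: operations are dispatched asynchronously, so a timer can stop before the computation has actually finished unless you block on the result. Here is a minimal sketch, assuming a hypothetical `run_epoch(epoch)` callable that trains and evaluates for one epoch and returns a JAX array; it is not the script used for the numbers here.

```python
import time


def time_epochs(run_epoch, num_epochs=10):
    """Time `num_epochs` calls to `run_epoch`, returning per-epoch seconds.

    `run_epoch` is a hypothetical callable that trains and evaluates for one
    epoch and returns its outputs (for example, a JAX array of metrics).
    """
    durations = []
    for epoch in range(num_epochs):
        start = time.perf_counter()
        result = run_epoch(epoch)
        # JAX dispatches work asynchronously, so wait for the result to be
        # computed before stopping the clock. Plain Python values are skipped.
        if hasattr(result, "block_until_ready"):
            result.block_until_ready()
        durations.append(time.perf_counter() - start)
    return durations
```

The first epoch also includes `jax.jit` compilation time, so it is usually slower than later ones; whether to count it depends on what you want to measure.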

| Device | Running Time (10 epochs) |
|---|---|
| M1 Pro (10-core) | 3m 3s |
| Colab GPU | 1m 7s |
| Colab CPU | ~20m |

My 10-core M1 Pro ran 10 epochs in 3 minutes and 3 seconds, whereas the Colab GPU took only 1 minute and 7 seconds. Although the M1 Pro was nearly 3x slower, I don't think that's too bad for a CPU, and it should be great for development environments. The picture could change a lot with different batch sizes and workloads, so take this with a pinch of salt. For reference, though, a Colab CPU took around 20 minutes to run the same code, roughly six to seven times longer than the M1 Pro, so the M1 CPU is certainly doing something right.
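Converting the table's running times to seconds makes the relative slowdowns easy to check (the Colab CPU figure is approximate, so its ratio is too):

```python
# Running times from the MNIST table, in seconds (Colab CPU is approximate).
times = {
    "M1 Pro (10-core)": 3 * 60 + 3,  # 183 s
    "Colab GPU": 1 * 60 + 7,         # 67 s
    "Colab CPU": 20 * 60,            # ~1200 s
}

m1 = times["M1 Pro (10-core)"]
print(f"M1 Pro vs Colab GPU: {m1 / times['Colab GPU']:.1f}x slower")  # 2.7x
print(f"Colab CPU vs M1 Pro: {times['Colab CPU'] / m1:.1f}x slower")  # 6.6x
```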

PyMC & BlackJAX

Another reason you might want to use JAX is probabilistic programming in PyMC. With a JAX-based backend such as BlackJAX or NumPyro, it's possible to speed up model compilation and fitting considerably. I also ran a quick comparison using a simple linear regression with 2000 data points to test sampling and compilation times in PyMC and CmdStan. I ran each model 5 times and took the lowest value for both sampling time and end-to-end time (which includes compilation).
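For anyone wanting to run a similar comparison, here is a sketch of the PyMC side under stated assumptions: the data-generating parameters and priors are made up for illustration, `make_linear_data`, `best_of`, and `fit_pymc` are hypothetical helpers, and the JAX samplers are called through the `pymc.sampling_jax` module that PyMC 4.x provides (later releases changed how these samplers are selected, so check the documentation for your version). `best_of` mirrors the "lowest of 5 runs" approach but only times whole calls, so it does not separate sampling from compilation.

```python
import random
import time


def make_linear_data(n=2000, alpha=1.0, beta=2.0, sigma=0.5, seed=0):
    """Simulate y = alpha + beta * x + Gaussian noise (illustrative parameters)."""
    rng = random.Random(seed)
    x = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    y = [alpha + beta * xi + rng.gauss(0.0, sigma) for xi in x]
    return x, y


def best_of(fn, repeats=5):
    """Call `fn` `repeats` times and return the fastest wall-clock time in seconds."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return min(timings)


def fit_pymc(x, y, backend="pymc"):
    """Build and sample a simple linear regression (assumes PyMC 4.x is installed)."""
    import numpy as np
    import pymc as pm

    x, y = np.asarray(x), np.asarray(y)
    with pm.Model():
        # Hypothetical weakly informative priors.
        alpha = pm.Normal("alpha", mu=0.0, sigma=10.0)
        beta = pm.Normal("beta", mu=0.0, sigma=10.0)
        sigma = pm.HalfNormal("sigma", sigma=1.0)
        pm.Normal("y", mu=alpha + beta * x, sigma=sigma, observed=y)
        if backend in ("blackjax", "numpyro"):
            from pymc import sampling_jax  # JAX samplers in PyMC 4.x
            if backend == "blackjax":
                return sampling_jax.sample_blackjax_nuts()
            return sampling_jax.sample_numpyro_nuts()
        return pm.sample()  # default PyMC NUTS sampler
```

Usage might look like `x, y = make_linear_data()` followed by `best_of(lambda: fit_pymc(x, y, backend="numpyro"))`.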

| Sampler | Sampling Time | End-to-end Time |
|---|---|---|
| CmdStan 2.30.0 | 1.1s | 6.7s |
| PyMC 4.2.2 | 9.0s | 10.9s |
| PyMC 4.2.2 + BlackJAX | 2.8s | 3.6s |
| PyMC 4.2.2 + NumPyro | 1.8s | 2.8s |

Interestingly, base PyMC was slowest on both measures, but benefitted a lot from the BlackJAX and NumPyro backends. The CmdStan sampler was fastest, but it took noticeably longer to compile, so its end-to-end time was slower than either JAX backend. This benchmark is far from exhaustive, and depending on model specification and dataset size, I suspect these results could change a lot.
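Subtracting sampling time from end-to-end time gives a rough sense of how much of each run went to compilation and other overhead. It is only approximate, since each column is the fastest of five runs taken separately, but it shows why CmdStan's quick sampler did not translate into the quickest end-to-end time:

```python
# (sampling time, end-to-end time) in seconds, copied from the sampler table.
results = {
    "CmdStan 2.30.0": (1.1, 6.7),
    "PyMC 4.2.2": (9.0, 10.9),
    "PyMC 4.2.2 + BlackJAX": (2.8, 3.6),
    "PyMC 4.2.2 + NumPyro": (1.8, 2.8),
}

# End-to-end minus sampling: a rough proxy for compilation and other overhead.
overhead = {name: round(total - sampling, 1) for name, (sampling, total) in results.items()}
for name, seconds in sorted(overhead.items(), key=lambda item: item[1]):
    print(f"{name}: ~{seconds}s outside sampling")
# PyMC 4.2.2 + BlackJAX: ~0.8s outside sampling
# PyMC 4.2.2 + NumPyro: ~1.0s outside sampling
# PyMC 4.2.2: ~1.9s outside sampling
# CmdStan 2.30.0: ~5.6s outside sampling
```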