
Accelerate Keras model training on Macs using Apple MLX

Apple MLX is a machine learning library from Apple’s Machine Learning Research team. MLX offers APIs similar to NumPy and PyTorch.

MLX has initial support as a Keras backend. According to the pull request that adds it (https://github.com/keras-team/keras/pull/18962#issue-2047200491), it is the fastest Keras backend on ARM (Apple silicon) Macs.

The default Keras backend is TensorFlow. You can check which backend is active with:

python -c "import keras; print(keras.backend.backend())"
tensorflow

If you’d like to switch your Keras backend to MLX, here’s how. First, if you already have Keras installed, uninstall it:

pip uninstall keras

Then install the mlx branch of Keras from GitHub:

pip install git+https://github.com/keras-team/keras.git@mlx

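When you train your model, set the KERAS_BACKEND environment variable to mlx. Keras reads this variable when it is imported, so it must be set before the import. You can confirm that the MLX backend is active: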

KERAS_BACKEND=mlx python -c "import keras; print(keras.backend.backend())"
mlx

Your model will now train on the MLX backend and, on Apple silicon Macs, should train faster than with TensorFlow.
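How much faster depends on your model and machine, so it is worth measuring. Below is a minimal, stdlib-only Python sketch that runs the same training script once per backend and reports wall-clock time. The function names `time_backend` and `compare_backends` are hypothetical helpers, not part of Keras, and the sketch assumes both backends are usable in your environment (TensorFlow still installed alongside the Keras mlx branch). Each backend gets a fresh process because Keras reads KERAS_BACKEND once, at import time, and cannot switch backends afterwards.

```python
import os
import subprocess
import sys
import time


def time_backend(args, backend):
    """Run `python *args` with KERAS_BACKEND=backend; return wall-clock seconds.

    A fresh interpreter is used for every run because Keras reads
    KERAS_BACKEND only once, when it is imported.
    """
    # Copy the parent environment and override only KERAS_BACKEND.
    env = {**os.environ, "KERAS_BACKEND": backend}
    start = time.perf_counter()
    # check=True raises CalledProcessError if the script exits with an error,
    # e.g. when the mlx backend hits a feature it does not support yet.
    subprocess.run([sys.executable, *args], env=env, check=True)
    return time.perf_counter() - start


def compare_backends(args, backends=("tensorflow", "mlx")):
    """Time the same script under each backend; returns {backend: seconds}."""
    return {backend: time_backend(args, backend) for backend in backends}
```

For example, `compare_backends(["train.py"])`, where `train.py` is a placeholder for your own training script. The timings include interpreter startup and import time, so use a training run long enough (several epochs, say) for those fixed costs to be small compared with training itself.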

Please note that, at the time of writing, the MLX backend is under active development and does not yet support every Keras feature (for example, Keras’ Conv2D layer). If you use a feature that is not supported yet, the MLX backend raises an exception telling you so.
