Accelerate Keras model training on Macs using Apple MLX
May 15, 2024
Apple MLX is a machine learning library from Apple’s Machine Learning Research team. MLX offers an API that closely mirrors NumPy’s API.
MLX also has initial support as a Keras backend and is reportedly the fastest Keras backend on ARM Macs: https://github.com/keras-team/keras/pull/18962#issue-2047200491
The default Keras backend is TensorFlow:
python -c "import tensorflow as tf; print(tf.keras.backend.backend())"
tensorflow
If you’d like to change your Keras backend to MLX, here’s how you can do so.
First, if you already have Keras installed, uninstall it:
pip uninstall keras
Then install Keras from its mlx branch:
pip install git+https://github.com/keras-team/keras.git@mlx
When you train your model, set the KERAS_BACKEND environment variable to mlx:
KERAS_BACKEND=mlx python -c "import tensorflow as tf; print(tf.keras.backend.backend())"
mlx
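The variable can also be set from inside a Python script, as long as that happens before Keras is first imported, because Keras chooses its backend at import time. Here is a minimal sketch; `select_backend` is a hypothetical helper written for illustration, not part of Keras:

```python
import os
import sys


def select_backend(name: str) -> None:
    """Set KERAS_BACKEND for this process (hypothetical helper, not a Keras API)."""
    # Keras picks its backend when it is first imported, so changing the
    # variable afterwards would not affect the already-loaded module.
    if "keras" in sys.modules:
        raise RuntimeError("keras is already imported; select the backend first")
    os.environ["KERAS_BACKEND"] = name


select_backend("mlx")

# Any later `import keras` in this process should now use the mlx backend,
# assuming the mlx branch of Keras is installed as described above.
```

Call the helper at the very top of your training script, before any import that pulls in Keras.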
Your model should now train faster than with the default TensorFlow backend.
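One way to check this on your own model is to time the same training call under each backend. The sketch below is only a rough wall-clock measurement, and `time_call` is a hypothetical helper, not something Keras or MLX provides:

```python
import time


def time_call(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed
```

In a training script you could write `history, seconds = time_call(model.fit, x_train, y_train, epochs=1)`, where `model`, `x_train` and `y_train` stand in for your own objects. Run the script once with the default backend and once with KERAS_BACKEND=mlx, then compare the two timings.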
Please note that as of the day of writing, the mlx backend is in active development and doesn’t yet support all Keras features (for example, the Conv2D layer). If you use a Keras feature that is not yet supported, the mlx backend will raise an exception and inform you.
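Because the backend is fixed once Keras has been imported, falling back to another backend after such an exception means starting a fresh process. The sketch below shows one possible approach under that assumption: it runs a Python command in a child process with KERAS_BACKEND set, and tries the next backend if the command exits with an error. Both function names are hypothetical.

```python
import os
import subprocess
import sys


def run_with_backend(args, backend):
    """Run `python <args>` in a child process with KERAS_BACKEND set."""
    env = dict(os.environ, KERAS_BACKEND=backend)
    return subprocess.run(
        [sys.executable, *args], env=env, capture_output=True, text=True
    )


def run_with_fallback(args, backends=("mlx", "tensorflow")):
    """Try each backend in order; return (backend, result) for the first success."""
    result = None
    for backend in backends:
        result = run_with_backend(args, backend)
        if result.returncode == 0:
            return backend, result
    # Every backend failed; the last result still carries stderr for debugging.
    return None, result
```

For example, `backend, result = run_with_fallback(["train.py"])` would first try a (hypothetical) `train.py` with mlx and rerun it with tensorflow only if the first attempt fails, for instance because of an unsupported layer.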