Overview
Dive into a comprehensive tutorial on Flax, a neural-network library built on top of JAX, covering everything from the basics to advanced concepts. Learn how to build performant and reproducible ML models, see how Flax compares with Haiku, and master key workflows such as linear regression, custom model creation, and CNN implementation. Gain hands-on experience with practical examples, including a linear regression toy example and a CNN trained on the MNIST dataset, and learn how to handle dropout, BatchNorm, and other stateful techniques needed for robust neural networks. Follow along with the provided Jupyter notebook and use the linked resources to deepen your understanding of Flax and its ecosystem.
Syllabus
Intro - Flax is performant and reproducible
Deepnote walk-through (sponsored)
Flax basics
Flax vs Haiku
Benchmarking Flax
Linear regression toy example
Introducing Optax (Adam optimizer state example)
Creating custom models
self.param example
self.variable example
Handling dropout, BatchNorm, etc.
CNN on MNIST example
TrainState source code
CNN dropout modification
Outro and summary
Taught by
Aleksa Gordić - The AI Epiphany