Learn how to optimize large language models with JAX in this 19-minute technical talk, which introduces key techniques for speeding up neural network computations. Explore JAX features including jaxprs (JAX's intermediate representation), XLA compilation with jax.jit, vectorization with vmap and multi-device parallelism with pmap, and handling asynchronous dispatch with block_until_ready(). Discover why JAX can, for some workloads, run 10x to 100x faster than other frameworks through XLA's memory-bandwidth optimizations and operator fusion. Gain insights into using Flax.linen for T5X optimization, with detailed explanations of core concepts and links to essential documentation and learning resources. Ideal for developers familiar with TensorFlow 2 and PyTorch 2 who want to add JAX's high-performance computing capabilities to their toolkit.
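The JAX features named above can be illustrated with a minimal sketch, assuming a recent JAX release on any backend (CPU is enough). The toy layer `gelu_layer` and its shapes are hypothetical, chosen only for illustration. The sketch prints a jaxpr with `jax.make_jaxpr`, compiles with `jax.jit`, batches with `jax.vmap`, and calls `block_until_ready()` so that a timing measures finished work rather than just asynchronous dispatch.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical toy layer for illustration: affine transform followed by GELU.
def gelu_layer(w, b, x):
    return jax.nn.gelu(x @ w + b)

key = jax.random.PRNGKey(0)
k_w, k_x = jax.random.split(key)
w = jax.random.normal(k_w, (4, 3))
b = jnp.zeros(3)
x = jax.random.normal(k_x, (4,))

# 1. The jaxpr is JAX's traced intermediate representation of the function.
print(jax.make_jaxpr(gelu_layer)(w, b, x))

# 2. jax.jit traces once, compiles with XLA, and reuses the compiled executable.
fast_layer = jax.jit(gelu_layer)

# 3. jax.vmap maps over a leading batch axis of x only; w and b are shared.
batched_layer = jax.jit(jax.vmap(gelu_layer, in_axes=(None, None, 0)))
xs = jax.random.normal(k_x, (8, 4))
ys = batched_layer(w, b, xs)

# 4. JAX dispatches work asynchronously; block_until_ready() waits for results.
fast_layer(w, b, x).block_until_ready()  # warm-up call includes compilation
start = time.perf_counter()
fast_layer(w, b, x).block_until_ready()
elapsed = time.perf_counter() - start
print(ys.block_until_ready().shape, f"{elapsed:.6f}s")
```

Under `jax.jit`, XLA is free to fuse the matrix multiply, bias add, and GELU into fewer kernels, which is the kind of memory-bandwidth saving described above. The first call is slower because it includes tracing and compilation, which is why the sketch warms up before timing.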
Overview
Syllabus
Introduction
Functions
Functional Programming
Stateful Layers
Neural Network Architecture
XLA
Neural Network
Notebooks
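The functional-programming and stateful-layer topics can be sketched in plain JAX, assuming no Flax dependency. The helpers `init_dense`, `apply_dense`, `loss_fn`, and `sgd_step` are hypothetical names for illustration: parameters live in an explicit pytree, the forward pass is a pure function, and a training step returns new parameters instead of mutating the old ones. Flax linen modules follow the same split, with `module.init(rng, x)` producing the variables and `module.apply(variables, x)` running the forward pass.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration; not part of JAX or Flax.
def init_dense(key, in_dim, out_dim):
    # State is created explicitly from a PRNG key and returned as a pytree (a dict).
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim),
        "b": jnp.zeros(out_dim),
    }

def apply_dense(params, x):
    # Pure function: the output depends only on the arguments; nothing is mutated.
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    # Mean squared error of the dense layer's predictions.
    return jnp.mean((apply_dense(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # "Updating" state means returning a new pytree of parameters.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

key_params, key_x = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (32, 3))
y = x @ jnp.array([[1.0], [-2.0], [0.5]]) + 0.3  # synthetic linear targets

params0 = init_dense(key_params, 3, 1)
params = params0
for _ in range(200):
    params = sgd_step(params, x, y)

print("loss before:", loss_fn(params0, x, y))
print("loss after: ", loss_fn(params, x, y))
```

Because `params0` is never modified, the initial and trained parameters can be compared side by side; that explicit threading of state is what makes `jax.jit` and `jax.grad` safe to apply to the whole step.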
Taught by
Discover AI