Introduction to JAX and XLA Optimization for Large Language Models 2023

Discover AI via YouTube

Overview

Learn how JAX can speed up Large Language Model computations in this 19-minute technical talk on the key concepts behind faster neural network code. Explore JAX features including Jaxpr (its intermediate representation), XLA compilation with tf.function, vectorization with vmap, multi-device parallelism with pmap, and asynchronous dispatch, where block_until_ready() is used to wait for a result. Discover why, according to the talk, JAX can run 10x to 100x faster than other frameworks, a gain it attributes to XLA operator fusion and reduced memory-bandwidth use. Gain insights into using flax.linen with T5X, with explanations of core concepts and links to essential documentation and learning resources. The talk is aimed at developers familiar with TensorFlow 2 and PyTorch 2 who want to add JAX's high-performance computing capabilities to their toolkit.
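For orientation, here is a minimal sketch of several features named in the overview: Jaxpr inspection, XLA compilation, vmap, and block_until_ready(). It is not taken from the talk. The function name `dense` and the array shapes are illustrative assumptions, and it uses jax.jit, JAX's own entry point to XLA. pmap is left out because it needs multiple devices to be meaningful.

```python
import jax
import jax.numpy as jnp


def dense(w, x):
    # A single dense layer followed by a tanh nonlinearity (hypothetical example).
    return jnp.tanh(jnp.dot(x, w))


# Illustrative shapes only: 4 input features, 3 output features, batch of 8.
w = jnp.ones((4, 3))
x = jnp.ones((4,))
batch = jnp.ones((8, 4))

# Jaxpr: the intermediate representation JAX traces before handing work to XLA.
print(jax.make_jaxpr(dense)(w, x))

# jax.jit compiles the traced function with XLA, which can fuse operations.
dense_jit = jax.jit(dense)

# vmap vectorizes over the leading axis of x while sharing the same w.
dense_batched = jax.jit(jax.vmap(dense, in_axes=(None, 0)))

# JAX dispatches work asynchronously; block_until_ready() waits for the
# computation to finish, which matters when timing code.
y = dense_jit(w, x).block_until_ready()
ys = dense_batched(w, batch).block_until_ready()
print(y.shape, ys.shape)
```

When benchmarking, call the compiled function once before timing so that compilation cost is not counted, and keep block_until_ready() inside the timed region.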

Syllabus

Introduction
Functions
Functional Programming
Stateful Layers
Neural Network Architecture
XLA
Neural Network
Notebooks

Taught by

Discover AI
