Coding a Neural Network from Scratch in Pure JAX - Machine Learning with JAX - Tutorial 3
Aleksa Gordić - The AI Epiphany via YouTube
Overview
Learn to code a neural network from scratch in pure JAX. In this 86-minute tutorial video, the instructor builds a multi-layer perceptron (MLP) and trains it as a classifier on the MNIST dataset. The walkthrough covers each step: initializing the MLP, implementing the prediction function, setting up PyTorch data loaders, and writing the training loop. The video then inspects the trained model by visualizing the learned weights, visualizing embeddings with t-SNE, and analyzing dead neurons.
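The initialize, predict, and train steps described above fit in a few dozen lines of pure JAX. This is a minimal sketch, not the instructor's code. The layer widths, initialization scale, learning rate, and the names `init_mlp`, `predict`, `loss_fn`, and `update` are all assumptions. A random batch stands in for MNIST, so the example runs without downloading data or using PyTorch loaders.

```python
import jax
import jax.numpy as jnp

# Assumed layer widths: 784 input pixels (28x28 image), two hidden layers, 10 classes.
LAYER_WIDTHS = [784, 512, 256, 10]
LEARNING_RATE = 0.01  # assumed plain-SGD step size

def init_mlp(layer_widths, key, scale=0.01):
    """Return a list of [W, b] pairs, one per layer, drawn from a small Gaussian."""
    params = []
    keys = jax.random.split(key, len(layer_widths) - 1)
    for k, n_in, n_out in zip(keys, layer_widths[:-1], layer_widths[1:]):
        w_key, b_key = jax.random.split(k)
        params.append([scale * jax.random.normal(w_key, (n_out, n_in)),
                       scale * jax.random.normal(b_key, (n_out,))])
    return params

def predict(params, x):
    """Forward pass for ONE flattened image; returns log-probabilities over classes."""
    hidden = x
    for w, b in params[:-1]:
        hidden = jax.nn.relu(jnp.dot(w, hidden) + b)
    w_last, b_last = params[-1]
    logits = jnp.dot(w_last, hidden) + b_last
    return jax.nn.log_softmax(logits)

# vmap maps over axis 0 of the images while sharing the same params (in_axes=None).
batched_predict = jax.vmap(predict, in_axes=(None, 0))

def loss_fn(params, images, one_hot_labels):
    """Mean cross-entropy over the batch."""
    log_probs = batched_predict(params, images)
    return -jnp.mean(jnp.sum(log_probs * one_hot_labels, axis=-1))

@jax.jit
def update(params, images, one_hot_labels):
    """One SGD step: compute loss and gradients, then move each parameter downhill."""
    loss, grads = jax.value_and_grad(loss_fn)(params, images, one_hot_labels)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return loss, new_params

# Usage with a random stand-in batch (real code would feed MNIST batches here).
params = init_mlp(LAYER_WIDTHS, jax.random.PRNGKey(0))
images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
labels = jax.random.randint(jax.random.PRNGKey(2), (32,), 0, 10)
one_hot = jax.nn.one_hot(labels, 10)

losses = []
for step in range(5):
    loss, params = update(params, images, one_hot)
    losses.append(loss)
```

Writing `predict` for a single image and lifting it with `jax.vmap` keeps the forward pass simple, while `jax.jit` compiles the whole gradient step. The parameters are a plain nested list, which JAX treats as a pytree, so `jax.tree_util.tree_map` can update every weight and bias in one call.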
Syllabus
Intro, structuring the code
MLP initialization function
Prediction function
PyTorch MNIST dataset
PyTorch data loaders
Training loop
Adding the accuracy metric
Visualize the image and prediction
Small code refactoring
Visualizing MLP weights
Visualizing embeddings using t-SNE
Analyzing dead neurons
Outro
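Two of the syllabus steps, adding the accuracy metric and analyzing dead neurons, reduce to short array operations. The sketch below is an illustration, not the video's code. The helper names `accuracy` and `dead_neuron_mask` are hypothetical. It also assumes a common definition of a dead neuron: a ReLU unit whose output is zero for every example in the batch. The video's exact analysis may differ.

```python
import jax
import jax.numpy as jnp

def accuracy(log_probs, labels):
    """Fraction of examples whose highest-scoring class equals the integer label."""
    return jnp.mean(jnp.argmax(log_probs, axis=-1) == labels)

def dead_neuron_mask(w, b, inputs):
    """Boolean mask over a ReLU layer's units: True where the unit outputs 0
    for every input in the batch (assumed definition of a 'dead' neuron)."""
    activations = jax.nn.relu(inputs @ w.T + b)  # shape (batch, n_units)
    return jnp.all(activations == 0.0, axis=0)

# Toy predictions: the argmax classes are 0, 2, 0 against labels 0, 2, 1.
log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.3, 0.6],
                               [0.5, 0.4, 0.1]]))
labels = jnp.array([0, 2, 1])
acc = accuracy(log_probs, labels)  # 2 of 3 correct

# Toy 2-unit layer: unit 1 has a large negative bias, so it never fires on these inputs.
w = jnp.array([[1.0, 1.0],
               [1.0, 1.0]])
b = jnp.array([0.0, -100.0])
inputs = jnp.array([[1.0, 2.0],
                    [3.0, 4.0]])
mask = dead_neuron_mask(w, b, inputs)  # [False, True]
```

In a real run, `inputs` would be a batch of MNIST images (or the previous layer's activations), and `mask.mean()` would give the fraction of dead units in that layer.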
Taught by
Aleksa Gordić - The AI Epiphany