Coding a Neural Network from Scratch in Pure JAX - Machine Learning with JAX - Tutorial 3

Aleksa Gordić - The AI Epiphany via YouTube

Overview

Learn to code a neural network from scratch in pure JAX in this tutorial video. You build a multilayer perceptron (MLP) and train it as a classifier on the MNIST dataset. The instructor walks through each step, from initializing the MLP and implementing the prediction function to setting up PyTorch data loaders and writing the training loop. The video then inspects the trained network by visualizing its learned weights, projecting its embeddings with t-SNE, and analyzing dead neurons. The tutorial runs 86 minutes and gives practical insight into implementing neural networks in JAX.
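To give a feel for what "pure JAX" means here, the sketch below covers the first steps the description lists: MLP initialization, a prediction function, and a gradient-descent training step. It is a minimal sketch under stated assumptions, not the video's code: the function names (`init_mlp`, `predict`, `update`), the layer widths `[784, 512, 256, 10]`, the init scale and the learning rate are all hypothetical choices, and a random batch stands in for the PyTorch MNIST loader.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

LEARNING_RATE = 0.01  # hypothetical value, not taken from the video


def init_mlp(layer_widths, key, scale=0.01):
    """Create one (weights, bias) pair per dense layer from a scaled normal."""
    params = []
    keys = jax.random.split(key, len(layer_widths) - 1)
    for in_width, out_width, layer_key in zip(layer_widths[:-1], layer_widths[1:], keys):
        w_key, b_key = jax.random.split(layer_key)
        weights = scale * jax.random.normal(w_key, (out_width, in_width))
        bias = scale * jax.random.normal(b_key, (out_width,))
        params.append((weights, bias))
    return params


def predict(params, x):
    """Forward pass for ONE flattened image x; returns log-probabilities."""
    hidden = x
    for weights, bias in params[:-1]:
        hidden = jax.nn.relu(weights @ hidden + bias)
    last_weights, last_bias = params[-1]
    logits = last_weights @ hidden + last_bias
    return logits - logsumexp(logits)  # log-softmax


# vmap lifts the single-example predict to a whole batch (params are shared).
batched_predict = jax.vmap(predict, in_axes=(None, 0))


def loss_fn(params, images, labels_onehot):
    """Mean cross-entropy between predicted log-probs and one-hot labels."""
    log_probs = batched_predict(params, images)
    return -jnp.mean(jnp.sum(log_probs * labels_onehot, axis=1))


@jax.jit
def update(params, images, labels_onehot):
    """One plain SGD step; returns the pre-update loss and the new params."""
    loss, grads = jax.value_and_grad(loss_fn)(params, images, labels_onehot)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return loss, new_params


# Usage with a placeholder batch instead of real MNIST data.
key = jax.random.PRNGKey(0)
init_key, data_key = jax.random.split(key)
params = init_mlp([784, 512, 256, 10], init_key)
images = jax.random.normal(data_key, (32, 784))
labels_onehot = jax.nn.one_hot(jnp.arange(32) % 10, 10)

losses = []
for step in range(20):
    loss, params = update(params, images, labels_onehot)
    losses.append(float(loss))
```

Writing `predict` for a single example and batching it with `jax.vmap` keeps the model code free of batch dimensions, and storing parameters as a plain list of tuples lets `jax.tree_util.tree_map` apply the update to every array at once.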

Syllabus

Intro, structuring the code
MLP initialization function
Prediction function
PyTorch MNIST dataset
PyTorch data loaders
Training loop
Adding the accuracy metric
Visualize the image and prediction
Small code refactoring
Visualizing MLP weights
Visualizing embeddings using t-SNE
Analyzing dead neurons
Outro
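Two of the later syllabus steps, the accuracy metric and the dead-neuron analysis, come down to a few array operations. The sketch below assumes parameters stored as a list of `(weights, bias)` pairs with ReLU hidden layers. The helper names are hypothetical, and treating a neuron as "dead" when its ReLU output is zero on every image in a batch is an assumed definition, not necessarily the one used in the video.

```python
import jax
import jax.numpy as jnp


def accuracy(log_probs, labels):
    """Fraction of rows whose highest-scoring class equals the integer label."""
    return jnp.mean(jnp.argmax(log_probs, axis=1) == labels)


def hidden_activations(params, x):
    """ReLU outputs of every hidden layer for a single flattened input x."""
    activations = []
    hidden = x
    for weights, bias in params[:-1]:
        hidden = jax.nn.relu(weights @ hidden + bias)
        activations.append(hidden)
    return activations


def dead_neuron_masks(params, images):
    """Per hidden layer, True for units that output 0 on every image in the batch."""
    batched = jax.vmap(hidden_activations, in_axes=(None, 0))(params, images)
    return [jnp.all(layer == 0, axis=0) for layer in batched]


# Toy network: hidden unit 1 has a large negative bias, so it never fires.
toy_params = [
    (jnp.eye(2), jnp.array([0.0, -100.0])),
    (jnp.ones((3, 2)), jnp.zeros(3)),
]
toy_images = jnp.array([[0.5, 0.2], [0.1, 0.9]])
masks = dead_neuron_masks(toy_params, toy_images)  # masks[0] flags unit 1 only

toy_log_probs = jnp.log(jnp.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]))
toy_accuracy = accuracy(toy_log_probs, jnp.array([1, 0, 0]))  # 2 of 3 correct
```

Summing a mask such as `masks[0]` gives the number of dead units in that layer, which is one simple way to track how many neurons stop contributing during training.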

Taught by

Aleksa Gordić - The AI Epiphany
