Coding a Neural Network from Scratch in Pure JAX - Machine Learning with JAX - Tutorial 3

Aleksa Gordić - The AI Epiphany via YouTube

Prediction function

3 of 13

YouTube videos curated by Class Central.

Classroom Contents

  1. Intro, structuring the code
  2. MLP initialization function
  3. Prediction function
  4. PyTorch MNIST dataset
  5. PyTorch data loaders
  6. Training loop
  7. Adding the accuracy metric
  8. Visualize the image and prediction
  9. Small code refactoring
  10. Visualizing MLP weights
  11. Visualizing embeddings using t-SNE
  12. Analyzing dead neurons
  13. Outro
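Chapters 2 and 3 cover an MLP initialization function and a prediction function written in pure JAX. The block below is a minimal sketch of what such a pair can look like, assuming parameters stored as a list of (weights, biases) tuples, ReLU hidden activations and a log-softmax output. It is not the code from the video. The names `init_mlp` and `predict`, the layer widths `[784, 512, 256, 10]` (a flattened 28x28 MNIST image in, 10 digit classes out) and the init scale of 0.01 are all hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def init_mlp(layer_widths, key, scale=0.01):
    """Hypothetical helper: returns a list of (weights, biases), one per layer."""
    params = []
    # One PRNG key per layer so each layer gets independent random values.
    keys = jax.random.split(key, len(layer_widths) - 1)
    for in_width, out_width, layer_key in zip(layer_widths[:-1], layer_widths[1:], keys):
        w_key, b_key = jax.random.split(layer_key)
        weights = scale * jax.random.normal(w_key, (out_width, in_width))
        biases = scale * jax.random.normal(b_key, (out_width,))
        params.append((weights, biases))
    return params


def predict(params, x):
    """Forward pass for ONE flattened image x; returns log-probabilities."""
    hidden = x
    # Hidden layers: affine transform followed by ReLU.
    for weights, biases in params[:-1]:
        hidden = jax.nn.relu(jnp.dot(weights, hidden) + biases)
    # Output layer: logits normalized with log-softmax.
    w_last, b_last = params[-1]
    logits = jnp.dot(w_last, hidden) + b_last
    return logits - logsumexp(logits)


# vmap turns the single-example function into a batched one:
# params are shared (None), inputs are mapped over axis 0.
batched_predict = jax.vmap(predict, in_axes=(None, 0))

params = init_mlp([784, 512, 256, 10], jax.random.PRNGKey(0))
images = jnp.ones((8, 784))  # dummy batch of 8 flattened "images"
log_probs = batched_predict(params, images)
print(log_probs.shape)  # (8, 10)
```

Writing `predict` for a single example and batching it with `jax.vmap` keeps the forward pass free of batch-dimension bookkeeping. Returning log-probabilities makes a cross-entropy loss a simple indexing-and-mean step in a later training loop.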
