In this tutorial we’ll have a look at creating a simple classifier for the MNIST dataset. We’ll try to do this in the least amount of code possible and explain every step in detail along the way.

If you found yourself here and want to learn more about the MNIST dataset, or what classification models are, you can read up on them here.

Import and Setup

It’s time to hop over to Google Colab, create a new notebook, and select a GPU instance as the runtime. Then we install and import torch and torchvision, since we’ll rely on them for most of the code we write:

!pip install torch torchvision
import torch, torchvision
from torch import nn, optim

Model Architecture

The model for this task is very simple and can consist of just one fully connected layer:

class Classifier(nn.Module):
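    def __init__(self, **kwargs):
        # initialise nn.Module first so the layer below gets registered
        super().__init__()
        # 784 input pixels (28 x 28) -> 10 class scores
        self.layer = nn.Linear(784, 10)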

    def forward(self, x):
        return self.layer(x)

Since the MNIST dataset is relatively simple, a single layer is sufficient for roughly 90% accuracy. If you want to squeeze out a bit more, you can add an additional layer. If you’re not familiar with this syntax, it might be best to check out our article on getting started with PyTorch.
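To get a feeling for how small this model is, and for what an extra layer would cost, we can count the trainable values by hand. This is a plain-Python sketch rather than PyTorch code; the helper name `linear_params` and the hidden size of 128 are assumptions made purely for illustration.

```python
def linear_params(in_features, out_features):
    """Trainable values in a fully connected layer: weights plus biases."""
    return in_features * out_features + out_features

# the single-layer model from this tutorial: 784 pixels -> 10 classes
single_layer = linear_params(784, 10)

# hypothetical two-layer variant with an assumed hidden size of 128
hidden = 128
two_layer = linear_params(784, hidden) + linear_params(hidden, 10)

print(single_layer)  # 7850
print(two_layer)     # 101770
```

Even the two-layer variant is tiny by modern standards, so trying it out in Colab should be cheap.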

Next, we instantiate our model and send it over to the GPU. This can be done with torch.device() and the .to() method in PyTorch:

#  use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load model to the specified device, either gpu or cpu
model = Classifier().to(device)

We also require an optimizer and a criterion to train our model. We’ll go with the standard Adam optimizer and the cross-entropy loss, which is the go-to loss function for multi-class classification tasks. For an in-depth explanation of this loss function, I recommend this resource.

# create an optimizer object
# Adam optimizer with learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# cross-entropy loss for multi-class classification
criterion = nn.CrossEntropyLoss()
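To make the loss less of a black box, here is a plain-Python sketch of the cross-entropy of a single sample, computed from raw scores (logits). It is only meant to illustrate the formula; the function name `cross_entropy` and the example scores are made up, and nn.CrossEntropyLoss additionally averages over the batch by default.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    # subtract the maximum before exponentiating for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# ten equal scores: the model is guessing uniformly among 10 digits
uniform = cross_entropy([0.0] * 10, target=3)

# one score is much larger than the rest and it is the correct class
confident = cross_entropy([0.0] * 9 + [10.0], target=9)

print(round(uniform, 4))  # 2.3026, which is log(10)
print(confident < 0.01)   # True
```

A loss near log(10) ≈ 2.3 therefore means the model is doing no better than guessing, which is a handy reference point when reading the training output.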

Data and Dataloaders

Lastly we need to get the training and testing data into our notebook, which we can do conveniently with the torchvision module:

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)
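For 8-bit images like MNIST, ToTensor rescales the integer pixel values from 0..255 to floats between 0 and 1. Later, the training loop flattens each 28 x 28 image into a vector of 784 values. The following plain-Python sketch mimics both steps on a made-up image, just to show what happens to the numbers; it does not use torchvision.

```python
# a made-up 28x28 grayscale image with integer pixel values 0..255
image = [[(row * 28 + col) % 256 for col in range(28)] for row in range(28)]

# scale to [0, 1], as ToTensor does for 8-bit images
scaled = [[pixel / 255 for pixel in row] for row in image]

# flatten row by row into one vector of 784 values
flat = [pixel for row in scaled for pixel in row]

print(len(flat))             # 784
print(min(flat), max(flat))  # 0.0 1.0
```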

We also need to create dataloaders to feed the data to our model:

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=4
)
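It helps to know how many batches these settings produce per pass over the data. MNIST has 60000 training and 10000 test images, and a dataloader keeps the last, smaller batch unless told otherwise. The sketch below works this out in plain Python; the helper `batch_sizes` is a hypothetical name, not part of PyTorch.

```python
def batch_sizes(num_samples, batch_size):
    """Sizes of the batches when the last partial batch is kept."""
    full, remainder = divmod(num_samples, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])

train_batches = batch_sizes(60000, 128)
test_batches = batch_sizes(10000, 32)

print(len(train_batches), train_batches[-1])  # 469 96
print(len(test_batches), test_batches[-1])    # 313 16
```

So one epoch corresponds to 469 parameter updates, and the last batch of each loader is smaller than the rest.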

Training Loop

Now we’ve done all the necessary setup and can put together the heart of our code, the training loop! We create a nested loop: an outer loop that iterates through the epochs, and an inner loop that iterates through the batches of the dataloader, feeding them one after another to our model. We also need to call the usual PyTorch functions that reset the gradients on every iteration and perform a forward pass, backpropagation, and a parameter update. The training loop looks like this:

epochs = 5
for epoch in range(epochs):
    loss = 0
    for batch_features, batch_labels in train_loader:
        # reshape mini-batch data to [N, 784] matrix
        # load it to the active device
        batch_features = batch_features.view(-1, 784).to(device)
        batch_labels = batch_labels.to(device)

        # reset the gradients back to zero
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()

        # compute class scores (logits) for the mini-batch
        outputs = model(batch_features)

        # compute training cross-entropy loss
        train_loss = criterion(outputs, batch_labels)

        # compute accumulated gradients
        train_loss.backward()

        # perform parameter update based on current gradients
        optimizer.step()

        # add the mini-batch training loss to epoch loss
        loss += train_loss.item()

    # compute the epoch training loss
    loss = loss / len(train_loader)

    # display the epoch training loss
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))
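The comment about accumulated gradients in the loop above is worth a closer look. The following toy example, written in plain Python rather than PyTorch, shows what goes wrong when the stored gradient is not reset between batches. The function `grad`, the weight `w`, and the data are all made up for illustration.

```python
def grad(w, x, y):
    """Gradient of the squared error (w * x - y) ** 2 with respect to w."""
    return 2 * (w * x - y) * x

w = 1.0
batches = [(1.0, 3.0), (2.0, 6.0)]  # made-up (x, y) pairs

# without a reset, the stored gradient is the sum over all batches so far
accumulated = 0.0
without_reset = []
for x, y in batches:
    accumulated += grad(w, x, y)
    without_reset.append(accumulated)

# with a reset, each step sees only the current batch's gradient
with_reset = []
for x, y in batches:
    accumulated = 0.0  # plays the role of resetting the gradients
    accumulated += grad(w, x, y)
    with_reset.append(accumulated)

print(without_reset)  # [-4.0, -20.0]
print(with_reset)     # [-4.0, -16.0]
```

Without the reset, the second update would be driven by a mixture of both batches instead of the current one.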

Testing Our Model

We still want to see how well our model performs. We can do so by comparing the predicted labels with the target labels and computing the fraction of correct predictions over the entire test set:

import matplotlib.pyplot as plt  # only needed for the optional plot below

total = 0
correct = 0
# gradients are not needed for evaluation
with torch.no_grad():
    for batch_features, labels in test_loader:
        batch_features = batch_features.view(-1, 784).to(device)
        outputs = model(batch_features)

        # optional: show the first image of the batch
        # plt.imshow(batch_features[0].cpu().view(28, 28).numpy())

        # the index of the largest score is the predicted digit
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted.cpu() == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
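If the step from scores to accuracy feels opaque, this plain-Python sketch spells it out: take the index of the largest score in each row as the prediction, then count how often it matches the label. The helper names and the scores are invented for the example.

```python
def argmax(values):
    """Index of the largest entry, i.e. the predicted class."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(score_rows, labels):
    """Fraction of rows whose top-scoring class equals the label."""
    predictions = [argmax(row) for row in score_rows]
    correct = sum(p == t for p, t in zip(predictions, labels))
    return correct / len(labels)

# four made-up samples with three classes each
scores = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 3.0], [2.2, 0.4, 0.1]]
labels = [1, 0, 2, 1]

print(accuracy(scores, labels))  # 0.75
```

Note that the print statement above formats the percentage with %d, which cuts off the decimals, so 91.7 would be shown as 91.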