After completing your first machine learning ‘Hello World!’ tutorial, which most likely involved building a classifier for the MNIST dataset of handwritten digits, you’ll probably wonder how to apply a classifier to color images. This tutorial shows how.
Fully Connected Model
We could approach this dataset the same way we solved the MNIST classification task, by flattening each image into a vector and feeding it to a fully connected network. However, this gives poor results.
When we flatten the image, we discard much of the structural information already present in its pixels, such as which pixels are neighbors of which. Rather than throwing this information out the window, we’d be better off using it to improve our results. One way to do this is to use a convolutional neural network instead of a fully connected one.
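To make the loss of structure concrete, here is a small pure-Python sketch of where a pixel lands after flattening. It assumes each image is stored channels-first as a contiguous 3×32×32 tensor (the C×H×W layout that torchvision’s `ToTensor` produces), which `view` then reads in row-major order. The helper `flat_index` is hypothetical, for illustration only. Horizontally adjacent pixels stay next to each other, but the pixel directly below ends up 32 positions away, and the green value of the same pixel ends up 1024 positions away from its red value.

```python
# Hypothetical helper: position of value (c, h, w) in the flattened vector,
# assuming a contiguous channels-first (C x H x W) image read in row-major order.
C, H, W = 3, 32, 32

def flat_index(c: int, h: int, w: int) -> int:
    return c * H * W + h * W + w

print(flat_index(0, 10, 10))  # red value of pixel (10, 10)        -> 330
print(flat_index(0, 10, 11))  # its right-hand neighbor            -> 331
print(flat_index(0, 11, 10))  # the pixel below it, W = 32 later   -> 362
print(flat_index(1, 10, 10))  # green value, H * W = 1024 later    -> 1354
print(C * H * W)              # length of the flattened vector     -> 3072
```

A fully connected layer treats all 3072 positions as unrelated inputs, so any notion of "nearby" has to be learned from data.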
First, let’s see how well a fully connected model performs. Suppose we use an architecture like the following, with two linear layers:
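from torch import nn

class Classifier(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.layer = nn.Linear(3072, 64)        # flattened 32x32x3 image -> 64 hidden units
        self.act = nn.ReLU()
        self.layer2 = nn.Linear(64, 10)         # 64 hidden units -> 10 class scores
        self.activation = nn.LogSoftmax(dim=1)  # log-probabilities over classes; pair with nn.NLLLoss

    def forward(self, x):
        x = self.layer(x)
        x = self.act(x)
        x = self.layer2(x)
        return self.activation(x)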
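As a quick sanity check on the model’s size: an `nn.Linear(in_features, out_features)` layer (with its default bias) holds `in_features * out_features` weights plus `out_features` biases, while `nn.ReLU` and `nn.LogSoftmax` have no parameters. The plain-Python sketch below applies that rule to the two layers above; the helper name `linear_params` is ours. With PyTorch installed, `sum(p.numel() for p in Classifier().parameters())` should report the same total.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Parameters in a fully connected layer with bias: weight matrix + bias vector."""
    return in_features * out_features + out_features

layer1 = linear_params(3072, 64)  # flattened 32x32x3 image -> 64 hidden units
layer2 = linear_params(64, 10)    # 64 hidden units -> 10 class scores
total = layer1 + layer2

print(layer1, layer2, total)  # 196672 650 197322
```

Almost all of these parameters sit in the first layer, where every hidden unit has a separate weight for each of the 3072 input values.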
During training, we flatten each batch of input images like this:
batch_features = batch_features.view(-1,3072).to(device)
This stretches each 32x32 RGB image (32x32x3 values) into a one-dimensional tensor with 3072 entries. After training for 10 epochs, we obtain an accuracy of about 43% (your results may vary). That beats the 10% expected from random guessing among 10 classes, but it isn’t very good.
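The `-1` in `view(-1, 3072)` tells PyTorch to infer that dimension from the total number of elements, so the same line works for any batch size, including a smaller final batch. The pure-Python sketch below mimics that inference rule for illustration; `infer_view_shape` is a hypothetical helper, not part of PyTorch.

```python
from math import prod

def infer_view_shape(shape, new_shape):
    """Hypothetical helper: resolve at most one -1 in new_shape like view/reshape do."""
    total = prod(shape)
    if new_shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(d for d in new_shape if d != -1)
    if -1 in new_shape:
        if total % known != 0:
            raise ValueError(f"{shape} cannot be viewed as {new_shape}")
        return tuple(total // known if d == -1 else d for d in new_shape)
    if total != known:
        raise ValueError(f"{shape} cannot be viewed as {new_shape}")
    return tuple(new_shape)

print(infer_view_shape((64, 3, 32, 32), (-1, 3072)))  # (64, 3072)
print(infer_view_shape((16, 3, 32, 32), (-1, 3072)))  # (16, 3072): smaller last batch
print(infer_view_shape((64, 3, 64, 64), (-1, 3072)))  # (256, 3072): wrong, but no error
```

The last line shows a pitfall: if the images had a different size that happened to divide evenly, the batch dimension would silently come out wrong. Writing `batch_features.view(batch_features.size(0), -1)` keeps the batch dimension fixed and infers the feature dimension instead.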
Convolutional Neural Network
Next, we’ll throw a convolutional neural network at the dataset and see how well it performs. We don’t actually have to change much from the previous code: only the network architecture and the line that feeds batch_features to the model, which no longer flattens the images (for example, batch_features = batch_features.to(device), keeping the (batch, 3, 32, 32) shape). Starting with the model, our classifier class looks like this:
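import torch
from torch import nn

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3 input channels (RGB) -> 6 feature maps, 5x5 kernels
        self.pool = nn.MaxPool2d(2, 2)         # 2x2 max pooling, halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6 -> 16 feature maps, 5x5 kernels
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of 5x5 after the second pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # 10 class scores

    def forward(self, x):
        # x: (batch, 3, 32, 32); the image is no longer flattened up front
        x = self.pool(nn.functional.relu(self.conv1(x)))  # -> (batch, 6, 14, 14)
        x = self.pool(nn.functional.relu(self.conv2(x)))  # -> (batch, 16, 5, 5)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch -> (batch, 400)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        # Returns raw logits: train with nn.CrossEntropyLoss, or apply
        # nn.functional.log_softmax(x, dim=1) first if you keep nn.NLLLoss.
        return x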
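The `16 * 5 * 5` in `fc1` comes from tracing the spatial size through the network. With stride 1 and no padding (the `nn.Conv2d` defaults), a k×k convolution shrinks each side by k - 1, and `nn.MaxPool2d(2, 2)` halves it, rounding down. The plain-Python sketch below follows a 3×32×32 input through each layer and also counts parameters, using the rule that a convolution with bias has `in_channels * out_channels * k * k + out_channels` of them. The helper names are ours, for illustration only.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling window (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(in_ch: int, out_ch: int, kernel: int) -> int:
    """Weights (out_ch x in_ch x kernel x kernel) plus one bias per output channel."""
    return in_ch * out_ch * kernel * kernel + out_ch

def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix plus bias vector of a fully connected layer."""
    return in_features * out_features + out_features

size = 32                           # input: 3 x 32 x 32
size = conv_out(size, 5)            # conv1: 6 x 28 x 28
size = conv_out(size, 2, stride=2)  # pool:  6 x 14 x 14
size = conv_out(size, 5)            # conv2: 16 x 10 x 10
size = conv_out(size, 2, stride=2)  # pool:  16 x 5 x 5
flat = 16 * size * size             # 400 features reach fc1

params = (
    conv_params(3, 6, 5)        # conv1
    + conv_params(6, 16, 5)     # conv2
    + linear_params(flat, 120)  # fc1
    + linear_params(120, 84)    # fc2
    + linear_params(84, 10)     # fc3
)
print(size, flat, params)  # 5 400 62006
```

The convolutional model therefore has about 62 thousand parameters, roughly a third of the 197,322 in the two-layer fully connected model, because each convolution reuses the same small 5x5 filters at every position in the image instead of learning a separate weight for every input value. Note that `fc1` is tied to the 32x32 input size: different image sizes would require recomputing the `16 * 5 * 5` term.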