
Enhancing Neural Networks with Mixup in PyTorch

Published at 2021-09-29

Randomly mixing up images, and it works better?

Image classification is one of the domains that have thrived with the exponential improvement of deep learning. Traditional image recognition relies heavily on processing methods such as dilations/erosions, kernels, and transforms to the frequency domain, yet the difficulty of feature extraction has ultimately limited the progress these methods can make. Neural networks, on the other hand, learn the relationship between input images and output labels by 'tuning' an architecture for that purpose. While the resulting gain in accuracy has been significant, networks often require vast quantities of training data, so much research now focuses on data augmentation, the process of increasing the quantity of data from a pre-existing dataset.

This article introduces a simple yet surprisingly effective augmentation strategy, mixup, together with a PyTorch implementation and a comparison of results.

Before Mixup — Why Data Augmentation?

Parameters inside a neural network are trained and updated on a given set of training data. However, because the training data covers only part of the entire distribution of possible data, the network is likely to overfit to the 'seen' part of the distribution. Hence, in theory, the more training data we have, the better it covers the entire distribution.

Although the amount of data we have is limited, we can always slightly alter existing images and feed them to the network as 'new' training samples. This process is called data augmentation.
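As an illustration of such alterations, here is a minimal NumPy sketch of two common augmentations: a random horizontal flip and a random crop after zero-padding the borders. The function name `augment` and the 4-pixel padding are our own choices for this example, not part of the article's pipeline; torchvision provides ready-made versions such as `transforms.RandomHorizontalFlip` and `transforms.RandomCrop`.

```python
import numpy as np

def augment(image, rng, pad=4):
    """Randomly flip and crop an HWC image (hypothetical helper for illustration)."""
    h, w, _ = image.shape
    # Horizontal flip with probability 0.5
    if rng.random() < 0.5:
        image = image[:, ::-1, :]
    # Zero-pad height and width, then crop a random window of the original size
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)))
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w, :]

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
new_image = augment(image, rng)
print(new_image.shape)  # (32, 32, 3)
```

Each call yields a slightly different version of the same picture with the same label, which is exactly the kind of "new" sample described above.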

What is Mixup?


Suppose we are classifying images of dogs and cats, and we are given a set of labeled images for each class (e.g., [1, 0] -> dog, [0, 1] -> cat). Mixup simply takes a weighted average of two images, and the same weighted average of their labels, to create a new training sample.

Specifically, we can write the concept of mixup mathematically:

$$x = \lambda x_i + (1 - \lambda) x_j$$
$$y = \lambda y_i + (1 - \lambda) y_j$$

where x and y are the mixed image and label created from xᵢ (with label yᵢ) and xⱼ (with label yⱼ), and λ ∈ [0, 1] is a random number drawn from a beta distribution Beta(α, α).

This provides continuous samples of data in between the different classes, which intuitively expands the distribution of a given training set and thus makes the network more robust during the testing phase.
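To make the formula concrete, the sketch below mixes two toy 2×2 grayscale "images" with one-hot labels using NumPy. The toy arrays are made up for illustration; α = 0.2 matches the value used later in this article.

```python
import numpy as np

rng = np.random.default_rng(42)

# Two toy "images" with one-hot labels: [1, 0] = dog, [0, 1] = cat
x_i, y_i = np.array([[1.0, 1.0], [1.0, 1.0]]), np.array([1.0, 0.0])
x_j, y_j = np.array([[0.0, 0.0], [0.0, 0.0]]), np.array([0.0, 1.0])

# lambda is drawn from Beta(alpha, alpha)
alpha = 0.2
lam = rng.beta(alpha, alpha)

# Mix the images and the labels with the same lambda
x = lam * x_i + (1 - lam) * x_j
y = lam * y_i + (1 - lam) * y_j
print(lam, y)  # the mixed label is [lam, 1 - lam]
```

Because both one-hot labels sum to 1, the mixed label also sums to 1 and can be read as "λ dog, 1 − λ cat".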

Using mixup with any network

Since mixup is merely a data augmentation method, it is orthogonal to the choice of classification architecture: you can apply it to a classification dataset with any network you wish.
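Because mixup only touches inputs and labels, it can be written without any reference to the model. The sketch below uses a batch-level variant, mixing each sample with a partner from a shuffled copy of the same batch; this is a common alternative to the dataset-level approach this article implements later, and the function name `mixup_batch` is hypothetical.

```python
import torch

def mixup_batch(images, labels, alpha=0.2):
    """Mix each sample in a batch with a randomly chosen partner from the same batch.

    images: (N, C, H, W) float tensor; labels: (N, num_classes) one-hot or soft labels.
    """
    # One lambda per batch, drawn from Beta(alpha, alpha)
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    # Pair every sample with a randomly permuted partner
    perm = torch.randperm(images.size(0))
    mixed_images = lam * images + (1 - lam) * images[perm]
    mixed_labels = lam * labels + (1 - lam) * labels[perm]
    return mixed_images, mixed_labels

# Usage with any classifier: mix the batch right before the forward pass
images = torch.rand(8, 3, 32, 32)
labels = torch.nn.functional.one_hot(torch.randint(0, 10, (8,)), num_classes=10).float()
mixed_images, mixed_labels = mixup_batch(images, labels)
print(mixed_images.shape, mixed_labels.sum(dim=1))  # each label row still sums to 1
```

The model, optimizer, and loss stay untouched; only the batch fed into them changes.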

In the original paper, mixup: Beyond Empirical Risk Minimization, Zhang et al. experimented with multiple datasets and architectures, showing empirically that the benefit of mixup is not a one-off special case.

Computing Environment

Libraries

The entire program is built with the PyTorch library (including torchvision). Mixup requires sampling from a beta distribution, which the NumPy library provides, and we also use Python's random library to pick random images for mixup. The following code imports all the libraries:

"""
Import necessary libraries to train a network using mixup
The code is mainly developed using the PyTorch library
"""
import numpy as np
import pickle
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

Dataset

For demonstration, we apply mixup to a traditional image classification task, for which CIFAR-10 is a natural choice. CIFAR-10 contains 60000 32×32 color images in 10 classes (6000 per class), split into training and test sets in a 5:1 ratio (50000 and 10000 images). The images are fairly simple to classify, yet harder than those of MNIST, the most basic digit recognition dataset.

There are numerous ways to download the CIFAR-10 dataset, including from the University of Toronto website or via torchvision.datasets. One platform worth mentioning is Graviti Open Datasets, which hosts hundreds of datasets along with their authors and labels for each dataset's designated training tasks (e.g., classification, object detection). You may also download other classification datasets, such as CompCars or SVHN, to test the improvement mixup brings in different scenarios. The company is currently developing its SDKs, which at present take extra time to load data directly but may become very useful as batch downloading rapidly improves.
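The Python version of CIFAR-10 stores each batch's `data` entry as a uint8 array with one row of 3072 values per image: 1024 red values, then 1024 green, then 1024 blue, each channel laid out row by row. The sketch below builds a made-up two-image batch (no download needed) and applies the same reshape and transpose that the dataset class in this article uses.

```python
import numpy as np

# Synthetic stand-in for a CIFAR-10 batch: 2 images, each a flat row of 3072 values
num_images = 2
red = (np.arange(1024) % 256).astype(np.uint8)   # varies per pixel to expose row-major order
green = np.full(1024, 20, dtype=np.uint8)
blue = np.full(1024, 30, dtype=np.uint8)
rows = np.tile(np.concatenate([red, green, blue]), (num_images, 1))  # shape (2, 3072)

# (N, 3072) -> (N, 3, 32, 32) in CHW order -> (N, 32, 32, 3) in HWC order
images = rows.reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
print(images.shape)      # (2, 32, 32, 3)
print(images[0, 1, 2])   # [34 20 30]: pixel at row 1, column 2
```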

Hardware Requirements

Training the neural network on GPUs is preferred, as it significantly increases training speed. However, if only a CPU is available, you can still run the program. To let the program detect the available hardware itself, use the following:

"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
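Once `device` is chosen, the model and every tensor it processes must be moved there with `.to(device)`, as the training code does for the network and each batch. A minimal sketch, using a tiny `nn.Linear` layer as a hypothetical stand-in for the CNN:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Both the model parameters and the input batch must live on the same device
model = nn.Linear(4, 2).to(device)
batch = torch.randn(3, 4).to(device)
output = model(batch)
print(output.device)  # a CUDA device if one is available, otherwise cpu
```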

Implementation

Network

The goal is to examine the effect of mixup, not the network itself. Hence, for demonstration purposes, we implement a simple convolutional neural network (CNN) with 4 convolutional layers followed by 2 fully connected layers. The same network is used for both the mixup and non-mixup training runs to keep the comparison fair.

We can build this network as follows:

"""
Create a simple CNN
"""
class CNN(nn.Module):
    def __init__(self):
        super().__init__()

        # Network consists of 4 convolutional layers followed by 2 fully-connected layers
        self.conv11 = nn.Conv2d(3, 64, 3)
        self.conv12 = nn.Conv2d(64, 64, 3)
        self.conv21 = nn.Conv2d(64, 128, 3)
        self.conv22 = nn.Conv2d(128, 128, 3)
        self.fc1 = nn.Linear(128 * 5 * 5, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        # Input: (N, 3, 32, 32)
        x = F.relu(self.conv11(x))   # -> (N, 64, 30, 30)
        x = F.relu(self.conv12(x))   # -> (N, 64, 28, 28)
        x = F.max_pool2d(x, (2, 2))  # -> (N, 64, 14, 14)
        x = F.relu(self.conv21(x))   # -> (N, 128, 12, 12)
        x = F.relu(self.conv22(x))   # -> (N, 128, 10, 10)
        x = F.max_pool2d(x, (2, 2))  # -> (N, 128, 5, 5)

        # Size is calculated based on kernel size 3 and padding 0
        x = x.view(-1, 128 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        # Sigmoid gives per-class outputs in (0, 1), as required by the BCE loss
        return torch.sigmoid(x)
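The flatten size 128 * 5 * 5 in `fc1` follows from the standard output-size formula for convolution and pooling, floor((W − K + 2P) / S) + 1. The short sketch below (the helper names `conv_out` and `pool_out` are our own) traces a 32×32 CIFAR-10 image through the layers of the CNN above:

```python
def conv_out(size, kernel=3, padding=0, stride=1):
    """Spatial size after a convolution: (W - K + 2P) // S + 1."""
    return (size - kernel + 2 * padding) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Spatial size after max pooling without padding: (W - K) // S + 1."""
    return (size - kernel) // stride + 1

size = 32                        # CIFAR-10 images are 32x32
size = conv_out(conv_out(size))  # conv11, conv12: 32 -> 30 -> 28
size = pool_out(size)            # max pool: 28 -> 14
size = conv_out(conv_out(size))  # conv21, conv22: 14 -> 12 -> 10
size = pool_out(size)            # max pool: 10 -> 5
print(size, 128 * size * size)   # 5 3200
```

If you change the kernel sizes, add padding, or feed larger images, rerunning this arithmetic gives the new input size for `fc1`.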

Mixup

In this implementation, mixup is performed while loading the data, so we write our own dataset class instead of using the default ones provided by torchvision.datasets.

The following is a simple implementation of mixup that samples λ from NumPy's beta distribution:

"""
Dataset and Dataloader creation
All data are downloaded via Graviti Open Datasets, which links to the CIFAR-10 official page
The dataset implementation is where mixup takes place
"""

class CIFAR_Dataset(Dataset):
    def __init__(self, data_dir, train, transform):
        # data_dir is the folder holding the CIFAR-10 python batches (must end with a path separator)
        self.data_dir = data_dir
        self.train = train
        self.transform = transform
        self.data = []
        self.targets = []

        # Loading all the data depending on whether the dataset is training or testing
        if self.train:
            for i in range(5):
                with open(data_dir + 'data_batch_' + str(i+1), 'rb') as f:
                    entry = pickle.load(f, encoding='latin1')
                    self.data.append(entry['data'])
                    self.targets.extend(entry['labels'])
        else:
            with open(data_dir + 'test_batch', 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])

        # Each row stores one image as 3 x 32 x 32 (CHW); transpose it to HWC,
        # the layout transforms.ToTensor expects for NumPy images
        # Original CIFAR format can be seen via its official page
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        # Create a one hot label
        label = torch.zeros(10)
        label[self.targets[idx]] = 1.

        # Transform the image (e.g., convert it to a tensor and normalize it)
        image = self.data[idx]
        if self.transform:
            image = self.transform(image)

        # If data is for training, perform mixup, only perform mixup roughly on 1 for every 5 images
        if self.train and idx > 0 and idx % 5 == 0:

            # Choose another image/label randomly
            mixup_idx = random.randint(0, len(self.data) - 1)
            mixup_label = torch.zeros(10)
            mixup_label[self.targets[mixup_idx]] = 1.
            mixup_image = self.data[mixup_idx]
            if self.transform:
                mixup_image = self.transform(mixup_image)

            # Select a random number from the given beta distribution
            # Mixup the images and labels accordingly
            alpha = 0.2
            lam = float(np.random.beta(alpha, alpha))
            image = lam * image + (1 - lam) * mixup_image
            label = lam * label + (1 - lam) * mixup_label

        return image, label

Note that we do not apply mixup to every image, only to roughly one in five. We also set the beta distribution parameter α to 0.2, i.e., λ ~ Beta(0.2, 0.2). You may change the distribution as well as the proportion of mixed images for different experiments; perhaps you will achieve even better results!
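Why α = 0.2? For α < 1, Beta(α, α) is U-shaped, so most draws of λ land near 0 or 1 and most mixed images are dominated by one of the two source images; α = 1 gives a uniform λ. The NumPy sketch below (the helper name `fraction_near_extremes` is ours) estimates how often λ falls within 0.1 of either end. By our estimate this is roughly two-thirds of draws for α = 0.2, versus about one-fifth for α = 1.

```python
import numpy as np

rng = np.random.default_rng(0)

def fraction_near_extremes(alpha, n=100_000, margin=0.1):
    """Fraction of Beta(alpha, alpha) samples within `margin` of 0 or 1."""
    lam = rng.beta(alpha, alpha, size=n)
    return np.mean((lam < margin) | (lam > 1 - margin))

for alpha in (0.2, 1.0):
    print(alpha, fraction_near_extremes(alpha))
```

Larger α values push λ toward 0.5 and produce stronger blends, which is one knob worth trying in the experiments suggested above.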

Training and Evaluation

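The following code shows the training procedure. We set the batch size to 128, the learning rate to 1e-3, and the total number of epochs to 30. The entire training is performed twice, once with and once without mixup. We write the binary cross-entropy (BCE) loss by hand below; note, however, that PyTorch's built-in nn.BCELoss (and F.binary_cross_entropy) also accepts soft targets between 0 and 1, such as mixed labels, so it could be used instead:

"""
Hyperparameters, transforms, and dataloaders
DATA_DIR and the normalization values are assumptions; adjust them to your setup
"""
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
NUM_EPOCHS = 30
DATA_DIR = 'cifar-10-batches-py/'  # assumed location of the extracted CIFAR-10 python batches

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dataloader = DataLoader(CIFAR_Dataset(DATA_DIR, True, transform),
                              batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(CIFAR_Dataset(DATA_DIR, False, transform),
                             batch_size=BATCH_SIZE, shuffle=False)

"""
Initialize the network, loss, and Adam optimizer
We implement BCE by hand; F.binary_cross_entropy also accepts soft (mixup) labels
"""
net = CNN().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
def bceloss(x, y):
    eps = 1e-6
    return -torch.mean(y * torch.log(x + eps) + (1 - y) * torch.log(1 - x + eps))
best_Acc = 0


"""
Training Procedure
"""
for epoch in range(NUM_EPOCHS):
    net.train()
    # We train and print the loss every 100 iterations
    for idx, (imgs, labels) in enumerate(train_dataloader):
        imgs = imgs.to(device)
        labels = labels.to(device)
        preds = net(imgs)
        loss = bceloss(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print("Epoch {} Iteration {}, Current Loss: {}".format(epoch, idx, loss.item()))

    # We evaluate the network after every epoch based on test set accuracy
    net.eval()
    with torch.no_grad():
        total = 0
        numCorrect = 0
        for (imgs, labels) in test_dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            preds = net(imgs)
            # Test labels are one-hot, so argmax recovers the class index
            numCorrect += (torch.argmax(preds, dim=1) == torch.argmax(labels, dim=1)).sum().item()
            total += len(imgs)
        acc = numCorrect / total
        print("Current image classification accuracy at epoch {}: {}".format(epoch, acc))
        if acc > best_Acc:
            best_Acc = acc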
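As a quick check that PyTorch's built-in loss handles mixed labels, the sketch below compares the hand-written `bceloss` from the training code with `F.binary_cross_entropy` on random predictions and labels mixed with λ = 0.3. The predictions and class indices are made up; the small `eps` term makes the two values differ only slightly.

```python
import torch
import torch.nn.functional as F

def bceloss(x, y):
    # Hand-written BCE, as in the training code
    eps = 1e-6
    return -torch.mean(y * torch.log(x + eps) + (1 - y) * torch.log(1 - x + eps))

torch.manual_seed(0)
preds = torch.sigmoid(torch.randn(4, 10))  # network-like outputs in (0, 1)

# Mixed labels: 0.3 of one class and 0.7 of another, as mixup would produce
y_a = F.one_hot(torch.tensor([0, 1, 2, 3]), num_classes=10).float()
y_b = F.one_hot(torch.tensor([5, 6, 7, 8]), num_classes=10).float()
lam = 0.3
soft_labels = lam * y_a + (1 - lam) * y_b

builtin = F.binary_cross_entropy(preds, soft_labels)
print(bceloss(preds, soft_labels).item(), builtin.item())
```

Using the built-in function also avoids choosing an `eps` by hand, since PyTorch clamps the log terms internally.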

Extending Beyond Image Classification

While mixup has pushed state-of-the-art accuracies in image classification, research has shown that its benefits extend to other areas of computer vision, such as generative modeling and robustness to adversarial examples. The research literature has also extended the concept to 3D representations, where it has likewise proved very effective (e.g., PointMixup).

Conclusion

So there you have it! Hopefully this article gives you a basic overview of mixup and guidance on how to apply it when training your image classification network.


This article was originally published by Ta-Ying Cheng on Towards Data Science