This article was originally published by Ta-Ying Cheng on Towards Data Science.
Generative Adversarial Networks (GANs), proposed by Goodfellow et al. in 2014, revolutionized image generation in computer vision: it was hard to believe that such stunning, lifelike images were generated purely by machines. People had considered generation nearly impossible, because, unlike in classification, there is no ground truth to compare a generated image against, so the power of GANs came as a real surprise.
This article introduces the simple intuition behind GANs, followed by a PyTorch implementation of a fully connected GAN and its training procedure.
Unlike traditional classification, where network predictions can be compared directly to the ground truth, the ‘correctness’ of a generated image is hard to define and measure. Goodfellow et al., in their original paper “Generative Adversarial Networks”, proposed an interesting idea: use a well-trained classifier to distinguish generated images from real ones. If such a classifier exists, we can train a generator network until it outputs images that completely fool the classifier.
A GAN is the product of this procedure: it contains a generator that turns random noise into images resembling a given dataset, and a discriminator (a classifier) that decides whether an image is real or generated. The detailed pipeline of a GAN can be seen in Figure 1.
Optimizing the generator and the discriminator together is difficult because, as you may imagine, the two networks have opposite goals: the generator wants to create images that are as realistic as possible, while the discriminator wants to tell generated images apart from real ones.
To illustrate this, let D(x) be the output of the discriminator, i.e., the probability that x is a real image, and let G(z) be the output of the generator for an input noise vector z. The discriminator is analogous to a binary classifier, so its goal is to maximise the function:

$$\mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

where $p_{\text{data}}$ is the distribution of real images and $p_z$ is the distribution of the input noise. The generator, on the other hand, wants D(G(z)) to be close to 1, so it tries to minimise the same function. Together, the two networks play a minimax game:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$

In theory, this game converges when the generated distribution matches the real one, at which point the discriminator can do no better than predicting a probability of 0.5 for every input.
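To make the objective concrete, the sketch below estimates V(D, G) on a batch of discriminator outputs and checks two things: that the familiar binary cross-entropy loss (label 1 for real, 0 for generated) is exactly the negative of V, and that a discriminator stuck at 0.5 yields V = -log 4. The helper names `gan_value` and `discriminator_bce_loss` are hypothetical, chosen for illustration.

```python
import math

import torch
import torch.nn.functional as F


def gan_value(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    """Batch estimate of V(D, G).

    d_real: D(x) on real images, values in (0, 1).
    d_fake: D(G(z)) on generated images, values in (0, 1).
    """
    return torch.log(d_real).mean() + torch.log(1 - d_fake).mean()


def discriminator_bce_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # BCE with target 1 gives -log D(x); with target 0 it gives -log(1 - D(G(z))).
    # Minimizing their sum is therefore the same as maximizing V(D, G).
    return (F.binary_cross_entropy(d_real, torch.ones_like(d_real))
            + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))


# At the theoretical optimum the discriminator outputs 0.5 for every input.
half = torch.full((8, 1), 0.5)
print(gan_value(half, half).item())  # approximately -1.3863, i.e. -log 4
print(-math.log(4))
```

This is why GAN code usually trains the discriminator with an ordinary BCE loss rather than coding the expectation by hand.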
In practice, however, the minimax game often fails to converge, so the training process must be tuned carefully. Hyperparameters such as the learning rate matter far more when training a GAN than when training a classifier: small changes can cause the generator to produce the same output regardless of the input noise (a failure known as mode collapse).
The entire program is built with the PyTorch library (including torchvision), and the generated results are visualized with Matplotlib.
Datasets are an important aspect of training GANs. Images are unstructured data, so any given class (e.g., dogs, cats, or a handwritten digit) corresponds to a distribution of possible images, and that distribution is ultimately what the GAN learns to generate from.
For demonstration, this article uses the simple MNIST dataset, which contains 60,000 training images of handwritten digits from 0 to 9. Unstructured datasets like MNIST can also be found on Graviti, a young startup that aims to help the community with unstructured datasets and hosts some of the best public ones on its platform, including MNIST.
It is preferable to train the neural network on a GPU, which speeds up training significantly. If only a CPU is available, you can still run the program; PyTorch can detect the available hardware at runtime, so the program can choose the device itself.
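A minimal sketch of that hardware check, using PyTorch's standard `torch.cuda.is_available()`:

```python
import torch

# Use a CUDA GPU if PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")
```

Models and tensors are then moved onto this device with `.to(device)`, so the same script runs unchanged on either kind of machine.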
Because MNIST digits are simple images, both architectures (discriminator and generator) are built from fully connected layers. Note that a fully connected GAN can sometimes converge more easily than a DCGAN (deep convolutional GAN).
Both architectures are implemented in PyTorch.
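The original listings are not reproduced here, so the following is a sketch of what fully connected generator and discriminator modules could look like. The layer widths, the LeakyReLU slope of 0.2, the dropout rate and the noise size `LATENT_DIM = 100` are all assumptions; the overall shape (noise in, flattened 28×28 image out; image in, probability out) is what the text describes.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100     # size of the noise vector z (assumed)
IMAGE_DIM = 28 * 28  # a flattened MNIST image


class Generator(nn.Module):
    """Maps noise z of shape (N, latent_dim) to flattened images in [-1, 1]."""

    def __init__(self, latent_dim: int = LATENT_DIM, image_dim: int = IMAGE_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 1024), nn.LeakyReLU(0.2),
            nn.Linear(1024, image_dim),
            nn.Tanh(),  # output range matches images normalized to [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


class Discriminator(nn.Module):
    """Maps images (N, 1, 28, 28) or (N, 784) to D(x), the probability each is real."""

    def __init__(self, image_dim: int = IMAGE_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_dim, 1024), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(1024, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squash the score into a probability
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten each image so the fully connected layers can consume it.
        return self.net(x.reshape(x.size(0), -1))
```

The generator mirrors the discriminator in reverse, widening from the noise vector up to 784 pixels, while the discriminator narrows 784 pixels down to a single probability.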
When training a GAN, we optimize the discriminator and, at the same time, improve the generator. Each iteration therefore involves two losses that pull against each other, and both networks are updated in turn. The generator takes random noise as input and should map different noise vectors to different images, so that small changes in the noise produce variation in the generated digits.
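The author's training code is not reproduced here; below is a sketch of the standard alternating scheme. Each step first updates the discriminator with BCE (label 1 for real, 0 for generated), then updates the generator. For the generator it swaps in the non-saturating loss suggested by Goodfellow et al., maximizing log D(G(z)) instead of minimizing log(1 - D(G(z))), which gives stronger gradients early in training. The function names, `latent_dim=100`, and the Adam settings (lr = 2e-4, betas = (0.5, 0.999), borrowed from the DCGAN paper) are assumptions.

```python
import torch
import torch.nn as nn


def train_step(generator: nn.Module, discriminator: nn.Module,
               g_opt: torch.optim.Optimizer, d_opt: torch.optim.Optimizer,
               real: torch.Tensor, latent_dim: int = 100):
    """One GAN iteration: update D, then update G. Returns (d_loss, g_loss)."""
    bce = nn.BCELoss()
    n, device = real.size(0), real.device
    real_labels = torch.ones(n, 1, device=device)
    fake_labels = torch.zeros(n, 1, device=device)

    # 1) Discriminator: maximize log D(x) + log(1 - D(G(z))),
    #    i.e. minimize BCE with label 1 for real and 0 for generated images.
    z = torch.randn(n, latent_dim, device=device)
    fake = generator(z)
    d_loss = (bce(discriminator(real), real_labels)
              + bce(discriminator(fake.detach()), fake_labels))  # detach: no G update here
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # 2) Generator (non-saturating loss): label its fakes as real,
    #    which maximizes log D(G(z)) under the freshly updated discriminator.
    g_loss = bce(discriminator(fake), real_labels)
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
    return d_loss.item(), g_loss.item()


def train(generator, discriminator, loader, epochs=100, latent_dim=100,
          lr=2e-4, device="cpu"):
    """Alternate D and G updates over every batch for `epochs` passes."""
    generator.to(device)
    discriminator.to(device)
    # Separate optimizers, one per network (assumed DCGAN-style settings).
    g_opt = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    d_opt = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    for epoch in range(epochs):
        for real, _ in loader:  # digit labels are ignored
            d_loss, g_loss = train_step(generator, discriminator, g_opt, d_opt,
                                        real.to(device), latent_dim)
        print(f"epoch {epoch + 1}/{epochs}  d_loss={d_loss:.4f}  g_loss={g_loss:.4f}")
    return generator
```

Usage: pass a generator, a discriminator and a DataLoader yielding `(image, label)` batches, e.g. `train(G, D, loader, epochs=100, device=device)`. Detaching `fake` in the discriminator step keeps that update from changing the generator's weights.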
After 100 epochs, we can plot digits generated from random noise.
The generated digits look fairly similar to real ones. Considering how simple the networks are, the results are promising!
The creation of GANs was a sharp departure from prior work in computer vision, and the numerous applications that followed surprised the academic community with what deep networks are capable of. Some remarkable examples are described below.
CycleGAN by Zhu et al. introduced a method for translating an image from domain X to domain Y without requiring paired samples. By turning horses into zebras and sunny summer landscapes into snowy winter scenes, CycleGAN produced strikingly convincing results.
Nvidia used GANs to convert simple paintings into elegant, realistic photographs based on the semantic label assigned to each paintbrush. Although training was computationally expensive, the work opened up an entirely new area of research and application.
GANs have also been extended to clean up adversarial images, transforming them into clean examples that no longer fool classifiers.
So there you have it! Hopefully this article provides an overview of how to build a GAN yourself.