A Generative Adversarial Network (GAN) is a deep learning architecture that consists of two neural networks competing against each other in a zero-sum game framework. The goal of GANs is to generate new, synthetic data that resembles some known data distribution.
What is a Generative Adversarial Network?
Generative Adversarial Networks (GANs) are a powerful class of neural networks used for unsupervised learning. They were introduced by Ian J. Goodfellow and his colleagues in 2014. A GAN is made up of two neural network models that compete with each other and, in doing so, learn to analyze, capture, and copy the variations within a dataset.
Why were GANs developed in the first place?
It has been observed that most mainstream neural networks can easily be fooled into misclassifying inputs by adding only a small amount of noise to the original data. Surprisingly, after the noise is added, the model can be more confident in its wrong prediction than it was in its correct one. One reason for this vulnerability is that most machine learning models learn from a limited amount of data, which makes them prone to overfitting. Another is that the mapping between input and output is almost linear: although the boundaries separating the various classes may appear nonlinear, in reality they are composed of linear pieces, so even a small change to a point in the feature space can push it across a boundary and lead to misclassification.
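The linearity argument can be illustrated with a small NumPy sketch. It assumes a hypothetical linear classifier whose score is `w @ x` on a flattened 32x32 RGB input; the weights, the input, and `eps` are made-up values. Moving every pixel by at most `eps` in the direction of `sign(w)` (the idea behind the fast gradient sign method) shifts the score by `eps` times the L1 norm of `w`, a quantity that grows with the input dimension, so many tiny changes add up to a large one.

```python
import numpy as np

rng = np.random.default_rng(0)

n = 32 * 32 * 3                 # a flattened 32x32 RGB image
w = rng.normal(size=n)          # weights of a hypothetical linear classifier
x = rng.uniform(0, 1, size=n)   # a hypothetical input image

eps = 0.01                      # maximum change allowed per pixel
eta = eps * np.sign(w)          # perturbation aligned with the weights

score_change = w @ (x + eta) - w @ x

# No pixel moves by more than eps, yet the score shifts by eps * ||w||_1,
# which grows linearly with the number of input dimensions n.
print(np.abs(eta).max(), score_change, eps * np.abs(w).sum())
```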
How do GANs work?
Generative Adversarial Networks (GANs) can be broken down into three parts:
- Generative: To learn a generative model, which describes how data is generated in terms of a probabilistic model.
- Adversarial: The training of a model is done in an adversarial setting.
- Networks: Deep neural networks are used as the models (the Generator and the Discriminator) that are trained.
In a GAN, there is a Generator and a Discriminator. The Generator produces fake samples of data (be it images, audio, etc.) and tries to fool the Discriminator, while the Discriminator tries to distinguish between real and fake samples. Both are neural networks, and they compete with each other during the training phase. These steps are repeated many times, and with each repetition the Generator and the Discriminator get better at their respective jobs. The process can be visualized in the diagram given below:
Here, the Generator captures the distribution of the data and is trained to maximize the probability of the Discriminator making a mistake. The Discriminator, on the other hand, estimates the probability that the sample it receives came from the training data rather than from the Generator. GANs are formulated as a minimax game: the Discriminator tries to maximize the value function V(D, G), while the Generator tries to minimize it, that is, to minimize the Discriminator's reward. It can be described mathematically by the formula below:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim P_{data}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim P(z)}\left[\log\left(1 - D(G(z))\right)\right]$$
where,
- G = Generator
- D = Discriminator
- Pdata(x) = distribution of real data
- P(z) = distribution of the input noise that the Generator samples from
- x = sample from Pdata(x)
- z = sample from P(z)
- D(x) = Discriminator network's estimated probability that x is real
- G(z) = Generator network's output (a generated sample) for the noise z
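To connect the value function with the binary cross-entropy loss (`nn.BCELoss`) used in the implementation, here is a small PyTorch sketch that estimates V(D, G) from a batch of discriminator outputs. The output values are hypothetical, chosen only for illustration, and `value_function` is a helper name introduced here. It checks two facts: the Discriminator's BCE loss (real samples labelled 1, fake samples labelled 0) equals -V, so maximizing V and minimizing that loss are the same thing; and when the Discriminator outputs 1/2 for every sample, which is its optimum once the generated distribution matches the real one, V equals -log 4.

```python
import math

import torch
import torch.nn.functional as F


def value_function(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    """Monte Carlo estimate of V(D, G).

    d_real: D(x) for real samples x ~ P_data(x), values in (0, 1)
    d_fake: D(G(z)) for generated samples with z ~ P(z), values in (0, 1)
    """
    return torch.log(d_real).mean() + torch.log(1 - d_fake).mean()


# Hypothetical discriminator outputs (illustration only)
d_real = torch.tensor([0.9, 0.8, 0.7])
d_fake = torch.tensor([0.2, 0.1, 0.3])
v = value_function(d_real, d_fake)

# BCE against label 1 gives -log D(x); against label 0 gives -log(1 - D(G(z)))
d_loss = (F.binary_cross_entropy(d_real, torch.ones_like(d_real))
          + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))
print(torch.isclose(d_loss, -v))  # tensor(True)

# At the optimum D outputs 1/2 for every sample, so V = 2 * log(1/2) = -log 4
half = torch.full((4,), 0.5)
print(value_function(half, half).item(), -math.log(4))
```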
Generator Model
The Generator is trained while the Discriminator is idle, meaning the Discriminator's weights are held fixed. After the Discriminator has been trained on the Generator's fake data, its predictions on newly generated samples are used to train the Generator, which improves on its previous state and becomes better at fooling the Discriminator.
Discriminator Model
The Discriminator is trained while the Generator is idle. In this phase, the Generator is only forward propagated to produce fake samples, and no back-propagation is done through it. The Discriminator is trained on real data for n epochs to see whether it can correctly predict them as real. In the same phase, it is also trained on the fake data produced by the Generator to see whether it can correctly predict them as fake.
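The two alternating phases can be sketched in PyTorch as two update functions. This is a minimal sketch under stated assumptions rather than the implementation later in this article: the tiny fully connected networks, their sizes, the random "real" data, and the names `discriminator_step` and `generator_step` are all hypothetical. In the Discriminator step, the Generator runs forward under `torch.no_grad()`, so it receives no gradients; in the Generator step, the Discriminator's parameters are frozen with `requires_grad_(False)`, so gradients flow through it into the Generator without changing it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in networks; sizes are hypothetical, for illustration only
latent_dim, data_dim = 4, 2
G = nn.Sequential(nn.Linear(latent_dim, 8), nn.ReLU(), nn.Linear(8, data_dim))
D = nn.Sequential(nn.Linear(data_dim, 8), nn.LeakyReLU(0.2),
                  nn.Linear(8, 1), nn.Sigmoid())
opt_G = torch.optim.Adam(G.parameters(), lr=1e-3)
opt_D = torch.optim.Adam(D.parameters(), lr=1e-3)
bce = nn.BCELoss()


def discriminator_step(real: torch.Tensor) -> float:
    """Update D only: G is run forward to make fakes, with no backprop into G."""
    n = real.size(0)
    with torch.no_grad():
        fake = G(torch.randn(n, latent_dim))
    opt_D.zero_grad()
    loss = (bce(D(real), torch.ones(n, 1))      # real samples -> label 1
            + bce(D(fake), torch.zeros(n, 1)))  # fake samples -> label 0
    loss.backward()
    opt_D.step()
    return loss.item()


def generator_step(n: int) -> float:
    """Update G only: D's weights are frozen during this step."""
    for p in D.parameters():
        p.requires_grad_(False)
    opt_G.zero_grad()
    fake = G(torch.randn(n, latent_dim))
    loss = bce(D(fake), torch.ones(n, 1))  # G wants D to call its fakes "real"
    loss.backward()                        # gradients pass through D into G
    opt_G.step()
    for p in D.parameters():
        p.requires_grad_(True)
    return loss.item()


# One round of alternating training on hypothetical "real" data
real_batch = torch.randn(16, data_dim) + 3.0
print(discriminator_step(real_batch), generator_step(16))
```

Freezing the Discriminator with `requires_grad_(False)` is one option; another is simply never calling the Discriminator's optimizer during the Generator step and zeroing its gradients before its own next update.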
Different Types of GAN Models
- Vanilla GAN: This is the simplest type of GAN. Here, the Generator and the Discriminator are simple multi-layer perceptrons, and the algorithm simply optimizes the minimax objective V(D, G) using stochastic gradient descent.
- Conditional GAN (CGAN): A CGAN is a GAN in which both networks are conditioned on extra information. An additional parameter ‘y’ (for example, a class label) is fed to the Generator so that it generates data corresponding to that condition. The labels are also fed to the Discriminator, which helps it distinguish the real data from the fake generated data.
- Deep Convolutional GAN (DCGAN): DCGAN is one of the most popular and most successful implementations of GAN. It uses ConvNets in place of multi-layer perceptrons. The ConvNets are implemented without max pooling, which is replaced by strided convolutions, and the hidden layers are not fully connected.
- Laplacian Pyramid GAN (LAPGAN): The Laplacian pyramid is a linear invertible image representation consisting of a set of band-pass images, spaced an octave apart, plus a low-frequency residual. This approach uses multiple Generator and Discriminator networks, one at each level of the Laplacian pyramid, and is used mainly because it produces very high-quality images. The image is first down-sampled at each level of the pyramid; then, in a backward pass, it is up-scaled level by level, acquiring detail generated from noise by a Conditional GAN at each level until it reaches its original size.
- Super Resolution GAN (SRGAN): As the name suggests, SRGAN is a way of designing a GAN in which a deep neural network is used along with an adversarial network to produce higher-resolution images. This type of GAN is particularly useful for up-scaling native low-resolution images, enhancing their details while minimizing errors.
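The conditioning idea behind a CGAN can be sketched in PyTorch as follows. This is a minimal illustration, not a reference CGAN: the class names `ConditionalGenerator` and `ConditionalDiscriminator`, the layer sizes, and the choice to encode the label with `nn.Embedding` and concatenate it with the network input are all assumptions (concatenating a label encoding with the input is one common way to condition both networks).

```python
import torch
import torch.nn as nn


class ConditionalGenerator(nn.Module):
    """Hypothetical MLP generator conditioned on a class label y."""

    def __init__(self, latent_dim: int = 100, num_classes: int = 10,
                 out_dim: int = 784):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim),
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Concatenate the noise with a learned encoding of the label
        return self.model(torch.cat([z, self.label_emb(y)], dim=1))


class ConditionalDiscriminator(nn.Module):
    """Hypothetical MLP discriminator that also sees the label y."""

    def __init__(self, num_classes: int = 10, in_dim: int = 784):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(in_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Judge whether x is a real sample for this particular label
        return self.model(torch.cat([x, self.label_emb(y)], dim=1))


G, D = ConditionalGenerator(), ConditionalDiscriminator()
z = torch.randn(8, 100)
y = torch.randint(0, 10, (8,))   # one requested class per sample
fake = G(z, y)
print(fake.shape, D(fake, y).shape)  # torch.Size([8, 784]) torch.Size([8, 1])
```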
GANs Implementation using PyTorch
Step 1: Importing the required libraries
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```
Step 2: Loading the Dataset
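```python
# Convert images to tensors and scale each RGB channel from [0, 1] to [-1, 1],
# matching the range of the Generator's Tanh output
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Download the CIFAR-10 training set and wrap it in a shuffled DataLoader
train_dataset = datasets.CIFAR10(root='./data', train=True,
                                 download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(train_dataset,
                                         batch_size=32, shuffle=True)
```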
Step 3: Defining parameters to be used in later processes
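```python
# Hyperparameters
latent_dim = 100   # size of the Generator's input noise vector
lr = 0.0002        # learning rate for both Adam optimizers
beta1 = 0.5        # Adam beta1
beta2 = 0.999      # Adam beta2
num_epochs = 10
```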
Step 4: Defining a Utility Class to Build the Generator
```python
# Define the generator: maps a latent vector z to a 3x32x32 image in [-1, 1]
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),                  # -> 128 x 8 x 8
            nn.Upsample(scale_factor=2),                   # -> 128 x 16 x 16
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, momentum=0.78),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),                   # -> 128 x 32 x 32
            nn.Conv2d(128, 64, kernel_size=3, padding=1),  # -> 64 x 32 x 32
            nn.BatchNorm2d(64, momentum=0.78),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),    # -> 3 x 32 x 32
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img
```
Step 5: Defining a Utility Class to Build the Discriminator
```python
# Define the discriminator: maps a 3x32x32 image to the probability it is real
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),     # -> 32 x 16 x 16
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),    # -> 64 x 8 x 8
            nn.ZeroPad2d((0, 1, 0, 1)),                               # -> 64 x 9 x 9
            nn.BatchNorm2d(64, momentum=0.82),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),   # -> 128 x 5 x 5
            nn.BatchNorm2d(128, momentum=0.82),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),  # -> 256 x 5 x 5
            nn.BatchNorm2d(256, momentum=0.8),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(256 * 5 * 5, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        validity = self.model(img)
        return validity
```
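The `256 * 5 * 5` input size of the Discriminator's final `nn.Linear` layer follows from the standard convolution output-size formula, floor((n + 2p - k) / s) + 1, with dilation 1. The small helper below (`conv_out` is a name introduced here for illustration) traces the spatial size through the Discriminator's layers for a 32x32 CIFAR-10 image.

```python
def conv_out(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a square Conv2d layer (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


sizes = [32]                                 # CIFAR-10 images are 32x32
sizes.append(conv_out(sizes[-1], 3, 2, 1))   # Conv2d(3, 32, stride=2)
sizes.append(conv_out(sizes[-1], 3, 2, 1))   # Conv2d(32, 64, stride=2)
sizes.append(sizes[-1] + 1)                  # ZeroPad2d((0, 1, 0, 1)): +1 row and col
sizes.append(conv_out(sizes[-1], 3, 2, 1))   # Conv2d(64, 128, stride=2)
sizes.append(conv_out(sizes[-1], 3, 1, 1))   # Conv2d(128, 256, stride=1)

print(sizes)                   # [32, 16, 8, 9, 5, 5]
print(256 * sizes[-1] ** 2)    # 6400, i.e. 256 * 5 * 5
```

Without the `ZeroPad2d` layer, the third convolution would turn the 8x8 map into 4x4 instead, and the linear layer would need `256 * 4 * 4` inputs.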
Step 6: Building the Generative Adversarial Network
```python
# Initialize generator and discriminator
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Loss function
adversarial_loss = nn.BCELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
```
Step 7: Training the Generative Adversarial Network
```python
# Training loop
for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        # Each batch is (images, labels); only the images are needed
        real_images = batch[0].to(device)

        # Adversarial ground truths
        valid = torch.ones(real_images.size(0), 1, device=device)
        fake = torch.zeros(real_images.size(0), 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = torch.randn(real_images.size(0), latent_dim, device=device)

        # Generate a batch of images
        fake_images = generator(z)

        # Measure discriminator's ability to classify real and fake images;
        # detach() stops this loss from back-propagating into the generator
        real_loss = adversarial_loss(discriminator(real_images), valid)
        fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        # Backward pass and optimize
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_images = generator(z)

        # Adversarial loss: the generator wants its images classified as real
        g_loss = adversarial_loss(discriminator(gen_images), valid)

        # Backward pass and optimize (only the generator's weights are stepped)
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Progress Monitoring
        # ---------------------
        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}] "
                f"Batch {i + 1}/{len(dataloader)} "
                f"Discriminator Loss: {d_loss.item():.4f} "
                f"Generator Loss: {g_loss.item():.4f}"
            )

    # Show a grid of generated images every 10 epochs
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated = generator(z).cpu()
            grid = torchvision.utils.make_grid(generated, nrow=4, normalize=True)
            plt.imshow(grid.permute(1, 2, 0).numpy())
            plt.axis("off")
            plt.show()
```
Output:
Epoch [10/10] Batch 1500/1563 Discriminator Loss: 0.5253 Generator Loss: 1.3269
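Note that the Generator step in the training loop uses `adversarial_loss(discriminator(gen_images), valid)`, i.e. it minimizes -log D(G(z)) instead of the log(1 - D(G(z))) term of the minimax objective V(D, G). The original GAN paper suggested this substitution because it gives much stronger gradients early in training, when the Discriminator easily rejects the fakes. The sketch below compares the gradients of both losses with respect to a few discriminator logits; the logit values are made up for illustration.

```python
import torch

# Hypothetical discriminator logits for three fake samples; a very negative
# logit means D(G(z)) = sigmoid(logit) is near 0, i.e. D confidently rejects it
logits = torch.tensor([-6.0, -3.0, 0.0], requires_grad=True)

# Minimax generator loss: minimize log(1 - D(G(z)))
saturating = torch.log(1 - torch.sigmoid(logits)).sum()
(grad_sat,) = torch.autograd.grad(saturating, logits)

# Non-saturating loss: minimize -log D(G(z)), i.e. BCE against the "real" label
non_saturating = -torch.log(torch.sigmoid(logits)).sum()
(grad_ns,) = torch.autograd.grad(non_saturating, logits)

# Analytically, grad_sat = -sigmoid(logit) and grad_ns = -(1 - sigmoid(logit)),
# so the minimax loss gives almost no gradient exactly when D(G(z)) is near 0
print(grad_sat)
print(grad_ns)
```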
Applications of Generative Adversarial Networks (GANs):
GANs have many uses across different fields. Here are some of the most widely recognized:
- Image Synthesis and Generation: GANs are often used for image synthesis and generation tasks. By learning the distribution that describes the training dataset, they can create new, lifelike images that mimic the training data. These generative networks have facilitated the development of lifelike avatars, high-resolution photographs, and new artwork.
- Image-to-Image Translation: GANs may be used for image-to-image translation problems, where the objective is to convert an input image from one domain to another while maintaining its key features. For instance, GANs can change pictures from day to night, transform drawings into realistic images, or change the artistic style of an image.
- Text-to-Image Synthesis: GANs have been used to create visuals from text descriptions. Given a text input such as a phrase or a caption, a GAN can produce an image that matches the description. This application could change how realistic visual material is produced from text-based instructions.
- Data Augmentation: By creating synthetic data samples, GANs can augment existing data and increase the robustness and generalizability of machine learning models.
- Super-Resolution and Image Enhancement: GANs can enhance the resolution and quality of low-resolution images. By training on pairs of low-resolution and high-resolution images, GANs can generate high-resolution images from low-resolution inputs, enabling improved image quality in applications such as medical imaging, satellite imaging, and video enhancement.
- Style Transfer and Editing: GANs have been employed for style transfer and editing in images and videos. They can learn the style of a reference image or video and apply that style to other images or videos, enabling artistic transformations such as converting photographs into paintings or altering the appearance of videos.
Advantages of Generative Adversarial Networks (GANs):
- Synthetic data generation: GANs can generate new, synthetic data that resembles some known data distribution, which can be useful for data augmentation, anomaly detection, or creative applications.
- High-quality results: GANs can produce high-quality, photorealistic results in image synthesis, video synthesis, music synthesis, and other tasks.
- Unsupervised learning: GANs can be trained without labeled data, making them suitable for unsupervised learning tasks, where labeled data is scarce or difficult to obtain.
- Versatility: GANs can be applied to a wide range of tasks, including image synthesis, text-to-image synthesis, image-to-image translation, anomaly detection, data augmentation, and others.
Disadvantages of Generative Adversarial Networks (GANs):
- Training Instability: GANs can be difficult to train, with the risk of instability, mode collapse, or failure to converge.
- Computational Cost: GANs can require a lot of computational resources and can be slow to train, especially for high-resolution images or large datasets.
- Overfitting: GANs can overfit the training data, producing synthetic data that is too similar to the training data and lacking diversity.
- Bias and Fairness: GANs can reflect the biases and unfairness present in the training data, leading to discriminatory or biased synthetic data.
- Interpretability and Accountability: GANs can be opaque and difficult to interpret or explain, making it challenging to ensure accountability, transparency, or fairness in their applications.