The aim of this article is to implement a GAN architecture using the PyTorch framework. It provides a comprehensive overview of GANs in PyTorch along with an in-depth explanation of the code.
Generative Adversarial Networks (GANs) are a class of artificial intelligence algorithms used in unsupervised machine learning. They consist of two neural networks, the generator and the discriminator, which are trained simultaneously through a competitive process. The generator creates new data instances, while the discriminator evaluates whether they are real (drawn from the true data distribution) or fake (produced by the generator). This adversarial process pushes both networks to improve over time: the discriminator gets better at spotting fakes, and the generator gets better at producing samples the discriminator cannot tell apart from real data.
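Formally, this competition is usually written as a minimax game between the discriminator D and the generator G. The standard objective is reproduced below for reference, where D(x) is the probability the discriminator assigns to x being real and z is the noise vector fed to the generator:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

In practice the generator is often trained to maximize log D(G(z)) instead of minimizing log(1 − D(G(z))), because the former gives stronger gradients early in training. The training loop in this article follows that variant by labelling the generator's fake images as real when computing the generator loss.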
Implementing GANs using PyTorch Framework

In this section, we demonstrate the implementation of a Generative Adversarial Network (GAN) that generates realistic handwritten digits, using the following steps:
Step 1: Importing Necessary Libraries

We import the fundamental PyTorch libraries: torch and torch.nn for building the networks, and torch.optim for updating their parameters. torchvision is used to load and preprocess the MNIST dataset, which makes working with image data in PyTorch easier, and torchvision.transforms defines the transformations applied to the MNIST images before they are fed into the GAN. Finally, matplotlib.pyplot is used to plot the generated images.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
```

Step 2: Define the Generator

We define a Generator class with the following parts:
- Initialization: Inherits from nn.Module and takes a parameter noise_dim, the dimensionality of the input noise vector. The network architecture is defined in this method.
- Architecture: Uses a sequential model (self.main) made of a linear layer, ReLU activations, an unflatten layer, transposed convolution layers, and batch normalization. The linear layer projects the noise vector to 7 × 7 × 256 values, which are reshaped into a 256-channel 7×7 feature map; the transposed convolutions then progressively upsample it (7×7 → 7×7 → 14×14 → 28×28) into a 28×28 grayscale image resembling a handwritten digit.
- Output Layer: The final layer applies a Tanh activation, which squashes the pixel values into the range [-1, 1], matching the range of the normalized real images.
- Forward Method: Implements the forward pass of the generator. It takes an input noise vector (x) and passes it through the sequential model (self.main) to produce the output image.
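To see how the generator reaches a 28×28 image, the output size of each transposed convolution can be computed with the formula from the PyTorch nn.ConvTranspose2d documentation: H_out = (H_in − 1)·stride − 2·padding + dilation·(kernel_size − 1) + output_padding + 1. The pure-Python sketch below (conv_transpose_out is a helper written for illustration, not a PyTorch function) applies it to the generator's three transposed convolutions, starting from the 7×7 map produced by nn.Unflatten.

```python
def conv_transpose_out(h_in, kernel_size, stride=1, padding=0,
                       output_padding=0, dilation=1):
    """Spatial output size of nn.ConvTranspose2d along one dimension."""
    return ((h_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# The three transposed-convolution layers of the Generator (kernel_size=5).
layers = [
    dict(kernel_size=5, stride=1, padding=2),                    # 256 -> 128 channels
    dict(kernel_size=5, stride=2, padding=2, output_padding=1),  # 128 -> 64 channels
    dict(kernel_size=5, stride=2, padding=2, output_padding=1),  # 64 -> 1 channel
]

size = 7  # spatial size right after nn.Unflatten(1, (256, 7, 7))
sizes = [size]
for cfg in layers:
    size = conv_transpose_out(size, **cfg)
    sizes.append(size)

print(sizes)  # [7, 7, 14, 28]
```

The stride-1 layer keeps the 7×7 size, and each stride-2 layer with output_padding=1 exactly doubles it.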
```python
# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            # Project the noise vector and reshape it to (N, 256, 7, 7)
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            # (N, 256, 7, 7) -> (N, 128, 7, 7)
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (N, 128, 7, 7) -> (N, 64, 14, 14)
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # (N, 64, 14, 14) -> (N, 1, 28, 28), pixel values in [-1, 1]
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)
```

Step 3: Define the Discriminator

Next, we define a Discriminator class:
- Initialization: Inherits from nn.Module. Its constructor takes no arguments, since the layer sizes are hard-coded for 1×28×28 MNIST images.
- Architecture: Uses a sequential model (self.main) of two strided convolutional layers, each followed by a LeakyReLU activation and batch normalization. These layers downsample the 28×28 input to 14×14 and then to 7×7 feature maps, which are flattened for the final layer.
- Output Layer: The final layer is a fully connected linear layer that produces a single scalar per image. This value is a raw score (logit) rather than a probability: higher values mean the discriminator considers the image more likely to be real, and the sigmoid that turns it into a probability is applied inside the loss function, nn.BCEWithLogitsLoss.
- Forward Method: Implements the forward pass of the discriminator. It takes an input image (x) and passes it through the sequential model (self.main) to compute the discriminator's output.
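The matching check for the discriminator uses the nn.Conv2d output-size formula, H_out = ⌊(H_in + 2·padding − dilation·(kernel_size − 1) − 1) / stride⌋ + 1. This pure-Python sketch (conv_out is an illustrative helper, not a PyTorch function) shows why the final linear layer expects 7 × 7 × 128 input features.

```python
def conv_out(h_in, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of nn.Conv2d along one dimension (floor division)."""
    return (h_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


size = 28  # MNIST images are 28x28
sizes = [size]
# Both Discriminator convolutions use kernel_size=5, stride=2, padding=2.
for _ in range(2):
    size = conv_out(size, kernel_size=5, stride=2, padding=2)
    sizes.append(size)

# nn.Flatten turns the (N, 128, 7, 7) feature map into (N, 128 * 7 * 7).
flat_features = 128 * size * size

print(sizes, flat_features)  # [28, 14, 7] 6272
```

If the input resolution changed, the in_features of the final nn.Linear would have to change with it.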
```python
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # (N, 1, 28, 28) -> (N, 64, 14, 14)
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            # (N, 64, 14, 14) -> (N, 128, 7, 7)
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            # Flatten and map to one logit per image: (N, 1)
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)
```

Step 4: Instantiate the Generator and Discriminator

Here we create an instance named generator with the specified noise dimension; it is responsible for generating fake images from random noise. Next, we create another instance named discriminator to distinguish between real and fake images.
```python
# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()
```

Step 5: Device Configuration

We select a CUDA-capable GPU when one is available and fall back to the CPU otherwise, then move both models to that device so that training runs on the fastest available hardware.
```python
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)
```

Step 6: Set Loss Function, Optimizer and Hyperparameters

In this part of the code, we use binary cross-entropy with logits (nn.BCEWithLogitsLoss) as the loss function. It is designed for binary classification, which suits the task of distinguishing real images from fake ones, and it applies the sigmoid internally, so the discriminator can output raw logits. We initialize two Adam optimizers, one for the generator (generator_optimizer) and one for the discriminator (discriminator_optimizer), each with a learning rate of 0.0002 and betas of (0.5, 0.999).
We set the number of epochs (NUM_EPOCHS) to 5 and the batch size (BATCH_SIZE) to 256. NUM_EPOCHS is the number of complete passes over the training set, and BATCH_SIZE is the number of images processed in each training step.
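As a quick sanity check on these hyperparameters, the pure-Python arithmetic below estimates how many training steps they imply. It assumes the standard 60,000-image MNIST training split and the DataLoader default drop_last=False, under which the last, smaller batch is kept rather than discarded.

```python
import math

NUM_TRAIN_IMAGES = 60_000  # size of the MNIST training split (assumed)
BATCH_SIZE = 256
NUM_EPOCHS = 5

# With drop_last=False the final partial batch is kept, so round up.
batches_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
# Size of that final, smaller batch.
last_batch_size = NUM_TRAIN_IMAGES - (batches_per_epoch - 1) * BATCH_SIZE
# Total number of discriminator/generator update pairs over all epochs.
total_updates = batches_per_epoch * NUM_EPOCHS

print(batches_per_epoch, last_batch_size, total_updates)  # 235 96 1175
```

The 235 batches per epoch match the Step [.../235] counter in the training log, and the 96-image final batch is why the training loop sizes its label and noise tensors with real_images.size(0) instead of BATCH_SIZE.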
```python
# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256
```

Step 7: DataLoader

This section of the code prepares the MNIST dataset for training the GAN:
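- Transformations: Images are converted to tensors with pixel values in [0, 1] and then normalized with mean 0.5 and standard deviation 0.5, which maps them to the range [-1, 1] produced by the generator's Tanh output.
- Dataset: The MNIST training set is loaded with the specified transformations and downloaded if necessary.
- DataLoader: Splits the dataset into batches of BATCH_SIZE images, reshuffles the data every epoch, and loads the batches during training.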
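The effect of the normalization can be checked by hand. transforms.ToTensor() scales 8-bit pixel values to [0, 1], and Normalize((0.5,), (0.5,)) then computes (p − 0.5) / 0.5 for each pixel. The pure-Python sketch below (normalize and denormalize are illustrative names, not torchvision functions) applies that mapping and its inverse, which is also how the generator's Tanh outputs can be brought back to [0, 1].

```python
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize((0.5,), (0.5,))


def normalize(p):
    """What Normalize does to one pixel already scaled to [0, 1] by ToTensor."""
    return (p - MEAN) / STD


def denormalize(v):
    """Inverse mapping, e.g. to turn Tanh outputs in [-1, 1] back into [0, 1]."""
    return v * STD + MEAN


print([normalize(p) for p in (0.0, 0.5, 1.0)])     # [-1.0, 0.0, 1.0]
print([denormalize(v) for v in (-1.0, 0.0, 1.0)])  # [0.0, 0.5, 1.0]
```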
```python
# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
```

Step 8: Training Process

This training loop iterates over the specified number of epochs, training the GAN by alternating between updating the discriminator and the generator:
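- For each epoch, it iterates through batches of real images from the DataLoader.
- It first computes the discriminator's loss on the real images against real labels (ones) and backpropagates it.
- Next, it generates fake images from random noise and computes the discriminator's loss on them against fake labels (zeros). The fake images are detached, so this loss does not produce gradients for the generator. The gradients from the real and fake batches accumulate, and the discriminator's parameters are updated once with discriminator_optimizer.step().
- Finally, it trains the generator: the same fake images are passed through the updated discriminator, this time with labels of ones, so the generator's loss is low when the discriminator classifies its images as real. This loss is backpropagated and the generator's parameters are updated.
- Losses are printed every 100 steps to monitor training progress.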
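To make the label choices in these steps concrete, here is a small pure-Python sketch (no PyTorch required) of the per-sample loss that nn.BCEWithLogitsLoss computes, written in the mathematically equivalent, numerically stable form max(x, 0) − x·y + log(1 + e^(−|x|)). The two logits are made-up example values, not taken from a real training run. The sketch also compares the gradient of the original minimax generator loss, log(1 − D(G(z))), with that of −log D(G(z)), the loss the generator effectively minimizes when fake images are labelled as real.

```python
import math


def sigmoid(x):
    """Turn a raw discriminator score (logit) into a probability of 'real'."""
    return 1.0 / (1.0 + math.exp(-x))


def bce_with_logits(logit, target):
    """Per-sample binary cross-entropy on a logit, in a numerically stable
    form equal to -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))]."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Hypothetical discriminator logits for one real and one fake image.
real_logit, fake_logit = 2.0, -3.0

# Discriminator step: real images get label 1, fake images get label 0.
d_loss = bce_with_logits(real_logit, 1.0) + bce_with_logits(fake_logit, 0.0)

# Generator step: the same fake image is labelled 1, so the generator
# minimises -log(D(G(z))) and is rewarded for fooling the discriminator.
g_loss = bce_with_logits(fake_logit, 1.0)

# Gradients with respect to the fake logit x:
#   original minimax loss  log(1 - sigmoid(x))  ->  -sigmoid(x)
#   label-flipped loss     -log(sigmoid(x))     ->  sigmoid(x) - 1
saturating_grad = -sigmoid(fake_logit)
non_saturating_grad = sigmoid(fake_logit) - 1.0

print(f"D loss: {d_loss:.4f}, G loss: {g_loss:.4f}")
print(f"minimax grad: {saturating_grad:.4f}, flipped-label grad: {non_saturating_grad:.4f}")
```

When the discriminator confidently rejects a fake (a very negative logit), the minimax gradient shrinks toward zero while the flipped-label gradient stays close to −1, which is why labelling fakes as real in the generator step is a common choice.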
```python
# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')
```
Step 9: Visualization

Next, we define generate_and_save_images, which generates fake images with the trained generator and saves them to a file:
- It sets the generator to evaluation mode and generates fake images from the given noise vectors without tracking gradients.
- The generated images are reshaped to 28×28 and plotted in a 4×4 grid using Matplotlib.
- The function saves the figure to a PNG file named with the epoch number and then displays it.
- Finally, the script generates 16 test noise vectors and calls the function to create and save fake images with the trained generator.
```python
# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
    fake_images = fake_images.view(fake_images.size(0), 28, 28)

    fig = plt.figure(figsize=(4, 4))
    for i in range(fake_images.size(0)):
        plt.subplot(4, 4, i+1)
        plt.imshow(fake_images[i], cmap='gray')
        plt.axis('off')

    # `epoch` is the number of completed epochs, so it is used as-is
    plt.savefig(f'image_at_epoch_{epoch:04d}.png')
    plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)
```

Complete Code and Output:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.main = nn.Sequential(
            nn.Linear(noise_dim, 7 * 7 * 256),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
            nn.ConvTranspose2d(256, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 5, stride=2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 5, stride=2, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            nn.Flatten(),
            nn.Linear(7 * 7 * 128, 1)
        )

    def forward(self, x):
        return self.main(x)

# Noise dimension
NOISE_DIM = 100

# Generator and discriminator
generator = Generator(NOISE_DIM)
discriminator = Discriminator()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = generator.to(device)
discriminator = discriminator.to(device)

# Loss function
criterion = nn.BCEWithLogitsLoss()

# Optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 256

# DataLoader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(train_loader):
        real_images, _ = data
        real_images = real_images.to(device)

        # Train discriminator with real images
        discriminator_optimizer.zero_grad()
        real_labels = torch.ones(real_images.size(0), 1, device=device)
        real_outputs = discriminator(real_images)
        real_loss = criterion(real_outputs, real_labels)
        real_loss.backward()

        # Train discriminator with fake images
        noise = torch.randn(real_images.size(0), NOISE_DIM, device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images.detach())
        fake_loss = criterion(fake_outputs, fake_labels)
        fake_loss.backward()
        discriminator_optimizer.step()

        # Train generator
        generator_optimizer.zero_grad()
        fake_labels = torch.ones(real_images.size(0), 1, device=device)
        fake_outputs = discriminator(fake_images)
        gen_loss = criterion(fake_outputs, fake_labels)
        gen_loss.backward()
        generator_optimizer.step()

        # Print losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(train_loader)}], '
                  f'Discriminator Loss: {real_loss.item() + fake_loss.item():.4f}, '
                  f'Generator Loss: {gen_loss.item():.4f}')

# Generate and save images
def generate_and_save_images(model, epoch, noise):
    model.eval()
    with torch.no_grad():
        fake_images = model(noise).cpu()
    fake_images = fake_images.view(fake_images.size(0), 28, 28)

    fig = plt.figure(figsize=(4, 4))
    for i in range(fake_images.size(0)):
        plt.subplot(4, 4, i+1)
        plt.imshow(fake_images[i], cmap='gray')
        plt.axis('off')

    # `epoch` is the number of completed epochs, so it is used as-is
    plt.savefig(f'image_at_epoch_{epoch:04d}.png')
    plt.show()

# Generate test noise
test_noise = torch.randn(16, NOISE_DIM, device=device)
generate_and_save_images(generator, NUM_EPOCHS, test_noise)
```
Output:
```
Epoch [1/5], Step [1/235], Discriminator Loss: 1.6305, Generator Loss: 1.0509
Epoch [1/5], Step [101/235], Discriminator Loss: 0.2560, Generator Loss: 4.2435
Epoch [1/5], Step [201/235], Discriminator Loss: 0.2019, Generator Loss: 5.7860
Epoch [2/5], Step [1/235], Discriminator Loss: 0.0429, Generator Loss: 4.2411
Epoch [2/5], Step [101/235], Discriminator Loss: 0.0505, Generator Loss: 4.4958
Epoch [2/5], Step [201/235], Discriminator Loss: 0.0449, Generator Loss: 4.6327
Epoch [3/5], Step [1/235], Discriminator Loss: 0.0257, Generator Loss: 5.1921
Epoch [3/5], Step [101/235], Discriminator Loss: 0.0354, Generator Loss: 5.5234
Epoch [3/5], Step [201/235], Discriminator Loss: 0.0290, Generator Loss: 5.2325
Epoch [4/5], Step [1/235], Discriminator Loss: 0.0104, Generator Loss: 5.6811
Epoch [4/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.6416
Epoch [4/5], Step [201/235], Discriminator Loss: 0.0030, Generator Loss: 6.3280
Epoch [5/5], Step [1/235], Discriminator Loss: 0.0079, Generator Loss: 5.6755
Epoch [5/5], Step [101/235], Discriminator Loss: 0.0097, Generator Loss: 5.9742
Epoch [5/5], Step [201/235], Discriminator Loss: 0.0055, Generator Loss: 6.0514
```

The generated digits are not yet clear because the GAN was trained for only 5 epochs; training for more epochs should produce better results. The log also shows the discriminator loss falling close to zero while the generator loss keeps rising, a sign that the discriminator is dominating the generator at this stage of training.
