Deep Learning and AI Superhero

Chapter 9: Practical Projects

9.5 Project 5: GAN-based Image Generation

Generative Adversarial Networks (GANs) have become one of the main approaches to image generation. This project extends the basic GAN implementation for generating handwritten digits from the widely used MNIST dataset.

Our goal is to add a series of established enhancements designed to improve training stability and the quality of the generated images.

These techniques target common challenges in GAN training, such as mode collapse and poor convergence, and aim to produce more realistic and more diverse samples.

By the end of the project, we aim to have a GAN that produces convincing handwritten digits, together with quantitative metrics for measuring how close those digits are to real ones.

9.5.1 Enhanced GAN Architecture

To give our GAN more capacity, we use a deeper convolutional architecture for both the generator and the discriminator. The generator stacks transposed convolutions with batch normalization and LeakyReLU activations. The discriminator stacks strided convolutions with LeakyReLU activations and dropout. This added depth helps the models learn more of the stroke-level structure in the data and produce more realistic, detailed handwritten digits.

import tensorflow as tf
from tensorflow.keras import layers, models

def build_generator(latent_dim):
    model = models.Sequential([
        # Project the noise vector and reshape it into a 7x7x256 feature map
        layers.Dense(7*7*256, use_bias=False, input_shape=(latent_dim,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 256)),
        
        # 7x7x256 -> 7x7x128 (stride 1 keeps the resolution)
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        # 7x7x128 -> 14x14x64
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        # 14x14x64 -> 28x28x1; tanh maps pixel values to [-1, 1]
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ])
    return model

def build_discriminator():
    model = models.Sequential([
        # 28x28x1 -> 14x14x64
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        # 14x14x64 -> 7x7x128
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Flatten(),
        # Single unbounded score (no sigmoid); used as a critic with the Wasserstein loss
        layers.Dense(1)
    ])
    return model

generator = build_generator(latent_dim=100)
discriminator = build_discriminator()

Let's break it down:

  1. Generator:
  • Takes a latent vector (noise) as input
  • Uses transposed convolutions to upsample from a 7x7 feature map to a 28x28 image
  • Incorporates batch normalization and LeakyReLU activations for stability and non-linearity
  • Final layer uses tanh activation, so pixel values lie in [-1, 1] (real images must be scaled to the same range)
  2. Discriminator:
  • Takes a 28x28 image as input
  • Uses strided convolutional layers to downsample the input (28x28 → 14x14 → 7x7)
  • Incorporates LeakyReLU activations and dropout for regularization
  • Final dense layer has no activation and outputs a single unbounded score. With the Wasserstein loss used below, this score acts as a "critic" value rather than a probability that the input is real

The architecture is designed to generate and discriminate 28x28 grayscale images, matching the MNIST dataset format. Batch normalization, LeakyReLU, and dropout can help stabilize training, but on their own they do not prevent mode collapse; the loss and regularization changes in the following sections target that problem more directly.
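To check the shape arithmetic behind this architecture, the short sketch below traces spatial sizes under padding='same', where a convolution with stride s produces ceil(n / s) pixels and a transposed convolution produces n * s pixels. The helper names are hypothetical; the sketch only reproduces the size arithmetic, not the Keras layers themselves.

```python
import math

def conv_same(size, stride):
    # Conv2D with padding='same': output = ceil(input / stride)
    return math.ceil(size / stride)

def conv_transpose_same(size, stride):
    # Conv2DTranspose with padding='same': output = input * stride
    return size * stride

def trace(start, strides, layer_fn):
    # Return the spatial size before and after each layer
    sizes = [start]
    for s in strides:
        sizes.append(layer_fn(sizes[-1], s))
    return sizes

# Generator: Dense -> 7x7, then transposed convolutions with strides 1, 2, 2
generator_sizes = trace(7, [1, 2, 2], conv_transpose_same)
# Discriminator: 28x28 input, then convolutions with strides 2, 2
discriminator_sizes = trace(28, [2, 2], conv_same)

print(generator_sizes)      # [7, 7, 14, 28]
print(discriminator_sizes)  # [28, 14, 7]
```

After the last discriminator convolution, Flatten therefore feeds 7 * 7 * 128 = 6272 features into the final Dense(1) layer.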

9.5.2 Wasserstein Loss with Gradient Penalty

To improve training stability and reduce the risk of mode collapse, we implement the Wasserstein loss with gradient penalty, known as WGAN-GP (Wasserstein GAN with Gradient Penalty).

The Wasserstein distance measures how far apart the real and generated data distributions are. Unlike the standard GAN loss, it still gives the generator useful gradients when the two distributions barely overlap, which tends to make training more stable and can lead to higher quality images.

The Wasserstein formulation requires the critic (discriminator) to be 1-Lipschitz. The gradient penalty enforces this softly by penalizing the critic whenever the norm of its gradient, measured at random points between real and generated samples, moves away from 1. This replaces the weight clipping used in the original WGAN, which can lead to vanishing or exploding gradients.
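For reference, the objectives implemented in the following code can be written as below, with D the critic, G the generator, x a real image, z a noise vector, and λ the penalty weight (10 in the code, the value used in the WGAN-GP paper):

```latex
\begin{aligned}
L_D &= \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x}\big[D(x)\big]
      + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big], \\
L_G &= -\,\mathbb{E}_{z}\big[D(G(z))\big], \\
\hat{x} &= \alpha\,x + (1 - \alpha)\,G(z), \qquad \alpha \sim \mathcal{U}[0, 1].
\end{aligned}
```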

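import tensorflow as tf

# Uses `generator` and `discriminator` from 9.5.1 and the optimizers defined in 9.5.6.
# The critic (discriminator) outputs unbounded scores, so no cross-entropy is needed.

def discriminator_loss(real_output, fake_output):
    # The critic wants high scores for real images and low scores for fakes,
    # so it minimizes E[D(fake)] - E[D(real)]
    real_loss = tf.reduce_mean(real_output)
    fake_loss = tf.reduce_mean(fake_output)
    return fake_loss - real_loss

def generator_loss(fake_output):
    # The generator wants the critic to score its fakes highly
    return -tf.reduce_mean(fake_output)

def gradient_penalty(discriminator, real_images, fake_images):
    # Use the dynamic batch size so a smaller final batch still works inside tf.function
    batch_size = tf.shape(real_images)[0]
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = alpha * real_images + (1.0 - alpha) * fake_images
    
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        pred = discriminator(interpolated, training=True)
    
    grads = gp_tape.gradient(pred, interpolated)
    # Per-sample L2 norm; the small epsilon avoids an infinite sqrt gradient at 0
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp

GP_WEIGHT = 10.0  # lambda in the WGAN-GP paper

@tf.function
def train_step(images, batch_size, latent_dim):
    # images: a batch of real images scaled to [-1, 1] to match the generator's tanh output.
    # The actual batch size is read from the data, so a smaller final batch
    # does not cause a shape mismatch with the generated images.
    images = tf.cast(images, tf.float32)
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, latent_dim])
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
        
        # Computed inside disc_tape so the penalty's own gradient reaches the critic weights
        gp = gradient_penalty(discriminator, images, generated_images)
        disc_loss += GP_WEIGHT * gp
    
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return gen_loss, disc_loss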

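Let's break it down:

  1. Loss Functions:
  • The discriminator_loss function calculates the Wasserstein loss for the critic: the mean score of fake images minus the mean score of real images.
  • The generator_loss function calculates the Wasserstein loss for the generator: the negated mean score of fake images.
  2. Gradient Penalty:
  • The gradient_penalty function samples random points on straight lines between real and generated images and penalizes the critic when its gradient norm at those points differs from 1. This softly enforces the Lipschitz constraint on the discriminator.
  3. Training Step:
  • The train_step function defines a single training iteration for both the generator and discriminator.
  • It generates fake images, computes both losses, adds the gradient penalty (weighted by 10, the value used in the WGAN-GP paper) to the critic loss, and updates both networks.
  • For simplicity, both networks are updated once per step. The WGAN-GP paper instead trains the critic several times (5 in its experiments) for every generator update, which is worth trying if training is unstable.

This implementation aims to improve training stability and reduce the risk of mode collapse, which are common challenges in GAN training.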
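To see what the penalty actually measures, here is a framework-free NumPy sketch. It replaces the neural critic with a hypothetical toy critic, D(x) = w · x + 0.5 * c * ||x||^2, chosen only because its gradient w + c * x can be written by hand. It interpolates between real and fake samples the same way gradient_penalty does and averages (||gradient|| - 1)^2.

```python
import numpy as np

def toy_critic_grad(x, w, c):
    # Gradient of the hypothetical toy critic D(x) = w.x + 0.5 * c * ||x||^2
    return w + c * x

def gradient_penalty_np(real, fake, w, c, rng):
    # real, fake: arrays of shape (batch, features)
    alpha = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))  # one alpha per sample
    interpolated = alpha * real + (1.0 - alpha) * fake
    grads = toy_critic_grad(interpolated, w, c)
    norms = np.sqrt(np.sum(grads ** 2, axis=1) + 1e-12)  # per-sample L2 norm
    return np.mean((norms - 1.0) ** 2)

rng = np.random.default_rng(0)
real = rng.normal(size=(64, 2))
fake = rng.normal(loc=3.0, size=(64, 2))

# A linear critic with ||w|| = 1 is exactly 1-Lipschitz, so the penalty is ~0
print(gradient_penalty_np(real, fake, w=np.array([0.6, 0.8]), c=0.0, rng=rng))
# A linear critic with ||w|| = 2 has gradient norm 2 everywhere, so the penalty is ~1
print(gradient_penalty_np(real, fake, w=np.array([2.0, 0.0]), c=0.0, rng=rng))
```

With c > 0 the gradient depends on where the interpolated points fall, which is why the penalty is evaluated on samples drawn between the real and generated data rather than at a fixed location.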

9.5.3 Progressive Growing

Implement progressive growing to gradually increase the resolution of generated images during training. Training starts with low-resolution images, and new layers are progressively added to both the generator and the discriminator, each one faded in smoothly. This lets the model learn coarse features first before focusing on finer details.

Progressive growing was introduced to stabilize training at high resolutions (ProGAN used it to generate 1024x1024 face images). At MNIST's 28x28 resolution its benefits are likely to be more modest, but the same resolution-doubling structure applies.

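from tensorflow.keras import layers, models

def build_progressive_generator(latent_dim, target_resolution=28, base_resolution=7):
    # Each block doubles the resolution, so target_resolution / base_resolution must be a
    # power of two (7 -> 14 -> 28 for MNIST). Starting from 4x4 would overshoot to 32x32.
    ratio = target_resolution // base_resolution
    assert target_resolution % base_resolution == 0 and ratio >= 1 and (ratio & (ratio - 1)) == 0, \
        "target_resolution / base_resolution must be a power of two"

    model = models.Sequential()
    model.add(layers.Dense(base_resolution * base_resolution * 256, use_bias=False, input_shape=(latent_dim,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Reshape((base_resolution, base_resolution, 256)))
    
    current_resolution = base_resolution
    while current_resolution < target_resolution:
        # Upsampling block: doubles height and width
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha=0.2))
        current_resolution *= 2
    
    # Output layer: one grayscale channel with tanh, so pixels lie in [-1, 1]
    model.add(layers.Conv2D(1, (5, 5), padding='same', use_bias=False, activation='tanh'))
    return model

progressive_generator = build_progressive_generator(latent_dim=100)

Here's a breakdown of the code:

  • The function takes three parameters: latent_dim (the size of the input noise vector), target_resolution (default 28, which matches the MNIST image size), and base_resolution (default 7, the size of the first feature map).
  • It starts with a dense layer whose output is reshaped to a base_resolution x base_resolution x 256 tensor, followed by batch normalization and LeakyReLU activation.
  • The while loop adds the upsampling blocks:
    • It keeps adding transposed convolutional layers until the current resolution reaches the target resolution.
    • Each iteration doubles the resolution (7x7 → 14x14 → 28x28). Because of this doubling, the target must be the base resolution times a power of two; starting from 4x4 would give 4x4 → 8x8 → 16x16 → 32x32 and overshoot 28x28.
  • Each upsampling step includes a Conv2DTranspose layer, batch normalization, and LeakyReLU activation.
  • The final layer is a Conv2D layer with a tanh activation, which produces the output image.
  • After defining the function, it's used to create a progressive_generator with a latent dimension of 100.

This function builds the complete generator at once, so on its own it does not train coarse-to-fine. To get the behavior of progressive growing, you would train at the base resolution first, then append one upsampling block at a time (with a matching downsampling block at the input of the discriminator) and fade each new block in gradually. That coarse-to-fine schedule is what can make training more stable and improve the quality of the generated images.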
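The smooth transition that progressive growing relies on can be sketched without a framework. When a new block that doubles the resolution is added, the output is a blend of the upsampled output of the old path and the output of the new block, with a weight alpha ramped from 0 to 1 over training. The helper names, the nearest-neighbor upsampling, and the linear ramp below are illustrative assumptions, not part of the code above.

```python
import numpy as np

def upsample_nearest(images, factor=2):
    # images: (batch, height, width, channels); repeat each pixel factor x factor times
    return images.repeat(factor, axis=1).repeat(factor, axis=2)

def fade_in(low_res_output, high_res_output, alpha):
    # Blend the upsampled old (low-res) path with the new block's output.
    # alpha = 0 -> only the old path, alpha = 1 -> only the new block.
    return (1.0 - alpha) * upsample_nearest(low_res_output) + alpha * high_res_output

def alpha_schedule(step, fade_steps):
    # Linearly ramp alpha from 0 to 1 over fade_steps training steps, then hold at 1
    return min(1.0, step / fade_steps)

low = np.random.default_rng(0).uniform(-1, 1, size=(2, 14, 14, 1))   # old 14x14 output
high = np.random.default_rng(1).uniform(-1, 1, size=(2, 28, 28, 1))  # new 28x28 block output

for step in (0, 500, 1000):
    a = alpha_schedule(step, fade_steps=1000)
    blended = fade_in(low, high, a)
    print(step, a, blended.shape)
```

At the start of the fade the new block contributes nothing, so adding it does not suddenly disturb what the network has already learned at the lower resolution.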

9.5.4 Spectral Normalization

Implement spectral normalization in the discriminator to stabilize training. This technique constrains the Lipschitz constant of the discriminator, limiting how much its output can change in response to small changes in its input.

Spectral normalization divides each layer's weight matrix by its largest singular value (its spectral norm), so every layer ends up with a spectral norm of about 1. For 1-Lipschitz activations such as LeakyReLU, the Lipschitz constant of the whole network is bounded by the product of its layers' spectral norms, so this keeps the discriminator's overall Lipschitz constant under control. The technique was introduced specifically to stabilize GAN training (SN-GAN).

Used alongside the gradient penalty from 9.5.2, or instead of it, spectral normalization can make training more robust and convergence more reliable.

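import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense

class SpectralNormalization(tf.keras.constraints.Constraint):
    """Kernel constraint that divides the weights by an estimate of their spectral norm.

    Keras applies constraints after each optimizer update. The vector u is
    re-sampled on every call (it is not persisted between updates), so a few
    power iterations are used to get a reasonable estimate.
    """
    def __init__(self, iterations=5):
        self.iterations = iterations
    
    def __call__(self, w):
        w_shape = w.shape  # (kh, kw, in, out) for Conv2D, (in, out) for Dense
        w_mat = tf.reshape(w, [-1, w_shape[-1]])  # flatten to 2D: (K, out)
        u = tf.random.normal([1, w_shape[-1]])
        
        # Power iteration: alternately multiply by W^T and W, normalizing each time
        for _ in range(self.iterations):
            v = tf.math.l2_normalize(tf.matmul(u, w_mat, transpose_b=True))  # (1, K)
            u = tf.math.l2_normalize(tf.matmul(v, w_mat))                    # (1, out)
        
        # Estimate of the largest singular value
        sigma = tf.matmul(tf.matmul(v, w_mat), u, transpose_b=True)[0, 0]
        # Divide the kernel in its ORIGINAL shape so Keras can assign it back
        return w / (sigma + 1e-12)

    def get_config(self):
        return {"iterations": self.iterations}

def SpectralConv2D(filters, kernel_size, **kwargs):
    return Conv2D(filters, kernel_size, kernel_constraint=SpectralNormalization(), **kwargs)

def SpectralDense(units, **kwargs):
    return Dense(units, kernel_constraint=SpectralNormalization(), **kwargs)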

Here's a code breakdown:

  • SpectralNormalization class: A custom kernel constraint. Keras calls it after every optimizer update, and it rescales the layer's kernel by an estimate of its spectral norm.
  • __call__ method: Reshapes the kernel to a 2D matrix, runs power iteration from a random vector u to estimate the largest singular value (spectral norm), and divides the kernel by that value.
  • SpectralConv2D and SpectralDense functions: Wrapper functions that create Conv2D and Dense layers with this constraint applied to their kernels, making it easy to add spectral normalization to a model.

Because this version is a constraint, it rescales the stored weights after each update rather than normalizing them inside the forward pass as in the original SN-GAN formulation, and it re-samples u on each call instead of keeping it between updates. Recent TensorFlow/Keras releases also include a built-in SpectralNormalization layer wrapper; check whether your version provides it before writing your own. Either way, the purpose is to constrain the Lipschitz constant of the discriminator, which helps stabilize training and can lead to higher quality generated images and better convergence.
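The power iteration inside SpectralNormalization can be checked against an exact answer. The NumPy sketch below mirrors its update rule (multiply by W^T, normalize, multiply by W, normalize) on a matrix shaped like a reshaped 3x3 Conv2D kernel with 4 input and 8 output channels. The matrix is built with chosen singular values so the true spectral norm (10) is known; the function names and iteration counts are assumptions for illustration.

```python
import numpy as np

def spectral_norm_power_iteration(w, iterations, rng):
    # w: 2D weight matrix of shape (K, out), e.g. a kernel reshaped to (kh*kw*in, out)
    u = rng.normal(size=(1, w.shape[1]))
    for _ in range(iterations):
        v = u @ w.T
        v /= np.linalg.norm(v)
        u = v @ w
        u /= np.linalg.norm(u)
    # sigma = v W u^T estimates the largest singular value (it never exceeds it)
    return (v @ w @ u.T).item()

def matrix_with_singular_values(rows, singular_values, rng):
    # Build W = U diag(s) V^T with orthonormal U, V so the exact answer is known
    cols = len(singular_values)
    u_mat, _ = np.linalg.qr(rng.normal(size=(rows, cols)))
    v_mat, _ = np.linalg.qr(rng.normal(size=(cols, cols)))
    return u_mat @ np.diag(singular_values) @ v_mat.T

rng = np.random.default_rng(0)
s = np.array([10.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.1])
w = matrix_with_singular_values(3 * 3 * 4, s, rng)  # shape (36, 8)

for iters in (1, 3, 30):
    print(iters, spectral_norm_power_iteration(w, iters, rng))  # approaches 10.0

w_normalized = w / spectral_norm_power_iteration(w, 30, rng)
print(np.linalg.svd(w_normalized, compute_uv=False)[0])  # approximately 1.0
```

How fast the estimate converges depends on the gap between the two largest singular values. Because the constraint above re-samples u on every call, a small number of iterations can underestimate the spectral norm for real kernels; persisting u between updates, as the SN-GAN formulation does, makes one iteration per step sufficient in practice.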

9.5.5 Self-Attention Mechanism

Incorporate a self-attention mechanism to improve the model's ability to capture global dependencies in the generated images. Self-attention lets every spatial location draw on features from every other location, which can improve the coherence and detail of the output.

Self-attention layers can be inserted into both the generator and the discriminator so that the model learns long-range dependencies more effectively. This approach was popularized for GANs by SAGAN (Self-Attention GAN), where it improved results on ImageNet; for small 28x28 digits the gain may be smaller.

import tensorflow as tf
from tensorflow.keras import layers

class SelfAttention(layers.Layer):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.channels = channels
        
        # Conv layers for self-attention
        self.f = layers.Conv2D(channels // 8, 1, kernel_initializer='he_normal')
        self.g = layers.Conv2D(channels // 8, 1, kernel_initializer='he_normal')
        self.h = layers.Conv2D(channels, 1, kernel_initializer='he_normal')

        # Trainable scalar weight gamma
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)

    def call(self, x):
        # Dynamic shapes, so the layer also works when the batch size is unknown in graph mode
        batch_size, height, width, channels = tf.unstack(tf.shape(x))

        # Compute f, g, h transformations (1x1 convolutions)
        f = self.f(x)  # Key
        g = self.g(x)  # Query
        h = self.h(x)  # Value (self.channels must equal the input's channel count)

        # Reshape tensors for self-attention calculation
        f_flatten = tf.reshape(f, [batch_size, height * width, -1])  # (B, H*W, C//8)
        g_flatten = tf.reshape(g, [batch_size, height * width, -1])  # (B, H*W, C//8)
        h_flatten = tf.reshape(h, [batch_size, height * width, channels])  # (B, H*W, C)

        # Attention scores: s[b, i, j] = (query at position i) . (key at position j)
        s = tf.matmul(g_flatten, f_flatten, transpose_b=True)  # (B, H*W, H*W)
        beta = tf.nn.softmax(s, axis=-1)  # Attention map; each row sums to 1 over positions j

        # Each output position is a weighted sum of the value vectors
        o = tf.matmul(beta, h_flatten)  # (B, H*W, C)
        o = tf.reshape(o, [batch_size, height, width, channels])  # Reshape back

        # Weighted residual connection; gamma starts at 0, so the layer starts as the identity
        return self.gamma * o + x

Let's break it down:

  1. The SelfAttention class is a custom layer that inherits from layers.Layer.
    • This layer implements self-attention, allowing the model to learn long-range dependencies in an image.
    • This form of attention was popularized for GANs by SAGAN; related attention mechanisms are also used in image segmentation models and transformers.
  2. In the __init__ method:
    • Three convolutional layers (f, g, and h) are defined, each with a 1x1 kernel.
      • f: Learns key features (reduces the number of channels to C/8).
      • g: Learns query features (reduces the number of channels to C/8).
      • h: Learns value features (keeps the original number of channels, C).
    • A trainable parameter gamma is added, initialized to zero, to control the contribution of the attention mechanism.
  3. The call method defines the forward pass:
    • Reads the input dimensions (batch_size, height, width, channels) dynamically with tf.shape, so the layer works even when the batch size is unknown when the graph is built.
    • Computes feature transformations using 1x1 Conv2D layers:
      • f(x): Generates the key representation.
      • g(x): Generates the query representation.
      • h(x): Generates the value representation.
    • Computes the attention map:
      • Multiplies g and f (dot-product similarity between every pair of positions).
      • Applies softmax to normalize the attention scores for each query position.
    • Applies the attention map to h (weighted sum of value features).
    • Uses a residual connection (gamma * o + x) to blend the original input with the attention output.
  4. Why this matters:
    • This self-attention mechanism allows the model to focus on relevant features across different spatial locations.
    • It is particularly useful in image generation tasks (GANs) to improve the quality and coherence of generated images.
    • It helps capture long-range dependencies, unlike convolutional layers, which have local receptive fields.
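The same computation can be traced in NumPy to confirm the shapes and the effect of gamma. The sketch below replaces the 1x1 convolutions with plain matrix multiplications (a 1x1 convolution is a per-pixel linear map), uses random hypothetical weights, and omits biases.

```python
import numpy as np

def softmax(s, axis=-1):
    s = s - s.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_np(x, w_f, w_g, w_h, gamma):
    # x: (B, H, W, C); w_f, w_g: (C, C // 8); w_h: (C, C)
    b, height, width, c = x.shape
    x_flat = x.reshape(b, height * width, c)
    keys = x_flat @ w_f      # (B, N, C//8)
    queries = x_flat @ w_g   # (B, N, C//8)
    values = x_flat @ w_h    # (B, N, C)
    beta = softmax(queries @ keys.transpose(0, 2, 1))  # (B, N, N); each row sums to 1
    o = (beta @ values).reshape(b, height, width, c)
    return gamma * o + x, beta

rng = np.random.default_rng(0)
C = 16
x = rng.normal(size=(2, 7, 7, C))
w_f, w_g = rng.normal(size=(C, C // 8)), rng.normal(size=(C, C // 8))
w_h = rng.normal(size=(C, C))

out0, beta = self_attention_np(x, w_f, w_g, w_h, gamma=0.0)
out1, _ = self_attention_np(x, w_f, w_g, w_h, gamma=0.5)
print(out0.shape, beta.shape)                # (2, 7, 7, 16) (2, 49, 49)
print(np.allclose(out0, x))                  # True: with gamma = 0 the layer is the identity
print(np.allclose(beta.sum(axis=-1), 1.0))   # True
```

Because gamma starts at zero, a freshly inserted attention layer passes its input through unchanged, and the network learns during training how much of the attention output to mix in. Note also that the attention map has (H*W)^2 entries per image, so self-attention is usually placed at intermediate resolutions such as 14x14 rather than at the highest one.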

9.5.6 Improved Training Loop

Enhance the training process with a training loop that adjusts the learning rate over time and stops early when the quality of the generated images stops improving. Gradually reducing the learning rate lets the model make smaller, finer updates later in training, and stopping once performance plateaus saves computation while keeping the best generator found so far.

Key features of this improved training loop include:

  • Learning rate scheduling: Use a schedule such as exponential decay or cosine annealing to gradually reduce the learning rate as training progresses, allowing finer adjustment of model parameters.
  • Early stopping: Use a patience-based criterion that monitors a relevant quality metric (here, the FID score) and halts training if it does not improve over a specified number of evaluations.
  • Checkpoint saving: Save the generator whenever the monitored metric improves, preserving the best-performing version for later use or evaluation.
  • Progress monitoring: Log key metrics (here, the losses after each epoch and the periodic FID score) and periodically save sample images to track training dynamics.

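With staircase=True, Keras's ExponentialDecay computes initial_learning_rate * decay_rate ** floor(step / decay_steps), so the rate drops in discrete steps. This plain-Python sketch (the function name is hypothetical) reproduces that formula for the values used in this section (initial rate 0.0002, decay_steps=10000, decay_rate=0.96) to show how slowly the rate decays:

```python
import math

def staircase_exponential_decay(step, initial_lr=0.0002, decay_steps=10000, decay_rate=0.96):
    # Same formula as ExponentialDecay with staircase=True
    return initial_lr * decay_rate ** math.floor(step / decay_steps)

for step in (0, 9999, 10000, 25000, 100000):
    print(step, staircase_exponential_decay(step))
```

Note that step counts optimizer updates (batches), not epochs. MNIST has 60,000 training images, so at a batch size of 256, for example, one epoch is about 235 steps and the first decay happens after roughly 43 epochs.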
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay

# Learning rate schedule: multiply the rate by 0.96 every 10,000 optimizer steps
initial_learning_rate = 0.0002
lr_schedule = ExponentialDecay(initial_learning_rate, decay_steps=10000, decay_rate=0.96, staircase=True)

# Optimizers (beta_1=0.5 is a common choice for GANs)
generator_optimizer = Adam(learning_rate=lr_schedule, beta_1=0.5)
discriminator_optimizer = Adam(learning_rate=lr_schedule, beta_1=0.5)

LATENT_DIM = 100    # Must match the latent_dim used to build the generator
num_samples = 16    # Number of images in the visualization grid
FID_SAMPLES = 1000  # FID needs many samples; a handful of images gives a meaningless estimate

# Fixed noise so that saved image grids are comparable across epochs
seed = tf.random.normal([num_samples, LATENT_DIM])

def train(dataset, epochs, batch_size, latent_dim, eval_every=10, patience=10):
    # dataset: a batched tf.data.Dataset of real images scaled to [-1, 1]
    # patience is counted in evaluations (one every eval_every epochs), not in epochs
    best_fid = float('inf')
    no_improvement = 0
    
    for epoch in range(epochs):
        for batch in dataset:
            gen_loss, disc_loss = train_step(batch, batch_size, latent_dim)
        
        print(f"Epoch {epoch + 1}, Gen Loss: {float(gen_loss):.4f}, Disc Loss: {float(disc_loss):.4f}")
        
        if (epoch + 1) % eval_every == 0:
            # generate_and_save_images is not defined in this section and must be supplied
            generate_and_save_images(generator, epoch + 1, seed)
            
            # Fresh samples for FID (the fixed seed is only used for visualization)
            noise = tf.random.normal([FID_SAMPLES, latent_dim])
            generated_images = generator(noise, training=False)
            real_images = next(iter(dataset.unbatch().batch(FID_SAMPLES)))

            current_fid = calculate_fid(real_images, generated_images)  # defined in 9.5.7
            print(f"Epoch {epoch + 1}, FID: {current_fid:.2f}")
            
            if current_fid < best_fid:
                best_fid = current_fid
                no_improvement = 0
                # Checkpoint the generator whenever FID improves
                generator.save(f"generator_epoch_{epoch + 1}.h5")
            else:
                no_improvement += 1
            
            if no_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

# train_dataset, EPOCHS, and BATCH_SIZE must be defined before this call
train(train_dataset, EPOCHS, BATCH_SIZE, LATENT_DIM)

Here's the code breakdown:

  1. Learning Rate Scheduling:
    • Uses an ExponentialDecay schedule to gradually reduce the learning rate, allowing finer adjustments to model parameters later in training.
    • With staircase=True, the rate is multiplied by 0.96 once every 10,000 optimizer steps; the smaller updates late in training can reduce instability.
  2. Optimizers:
    • Uses Adam optimizers for both the generator and discriminator, with:
      • A decaying learning rate (lr_schedule).
      • beta_1=0.5, which is common in GAN training to stabilize updates.
  3. Training Loop:
    • Iterates through epochs and batches, calling train_step() (defined in 9.5.2) to update the generator and discriminator weights.
    • Each batch update is meant to improve the generator's ability to create realistic samples and the discriminator's ability to distinguish real from fake images.
  4. Periodic Evaluation (every 10 epochs):
    • Generates and saves images using a fixed random noise seed to track progression. The generate_and_save_images helper is not defined in this section and must be supplied (for example, a function that plots the images in a grid and saves the figure).
    • Calculates the Fréchet Inception Distance (FID), a widely used metric for evaluating the quality and diversity of generated images. FID is only meaningful when computed on many samples (typically thousands), not on a handful of images.
  5. Model Saving:
    • Saves the generator model (generator.save()) when a new best FID score is achieved.
    • Helps preserve the best-performing generator instead of just the final epoch.
  6. Early Stopping:
    • If FID does not improve for patience consecutive evaluations, training stops early. Because evaluation happens every 10 epochs, a patience of 10 means up to 100 epochs without improvement.
    • This saves computation, and together with checkpointing it ensures the best generator is kept even if sample quality later degrades (for example, through mode collapse, where the generator produces only a few similar images).
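The early-stopping bookkeeping is easy to misread because patience counts evaluations rather than epochs. This plain-Python sketch (a hypothetical function with made-up FID values, for illustration only) replays the same logic on a list of FID scores and reports when training would stop and which checkpoint would be kept:

```python
def early_stopping_replay(fid_scores, eval_every=10, patience=3):
    # fid_scores[i] is the FID measured after epoch (i + 1) * eval_every
    best_fid, best_epoch, no_improvement = float("inf"), None, 0
    for i, fid in enumerate(fid_scores):
        epoch = (i + 1) * eval_every
        if fid < best_fid:
            best_fid, best_epoch, no_improvement = fid, epoch, 0  # a checkpoint would be saved here
        else:
            no_improvement += 1
        if no_improvement >= patience:
            return epoch, best_epoch  # (stop epoch, epoch of the best checkpoint)
    return None, best_epoch  # never stopped early

# Illustrative values only: FID improves, then plateaus
print(early_stopping_replay([80.0, 45.0, 30.0, 31.0, 32.5, 30.5]))  # (60, 30)
```

With eval_every=10 and patience=10, as in the training loop, the same logic waits 100 epochs after the best score before stopping.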

9.5.7 Evaluation Metrics

Implement and utilize advanced evaluation metrics to assess the quality and diversity of generated images. Two key metrics we will focus on are:

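  1. Fréchet Inception Distance (FID): This metric measures how similar real and generated images are by comparing the statistics (mean and covariance) of their feature representations extracted from a pre-trained Inception network. A lower FID score indicates that the generated images are statistically closer to the real ones.
  2. Inception Score (IS): This metric evaluates both the quality and diversity of generated images. It uses a pre-trained Inception network to check whether each image is classified confidently (quality) and whether the images as a whole spread across many classes (diversity). A higher Inception Score suggests better quality and more diverse generated images. Because the Inception network is trained on ImageNet photographs rather than digits, both scores are only rough indicators for MNIST; a common alternative is to compute the same statistics with a classifier trained on MNIST.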

By incorporating these metrics into our evaluation process, we can quantitatively assess the performance of our GAN model and track improvements over time. This will provide valuable insights into the effectiveness of our various architectural and training enhancements.

import numpy as np
import tensorflow as tf
from scipy.linalg import sqrtm
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input

# Build the Inception models once (this downloads ImageNet weights); rebuilding them
# inside every call is slow. FID uses the 2048-d pooled features, IS the full classifier.
fid_inception = InceptionV3(include_top=False, weights="imagenet", pooling='avg', input_shape=(299, 299, 3))
is_inception = InceptionV3(include_top=True, weights="imagenet")

def _to_inception_input(images):
    """Converts a batch of images in [-1, 1] (generator/tanh range) to Inception's input format."""
    images = tf.cast(images, tf.float32)
    if images.shape[-1] == 1:
        images = tf.image.grayscale_to_rgb(images)  # Inception expects 3 channels
    images = tf.image.resize(images, (299, 299))
    images = (images + 1.0) * 127.5  # [-1, 1] -> [0, 255], the range preprocess_input expects
    return preprocess_input(images)  # [0, 255] -> [-1, 1]

def _predict_in_batches(model, images, batch_size):
    # Resize batch by batch to avoid holding thousands of 299x299x3 images in memory
    outputs = [model.predict(_to_inception_input(images[i:i + batch_size]), verbose=0)
               for i in range(0, len(images), batch_size)]
    return np.concatenate(outputs, axis=0)

def calculate_fid(real_images, generated_images, batch_size=32):
    """
    Calculates the Fréchet Inception Distance (FID) between real and generated images.
    Both inputs are assumed to be in [-1, 1] with shape (N, H, W, 1) or (N, H, W, 3).
    """
    # Extract features
    real_features = _predict_in_batches(fid_inception, real_images, batch_size)
    generated_features = _predict_in_batches(fid_inception, generated_images, batch_size)

    # Compute mean and covariance of features
    mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)

    # Compute squared mean difference
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Matrix square root of the covariance product
    covmean = sqrtm(sigma1.dot(sigma2))

    # sqrtm can return tiny imaginary parts due to numerical error; discard them
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Compute final FID score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return float(fid)

def calculate_inception_score(images, batch_size=32, splits=10, eps=1e-12):
    """
    Computes the Inception Score (IS) for generated images in [-1, 1].
    """
    # The full InceptionV3 already ends in a softmax, so these are probabilities p(y|x)
    preds = _predict_in_batches(is_inception, images, batch_size)

    scores = []
    split_size = len(preds) // splits
    for i in range(splits):
        part = preds[i * split_size:(i + 1) * split_size]
        p_y = np.mean(part, axis=0, keepdims=True)  # Marginal class distribution p(y)
        # KL(p(y|x) || p(y)) for each image; eps avoids log(0)
        kl = np.sum(part * (np.log(part + eps) - np.log(p_y + eps)), axis=1)
        scores.append(np.exp(np.mean(kl)))

    return float(np.mean(scores)), float(np.std(scores))

Let's break down each function:

1. FID (Fréchet Inception Distance)

Compares real vs. generated images to check quality.

Uses InceptionV3 to extract image features.

Measures the difference in feature distributions (mean & covariance).

Lower FID = More realistic images.
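Once features are extracted, FID depends only on the two means and covariances: ||mu1 - mu2||^2 + Tr(Sigma1 + Sigma2 - 2 (Sigma1 Sigma2)^(1/2)). The sketch below (a hypothetical helper using the same scipy.linalg.sqrtm call as calculate_fid) evaluates that formula on small synthetic Gaussian data instead of Inception features, which makes two properties easy to check: identical statistics give 0, and a pure mean shift adds exactly the squared length of the shift.

```python
import numpy as np
from scipy.linalg import sqrtm

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):  # drop tiny imaginary parts from numerical error
        covmean = covmean.real
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(sigma1 + sigma2 - 2.0 * covmean))

rng = np.random.default_rng(0)
a = rng.normal(size=(500, 4))
mu, sigma = a.mean(axis=0), np.cov(a, rowvar=False)

shift = np.array([1.0, 2.0, 0.0, 0.0])
print(fid_from_stats(mu, sigma, mu, sigma))          # ~0
print(fid_from_stats(mu, sigma, mu + shift, sigma))  # ~5.0 (1^2 + 2^2)
```

InceptionV3's pooled features are 2048-dimensional, so with fewer samples than that the covariance matrix is singular. This is why FID computed from a few dozen images is not meaningful.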

2. IS (Inception Score)

Checks quality & diversity of generated images.

Uses InceptionV3 to classify images.

Measures sharpness (confident predictions) and variation (spread across classes).

Higher IS = Better quality & diversity.
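For one split, the Inception Score is exp of the mean KL divergence between each image's predicted class distribution p(y|x) and the marginal p(y). This NumPy sketch (a hypothetical helper operating on hand-made probability arrays rather than Inception outputs) shows the two extremes: confident predictions spread evenly over K classes give IS = K, while confident predictions that all pick the same class give IS = 1.

```python
import numpy as np

def inception_score_from_probs(probs, eps=1e-12):
    # probs: (N, num_classes); each row is a probability distribution p(y|x)
    p_y = probs.mean(axis=0, keepdims=True)  # marginal distribution p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

num_classes = 10
# Confident and diverse: each image is one-hot, classes used equally often
diverse = np.eye(num_classes)[np.arange(100) % num_classes]
# Confident but collapsed: every image predicted as class 3
collapsed = np.eye(num_classes)[np.full(100, 3)]

print(inception_score_from_probs(diverse))    # ~10.0
print(inception_score_from_probs(collapsed))  # ~1.0
```

The upper bound is the number of classes (1,000 for the ImageNet-trained InceptionV3), so an IS computed for MNIST digits with that network is hard to interpret on its own; it is most useful for comparing models evaluated in the same way.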

9.5.8 Conclusion

This GAN project incorporates several advanced techniques to enhance the quality of generated images and the stability of training. The key improvements include:

  1. A deeper convolutional architecture for both the generator and discriminator.
  2. Wasserstein loss with gradient penalty for improved training stability.
  3. A resolution-doubling generator design that forms the basis for progressive growing.
  4. Spectral normalization in the discriminator to constrain its Lipschitz constant.
  5. A self-attention mechanism to capture global dependencies in generated images.
  6. An improved training loop with learning rate scheduling, checkpointing, and early stopping.
  7. Evaluation metrics (FID and Inception Score) for quantitative assessment of generated image quality.

These enhancements are intended to improve the quality of the generated images and the stability of training, but how much each one helps on MNIST depends on the hyperparameters and on how the components are combined. Experiment with hyperparameters and architectures to find the best configuration for your specific use case.

9.5 Project 5: GAN-based Image Generation

Generative Adversarial Networks (GANs) have ushered in a new era in the realm of image generation, revolutionizing the field with their innovative approach. This ambitious project seeks to elevate the original GAN implementation, specifically tailored for generating handwritten digits from the widely-used MNIST dataset.

Our primary objective is to incorporate a series of cutting-edge enhancements designed to significantly boost overall performance, improve training stability, and elevate the quality of generated images to unprecedented levels.

By leveraging state-of-the-art techniques and architectural improvements, we aim to push the boundaries of what's possible with GANs. These enhancements will not only address common challenges associated with GAN training, such as mode collapse and convergence issues, but also introduce novel features that promise to yield more realistic and diverse output.

Through this project, we anticipate demonstrating the full potential of GANs in creating high-fidelity, handwritten digit images that are virtually indistinguishable from their real counterparts.

9.5.1 Enhanced GAN Architecture

To enhance the overall performance and capability of our GAN, we will implement a more intricate and layered architecture for both the generator and discriminator components. This advanced structure will incorporate additional convolutional layers, skip connections, and normalization techniques to improve the network's ability to learn complex features and generate high-quality images. By increasing the depth and sophistication of our models, we aim to capture more nuanced patterns in the data and produce more realistic and detailed handwritten digit images.

import tensorflow as tf
from tensorflow.keras import layers, models

def build_generator(latent_dim):
    model = models.Sequential([
        layers.Dense(7*7*256, use_bias=False, input_shape=(latent_dim,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 256)),
        
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ])
    return model

def build_discriminator():
    model = models.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Flatten(),
        layers.Dense(1)
    ])
    return model

generator = build_generator(latent_dim=100)
discriminator = build_discriminator()

Let's break it down:

  1. Generator:
  • Takes a latent vector (noise) as input
  • Uses transposed convolutions to upsample the input to a 28x28 image
  • Incorporates batch normalization and LeakyReLU activations for stability and non-linearity
  • Final layer uses tanh activation to produce image-like output
  1. Discriminator:
  • Takes a 28x28 image as input
  • Uses convolutional layers to downsample the input
  • Incorporates LeakyReLU activations and dropout for regularization
  • Final dense layer outputs a single value, representing the probability of the input being real

The architecture is designed to generate and discriminate 28x28 grayscale images, which aligns with the MNIST dataset format. The use of batch normalization, LeakyReLU, and dropout helps in stabilizing the training process and preventing issues like mode collapse.

9.5.2 Wasserstein Loss with Gradient Penalty

To enhance training stability and mitigate mode collapse, we will implement the Wasserstein loss function with gradient penalty. This advanced technique, known as WGAN-GP (Wasserstein GAN with Gradient Penalty), offers several advantages over traditional GAN loss functions.

By utilizing the Wasserstein distance as a measure of dissimilarity between the real and generated data distributions, we can achieve more stable training dynamics and potentially generate higher quality images.

The gradient penalty term further enforces the Lipschitz constraint on the critic (discriminator) function, helping to prevent issues such as vanishing gradients and ensuring a smoother training process. This implementation will contribute significantly to the overall robustness and performance of our GAN model.

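import tensorflow as tf

# Weight of the gradient penalty term (lambda = 10, as in the WGAN-GP paper)
GP_WEIGHT = 10.0

def discriminator_loss(real_output, fake_output):
    # Critic loss E[D(fake)] - E[D(real)] on raw scores (no sigmoid, no cross-entropy)
    real_loss = tf.reduce_mean(real_output)
    fake_loss = tf.reduce_mean(fake_output)
    return fake_loss - real_loss

def generator_loss(fake_output):
    # The generator is rewarded for raising the critic's score on its samples
    return -tf.reduce_mean(fake_output)

def gradient_penalty(discriminator, real_images, fake_images):
    # One interpolation coefficient per example. tf.shape (rather than .shape[0])
    # stays valid inside tf.function and for a smaller final batch.
    batch_size = tf.shape(real_images)[0]
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = alpha * real_images + (1.0 - alpha) * fake_images

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        pred = discriminator(interpolated, training=True)

    # Gradient of the critic's output with respect to its input images
    grads = gp_tape.gradient(pred, interpolated)
    # Per-example L2 norm; the small epsilon keeps the gradient of sqrt finite at 0
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean((norm - 1.0) ** 2)

# Relies on the global generator and discriminator (9.5.1) and on
# generator_optimizer / discriminator_optimizer (defined in 9.5.6).
# Note: the WGAN-GP paper updates the critic several times (n_critic = 5) per
# generator update; this simplified step updates each network once.
@tf.function
def train_step(images, batch_size, latent_dim):
    # batch_size is kept for compatibility with the training loop; the actual size
    # is read from `images` so a smaller last batch does not cause a shape mismatch
    noise = tf.random.normal([tf.shape(images)[0], latent_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

        # Computed inside disc_tape so the penalty's dependence on the critic's
        # weights (a second-order gradient) is recorded
        gp = gradient_penalty(discriminator, images, generated_images)
        disc_loss += GP_WEIGHT * gp

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss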

Let's break it down:

  1. Loss Functions:
  • discriminator_loss computes the critic's Wasserstein loss, mean(D(fake)) - mean(D(real)); minimizing it pushes scores on real images up and scores on generated images down.
  • generator_loss is -mean(D(fake)), so the generator is rewarded for raising the critic's score on its samples.
  2. Gradient Penalty:
  • gradient_penalty draws one random α in [0, 1] per example, interpolates between each real and generated image, and penalizes the squared deviation of the critic's input-gradient norm from 1 at those points, softly enforcing the Lipschitz constraint.
  3. Training Step:
  • train_step runs one training iteration for both networks: it generates fake images, computes both losses, adds the gradient penalty to the critic loss with weight 10 (the value recommended in the WGAN-GP paper), and applies one update to each network.

This setup is intended to stabilize training and reduce mode collapse, two common challenges in GAN training.
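TensorFlow obtains the critic's input gradient with a GradientTape. To show what the penalty measures without TensorFlow, the NumPy sketch below substitutes a hypothetical toy critic D(x) = 0.5 * sum(x**2), whose input gradient is simply x, so the per-example gradient norm can be written down directly. It mirrors the interpolation, per-example norm, and squared deviation from 1 in gradient_penalty; the toy critic is an illustration only.

```python
import numpy as np

def toy_critic_grad(x):
    # Input gradient of the hypothetical toy critic D(x) = 0.5 * sum(x**2): dD/dx = x
    return x

def gradient_penalty_np(real, fake, rng):
    """Mirror of gradient_penalty for the toy critic (analytic gradient instead of a tape)."""
    batch = real.shape[0]
    # One interpolation coefficient per example, broadcast over H, W and C
    alpha = rng.uniform(0.0, 1.0, size=(batch, 1, 1, 1))
    interpolated = alpha * real + (1.0 - alpha) * fake
    grads = toy_critic_grad(interpolated)
    # Per-example L2 norm over the pixel and channel axes, like axis=[1, 2, 3]
    norms = np.sqrt(np.sum(grads ** 2, axis=(1, 2, 3)))
    return np.mean((norms - 1.0) ** 2), norms

rng = np.random.default_rng(0)
real = rng.uniform(-1.0, 1.0, size=(4, 28, 28, 1))
fake = rng.uniform(-1.0, 1.0, size=(4, 28, 28, 1))
gp, norms = gradient_penalty_np(real, fake, rng)
print(norms.shape)  # (4,): one gradient norm per example
print(gp)           # large, because these gradient norms are far from 1

# An image with unit L2 norm has a unit gradient norm under this critic: no penalty
unit = np.full((2, 28, 28, 1), 1.0 / 28.0)
gp_unit, _ = gradient_penalty_np(unit, unit, rng)
print(gp_unit)      # ~0.0
```

The penalty is zero only where the critic's gradient has norm exactly 1, which is the soft form of the 1-Lipschitz requirement.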

9.5.3 Progressive Growing

Progressive growing is a technique for gradually increasing the resolution of generated images during training. Training starts at a low resolution, and new layers are added to both the generator and discriminator as training proceeds, so the model learns coarse structure before fine detail.

The technique was introduced for high-resolution face synthesis, where it markedly improved training stability. For 28x28 MNIST digits the benefit is likely to be smaller, but the same resolution-doubling structure applies.

from tensorflow.keras import layers, models

def build_progressive_generator(latent_dim, target_resolution=28, base_resolution=7):
    # Each stride-2 block doubles the spatial size, so target_resolution must equal
    # base_resolution * 2**k. Starting from 7x7 gives 7 -> 14 -> 28 for MNIST.
    ratio = target_resolution // base_resolution
    if target_resolution % base_resolution or ratio & (ratio - 1):
        raise ValueError("target_resolution must be base_resolution * 2**k")

    model = models.Sequential()
    model.add(layers.Dense(base_resolution * base_resolution * 256, use_bias=False, input_shape=(latent_dim,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Reshape((base_resolution, base_resolution, 256)))

    current_resolution = base_resolution
    while current_resolution < target_resolution:
        # Upsampling block: doubles height and width
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha=0.2))
        current_resolution *= 2

    # Map the features to one grayscale channel in [-1, 1]
    model.add(layers.Conv2D(1, (5, 5), padding='same', use_bias=False, activation='tanh'))
    return model

progressive_generator = build_progressive_generator(latent_dim=100)

Here's a breakdown of the code:

  • The function takes latent_dim (the size of the input noise vector), target_resolution (default 28, the MNIST image size), and base_resolution (default 7, the size of the first feature map).
  • It starts with a dense layer reshaped to a base_resolution x base_resolution x 256 tensor, followed by batch normalization and LeakyReLU activation.
  • The resolution-doubling logic lives in the while loop:
    • It keeps adding stride-2 transposed convolutional layers (upsampling) until the current resolution reaches the target resolution.
    • Each iteration doubles the resolution (7x7 → 14x14 → 28x28). Starting from 4x4 instead would give 4x4 → 8x8 → 16x16 → 32x32 and overshoot 28x28, which is why the function checks that target_resolution is base_resolution times a power of two.
  • Each upsampling step includes a Conv2DTranspose layer, batch normalization, and LeakyReLU activation.
  • The final layer is a Conv2D layer with a tanh activation, which produces the single-channel output image.
  • After defining the function, it's used to create a progressive_generator with a latent dimension of 100.

Note that this function fixes the network's depth before training begins. Full progressive growing instead trains the low-resolution network first and adds each new block during training, fading it in gradually; this builder provides the resolution-doubling structure that such a schedule would grow through.
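In progressive growing as described for ProGAN (Karras et al.), a newly added higher-resolution block is not switched on abruptly: its output is blended with an upsampled copy of the previous stage's output, and the blending weight alpha is ramped from 0 to 1. The NumPy sketch below shows only that blending step, assuming nearest-neighbour upsampling; how alpha is scheduled and how the discriminator fades in its matching block are not shown, and the function names are illustrative.

```python
import numpy as np

def upsample_nearest(x, factor=2):
    # Nearest-neighbour upsampling of a (batch, H, W, C) array:
    # each pixel becomes a factor x factor block
    return x.repeat(factor, axis=1).repeat(factor, axis=2)

def fade_in(old_output, new_output, alpha):
    """Blend the previous stage's output (upsampled) with the new block's output.
    alpha is ramped from 0 to 1 while the new, higher-resolution block is phased in."""
    return (1.0 - alpha) * upsample_nearest(old_output) + alpha * new_output

rng = np.random.default_rng(0)
old = rng.uniform(-1.0, 1.0, size=(2, 14, 14, 1))  # image from the existing 14x14 stage
new = rng.uniform(-1.0, 1.0, size=(2, 28, 28, 1))  # image from the newly added 28x28 block

for alpha in (0.0, 0.5, 1.0):
    print(alpha, fade_in(old, new, alpha).shape)  # (2, 28, 28, 1) for every alpha
```

At alpha = 0 the network behaves exactly like the smaller model, so adding the block does not disrupt what has already been learned.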

9.5.4 Spectral Normalization

Apply spectral normalization to the discriminator to stabilize training and keep its gradients from growing without bound. This technique constrains the Lipschitz constant of the discriminator, so that a small change to the input can only produce a proportionally small change in the output.

Dividing each layer's weight matrix by its spectral norm (its largest singular value) makes that layer at most 1-Lipschitz (approximately, for convolutional kernels). Because the Lipschitz constant of a composition is at most the product of its parts, and activations such as LeakyReLU are themselves 1-Lipschitz, this bounds the Lipschitz constant of the whole discriminator. The approach has proven effective at stabilizing GAN training, particularly with deeper architectures or harder datasets.

Spectral normalization can therefore make our GAN more robust, potentially improving both the quality of generated images and convergence.

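import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense

class SpectralNormalization(tf.keras.constraints.Constraint):
    """Keras constraint that rescales a kernel so its largest singular value is ~1.

    Keras applies constraints after each optimizer update, so this projects the
    weights after every step. Because u is redrawn on every call, several
    power-iteration steps are used.
    """
    def __init__(self, iterations=5):
        self.iterations = iterations

    def __call__(self, w):
        # Flatten e.g. a (kh, kw, in, out) conv kernel to a (kh*kw*in, out) matrix
        w_mat = tf.reshape(w, [-1, w.shape[-1]])
        u = tf.random.normal([1, w.shape[-1]], dtype=w_mat.dtype)

        # Power iteration toward the top left (v) and right (u) singular vectors
        for _ in range(self.iterations):
            v = tf.matmul(u, w_mat, transpose_b=True)   # (1, rows)
            v = v / (tf.norm(v) + 1e-12)
            u = tf.matmul(v, w_mat)                      # (1, out)
            u = u / (tf.norm(u) + 1e-12)

        # Estimate of the largest singular value: sigma = v W u^T
        sigma = tf.matmul(tf.matmul(v, w_mat), u, transpose_b=True)[0, 0]
        # Divide, then restore the original kernel shape (required for 4-D Conv2D kernels)
        return tf.reshape(w_mat / sigma, w.shape)

    def get_config(self):
        return {"iterations": self.iterations}

def SpectralConv2D(filters, kernel_size, **kwargs):
    return Conv2D(filters, kernel_size, kernel_constraint=SpectralNormalization(), **kwargs)

def SpectralDense(units, **kwargs):
    return Dense(units, kernel_constraint=SpectralNormalization(), **kwargs)

Here's a code breakdown:

  • SpectralNormalization class: a custom Keras constraint that rescales a layer's kernel by an estimate of its spectral norm (largest singular value). Keras applies constraints after each optimizer update, so this variant projects the weights after every step rather than normalizing them inside the forward pass, as the original spectral normalization method does.
  • __call__ method: flattens the kernel to a 2-D matrix, runs power iteration to estimate the largest singular value, divides by it, and reshapes the result back to the kernel's original shape (a Conv2D kernel is 4-D, so returning the flattened matrix would not match the variable). Because u is redrawn at random on every call, several iterations (5 by default) are used instead of one.
  • get_config method: returns the constructor arguments so that models using the constraint can be serialized.
  • SpectralConv2D and SpectralDense functions: wrappers that create Conv2D and Dense layers with the constraint applied to their kernels, making it easy to swap them into build_discriminator.

Recent Keras versions also provide a built-in tf.keras.layers.SpectralNormalization layer wrapper, which keeps u as a persistent variable across training steps as the original method does.

The purpose of spectral normalization is to constrain the Lipschitz constant of the discriminator in a GAN. This keeps its gradients bounded and stabilizes training, which may in turn improve image quality and convergence.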
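To check the power-iteration step in isolation, the NumPy sketch below runs the same u/v updates on a matrix with a planted spectrum (largest singular value 3.0), shaped like a flattened (5, 5, 64, 128) conv kernel, and compares the estimate against np.linalg.svd. The planted gap between the top two singular values is an assumption made so that 20 iterations converge; for real kernels with closer singular values, a single iteration from a fresh random u gives only a rough estimate, which is why the original method reuses u across steps.

```python
import numpy as np

def spectral_norm_estimate(w, iterations=20, seed=0):
    """Power-iteration estimate of the largest singular value of a 2-D matrix,
    using the same u/v updates as the SpectralNormalization constraint."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=(1, w.shape[1]))
    for _ in range(iterations):
        v = u @ w.T
        v /= np.linalg.norm(v)
        u = v @ w
        u /= np.linalg.norm(u)
    return (v @ w @ u.T).item()

rng = np.random.default_rng(1)
rows, cols = 5 * 5 * 64, 128  # a (5, 5, 64, 128) conv kernel after the constraint's reshape
# Build a matrix with known singular values: largest 3.0, the rest between 1.5 and 0.1
left, _ = np.linalg.qr(rng.normal(size=(rows, cols)))   # orthonormal columns
right, _ = np.linalg.qr(rng.normal(size=(cols, cols)))  # orthogonal matrix
singular_values = np.concatenate([[3.0], np.linspace(1.5, 0.1, cols - 1)])
w = left @ np.diag(singular_values) @ right.T

sigma_est = spectral_norm_estimate(w)
w_normalized = w / sigma_est

print(sigma_est)                                         # ~3.0
print(np.linalg.svd(w, compute_uv=False)[0])             # ~3.0 (reference)
print(np.linalg.svd(w_normalized, compute_uv=False)[0])  # ~1.0
```

After division, the largest singular value of the normalized matrix is 1, which is the per-layer bound spectral normalization aims for.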

9.5.5 Self-Attention Mechanism

Incorporate a self-attention mechanism to strengthen the model's ability to capture global dependencies in the generated images. Self-attention lets every spatial position draw on features from every other position, which can improve the coherence and detail of the output.

Self-attention layers can be inserted into both the generator and the discriminator so that the model learns long-range dependencies more effectively. This idea, popularized by SAGAN (Self-Attention GAN), improved results on image generation benchmarks and may likewise help produce more coherent handwritten digits.

import tensorflow as tf
from tensorflow.keras import layers

class SelfAttention(layers.Layer):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.channels = channels
        
        # Conv layers for self-attention
        self.f = layers.Conv2D(channels // 8, 1, kernel_initializer='he_normal')
        self.g = layers.Conv2D(channels // 8, 1, kernel_initializer='he_normal')
        self.h = layers.Conv2D(channels, 1, kernel_initializer='he_normal')

        # Trainable scalar weight gamma
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)

    def call(self, x):
        batch_size, height, width, channels = tf.unstack(tf.shape(x))

        # Compute f, g, h transformations
        f = self.f(x)  # Query
        g = self.g(x)  # Key
        h = self.h(x)  # Value

        # Reshape tensors for self-attention calculation
        f_flatten = tf.reshape(f, [batch_size, height * width, -1])  # (B, H*W, C//8)
        g_flatten = tf.reshape(g, [batch_size, height * width, -1])  # (B, H*W, C//8)
        h_flatten = tf.reshape(h, [batch_size, height * width, channels])  # (B, H*W, C)

        # Compute attention scores
        s = tf.matmul(g_flatten, f_flatten, transpose_b=True)  # (B, H*W, H*W)
        beta = tf.nn.softmax(s)  # Attention map (B, H*W, H*W)

        # Apply attention weights to h
        o = tf.matmul(beta, h_flatten)  # (B, H*W, C)
        o = tf.reshape(o, [batch_size, height, width, channels])  # Reshape back

        # Apply self-attention mechanism
        return self.gamma * o + x  # Weighted residual connection

Let's break it down:

  1. The SelfAttention class is a custom Keras layer that inherits from layers.Layer
    • It implements self-attention, allowing the model to relate features at distant positions in an image.
    • This convolutional formulation follows SAGAN; transformers use a related, multi-head form of the same idea.
  2. In the __init__ method:
    • Three convolutional layers (f, g, and h) are defined, each with a 1x1 kernel.
      • f: Learns query features (reduces the channels to C // 8).
      • g: Learns key features (reduces the channels to C // 8).
      • h: Learns value features (keeps all C channels).
    • A trainable scalar gamma is added and initialized to zero, so the layer starts out as an identity mapping and the network learns how much attention output to mix in.
    • Because of the C // 8 reduction, the layer needs at least 8 input channels.
  3. The call method defines the forward pass:
    • Reads batch_size, height, width, and channels from tf.shape(x) at run time, so the layer still works when the batch size is unknown while the graph is being built.
    • Computes feature transformations using 1x1 Conv2D layers:
      • f(x): Generates the query representation.
      • g(x): Generates the key representation.
      • h(x): Generates the value representation.
    • Computes the attention map:
      • Flattens the spatial positions and takes dot products between g and f for every pair of positions, giving an (H·W) x (H·W) similarity matrix.
      • Applies softmax to each row so that every position's attention weights sum to 1.
    • Applies the attention map to h, producing a weighted sum of value features for each position.
    • Uses a residual connection (gamma * o + x) to blend the original input with the attention output.
  4. Why this matters:
    • This self-attention mechanism allows the model to focus on relevant features across different spatial locations.
    • It is particularly useful in GAN image generation to improve the quality and coherence of generated images.
    • It captures long-range dependencies in a single layer, unlike convolutional layers, which have local receptive fields.
    • The cost grows quickly with resolution: the attention map has (H·W)² entries per image, 49 x 49 at 7x7 but 784 x 784 at 28x28, which is why the layer is usually placed at a lower-resolution stage.
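The forward pass above can be reproduced in NumPy to check its shapes and the two properties the breakdown relies on: each row of the attention map sums to 1, and with gamma = 0 the layer returns its input unchanged. The sketch assumes random weights and treats each 1x1 convolution as a per-pixel matrix multiply with the bias omitted (Keras Conv2D layers include a bias by default); it mirrors SelfAttention.call but is not the Keras layer itself.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_np(x, wf, wg, wh, gamma):
    """NumPy mirror of SelfAttention.call. A 1x1 convolution is a per-pixel
    matrix multiply, so wf / wg / wh are (C, C') matrices (biases omitted)."""
    b, h, w, c = x.shape
    flat = x.reshape(b, h * w, c)
    f = flat @ wf                     # query, (B, HW, C//8)
    g = flat @ wg                     # key,   (B, HW, C//8)
    v = flat @ wh                     # value, (B, HW, C)
    s = g @ f.transpose(0, 2, 1)      # scores, (B, HW, HW)
    beta = softmax(s, axis=-1)        # attention map; each row sums to 1
    o = (beta @ v).reshape(b, h, w, c)
    return gamma * o + x, beta

rng = np.random.default_rng(0)
c = 64
x = rng.normal(size=(2, 7, 7, c))
wf = rng.normal(scale=0.1, size=(c, c // 8))
wg = rng.normal(scale=0.1, size=(c, c // 8))
wh = rng.normal(scale=0.1, size=(c, c))

out0, beta = self_attention_np(x, wf, wg, wh, gamma=0.0)
out1, _ = self_attention_np(x, wf, wg, wh, gamma=0.5)
print(beta.shape)            # (2, 49, 49)
print(np.allclose(out0, x))  # True: gamma = 0 makes the layer an identity
```

Starting gamma at zero means the attention branch contributes nothing at first, so inserting the layer into a working model does not change its initial behaviour.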

9.5.6 Improved Training Loop

Enhance the training process with a training loop that adjusts the learning rate over time and stops early once sample quality stops improving. A decaying learning rate can help late-stage convergence, and early stopping avoids spending computation on epochs that no longer help.

Key features of this improved training loop include:

  • Learning rate scheduling: Use a learning rate schedule such as exponential decay or cosine annealing to gradually reduce the learning rate as training progresses, allowing finer adjustment of model parameters in later stages.
  • Early stopping: Implement a patience-based early stopping criterion that monitors a relevant performance metric (e.g., FID score) and halts training if no improvement is observed over a specified number of epochs.
  • Checkpoint saving: Regularly save model checkpoints during training, preserving the best-performing model iterations for later use or evaluation.
  • Progress monitoring: Integrate comprehensive logging and visualization tools to track key metrics, enabling real-time assessment of model performance and training dynamics.

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay

# Learning rate schedule: multiply the rate by 0.96 every 10,000 optimizer steps
initial_learning_rate = 0.0002
lr_schedule = ExponentialDecay(initial_learning_rate, decay_steps=10000, decay_rate=0.96, staircase=True)

# Optimizers (used by train_step in 9.5.2)
generator_optimizer = Adam(learning_rate=lr_schedule, beta_1=0.5)
discriminator_optimizer = Adam(learning_rate=lr_schedule, beta_1=0.5)

LATENT_DIM = 100    # must match the generator's latent_dim
num_samples = 16    # images shown in each progress grid
FID_SAMPLES = 2048  # at least the 2048-d feature size keeps the covariances full rank;
                    # larger samples (e.g. 10,000) give more reliable estimates
EVAL_EVERY = 10     # evaluate every N epochs

# Fixed noise so the progress grids are comparable across epochs
seed = tf.random.normal([num_samples, LATENT_DIM])

def train(dataset, epochs, batch_size, latent_dim):
    best_fid = float('inf')
    patience = 10  # counted in evaluations, i.e. patience * EVAL_EVERY epochs
    no_improvement = 0

    for epoch in range(epochs):
        for batch in dataset:
            gen_loss, disc_loss = train_step(batch, batch_size, latent_dim)

        print(f"Epoch {epoch + 1}, Gen Loss: {float(gen_loss):.4f}, Disc Loss: {float(disc_loss):.4f}")

        if (epoch + 1) % EVAL_EVERY == 0:
            # generate_and_save_images is assumed to be defined elsewhere in the chapter
            generate_and_save_images(generator, epoch + 1, seed)

            # Compare FID_SAMPLES generated images against FID_SAMPLES real ones
            noise = tf.random.normal([FID_SAMPLES, latent_dim])
            generated_images = generator(noise, training=False)
            real_images = next(iter(dataset.unbatch().batch(FID_SAMPLES)))

            current_fid = calculate_fid(real_images, generated_images)
            print(f"Epoch {epoch + 1}, FID: {current_fid:.2f}")

            if current_fid < best_fid:
                best_fid = current_fid
                no_improvement = 0
                # Keep the best generator so far, not just the last one
                generator.save(f"generator_epoch_{epoch + 1}.h5")
            else:
                no_improvement += 1

            if no_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

# train_dataset (a batched tf.data.Dataset of MNIST images scaled to [-1, 1]),
# EPOCHS and BATCH_SIZE are assumed to be defined earlier in the chapter
train(train_dataset, EPOCHS, BATCH_SIZE, LATENT_DIM)

Here's the code breakdown:

  1. Learning Rate Scheduling:
    • Uses an ExponentialDecay schedule with staircase=True, which multiplies the learning rate by 0.96 once every 10,000 optimizer steps.
    • A gradually shrinking step size can make late-stage updates less erratic, although on its own it does not guarantee stable GAN training.
  2. Optimizers:
    • Uses Adam optimizers for both the generator and discriminator, with:
      • A decaying learning rate (lr_schedule).
      • beta_1=0.5, a common choice in GAN training that reduces momentum and helps stabilize updates.
  3. Training Loop:
    • Iterates through epochs and batches, calling train_step() from 9.5.2 to update the generator and discriminator weights.
    • Each step trains the discriminator to separate real from generated images and the generator to produce samples the discriminator scores as more realistic.
  4. Periodic Evaluation (every 10 epochs):
    • Generates and saves images from a fixed noise seed, so successive grids show how the same latent vectors evolve.
    • Calculates the Fréchet Inception Distance (FID), a widely used metric for the quality and diversity of generated images (lower is better).
  5. Model Saving:
    • Saves the generator (generator.save()) whenever a new best FID score is achieved.
    • This preserves the best-performing generator rather than only the one from the final epoch.
  6. Early Stopping:
    • Patience is counted in evaluations, not epochs: with an evaluation every 10 epochs and patience = 10, training stops after 100 epochs without an FID improvement.
    • This avoids spending computation once sample quality stops improving. It does not fix mode collapse, although a sustained rise in FID can be a symptom of it.
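ExponentialDecay computes initial_learning_rate * decay_rate ** (step / decay_steps), where step is the optimizer's iteration count and the division becomes integer division when staircase=True. The pure-Python mirror below (an illustration, not the Keras implementation) shows what the schedule in this section produces. The batch size of 256 is a hypothetical value, since BATCH_SIZE is not fixed in this section. Each of the two optimizers keeps its own iteration counter, and each advances by one per train_step, so both follow the same curve.

```python
import math

def exponential_decay(step, initial_lr=0.0002, decay_steps=10000,
                      decay_rate=0.96, staircase=True):
    """Pure-Python version of the ExponentialDecay formula:
    lr = initial_lr * decay_rate ** (step / decay_steps), with the exponent
    floored to an integer when staircase=True."""
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent

for step in (0, 9_999, 10_000, 20_000, 100_000):
    print(step, exponential_decay(step))

# Hypothetical batch size of 256 on the 60,000 MNIST training images
steps_per_epoch = math.ceil(60_000 / 256)
print(steps_per_epoch)           # 235
print(10_000 / steps_per_epoch)  # epochs between decays, about 42.6
```

With these settings the learning rate changes only every few dozen epochs, so the decay matters mainly for long training runs.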

9.5.7 Evaluation Metrics

Implement and utilize advanced evaluation metrics to assess the quality and diversity of generated images. Two key metrics we will focus on are:

  1. Fréchet Inception Distance (FID): This metric measures the similarity between real and generated images by comparing the mean and covariance of their feature representations, extracted from a pre-trained Inception network. A lower FID indicates that the generated images are closer to the real ones in both quality and diversity.
  2. Inception Score (IS): This metric evaluates both the quality and diversity of generated images. It uses a pre-trained Inception classifier and rewards image sets in which each image is classified confidently while the set as a whole spreads across many categories. A higher Inception Score suggests better quality and more diverse generated images.

By incorporating these metrics into our evaluation process, we can quantitatively assess the performance of our GAN model and track improvements over time. This will provide valuable insights into the effectiveness of our various architectural and training enhancements.
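For reference, the two metrics are defined as follows, with μ and Σ the mean and covariance of Inception features for real (r) and generated (g) images, and p(y | x) the classifier's predicted class distribution for an image x:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)

\mathrm{IS} = \exp\!\Big(\mathbb{E}_{x \sim p_g}\, D_{\mathrm{KL}}\big(p(y \mid x) \,\Vert\, p(y)\big)\Big),
\qquad p(y) = \mathbb{E}_{x \sim p_g}\big[p(y \mid x)\big]
```

In the code below, ssdiff is the first FID term and covmean the matrix square root; the Inception Score is computed separately on each split and then averaged.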

import numpy as np
import tensorflow as tf
from scipy.linalg import sqrtm
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input

_inception_cache = {}

def _get_inception(include_top):
    # Build each InceptionV3 variant once; reloading the weights on every call is slow
    if include_top not in _inception_cache:
        if include_top:
            model = InceptionV3(include_top=True, weights="imagenet")
        else:
            model = InceptionV3(include_top=False, pooling='avg', weights="imagenet",
                                input_shape=(299, 299, 3))
        _inception_cache[include_top] = model
    return _inception_cache[include_top]

def _inception_outputs(model, images, batch_size):
    """Run images in [-1, 1] with 1 or 3 channels through InceptionV3, batch by batch."""
    outputs = []
    for start in range(0, len(images), batch_size):
        batch = tf.convert_to_tensor(images[start:start + batch_size], dtype=tf.float32)
        if batch.shape[-1] == 1:
            batch = tf.image.grayscale_to_rgb(batch)  # InceptionV3 expects 3 channels
        batch = tf.image.resize(batch, (299, 299))
        batch = (batch + 1.0) * 127.5                 # generator range [-1, 1] -> [0, 255]
        batch = preprocess_input(batch)               # [0, 255] -> [-1, 1], as Inception expects
        outputs.append(model(batch, training=False).numpy())
    return np.concatenate(outputs, axis=0)

def calculate_fid(real_images, generated_images, batch_size=32):
    """
    Calculates the Fréchet Inception Distance (FID) between real and generated images
    (both in [-1, 1]). Lower is better.
    """
    model = _get_inception(include_top=False)  # 2048-d pooled features

    # Extract features
    real_features = _inception_outputs(model, real_images, batch_size)
    generated_features = _inception_outputs(model, generated_images, batch_size)

    # Compute mean and covariance of features
    mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)

    # Squared distance between the means
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Matrix square root of the covariance product
    covmean = sqrtm(sigma1.dot(sigma2))

    # sqrtm can return tiny imaginary parts caused by round-off; keep the real part
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

def calculate_inception_score(images, batch_size=32, splits=10, eps=1e-16):
    """
    Computes the Inception Score (IS) for generated images in [-1, 1].
    Returns (mean, std) over the splits. Higher is better.
    """
    model = _get_inception(include_top=True)

    # The classification head already ends in a softmax, so these are probabilities;
    # applying softmax again would flatten them and distort the score
    preds = _inception_outputs(model, images, batch_size)

    scores = []
    split_size = len(preds) // splits
    for i in range(splits):
        part = preds[i * split_size:(i + 1) * split_size]
        p_y = np.mean(part, axis=0, keepdims=True)  # marginal class distribution
        # eps avoids log(0) for classes with zero probability
        kl = part * (np.log(part + eps) - np.log(p_y + eps))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))

    return float(np.mean(scores)), float(np.std(scores))

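Let's break down each function:

  1. FID (Fréchet Inception Distance)
    • Compares the distribution of real images with that of generated images to check quality.
    • Uses InceptionV3 to extract a 2048-dimensional feature vector per image.
    • Measures the difference between the two feature distributions through their means and covariances.
    • Lower FID = more realistic images.
  2. IS (Inception Score)
    • Checks the quality and diversity of generated images.
    • Uses InceptionV3 to classify images.
    • Rewards sharpness (confident predictions per image) and variety (predictions spread across classes).
    • Higher IS = better quality and diversity.
    • Because InceptionV3 was trained on ImageNet rather than on digits, IS values for MNIST are best treated as a rough, relative signal.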
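To see what the two formulas reward without downloading InceptionV3, the NumPy sketch below applies them to synthetic "features" and synthetic class probabilities. The data is made up for illustration. The FID trace term is computed as Tr((Σ1^(1/2) Σ2 Σ1^(1/2))^(1/2)), a symmetric form with the same eigenvalues as Σ1 Σ2, instead of calling scipy.linalg.sqrtm. The Inception Score is computed over a single split.

```python
import numpy as np

def fid_from_features(real_feats, gen_feats):
    mu1, mu2 = real_feats.mean(axis=0), gen_feats.mean(axis=0)
    s1 = np.cov(real_feats, rowvar=False)
    s2 = np.cov(gen_feats, rowvar=False)
    # Tr((S1 S2)^(1/2)) via the symmetric PSD matrix S1^(1/2) S2 S1^(1/2),
    # which has the same eigenvalues as S1 S2
    evals1, evecs1 = np.linalg.eigh(s1)
    s1_sqrt = evecs1 @ np.diag(np.sqrt(np.clip(evals1, 0, None))) @ evecs1.T
    inner = s1_sqrt @ s2 @ s1_sqrt
    tr_covmean = np.sum(np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0, None)))
    return np.sum((mu1 - mu2) ** 2) + np.trace(s1) + np.trace(s2) - 2.0 * tr_covmean

def inception_score_from_probs(probs, eps=1e-12):
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

rng = np.random.default_rng(0)
feats = rng.normal(size=(2000, 8))
shifted = feats + 3.0  # same covariance, every mean coordinate shifted by 3

print(fid_from_features(feats, feats))    # ~0
print(fid_from_features(feats, shifted))  # ~72, i.e. 8 dimensions * 3**2

k = 10
confident = np.eye(k)[np.arange(1000) % k]  # one-hot predictions, classes evenly spread
uniform = np.full((1000, k), 1.0 / k)       # every image maximally ambiguous
print(inception_score_from_probs(confident))  # ~10, the number of classes
print(inception_score_from_probs(uniform))    # ~1, the minimum
```

A mean shift with unchanged covariance contributes exactly the squared distance between the means, while the Inception Score ranges from 1 (no confident or varied predictions) up to the number of classes.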

9.5.8 Conclusion

This GAN project combines several techniques intended to improve the quality of generated images and the stability of training. The key improvements include:

  1. A deeper convolutional architecture for both the generator and discriminator.
  2. Wasserstein loss with gradient penalty for more stable training.
  3. A resolution-doubling generator builder as the basis for progressive growing.
  4. Spectral normalization in the discriminator to bound its Lipschitz constant and keep gradients from exploding.
  5. A self-attention mechanism to capture global dependencies in generated images.
  6. An improved training loop with learning rate scheduling, checkpointing, and early stopping.
  7. Quantitative evaluation with FID and Inception Score.

The components are presented as separate building blocks: for example, the SelfAttention layer and the spectral-normalized layers still need to be inserted into build_generator and build_discriminator before they take effect. Taken together, these enhancements are expected to yield higher quality images and more stable training, but results vary, so experiment with hyperparameters and architectures to find the best configuration for your use case.

9.5 Project 5: GAN-based Image Generation

Generative Adversarial Networks (GANs) have ushered in a new era in the realm of image generation, revolutionizing the field with their innovative approach. This ambitious project seeks to elevate the original GAN implementation, specifically tailored for generating handwritten digits from the widely-used MNIST dataset.

Our primary objective is to incorporate a series of cutting-edge enhancements designed to significantly boost overall performance, improve training stability, and elevate the quality of generated images to unprecedented levels.

By leveraging state-of-the-art techniques and architectural improvements, we aim to push the boundaries of what's possible with GANs. These enhancements will not only address common challenges associated with GAN training, such as mode collapse and convergence issues, but also introduce novel features that promise to yield more realistic and diverse output.

Through this project, we anticipate demonstrating the full potential of GANs in creating high-fidelity, handwritten digit images that are virtually indistinguishable from their real counterparts.

9.5.1 Enhanced GAN Architecture

To enhance the overall performance and capability of our GAN, we will implement a more intricate and layered architecture for both the generator and discriminator components. This advanced structure will incorporate additional convolutional layers, skip connections, and normalization techniques to improve the network's ability to learn complex features and generate high-quality images. By increasing the depth and sophistication of our models, we aim to capture more nuanced patterns in the data and produce more realistic and detailed handwritten digit images.

import tensorflow as tf
from tensorflow.keras import layers, models

def build_generator(latent_dim):
    model = models.Sequential([
        layers.Dense(7*7*256, use_bias=False, input_shape=(latent_dim,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 256)),
        
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ])
    return model

def build_discriminator():
    model = models.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Flatten(),
        layers.Dense(1)
    ])
    return model

generator = build_generator(latent_dim=100)
discriminator = build_discriminator()

Let's break it down:

  1. Generator:
  • Takes a latent vector (noise) as input
  • Uses transposed convolutions to upsample the input to a 28x28 image
  • Incorporates batch normalization and LeakyReLU activations for stability and non-linearity
  • Final layer uses tanh activation to produce image-like output
  1. Discriminator:
  • Takes a 28x28 image as input
  • Uses convolutional layers to downsample the input
  • Incorporates LeakyReLU activations and dropout for regularization
  • Final dense layer outputs a single value, representing the probability of the input being real

The architecture is designed to generate and discriminate 28x28 grayscale images, which aligns with the MNIST dataset format. The use of batch normalization, LeakyReLU, and dropout helps in stabilizing the training process and preventing issues like mode collapse.

9.5.2 Wasserstein Loss with Gradient Penalty

To enhance training stability and mitigate mode collapse, we will implement the Wasserstein loss function with gradient penalty. This advanced technique, known as WGAN-GP (Wasserstein GAN with Gradient Penalty), offers several advantages over traditional GAN loss functions.

By utilizing the Wasserstein distance as a measure of dissimilarity between the real and generated data distributions, we can achieve more stable training dynamics and potentially generate higher quality images.

The gradient penalty term further enforces the Lipschitz constraint on the critic (discriminator) function, helping to prevent issues such as vanishing gradients and ensuring a smoother training process. This implementation will contribute significantly to the overall robustness and performance of our GAN model.

import tensorflow as tf

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = tf.reduce_mean(real_output)
    fake_loss = tf.reduce_mean(fake_output)
    return fake_loss - real_loss

def generator_loss(fake_output):
    return -tf.reduce_mean(fake_output)

def gradient_penalty(discriminator, real_images, fake_images):
    alpha = tf.random.uniform([real_images.shape[0], 1, 1, 1], 0.0, 1.0)
    interpolated = alpha * real_images + (1 - alpha) * fake_images
    
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        pred = discriminator(interpolated, training=True)
    
    grads = gp_tape.gradient(pred, interpolated)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp

@tf.function
def train_step(images, batch_size, latent_dim):
    noise = tf.random.normal([batch_size, latent_dim])
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
        
        gp = gradient_penalty(discriminator, images, generated_images)
        disc_loss += 10 * gp
    
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return gen_loss, disc_loss

 Let's break it down:

  1. Loss Functions:
  • The discriminator_loss function calculates the Wasserstein loss for the discriminator.
  • The generator_loss function calculates the Wasserstein loss for the generator.
  1. Gradient Penalty:
  • The gradient_penalty function implements the gradient penalty, which helps enforce the Lipschitz constraint on the discriminator.
  1. Training Step:
  • The train_step function defines a single training iteration for both the generator and discriminator.
  • It generates fake images, computes losses, applies the gradient penalty, and updates both networks.

This implementation aims to improve training stability and mitigate issues like mode collapse, which are common challenges in GAN training.

9.5.3 Progressive Growing

Implement progressive growing as an advanced technique to gradually increase the resolution and complexity of generated images during the training process. This approach starts with low-resolution images and progressively adds layers to both the generator and discriminator, allowing the model to learn coarse features first before focusing on finer details.

By doing so, we can achieve more stable training dynamics and potentially generate higher quality images at larger resolutions. This method has shown remarkable success in producing highly realistic images and can significantly improve the overall performance of our GAN model for handwritten digit generation.

def build_progressive_generator(latent_dim, target_resolution=28):
    model = models.Sequential()
    model.add(layers.Dense(4*4*256, use_bias=False, input_shape=(latent_dim,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Reshape((4, 4, 256)))
    
    current_resolution = 4
    while current_resolution < target_resolution:
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha=0.2))
        current_resolution *= 2
    
    model.add(layers.Conv2D(1, (5, 5), padding='same', use_bias=False, activation='tanh'))
    return model

progressive_generator = build_progressive_generator(latent_dim=100)

Here's a breakdown of the code:

  • The function takes two parameters: latent_dim (the size of the input noise vector) and target_resolution (default is 28, which matches the MNIST image size).
  • It starts by creating a base model with a dense layer that's reshaped to a 4x4x256 tensor, followed by batch normalization and LeakyReLU activation.
  • The core of the progressive growing technique is implemented in the while loop:
    • It keeps adding transposed convolutional layers (upsampling) until the current resolution reaches the target resolution.
    • Each iteration doubles the resolution (e.g., 4x4 → 8x8 → 16x16 → 28x28).
  • Each upsampling step includes a Conv2DTranspose layer, batch normalization, and LeakyReLU activation.
  • The final layer is a Conv2D layer with a tanh activation, which produces the output image.
  • After defining the function, it's used to create a progressive_generator with a latent dimension of 100.

This progressive growing approach allows the model to learn coarse features first before focusing on finer details, potentially leading to more stable training and higher quality generated images.

9.5.4 Spectral Normalization

Apply spectral normalization to the discriminator to stabilize training and prevent exploding gradients. This technique constrains the Lipschitz constant of the discriminator, which bounds how much its output can change in response to a change in its input.

Spectral normalization divides each weight matrix of the discriminator by its largest singular value (its spectral norm), so every layer ends up with a spectral norm of about 1. Activations such as LeakyReLU are 1-Lipschitz, so this bounds the Lipschitz constant of the whole network and gives more consistent training dynamics. Miyato et al. introduced the technique for GANs (SN-GAN) and reported that it stabilizes training across different architectures and hyperparameter settings.

In our model, spectral normalization should make the discriminator more robust, and it may improve both the quality of the generated images and convergence.
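The quantities involved can be written out explicitly. The equations below use the standard formulation for a weight matrix W. For convolution kernels, implementations (including the constraint below) first reshape the kernel into a matrix, and the norm of that matrix only approximates the norm of the full convolution operator.

```latex
\sigma(W) = \max_{h \neq 0} \frac{\lVert W h \rVert_2}{\lVert h \rVert_2},
\qquad
\bar{W}_{\mathrm{SN}} = \frac{W}{\sigma(W)} \;\Longrightarrow\; \sigma\big(\bar{W}_{\mathrm{SN}}\big) = 1

\lVert D \rVert_{\mathrm{Lip}}
\le \prod_{l=1}^{L} \sigma\big(\bar{W}^{(l)}_{\mathrm{SN}}\big) \prod_{l=1}^{L-1} \lVert a_l \rVert_{\mathrm{Lip}}
\le 1
```

Here a_l are the activation functions between layers. LeakyReLU with slope 0.2 is 1-Lipschitz, so the bound is 1. Bias terms do not change the Lipschitz constant.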

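import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense

class SpectralNormalization(tf.keras.constraints.Constraint):
    """Kernel constraint that rescales a weight tensor so that its largest
    singular value (spectral norm) is approximately 1."""

    def __init__(self, iterations=5):
        # The power-iteration vector is redrawn on every call, so use a few
        # iterations to get a reasonable estimate of the spectral norm
        self.iterations = iterations

    def __call__(self, w):
        w_shape = tf.shape(w)
        # Flatten conv kernels (kh, kw, in, out) into a 2D matrix (kh*kw*in, out)
        w_mat = tf.reshape(w, [-1, w.shape[-1]])
        u = tf.random.normal([1, w.shape[-1]])

        for _ in range(self.iterations):
            v = tf.math.l2_normalize(tf.matmul(u, w_mat, transpose_b=True))  # (1, kh*kw*in)
            u = tf.math.l2_normalize(tf.matmul(v, w_mat))                    # (1, out)

        # Estimate of the largest singular value: v W u^T
        sigma = tf.matmul(tf.matmul(v, w_mat), u, transpose_b=True)[0, 0]
        # Divide, then restore the kernel's original shape (Keras requires it)
        return tf.reshape(w_mat / sigma, w_shape)

    def get_config(self):
        return {"iterations": self.iterations}

def SpectralConv2D(filters, kernel_size, **kwargs):
    return Conv2D(filters, kernel_size, kernel_constraint=SpectralNormalization(), **kwargs)

def SpectralDense(units, **kwargs):
    return Dense(units, kernel_constraint=SpectralNormalization(), **kwargs)

Here's a code breakdown:

  • SpectralNormalization class: a custom Keras constraint. Keras applies a layer's kernel_constraint after every optimizer update, so the stored kernel is projected back to a spectral norm of about 1 after each training step. The original SN-GAN formulation instead divides by the spectral norm in the forward pass and keeps the raw weights as parameters; the constraint version is a simpler approximation of it.
  • __call__ method: flattens the kernel into a 2D matrix, uses power iteration to estimate its largest singular value (spectral norm), divides the kernel by that value, and reshapes the result back to the kernel's original shape. Because the starting vector is drawn at random on every call, a few iterations are used instead of one.
  • SpectralConv2D and SpectralDense functions: thin wrappers that create Conv2D and Dense layers with the constraint attached to their kernels, which makes it easy to add spectral normalization to a model.

Recent Keras versions also ship a built-in wrapper layer, keras.layers.SpectralNormalization, which keeps its power-iteration vector between steps; prefer it when it is available.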

The purpose of spectral normalization is to constrain the Lipschitz constant of the discriminator function in a GAN. This helps prevent exploding gradients and stabilizes the training process, potentially leading to higher quality generated images and improved convergence.
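The constraint relies on power iteration converging to the largest singular value. The NumPy sketch below builds a matrix with known singular values (3.0, 1.5, 1.0, 0.5) and shows both the estimate and the effect of dividing by it. spectral_norm_power_iteration is a hypothetical helper. The sketch runs many iterations from a random start. SN-GAN instead keeps the vector u between training steps, so a single iteration per step is enough there.

```python
import numpy as np

def spectral_norm_power_iteration(w, iterations=30, seed=0):
    """Estimate the largest singular value of a 2D matrix with power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=w.shape[1])
    for _ in range(iterations):
        v = w @ u
        v /= np.linalg.norm(v)  # left singular vector estimate
        u = w.T @ v
        u /= np.linalg.norm(u)  # right singular vector estimate
    return float(v @ w @ u)     # estimate of sigma_max

# Build a 6x4 matrix with known singular values 3.0, 1.5, 1.0, 0.5
rng = np.random.default_rng(1)
left, _ = np.linalg.qr(rng.normal(size=(6, 4)))   # 6x4 with orthonormal columns
right, _ = np.linalg.qr(rng.normal(size=(4, 4)))  # 4x4 orthogonal
w = left @ np.diag([3.0, 1.5, 1.0, 0.5]) @ right.T

sigma = spectral_norm_power_iteration(w)
w_sn = w / sigma  # spectrally normalized matrix
print(round(sigma, 6))                                     # 3.0
print(round(np.linalg.svd(w_sn, compute_uv=False)[0], 6))  # 1.0
```

How fast the estimate converges depends on the gap between the two largest singular values, here 3.0 versus 1.5. When they are closer, more iterations are needed for the same accuracy.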

9.5.5 Self-Attention Mechanism

Incorporate a self-attention mechanism to help the model capture global dependencies in the generated images. Self-attention lets the features at each spatial location draw on features from all other locations, which can improve the coherence and detail of the output.

Adding self-attention layers to both the generator and the discriminator, as in SAGAN (Zhang et al.), lets both networks model long-range dependencies directly instead of only through stacks of local convolutions. SAGAN reported clear improvements in image quality on ImageNet; on small 28x28 handwritten digits the gain is likely to be more modest.
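For reference, the computation performed by a SAGAN-style self-attention layer can be written out. The layer operates on a feature map with N = H*W positions x_1, ..., x_N and C channels, and convolution biases are ignored here. The notation is ours: W_q and W_k project to C/8 channels, and W_v keeps all C channels.

```latex
q_i = W_q x_i, \qquad k_j = W_k x_j, \qquad v_j = W_v x_j

s_{ij} = q_i^{\top} k_j, \qquad
\beta_{ij} = \frac{\exp(s_{ij})}{\sum_{j'=1}^{N} \exp(s_{ij'})}

o_i = \sum_{j=1}^{N} \beta_{ij}\, v_j, \qquad
y_i = \gamma\, o_i + x_i
```

Because gamma starts at 0, y_i = x_i at the beginning of training, and the network learns how much of the attention output to mix in.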

import tensorflow as tf
from tensorflow.keras import layers

class SelfAttention(layers.Layer):
    """SAGAN-style self-attention over the spatial positions of a feature map.

    `channels` must equal the number of channels of the input, since the
    attention output is added back to the input, and should be at least 8.
    """
    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels

        # 1x1 convolutions for the query, key and value projections
        self.f = layers.Conv2D(channels // 8, 1, kernel_initializer='he_normal')  # Query
        self.g = layers.Conv2D(channels // 8, 1, kernel_initializer='he_normal')  # Key
        self.h = layers.Conv2D(channels, 1, kernel_initializer='he_normal')       # Value

        # Trainable scalar gamma, initialized to 0 so the layer starts as an identity
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)

    def call(self, x):
        # Dynamic shapes, so the layer works with an unknown batch size
        shape = tf.shape(x)
        batch_size, height, width = shape[0], shape[1], shape[2]

        f = self.f(x)  # Query (B, H, W, C//8)
        g = self.g(x)  # Key   (B, H, W, C//8)
        h = self.h(x)  # Value (B, H, W, C)

        # Flatten the spatial dimensions: one row per pixel position
        f_flatten = tf.reshape(f, [batch_size, height * width, -1])  # (B, H*W, C//8)
        g_flatten = tf.reshape(g, [batch_size, height * width, -1])  # (B, H*W, C//8)
        h_flatten = tf.reshape(h, [batch_size, height * width, self.channels])  # (B, H*W, C)

        # Attention scores: query-key dot products, softmax over the key positions
        s = tf.matmul(f_flatten, g_flatten, transpose_b=True)  # (B, H*W, H*W)
        beta = tf.nn.softmax(s, axis=-1)  # Each row sums to 1

        # Weighted sum of the values, reshaped back to a feature map
        o = tf.matmul(beta, h_flatten)  # (B, H*W, C)
        o = tf.reshape(o, [batch_size, height, width, self.channels])

        # Gated residual connection
        return self.gamma * o + x

    def get_config(self):
        config = super().get_config()
        config.update({"channels": self.channels})
        return config

Let's break it down:

  1. The SelfAttention class is a custom layer that inherits from layers.Layer:
    • It implements self-attention, allowing the model to learn long-range dependencies in an image.
    • This convolutional form was popularized for GANs by SAGAN; related self-attention blocks are used in image segmentation models and transformers.
  2. In the __init__ method:
    • Three convolutional layers (f, g, and h) are defined, each with a 1x1 kernel:
      • f: Learns query features (reduces the channel count to channels // 8).
      • g: Learns key features (also channels // 8).
      • h: Learns value features (keeps the original channel count).
    • The channels argument must equal the number of channels in the input, because the attention output is added back to the input.
    • A trainable scalar gamma, initialized to zero, controls the contribution of the attention output, so the layer starts out as an identity mapping.
    • get_config returns the constructor argument, so models containing the layer can be saved and reloaded.
  3. The call method defines the forward pass:
    • Reads the spatial dimensions from tf.shape(x) at run time, so the layer also works when the batch size is unknown while the graph is being built.
    • Computes the query, key, and value maps with the 1x1 convolutions f(x), g(x), and h(x).
    • Flattens the spatial dimensions, so each of the N = H*W positions becomes one row.
    • Computes the attention map:
      • Takes the dot product of every query with every key, giving an N x N score matrix.
      • Applies softmax over the key positions, so each position's attention weights sum to 1.
    • Applies the attention map to the values (a weighted sum over all positions) and reshapes the result back to H x W x C.
    • Uses a residual connection (gamma * o + x) to blend the original input with the attention output.
  4. Why this matters:
    • Convolutional layers have local receptive fields, so relating distant parts of an image takes many stacked layers; self-attention connects any two positions in a single step.
    • This can improve the global coherence of generated images.
    • The cost is memory: the attention map has (H*W)^2 entries per image, 614,656 at 28x28, so the layer is usually placed on lower-resolution feature maps.
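To see the shapes and the role of gamma without TensorFlow, here is a NumPy sketch of the same computation on a single 4x4x16 feature map. It treats each 1x1 convolution as a per-pixel matrix multiplication and omits the convolution biases. self_attention and softmax are hypothetical helpers. The checks confirm that every row of the attention map sums to 1 and that with gamma = 0 the layer returns its input unchanged.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v, gamma):
    """x: (H, W, C) feature map; w_q, w_k: (C, C//8); w_v: (C, C)."""
    height, width, channels = x.shape
    flat = x.reshape(height * width, channels)       # one row per pixel position
    q, k, v = flat @ w_q, flat @ w_k, flat @ w_v     # 1x1 convs as per-pixel matmuls
    beta = softmax(q @ k.T, axis=-1)                 # (N, N) attention map
    o = (beta @ v).reshape(height, width, channels)  # weighted sum of values
    return gamma * o + x, beta

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 4, 16))
w_q, w_k = rng.normal(size=(16, 2)), rng.normal(size=(16, 2))
w_v = rng.normal(size=(16, 16))

out0, beta = self_attention(x, w_q, w_k, w_v, gamma=0.0)  # gamma = 0: identity
out1, _ = self_attention(x, w_q, w_k, w_v, gamma=0.5)     # gamma > 0: attention mixed in
print(beta.shape)  # (16, 16)
```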

9.5.6 Improved Training Loop

Enhance the training process with a training loop that adjusts the learning rate over time and stops early once performance plateaus. Decaying the learning rate helps the models settle as training progresses. Stopping after the evaluation metric stops improving saves computation and keeps the best model rather than the last one.

Key features of this improved training loop include:

  • Learning rate scheduling: Use a schedule such as exponential decay or cosine annealing to reduce the learning rate gradually as training progresses, which allows finer adjustments to the model parameters late in training.
  • Early stopping: Use a patience-based criterion that monitors a relevant performance metric (here, FID) and halts training if it has not improved for a specified number of evaluations.
  • Checkpoint saving: Save the generator whenever it reaches a new best score, preserving the best-performing model for later use or evaluation.
  • Progress monitoring: Log the losses every epoch and periodically save grids of generated images, so you can follow the training dynamics as they unfold.

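With staircase=True, ExponentialDecay computes initial_learning_rate * decay_rate ** floor(step / decay_steps), where step counts optimizer updates, not epochs. The plain-Python sketch below reproduces that formula to show when the first drop actually happens; staircase_exponential_decay is a hypothetical name. It assumes 60,000 MNIST training images and a batch size of 256 with the last partial batch dropped, which gives 234 steps per epoch. Adjust steps_per_epoch for your own setup.

```python
def staircase_exponential_decay(step, initial_lr=0.0002, decay_steps=10000, decay_rate=0.96):
    # Same formula as ExponentialDecay(..., staircase=True):
    # lr = initial_lr * decay_rate ** floor(step / decay_steps)
    return initial_lr * decay_rate ** (step // decay_steps)

steps_per_epoch = 60000 // 256  # assumption: 60,000 images, batch size 256 -> 234 steps

# The first drop happens between epoch 42 (9,828 steps) and epoch 43 (10,062 steps)
for epoch in (1, 42, 43, 100, 300):
    step = epoch * steps_per_epoch
    print(epoch, step, f"{staircase_exponential_decay(step):.3e}")
```

Under these assumptions, the rate is still about 75% of its initial value after 300 epochs. If you want the schedule to matter within a typical run, decay_steps may need to be smaller.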
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay

# Learning rate schedule: multiply the rate by 0.96 every 10,000 optimizer steps
initial_learning_rate = 0.0002
lr_schedule = ExponentialDecay(initial_learning_rate, decay_steps=10000, decay_rate=0.96, staircase=True)

# Optimizers (used by train_step from Section 9.5.2)
generator_optimizer = Adam(learning_rate=lr_schedule, beta_1=0.5)
discriminator_optimizer = Adam(learning_rate=lr_schedule, beta_1=0.5)

LATENT_DIM = 100  # Must match the generator's latent dimension
num_samples = 16  # Images per visualization grid

# Fixed noise seed, so saved image grids are comparable across epochs
seed = tf.random.normal([num_samples, LATENT_DIM])

def train(dataset, epochs, batch_size, latent_dim, eval_every=10, patience=10, fid_batches=8):
    """Trains the GAN with FID-based checkpointing and early stopping.

    `dataset` must yield batches of images only (no labels), scaled to [-1, 1]
    and built with drop_remainder=True, so every batch has `batch_size` images.
    `patience` counts FID evaluations, not epochs.
    """
    best_fid = float('inf')
    no_improvement = 0

    for epoch in range(epochs):
        for batch in dataset:
            gen_loss, disc_loss = train_step(batch, batch_size, latent_dim)

        # Losses from the last batch of the epoch
        print(f"Epoch {epoch + 1}, Gen Loss: {float(gen_loss):.4f}, Disc Loss: {float(disc_loss):.4f}")

        if (epoch + 1) % eval_every == 0:
            # generate_and_save_images is assumed to be defined earlier
            generate_and_save_images(generator, epoch + 1, seed)

            # FID needs many samples: use several real batches and the same number
            # of freshly sampled fakes, not the 16-image visualization seed
            real_images = tf.concat(list(dataset.take(fid_batches)), axis=0)
            noise = tf.random.normal([real_images.shape[0], latent_dim])
            generated_images = generator(noise, training=False)

            current_fid = calculate_fid(real_images, generated_images)
            print(f"Epoch {epoch + 1}, FID: {current_fid:.2f}")

            if current_fid < best_fid:
                best_fid = current_fid
                no_improvement = 0
                # Checkpoint the best generator so far
                generator.save(f"generator_epoch_{epoch + 1}.h5")
            else:
                no_improvement += 1

            if no_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

# train_dataset, EPOCHS and BATCH_SIZE are assumed to be defined when the data is loaded
train(train_dataset, EPOCHS, BATCH_SIZE, LATENT_DIM)

Here's the code breakdown:

  1. Learning Rate Scheduling:
    • An ExponentialDecay schedule multiplies the learning rate by 0.96 every 10,000 optimizer steps; staircase=True makes the drops discrete.
    • Smaller updates later in training can reduce oscillation, a common source of instability in GANs.
  2. Optimizers:
    • Both the generator and the discriminator use Adam with:
      • the decaying learning rate (lr_schedule);
      • beta_1=0.5, a common choice in GAN training since DCGAN, where the default of 0.9 was reported to cause oscillation.
  3. Training Loop:
    • Iterates over epochs and batches, calling train_step() (defined in Section 9.5.2) to update the generator and discriminator weights.
    • The dataset must yield image-only batches of exactly batch_size images, because train_step draws batch_size noise vectors per batch.
    • The printed losses are those of the last batch in each epoch.
  4. Periodic Evaluation (every eval_every epochs, 10 by default):
    • Generates and saves images from the fixed noise seed, so successive grids are directly comparable.
    • Computes the Fréchet Inception Distance (FID) between fid_batches batches of real images and the same number of freshly generated images. More samples give a more reliable estimate; published FID values are usually computed on 10,000 or more images.
  5. Model Saving:
    • Saves the generator (generator.save()) whenever a new best FID is reached, preserving the best-performing generator instead of only the final one.
  6. Early Stopping:
    • patience counts FID evaluations, not epochs: with an evaluation every 10 epochs and patience=10, training stops after 100 epochs without improvement.
    • This saves computation and avoids training long past the point where sample quality peaked, for example after the generator has started collapsing onto a few modes. It does not by itself prevent mode collapse.
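The patience rule can be tested on its own, away from the training loop. The class below tracks the best FID and counts evaluations without improvement; FIDEarlyStopping is a hypothetical name. The FID values are made up for illustration, and evaluations are assumed to happen every 10 epochs, as in the loop above.

```python
class FIDEarlyStopping:
    """Patience-based early stopping on FID, where lower is better."""

    def __init__(self, patience=10):
        self.patience = patience  # counted in evaluations, not epochs
        self.best = float("inf")
        self.evals_without_improvement = 0

    def update(self, fid):
        """Records one evaluation and returns (improved, should_stop)."""
        if fid < self.best:
            self.best = fid
            self.evals_without_improvement = 0
            return True, False
        self.evals_without_improvement += 1
        return False, self.evals_without_improvement >= self.patience

eval_every = 10  # epochs between FID evaluations
stopper = FIDEarlyStopping(patience=2)
fid_history = [80.0, 55.0, 60.0, 52.0, 53.0, 54.0, 50.0]  # made-up values for illustration

stopped_at_epoch = None
for i, fid in enumerate(fid_history):
    # In the real loop, the generator is saved whenever `improved` is True
    improved, should_stop = stopper.update(fid)
    if should_stop:
        stopped_at_epoch = (i + 1) * eval_every
        break

print(stopper.best, stopped_at_epoch)  # 52.0 60
```

With patience=2, the counter resets at the fourth evaluation (52.0), and training stops two evaluations later at epoch 60. That happens even though the next value (50.0) would have been better, which shows how too small a patience can cut off a run that is still improving slowly.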

9.5.7 Evaluation Metrics

Implement advanced evaluation metrics to assess the quality and diversity of generated images. We will focus on two key metrics:

  1. Fréchet Inception Distance (FID): This metric compares real and generated images through the features that a pre-trained Inception network extracts from them. Each set of features is summarized by its mean and covariance, and FID measures the distance between the two. A lower FID means the generated images have feature statistics closer to those of real images.
  2. Inception Score (IS): This metric uses a pre-trained Inception classifier to evaluate both quality and diversity. It is high when each image is classified confidently and the predictions are spread across many classes. Because the standard Inception network is trained on ImageNet rather than digits, IS is hard to interpret on MNIST; a classifier trained on MNIST itself is a common substitute.

By incorporating these metrics into our evaluation process, we can quantitatively assess the performance of our GAN and track improvements over time, which shows whether each architectural and training enhancement actually helps.
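Written out, the two metrics are shown below. Here mu_r, Sigma_r and mu_g, Sigma_g are the mean and covariance of the Inception features of real and generated images, and p(y|x) is the classifier's predicted class distribution for a generated image x.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)

\mathrm{IS} = \exp\!\Big( \mathbb{E}_{x \sim p_g} \big[ D_{\mathrm{KL}}\big( p(y \mid x) \,\Vert\, p(y) \big) \big] \Big),
\qquad p(y) = \mathbb{E}_{x \sim p_g}\, p(y \mid x)
```

FID is 0 when the two sets of feature statistics are identical. IS ranges from 1, when all predictions are identical or uniform, up to the number of classes, when predictions are confident and evenly spread.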

import numpy as np
import tensorflow as tf
from scipy.linalg import sqrtm
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input

_inception_models = {}

def _get_inception(include_top):
    """Builds each InceptionV3 variant once and reuses it across calls."""
    if include_top not in _inception_models:
        if include_top:
            # Full ImageNet classifier; its last layer already applies softmax
            _inception_models[include_top] = InceptionV3(include_top=True, weights="imagenet")
        else:
            # Feature extractor: 2048-dimensional pooled features per image
            _inception_models[include_top] = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))
    return _inception_models[include_top]

def _run_inception(model, images, batch_size):
    """Runs InceptionV3 on images in [-1, 1], one batch at a time to limit memory."""
    outputs = []
    for start in range(0, len(images), batch_size):
        batch = tf.image.resize(images[start:start + batch_size], (299, 299))
        if batch.shape[-1] == 1:
            batch = tf.image.grayscale_to_rgb(batch)  # InceptionV3 expects 3 channels
        batch = preprocess_input((batch + 1.0) * 127.5)  # [-1, 1] -> [0, 255] -> Inception's [-1, 1]
        outputs.append(model(batch, training=False).numpy())
    return np.concatenate(outputs, axis=0)

def calculate_fid(real_images, generated_images, batch_size=32):
    """
    Calculates the Fréchet Inception Distance (FID) between real and generated images.
    Use thousands of images per set: 2048x2048 covariances are unreliable from small samples.
    """
    model = _get_inception(include_top=False)
    real_features = _run_inception(model, real_images, batch_size)
    generated_features = _run_inception(model, generated_images, batch_size)

    # Compute mean and covariance of features
    mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)

    # Compute squared mean difference
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Matrix square root of the covariance product (part of the FID formula)
    covmean = sqrtm(sigma1.dot(sigma2))

    # Numerical error can leave tiny imaginary parts; keep the real part
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Compute final FID score
    return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

def calculate_inception_score(images, batch_size=32, splits=10, eps=1e-16):
    """
    Computes the Inception Score (IS) for generated images as (mean, std) over splits.
    """
    model = _get_inception(include_top=True)
    preds = _run_inception(model, images, batch_size)  # Already probabilities: no extra softmax

    scores = []
    split_size = len(preds) // splits
    for i in range(splits):
        part = preds[i * split_size:(i + 1) * split_size]
        p_y = np.mean(part, axis=0, keepdims=True)  # Marginal class distribution p(y)
        # KL(p(y|x) || p(y)) for each image; eps avoids log(0)
        kl = part * (np.log(part + eps) - np.log(p_y + eps))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))

    return float(np.mean(scores)), float(np.std(scores))

Let's break down each function:

  1. calculate_fid (Fréchet Inception Distance):
    • Compares real and generated images to check quality.
    • Uses InceptionV3 without its classification head to extract a 2048-dimensional feature vector per image.
    • Measures the difference between the two feature distributions through their means and covariances.
    • Lower FID = more realistic images.
  2. calculate_inception_score (Inception Score):
    • Checks the quality and diversity of generated images.
    • Uses the full InceptionV3 classifier to predict class probabilities for each image.
    • Rewards sharpness (confident predictions for each image) and variation (predictions spread across classes).
    • Higher IS = better quality and diversity.
  3. Shared helpers:
    • Both functions build InceptionV3 once and reuse it.
    • Images are resized, converted from one to three channels, and rescaled from [-1, 1] to the range preprocess_input expects, one batch at a time.
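Both formulas can be checked on toy inputs without downloading InceptionV3. The NumPy-only sketch below works directly on feature vectors and class probabilities; frechet_distance and inception_score are hypothetical helpers. To avoid a SciPy dependency, it computes the trace of (Sigma_a Sigma_b)^(1/2) through the symmetric matrix Sigma_a^(1/2) Sigma_b Sigma_a^(1/2), which has the same eigenvalues as Sigma_a Sigma_b, instead of calling sqrtm. The checks confirm the limiting cases:

  • identical features give FID 0;
  • shifting every feature by 1 in 4 dimensions gives FID 4;
  • confident predictions spread evenly over 10 classes give IS 10;
  • predictions that collapse onto one class, or are uniform, give IS 1.

```python
import numpy as np

def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    eigvals = np.clip(eigvals, 0.0, None)  # clip tiny negative values from round-off
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T

def frechet_distance(features_a, features_b):
    """FID formula applied directly to (N, D) feature arrays."""
    mu_a, mu_b = features_a.mean(axis=0), features_b.mean(axis=0)
    sigma_a = np.cov(features_a, rowvar=False)
    sigma_b = np.cov(features_b, rowvar=False)
    root_a = _sqrtm_psd(sigma_a)
    # Same trace as sqrtm(sigma_a @ sigma_b), computed from a symmetric matrix
    cross = _sqrtm_psd(root_a @ sigma_b @ root_a)
    return float(np.sum((mu_a - mu_b) ** 2) + np.trace(sigma_a + sigma_b - 2.0 * cross))

def inception_score(probs, eps=1e-16):
    """IS for an (N, K) array of class probabilities, computed as a single split."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    kl = probs * (np.log(probs + eps) - np.log(p_y + eps))  # per-image KL terms
    return float(np.exp(kl.sum(axis=1).mean()))

rng = np.random.default_rng(0)
feats = rng.normal(size=(500, 4))
print(frechet_distance(feats, feats))        # identical sets
print(frechet_distance(feats, feats + 1.0))  # mean shifted by 1 in each of 4 dims

diverse = np.eye(10)[np.arange(100) % 10]         # confident, 10 images per class
collapsed = np.eye(10)[np.zeros(100, dtype=int)]  # every image assigned to class 0
uniform = np.full((100, 10), 0.1)                 # no confident predictions
print(inception_score(diverse), inception_score(collapsed), inception_score(uniform))
```

The document's calculate_inception_score additionally splits the predictions into 10 groups and reports the mean and standard deviation across them.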

9.5.8 Conclusion

This GAN project incorporates several advanced techniques to enhance the quality of generated images and the stability of training. The key improvements include:

  1. A deeper and more sophisticated architecture for both the generator and discriminator.
  2. Wasserstein loss with gradient penalty for improved training stability.
  3. A coarse-to-fine generator inspired by progressive growing, which upsamples from a low-resolution base to the full image size.
  4. Spectral normalization in the discriminator to constrain its Lipschitz constant and prevent exploding gradients.
  5. Self-attention mechanism to capture global dependencies in generated images.
  6. An improved training loop with learning rate scheduling and early stopping.
  7. Advanced evaluation metrics (FID and Inception Score) for better assessment of generated image quality.

Together, these enhancements are expected to produce higher-quality images, more stable training, and better overall performance of the GAN. Remember to experiment with hyperparameters and architectures to find the best configuration for your specific use case.

9.5 Project 5: GAN-based Image Generation

Generative Adversarial Networks (GANs) have ushered in a new era in the realm of image generation, revolutionizing the field with their innovative approach. This ambitious project seeks to elevate the original GAN implementation, specifically tailored for generating handwritten digits from the widely-used MNIST dataset.

Our primary objective is to incorporate a series of cutting-edge enhancements designed to significantly boost overall performance, improve training stability, and elevate the quality of generated images to unprecedented levels.

By leveraging state-of-the-art techniques and architectural improvements, we aim to push the boundaries of what's possible with GANs. These enhancements will not only address common challenges associated with GAN training, such as mode collapse and convergence issues, but also introduce novel features that promise to yield more realistic and diverse output.

Through this project, we anticipate demonstrating the full potential of GANs in creating high-fidelity, handwritten digit images that are virtually indistinguishable from their real counterparts.

9.5.1 Enhanced GAN Architecture

To enhance the overall performance and capability of our GAN, we will implement a more intricate and layered architecture for both the generator and discriminator components. This advanced structure will incorporate additional convolutional layers, skip connections, and normalization techniques to improve the network's ability to learn complex features and generate high-quality images. By increasing the depth and sophistication of our models, we aim to capture more nuanced patterns in the data and produce more realistic and detailed handwritten digit images.

import tensorflow as tf
from tensorflow.keras import layers, models

def build_generator(latent_dim):
    model = models.Sequential([
        layers.Dense(7*7*256, use_bias=False, input_shape=(latent_dim,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 256)),
        
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(alpha=0.2),
        
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ])
    return model

def build_discriminator():
    model = models.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        
        layers.Flatten(),
        layers.Dense(1)
    ])
    return model

generator = build_generator(latent_dim=100)
discriminator = build_discriminator()

Let's break it down:

  1. Generator:
  • Takes a latent vector (noise) as input
  • Uses transposed convolutions to upsample the input to a 28x28 image
  • Incorporates batch normalization and LeakyReLU activations for stability and non-linearity
  • Final layer uses tanh activation to produce image-like output
  1. Discriminator:
  • Takes a 28x28 image as input
  • Uses convolutional layers to downsample the input
  • Incorporates LeakyReLU activations and dropout for regularization
  • Final dense layer outputs a single value, representing the probability of the input being real

The architecture is designed to generate and discriminate 28x28 grayscale images, which aligns with the MNIST dataset format. The use of batch normalization, LeakyReLU, and dropout helps in stabilizing the training process and preventing issues like mode collapse.

9.5.2 Wasserstein Loss with Gradient Penalty

To enhance training stability and mitigate mode collapse, we will implement the Wasserstein loss function with gradient penalty. This advanced technique, known as WGAN-GP (Wasserstein GAN with Gradient Penalty), offers several advantages over traditional GAN loss functions.

By utilizing the Wasserstein distance as a measure of dissimilarity between the real and generated data distributions, we can achieve more stable training dynamics and potentially generate higher quality images.

The gradient penalty term further enforces the Lipschitz constraint on the critic (discriminator) function, helping to prevent issues such as vanishing gradients and ensuring a smoother training process. This implementation will contribute significantly to the overall robustness and performance of our GAN model.

import tensorflow as tf

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = tf.reduce_mean(real_output)
    fake_loss = tf.reduce_mean(fake_output)
    return fake_loss - real_loss

def generator_loss(fake_output):
    return -tf.reduce_mean(fake_output)

def gradient_penalty(discriminator, real_images, fake_images):
    alpha = tf.random.uniform([real_images.shape[0], 1, 1, 1], 0.0, 1.0)
    interpolated = alpha * real_images + (1 - alpha) * fake_images
    
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        pred = discriminator(interpolated, training=True)
    
    grads = gp_tape.gradient(pred, interpolated)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return gp

@tf.function
def train_step(images, batch_size, latent_dim):
    noise = tf.random.normal([batch_size, latent_dim])
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
        
        gp = gradient_penalty(discriminator, images, generated_images)
        disc_loss += 10 * gp
    
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    
    return gen_loss, disc_loss

 Let's break it down:

  1. Loss Functions:
  • The discriminator_loss function calculates the Wasserstein loss for the discriminator.
  • The generator_loss function calculates the Wasserstein loss for the generator.
  1. Gradient Penalty:
  • The gradient_penalty function implements the gradient penalty, which helps enforce the Lipschitz constraint on the discriminator.
  1. Training Step:
  • The train_step function defines a single training iteration for both the generator and discriminator.
  • It generates fake images, computes losses, applies the gradient penalty, and updates both networks.

This implementation aims to improve training stability and mitigate issues like mode collapse, which are common challenges in GAN training.

9.5.3 Progressive Growing

Implement progressive growing as an advanced technique to gradually increase the resolution and complexity of generated images during the training process. This approach starts with low-resolution images and progressively adds layers to both the generator and discriminator, allowing the model to learn coarse features first before focusing on finer details.

By doing so, we can achieve more stable training dynamics and potentially generate higher quality images at larger resolutions. This method has shown remarkable success in producing highly realistic images and can significantly improve the overall performance of our GAN model for handwritten digit generation.

def build_progressive_generator(latent_dim, target_resolution=28):
    model = models.Sequential()
    model.add(layers.Dense(4*4*256, use_bias=False, input_shape=(latent_dim,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Reshape((4, 4, 256)))
    
    current_resolution = 4
    while current_resolution < target_resolution:
        model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha=0.2))
        current_resolution *= 2
    
    model.add(layers.Conv2D(1, (5, 5), padding='same', use_bias=False, activation='tanh'))
    return model

progressive_generator = build_progressive_generator(latent_dim=100)

Here's a breakdown of the code:

  • The function takes two parameters: latent_dim (the size of the input noise vector) and target_resolution (default is 28, which matches the MNIST image size).
  • It starts by creating a base model with a dense layer that's reshaped to a 4x4x256 tensor, followed by batch normalization and LeakyReLU activation.
  • The core of the progressive growing technique is implemented in the while loop:
    • It keeps adding transposed convolutional layers (upsampling) until the current resolution reaches the target resolution.
    • Each iteration doubles the resolution (e.g., 4x4 → 8x8 → 16x16 → 28x28).
  • Each upsampling step includes a Conv2DTranspose layer, batch normalization, and LeakyReLU activation.
  • The final layer is a Conv2D layer with a tanh activation, which produces the output image.
  • After defining the function, it's used to create a progressive_generator with a latent dimension of 100.

This progressive growing approach allows the model to learn coarse features first before focusing on finer details, potentially leading to more stable training and higher quality generated images.

9.5.4 Spectral Normalization

Implement spectral normalization for the discriminator to enhance training stability and prevent the occurrence of exploding gradients. This technique constrains the Lipschitz constant of the discriminator function, effectively limiting the impact of individual input perturbations on the output.

By applying spectral normalization to the weights of the discriminator's layers, we ensure that the largest singular value of the weight matrices is bounded, leading to more consistent and reliable training dynamics. This approach has been shown to be particularly effective in stabilizing GAN training, especially when dealing with complex architectures or challenging datasets.

The implementation of spectral normalization contributes significantly to the overall robustness of our GAN model, potentially resulting in higher quality generated images and improved convergence characteristics.

from tensorflow.keras.layers import Conv2D, Dense
from tensorflow.keras.constraints import max_norm

class SpectralNormalization(tf.keras.constraints.Constraint):
    def __init__(self, iterations=1):
        self.iterations = iterations
    
    def __call__(self, w):
        w_shape = w.shape.as_list()
        w = tf.reshape(w, [-1, w_shape[-1]])
        u = tf.random.normal([1, w_shape[-1]])
        
        for _ in range(self.iterations):
            v = tf.matmul(u, tf.transpose(w))
            v = v / tf.norm(v)
            u = tf.matmul(v, w)
            u = u / tf.norm(u)
        
        sigma = tf.matmul(tf.matmul(v, w), tf.transpose(u))[0, 0]
        return w / sigma

def SpectralConv2D(filters, kernel_size, **kwargs):
    return Conv2D(filters, kernel_size, kernel_constraint=SpectralNormalization(), **kwargs)

def SpectralDense(units, **kwargs):
    return Dense(units, kernel_constraint=SpectralNormalization(), **kwargs)

Here's a code breakdown:

  • SpectralNormalization class: This is a custom constraint class that applies spectral normalization to the weights of a layer. It works by estimating the spectral norm of the weight matrix and using it to normalize the weights.
  • __call__ method: This method implements the core algorithm of spectral normalization. It uses power iteration to estimate the largest singular value (spectral norm) of the weight matrix and then uses this to normalize the weights.
  • SpectralConv2D and SpectralDense functions: These are wrapper functions that create Conv2D and Dense layers with spectral normalization applied to their kernels. They make it easy to add spectral normalization to a model.

The purpose of spectral normalization is to constrain the Lipschitz constant of the discriminator function in a GAN. This helps prevent exploding gradients and stabilizes the training process, potentially leading to higher quality generated images and improved convergence.

9.5.5 Self-Attention Mechanism

Incorporate a self-attention mechanism to enhance the model's ability to capture global dependencies in the generated images. This advanced technique allows the network to focus on relevant features across different spatial locations, leading to improved coherence and detail in the output.

By implementing self-attention layers in both the generator and discriminator, we enable the model to learn long-range dependencies more effectively, resulting in higher quality and more realistic handwritten digit images. This approach has shown remarkable success in various image generation tasks and promises to significantly boost the performance of our GAN model.

import tensorflow as tf
from tensorflow.keras import layers

class SelfAttention(layers.Layer):
    """SAGAN-style self-attention block.

    The input must have exactly `channels` channels (for the residual add)
    and at least 8 channels (so that channels // 8 >= 1).
    """

    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels

        # 1x1 convolutions that project every position into query, key and value spaces
        self.f = layers.Conv2D(channels // 8, 1, kernel_initializer='he_normal')  # Query
        self.g = layers.Conv2D(channels // 8, 1, kernel_initializer='he_normal')  # Key
        self.h = layers.Conv2D(channels, 1, kernel_initializer='he_normal')       # Value

        # Trainable scalar gamma, initialized to zero so the block starts as an identity mapping
        self.gamma = self.add_weight(name='gamma', shape=(1,), initializer='zeros', trainable=True)

    def call(self, x):
        # Dynamic shape, so the layer works with any batch size and spatial size
        batch_size, height, width, channels = tf.unstack(tf.shape(x))

        # Compute f, g, h transformations
        f = self.f(x)  # Query: (B, H, W, C//8)
        g = self.g(x)  # Key:   (B, H, W, C//8)
        h = self.h(x)  # Value: (B, H, W, C)

        # Flatten the spatial grid so each of the N = H*W positions becomes one row
        f_flatten = tf.reshape(f, [batch_size, height * width, -1])        # (B, N, C//8)
        g_flatten = tf.reshape(g, [batch_size, height * width, -1])        # (B, N, C//8)
        h_flatten = tf.reshape(h, [batch_size, height * width, channels])  # (B, N, C)

        # s[b, i, j] = query_i . key_j; softmax over j turns each row into weights over all positions
        s = tf.matmul(f_flatten, g_flatten, transpose_b=True)  # (B, N, N)
        beta = tf.nn.softmax(s, axis=-1)                        # Attention map (B, N, N)

        # Attention-weighted sum of the values, reshaped back to the feature-map layout
        o = tf.matmul(beta, h_flatten)                          # (B, N, C)
        o = tf.reshape(o, [batch_size, height, width, channels])

        # Weighted residual connection; with gamma = 0 the output equals the input
        return self.gamma * o + x

Let's break it down:

  1. The SelfAttention class is a custom layer that inherits from layers.Layer
    • This layer implements self-attention, allowing the model to learn long-range dependencies in an image.
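    • The convolutional form used here follows the Self-Attention GAN (SAGAN); transformers apply the same query-key-value attention to token sequences, and variants also appear in image segmentation models.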
  2. In the __init__ method:
    • Three convolutional layers (f, g, and h) are defined, each with a 1x1 kernel.
      • f: Learns query features (reduces dimensionality).
      • g: Learns key features (reduces dimensionality).
      • h: Learns value features (keeps original dimensionality).
    • A trainable parameter gamma is added, initialized to zero, to control the contribution of the attention mechanism.
  3. The call method defines the forward pass:
    • Reads the input shape at run time with tf.shape (batch_size, height, width, channels), so the layer works even when the batch size or spatial size is unknown when the graph is built.
    • Computes feature transformations using Conv2D(1x1) convolutions:
      • f(x): Generates the query representation.
      • g(x): Generates the key representation.
      • h(x): Generates the value representation.
    • Computes the attention map:
      • Multiplies g and f (dot product similarity).
      • Applies softmax to normalize the attention scores.
    • Applies the attention map to h (weighted sum of attended features).
    • Uses a residual connection (gamma * o + x) to blend the original input with the attention output.
  4. Why this matters:
    • This self-attention mechanism allows the model to focus on relevant features across different spatial locations.
    • Particularly useful in image generation tasks (GANs) to improve the quality and coherence of generated images.
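    • Captures long-range dependencies directly, whereas a convolution sees only a local neighborhood and needs many stacked layers to relate distant pixels. The price is an (H·W)×(H·W) attention map per image, which is why the layer is typically applied to smaller feature maps (for MNIST, e.g. 7×7 or 14×14).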
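To make the shapes concrete, here is a minimal NumPy sketch of the same forward pass for a single feature map. It is an illustration under stated assumptions, not the chapter's layer: random matrices stand in for the trained 1×1 convolutions (a 1×1 convolution is a per-position matrix multiply), biases are omitted, and the names `self_attention_np`, `w_f`, `w_g`, `w_h` and the 7×7×64 size are hypothetical.

```python
import numpy as np

def softmax(s, axis=-1):
    # Subtract the row max before exponentiating for numerical stability
    s = s - s.max(axis=axis, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_np(x, w_f, w_g, w_h, gamma):
    """Self-attention forward pass for one feature map x of shape (H, W, C)."""
    height, width, channels = x.shape
    n = height * width
    x_flat = x.reshape(n, channels)    # (N, C): one row per spatial position

    q = x_flat @ w_f                   # (N, C//8) queries, like self.f
    k = x_flat @ w_g                   # (N, C//8) keys, like self.g
    v = x_flat @ w_h                   # (N, C) values, like self.h

    beta = softmax(q @ k.T, axis=-1)   # (N, N): row i = weights of position i over all positions
    o = beta @ v                       # (N, C): attention-weighted sum of values
    out = gamma * o.reshape(height, width, channels) + x  # residual connection
    return out, beta

rng = np.random.default_rng(0)
H, W, C = 7, 7, 64                     # hypothetical 7x7 feature map inside an MNIST generator
x = rng.standard_normal((H, W, C))
w_f = rng.standard_normal((C, C // 8)) * 0.1
w_g = rng.standard_normal((C, C // 8)) * 0.1
w_h = rng.standard_normal((C, C)) * 0.1

out0, beta = self_attention_np(x, w_f, w_g, w_h, gamma=0.0)
print(beta.shape)                            # (49, 49)
print(np.allclose(beta.sum(axis=1), 1.0))    # True: each row is a probability distribution
print(np.allclose(out0, x))                  # True: with gamma = 0 the block is the identity
```

The two checks mirror the design choices in the Keras layer: the softmax makes each row of the attention map a set of weights over all 49 positions, and initializing gamma to zero lets the generator start from a plain convolutional network and add attention gradually.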

9.5.6 Improved Training Loop

Enhance the training process with a training loop that adjusts the learning rate over time and stops early once sample quality stops improving. Decaying the learning rate allows finer parameter updates later in training, and stopping when the monitored metric plateaus avoids spending computation on epochs that no longer help.

Key features of this improved training loop include:

  • Learning rate scheduling: Use a learning-rate schedule such as exponential decay or cosine annealing to reduce the learning rate gradually as training progresses, allowing finer adjustments to the model parameters later in training.
  • Early stopping: Implement a patience-based early stopping criterion that monitors a relevant performance metric (e.g., FID score) and halts training if no improvement is observed over a specified number of epochs.
  • Checkpoint saving: Regularly save model checkpoints during training, preserving the best-performing model iterations for later use or evaluation.
  • Progress monitoring: Integrate comprehensive logging and visualization tools to track key metrics, enabling real-time assessment of model performance and training dynamics.
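The two schedules named above reduce to short formulas. The sketch below computes them in plain Python so the numbers can be checked by hand. The staircase exponential decay mirrors the `ExponentialDecay(0.0002, decay_steps=10000, decay_rate=0.96, staircase=True)` configuration in the code that follows; the cosine schedule and its `total_steps` and `min_lr` values are illustrative assumptions. In both, `step` counts optimizer updates (batches), not epochs.

```python
import math

def exponential_decay(step, initial_lr=2e-4, decay_steps=10_000, decay_rate=0.96, staircase=True):
    """lr = initial_lr * decay_rate ** (step / decay_steps); the exponent is floored when staircase=True."""
    exponent = step // decay_steps if staircase else step / decay_steps
    return initial_lr * decay_rate ** exponent

def cosine_annealing(step, total_steps, initial_lr=2e-4, min_lr=0.0):
    """Cosine decay from initial_lr to min_lr over total_steps, then held at min_lr."""
    t = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (initial_lr - min_lr) * (1.0 + math.cos(math.pi * t))

for step in (0, 9_999, 10_000, 25_000, 100_000):
    print(step, exponential_decay(step), cosine_annealing(step, total_steps=100_000))
```

With `staircase=True` the rate stays at 2e-4 for the first 10,000 steps, then drops to 1.92e-4, and to about 1.843e-4 after 20,000 steps. Because the unit is steps, the epoch at which a drop happens depends on the batch size: with MNIST's 60,000 training images and a batch size of, say, 256 (235 steps per epoch), the first drop occurs during epoch 43.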
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

# Learning rate schedule: multiply the rate by 0.96 every 10,000 optimizer steps (batches, not epochs)
initial_learning_rate = 0.0002
lr_schedule = ExponentialDecay(initial_learning_rate, decay_steps=10000, decay_rate=0.96, staircase=True)

# Optimizers (each keeps its own step counter, so they can share one schedule object)
generator_optimizer = Adam(learning_rate=lr_schedule, beta_1=0.5)
discriminator_optimizer = Adam(learning_rate=lr_schedule, beta_1=0.5)

# Number of samples in the progress image grid
num_samples = 16
LATENT_DIM = 100  # Must match the generator's input size

# Fixed noise, reused at every checkpoint so the saved image grids are comparable across epochs
seed = tf.random.normal([num_samples, LATENT_DIM])

def train(dataset, epochs, batch_size, latent_dim):
    best_fid = float('inf')
    patience = 10        # Measured in FID evaluations (one every 10 epochs), not in epochs
    no_improvement = 0

    for epoch in range(epochs):
        for batch in dataset:
            gen_loss, disc_loss = train_step(batch, batch_size, latent_dim)

        # float() converts the scalar loss tensors returned by train_step for printing
        print(f"Epoch {epoch + 1}, Gen Loss: {float(gen_loss):.4f}, Disc Loss: {float(disc_loss):.4f}")

        if (epoch + 1) % 10 == 0:
            generate_and_save_images(generator, epoch + 1, seed)

            # FID on one batch of real images and the same number of freshly generated ones.
            # A single batch gives a noisy, biased estimate; use thousands of images for reported scores.
            real_images = next(iter(dataset))
            noise = tf.random.normal([real_images.shape[0], latent_dim])
            generated_images = generator(noise, training=False)

            current_fid = calculate_fid(real_images, generated_images)

            if current_fid < best_fid:
                best_fid = current_fid
                no_improvement = 0

                # Keep a copy of the best generator seen so far
                generator.save(f"generator_epoch_{epoch + 1}.h5")
            else:
                no_improvement += 1

            if no_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

# train_dataset, EPOCHS, BATCH_SIZE, generator, train_step, generate_and_save_images
# and calculate_fid are defined in the earlier parts of this project.
train(train_dataset, EPOCHS, BATCH_SIZE, LATENT_DIM)

Here's the code breakdown:

  1. Learning Rate Scheduling:
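    • Uses an ExponentialDecay schedule with staircase=True: the learning rate is multiplied by 0.96 after every 10,000 optimizer steps (batches), so it decreases in discrete steps rather than continuously.
    • Smaller updates later in training can reduce oscillation between the generator and the discriminator, although a schedule alone does not guarantee stable GAN training.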
  2. Optimizers:
    • Uses Adam optimizers for both the generator and discriminator, with:
      • A decaying learning rate (lr_schedule).
      • beta_1=0.5, which is common in GAN training to stabilize updates.
  3. Training Loop:
    • Iterates through epochs and batches, calling train_step() (not shown) to update the generator and discriminator weights.
    • Each batch update trains the discriminator to separate real from fake images and the generator to produce samples the discriminator classifies as real.
  4. Periodic Evaluation (every 10 epochs):
    • Generates and saves images using a fixed random noise seed to track progression.
    • Calculates the Fréchet Inception Distance (FID) score, a widely used metric for evaluating the quality and diversity of generated images.
  5. Model Saving:
    • Saves the generator model (generator.save()) when a new best FID score is achieved.
    • Helps preserve the best-performing generator instead of just the final epoch.
  6. Early Stopping:
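    • If FID does not improve for patience consecutive evaluations, training stops early. Because FID is computed only every 10 epochs, patience = 10 corresponds to 100 epochs without improvement.
    • Saves computation and keeps training from running long after sample quality has stopped improving. It does not prevent mode collapse (a GAN failure where the generator produces only a few similar images), but a collapsing generator typically scores a worse FID, so the best checkpoint saved earlier is not overwritten.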
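The patience counter in train() counts FID evaluations, not epochs. A small framework-free sketch of the same logic makes the difference explicit; the class name `FIDEarlyStopper` and its `update` method are hypothetical helpers for illustration, not a Keras API, and the FID values are simulated.

```python
class FIDEarlyStopper:
    """Tracks the best FID so far and signals a stop after `patience` evaluations without improvement."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_fid = float("inf")
        self.no_improvement = 0

    def update(self, fid):
        """Record one FID evaluation. Returns (improved, should_stop)."""
        if fid < self.best_fid:
            self.best_fid = fid
            self.no_improvement = 0
            return True, False
        self.no_improvement += 1
        return False, self.no_improvement >= self.patience

# Simulated run: FID is evaluated every 10 epochs and stops improving after epoch 50
eval_every = 10
stopper = FIDEarlyStopper(patience=10)
fids = [120.0, 80.0, 60.0, 50.0, 45.0] + [47.0] * 20
for i, fid in enumerate(fids):
    epoch = (i + 1) * eval_every
    improved, stop = stopper.update(fid)
    if improved:
        print(f"epoch {epoch}: new best FID {fid:.1f} (the generator would be saved here)")
    if stop:
        print(f"early stop at epoch {epoch}")
        break
```

In this simulated run the last improvement is at epoch 50 and training stops at epoch 150, ten evaluations (100 epochs) later. To stop after 10 epochs without improvement instead, either evaluate FID every epoch or set patience to 1.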

9.5.7 Evaluation Metrics

Implement and utilize advanced evaluation metrics to assess the quality and diversity of generated images. Two key metrics we will focus on are:

  1. Fréchet Inception Distance (FID): This metric fits a Gaussian to the feature vectors that a pre-trained Inception network extracts from real images, fits another to the features of generated images, and measures the Fréchet distance between the two. A lower FID means the generated distribution is closer to the real one, in both quality and diversity.
  2. Inception Score (IS): This metric uses only generated images. It rewards images that a pre-trained Inception classifier assigns confidently to a single class (quality) while the predictions as a whole are spread across many classes (diversity); a higher score is better. Because the Inception network is trained on ImageNet rather than on digits, IS is only a rough signal for MNIST; a classifier trained on MNIST can be used in its place.
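For reference, these are the standard definitions that the code below computes. For FID, (μ_r, Σ_r) and (μ_g, Σ_g) are the mean and covariance of the Inception pool features of real and generated images; for IS, p(y | x) is the Inception classifier's output distribution for a generated image x, and p(y) is its average over the generated set:

```latex
\begin{aligned}
\mathrm{FID} &= \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right) \\
\mathrm{IS} &= \exp\!\Big( \mathbb{E}_{x \sim p_g}\, D_{\mathrm{KL}}\big(p(y \mid x) \,\|\, p(y)\big) \Big),
  \qquad p(y) = \mathbb{E}_{x \sim p_g}\, p(y \mid x)
\end{aligned}
```

IS ranges from 1, when every image gets the same prediction, up to the number of classes (1000 for the ImageNet Inception model), reached when each prediction is confident and the predictions are spread evenly across classes.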

By incorporating these metrics into our evaluation process, we can quantitatively assess the performance of our GAN model and track improvements over time. This will provide valuable insights into the effectiveness of our various architectural and training enhancements.

import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
import numpy as np
from scipy.linalg import sqrtm

def _to_inception_input(images):
    """Convert generator-style images to InceptionV3 input.

    Assumes float images in [-1, 1] (tanh output), shaped (N, H, W, 1) or (N, H, W, 3).
    """
    images = tf.convert_to_tensor(images, dtype=tf.float32)
    if images.shape[-1] == 1:
        images = tf.image.grayscale_to_rgb(images)  # InceptionV3 needs 3 channels
    images = tf.image.resize(images, (299, 299))
    images = (images + 1.0) * 127.5                 # [-1, 1] -> [0, 255], the range preprocess_input expects
    return preprocess_input(images)                 # Scales [0, 255] to [-1, 1]

def calculate_fid(real_images, generated_images, batch_size=32):
    """
    Calculates the Fréchet Inception Distance (FID) between real and generated images.
    With fewer images than the 2048 feature dimensions the covariance estimates are
    singular, so small-sample scores are noisy and biased.
    """
    inception_model = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))

    def get_features(images):
        return inception_model.predict(_to_inception_input(images), batch_size=batch_size, verbose=0)

    # Extract 2048-dimensional pool features
    real_features = get_features(real_images)
    generated_features = get_features(generated_images)

    # Compute mean and covariance of features
    mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)

    # Squared distance between the feature means
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Matrix square root of the covariance product (part of the Fréchet distance formula)
    covmean = sqrtm(sigma1.dot(sigma2))

    # sqrtm can return tiny imaginary parts from round-off; keep the real part
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Compute final FID score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return float(fid)

def calculate_inception_score(images, batch_size=32, splits=10, eps=1e-12):
    """
    Computes the Inception Score (IS) for generated images.
    """
    # With include_top=True the model ends in a softmax layer,
    # so predict() already returns class probabilities (no extra softmax needed)
    inception_model = InceptionV3(include_top=True, weights="imagenet")
    preds = inception_model.predict(_to_inception_input(images), batch_size=batch_size, verbose=0)

    scores = []
    n = len(preds) // splits
    for i in range(splits):
        part = preds[i * n:(i + 1) * n, :]
        p_y = np.mean(part, axis=0, keepdims=True)  # Marginal class distribution p(y)
        # Per-image KL(p(y|x) || p(y)); eps avoids log(0)
        kl = part * (np.log(part + eps) - np.log(p_y + eps))
        kl = np.mean(np.sum(kl, axis=1))
        scores.append(np.exp(kl))

    return float(np.mean(scores)), float(np.std(scores))

Let's break down each function:

  1. FID (Fréchet Inception Distance):
    • Compares real and generated images to assess quality.
    • Uses InceptionV3 to extract image features.
    • Measures the difference between the two feature distributions (mean and covariance).
    • Lower FID = more realistic images.
  2. IS (Inception Score):
    • Checks the quality and diversity of generated images.
    • Uses InceptionV3 to classify the images.
    • Measures sharpness (confident predictions) and variety (predictions spread across classes).
    • Higher IS = better quality and diversity.
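Both formulas can be sanity-checked without Inception by feeding them synthetic inputs: made-up feature vectors for FID and made-up probability vectors for IS. This is a minimal NumPy-only sketch; `fid_from_features` and `inception_score_from_probs` are illustrative names, not library functions. Instead of `scipy.linalg.sqrtm`, the FID version uses the identity Tr((Σ_r Σ_g)^{1/2}) = Tr((Σ_r^{1/2} Σ_g Σ_r^{1/2})^{1/2}), whose right-hand side involves only symmetric matrices and can be computed with `np.linalg.eigh`.

```python
import numpy as np

def _sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # Clip tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T

def fid_from_features(feat_r, feat_g):
    """FID between two feature sets of shape (N, D)."""
    mu_r, mu_g = feat_r.mean(axis=0), feat_g.mean(axis=0)
    sigma_r = np.cov(feat_r, rowvar=False)
    sigma_g = np.cov(feat_g, rowvar=False)
    root_r = _sqrtm_psd(sigma_r)
    # Tr((S_r S_g)^(1/2)) equals Tr((S_r^(1/2) S_g S_r^(1/2))^(1/2)); the latter matrix is symmetric PSD
    tr_covmean = np.trace(_sqrtm_psd(root_r @ sigma_g @ root_r))
    mean_term = np.sum((mu_r - mu_g) ** 2)
    return float(mean_term + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * tr_covmean)

def inception_score_from_probs(probs, eps=1e-12):
    """IS for an (N, K) array of class probabilities, computed as a single split."""
    p_y = probs.mean(axis=0, keepdims=True)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

rng = np.random.default_rng(0)
feats = rng.standard_normal((2000, 16))
print(fid_from_features(feats, feats))        # ~0: identical sets
print(fid_from_features(feats, feats + 3.0))  # ~144: each of the 16 means shifted by 3, so 16 * 3**2

one_hot = np.eye(10)[np.arange(1000) % 10]    # Confident and evenly spread over 10 classes
uniform = np.full((1000, 10), 0.1)            # Maximally unsure
print(inception_score_from_probs(one_hot))    # ~10 (the number of classes)
print(inception_score_from_probs(uniform))    # ~1
```

The checks show the behavior described above: FID depends only on the feature statistics (shifting every mean by 3 adds exactly the squared mean difference), and IS reaches its maximum only when predictions are both confident and spread across all classes.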

9.5.8 Conclusion

This GAN project incorporates several advanced techniques to enhance the quality of generated images and the stability of training. The key improvements include:

  1. A deeper and more sophisticated architecture for both the generator and discriminator.
  2. Wasserstein loss with gradient penalty for improved training stability.
  3. Progressive growing to generate higher resolution images.
  4. Spectral normalization in the discriminator to prevent exploding gradients.
  5. Self-attention mechanism to capture global dependencies in generated images.
  6. An improved training loop with learning rate scheduling and early stopping.
  7. Advanced evaluation metrics (FID and Inception Score) for better assessment of generated image quality.

These enhancements are expected to produce higher-quality images, more stable training, and better overall performance from the GAN. Experiment with hyperparameters and architectures to find the best configuration for your specific use case.