Denoising Diffusion Probabilistic Models

This blog post is essentially my learning notes from trying to understand and implement the paper Denoising Diffusion Probabilistic Models. All the resources I used to understand the paper are listed at the end of the post.

0.1 What are diffusion probabilistic models?

[Figure: DDPM]

Before we define what diffusion models are, let's first understand what diffusion means in the context of generative models. Here, diffusion refers to the process of iteratively corrupting an image, i.e. adding controlled amounts of noise to the data. Diffusion models are models that learn to reverse this corruption and denoise the perturbed images.

This setup therefore has two components:

  • A forward diffusion process that adds noise to the data
  • A reverse diffusion process that denoises the data

0.2 Forward diffusion process

[Figure: Forward diffusion process (image source)]

Given a data point \mathbf x_0 sampled from a distribution q(\mathbf x) (written \mathbf x_0 \sim q(\mathbf x)), the forward diffusion process gradually adds a small amount of Gaussian noise to the data point at each time step t according to a known variance schedule \beta_1, \dots, \beta_T. The variance schedule \beta_t determines how much noise is added at each time step and can be linear, quadratic, cosine, etc. Each step of the forward diffusion process q(\mathbf x_t | \mathbf x_{t-1}) can be summarized mathematically as:

q(\mathbf x_t | \mathbf x_{t-1}) = \mathcal N(\mathbf x_t; \sqrt{1 - \beta_t} \mathbf x_{t-1} , \beta_t \mathbf I)

From the image above, we can quickly break down this equation:

  • q(\mathbf x_t | \mathbf x_{t-1}) represents the probability distribution of the data point at time t given the data point at time step t-1. In other words, it tells us how likely we are to find a data point at time t if we know the data point at time t-1.

  • \mathcal N(\cdot) means that the probability distribution of \mathbf x_t given \mathbf x_{t-1} is a normal distribution with mean \sqrt{1 - \beta_t} \mathbf x_{t-1} and covariance \beta_t \mathbf I.

So at each time step t, \mathbf x_t is generated by taking the previous data point \mathbf x_{t-1}, scaling it down by \sqrt{1 - \beta_t} and adding Gaussian noise with variance \beta_t. As t increases, \mathbf x_t becomes noisier and noisier until, for a large enough T and a suitable schedule, \mathbf x_T is approximately pure Gaussian noise \mathcal N(\mathbf 0, \mathbf I).
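A single step of this process is easy to write down directly. The sketch below assumes a scalar \beta_t and uses a hypothetical helper `forward_step` (not part of the implementation that follows). It also shows why the mean is scaled by \sqrt{1 - \beta_t}: if \mathbf x_{t-1} has unit variance, then so does \mathbf x_t, because (1 - \beta_t) + \beta_t = 1.

```python
import torch

def forward_step(x_prev: torch.Tensor, beta_t: float) -> torch.Tensor:
    """Draw x_t ~ N(sqrt(1 - beta_t) * x_prev, beta_t * I)."""
    noise = torch.randn_like(x_prev)
    return (1 - beta_t) ** 0.5 * x_prev + beta_t ** 0.5 * noise

torch.manual_seed(0)
x_prev = torch.randn(100_000)            # unit-variance stand-in for x_{t-1}
x_t = forward_step(x_prev, beta_t=0.02)

print(x_t.std().item())                  # close to 1: the variance is preserved
print((x_prev * x_t).mean().item())      # close to sqrt(0.98): x_t still "remembers" x_{t-1}
```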

0.2.1 Defining the forward diffusion process in code

Let's put some of the evil-looking math above into code. First, we'll define various variance schedules \beta_t that determine the amount of noise added at each time step. The paper used a linear schedule, increasing from \beta_1 = 10^{-4} to \beta_T = 0.02 over T = 1000 steps.


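However, the paper Improved Denoising Diffusion Probabilistic Models showed that a cosine schedule leads to better results, since the linear schedule destroys information too quickly: towards the end of the forward process the samples are already almost pure noise, particularly for low-resolution images.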

import torch
# Defining variance schedules
def linear_beta_schedule(timesteps, beta_start = 0.0001, beta_end = 0.02):
    beta_schedule = torch.linspace(beta_start, beta_end, timesteps)
    return beta_schedule

def quadratic_beta_schedule(timesteps, beta_start = 0.0001, beta_end = 0.02):
    # Linearly spaced in sqrt(beta), then squared, so the betas grow quadratically
    beta_schedule = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
    return beta_schedule

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
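    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    # alpha_bar_t follows a squared-cosine curve (called alphas_cumprod elsewhere in this post)
    alphas_cumprod = torch.cos((x / timesteps + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0] # normalize so that alpha_bar_0 = 1
    beta_schedule = 1 - (alphas_cumprod[1:] / alphas_cumprod[: -1]) # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}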
    return torch.clip(input = beta_schedule, min = 0.0001, max = 0.9999)

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
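To make the claim about the linear schedule concrete, the sketch below compares how much of the original signal, \bar\alpha_t = \prod_{s=1}^{t}(1 - \beta_s), is left a quarter, half and three quarters of the way through the forward process with T = 1000 (the paper's setting). It re-declares the linear and cosine schedules from above so it runs on its own; the printed numbers are only illustrative.

```python
import math
import torch

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    x = torch.linspace(0, timesteps, timesteps + 1)
    alphas_cumprod = torch.cos((x / timesteps + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clip(betas, 0.0001, 0.9999)

T = 1000
signal_left = {}
for name, betas in [("linear", linear_beta_schedule(T)), ("cosine", cosine_beta_schedule(T))]:
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)
    # alpha_bar_t at t = T/4, T/2 and 3T/4 (index t - 1 because tensors are 0-indexed)
    signal_left[name] = [alphas_cumprod[int(f * T) - 1].item() for f in (0.25, 0.5, 0.75)]

print(signal_left)  # the linear schedule has almost no signal left by 3T/4
```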
    

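Say we start at time step t = 0 and want to obtain the image at time step t = 56: we would have to apply the forward diffusion process 56 times, which is slow and computationally expensive. Conveniently, writing \alpha_t = 1 - \beta_t, consecutive steps can be merged because a sum of independent Gaussians is again Gaussian. Two steps, for example, give \mathbf x_t = \sqrt{\alpha_t \alpha_{t-1}} \mathbf x_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \, \bar{\epsilon} with \bar{\epsilon} \sim \mathcal N(\mathbf 0, \mathbf I). Repeating this down to \mathbf x_0 lets us sample \mathbf x_t at any time step t directly: q(\mathbf x_t | \mathbf x_0) = \mathcal N(\mathbf x_t; \sqrt{\bar \alpha_t} \mathbf x_0, (1 - \bar \alpha_t) \mathbf I), where \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s.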
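A quick Monte Carlo check of this closed form, written as a sketch under assumed values (linear schedule, T = 300, t = 56, and a single pixel value 0.8 repeated many times): applying q(\mathbf x_t | \mathbf x_{t-1}) step by step and sampling directly from q(\mathbf x_t | \mathbf x_0) should give matching means and standard deviations.

```python
import torch

torch.manual_seed(0)
T, t = 300, 56
betas = torch.linspace(0.0001, 0.02, T)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)

x0 = torch.full((200_000,), 0.8)  # many copies of the same scalar "pixel"

# Route 1: apply q(x_t | x_{t-1}) step by step, t times
x = x0.clone()
for beta in betas[:t]:
    x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * torch.randn_like(x)

# Route 2: jump straight to x_t with q(x_t | x_0)
alpha_bar = alphas_cumprod[t - 1]  # 0-indexed: entry t - 1 holds alpha_bar_t
x_direct = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * torch.randn_like(x0)

print(x.mean().item(), x_direct.mean().item())  # both approximately sqrt(alpha_bar) * 0.8
print(x.std().item(), x_direct.std().item())    # both approximately sqrt(1 - alpha_bar)
```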

Let's put this into code, assuming T = 300 steps:

import matplotlib.pyplot as plt
timesteps = 300

# Define beta schedule
betas = linear_beta_schedule(timesteps)

# Define alphas
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim = 0)

# Calculating q(x_t | x_0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod) # standard deviation

# Function that allows us to extract the appropriate t index for a batch of indices
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu()) # get corresponding alphas for given time index
    return out.reshape(batch_size, *((1, ) * (len(x_shape) - 1))).to(t.device) # Bs, 1, 1, 1

# Define forward diffusion q(x_t | x_0)
def q_sample(x_start, t, noise = None):
    if noise is None:
        noise = torch.randn_like(x_start)
    
    # Look up the coefficients for each time step in the batch (shape: B, 1, 1, 1)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape) # scales x_0: sqrt(alpha_bar_t)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) # standard deviation

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    

sqrt_alphas_cumprod.gather(-1, torch.tensor(299)) # get corresponding alphas for given time index
tensor(0.2192)
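Because each image in a batch can sit at a different time step, `extract` looks up one coefficient per image and reshapes the result so it broadcasts over the channel and spatial dimensions. A small shape check, re-declaring `extract` and a linear schedule so it runs standalone (the batch size and time steps are arbitrary choices):

```python
import torch

def extract(a, t, x_shape):
    """Pick a[t] for every batch element and reshape to (B, 1, 1, ...) for broadcasting."""
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

timesteps = 300
betas = torch.linspace(0.0001, 0.02, timesteps)
sqrt_alphas_cumprod = torch.sqrt(torch.cumprod(1 - betas, dim=0))

x_start = torch.randn(4, 3, 32, 32)       # a batch of 4 RGB images
t = torch.tensor([0, 50, 150, 299])       # a different time step per image
coef = extract(sqrt_alphas_cumprod, t, x_start.shape)
print(coef.shape)                         # torch.Size([4, 1, 1, 1])
scaled = coef * x_start                   # broadcasts over C, H, W
print(scaled.shape)                       # torch.Size([4, 3, 32, 32])
```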

Now that we have defined the forward diffusion process, let's use it to add noise to an image. Noise is added to PyTorch tensors rather than Pillow images, so we first convert the image to a tensor and, as described in the paper, scale its pixel values linearly from [0, 255] to [-1, 1].


from torchvision import transforms
from PIL import Image
image = Image.open('images/blog_image.png')

image_size = 128
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(), # convert to tensor of shape (C, H, W) in range [0, 1]
    transforms.Lambda(lambda t: (t * 2) - 1) # rescale to [-1, 1]
])

# Transform image
x_start = transform(image).unsqueeze(0)
x_start.shape, x_start.min(), x_start.max()
(torch.Size([1, 3, 128, 128]), tensor(-0.8196), tensor(1.))

We also define the reverse transform, which takes a PyTorch tensor in the range [-1, 1] and turns it back into a Pillow image:

import numpy as np

reverse_transform = transforms.Compose([
    transforms.Lambda(lambda t: t.clamp(-1, 1)), # noisy tensors can leave [-1, 1]; clamp to avoid invalid uint8 values
    transforms.Lambda(lambda t: (t + 1) / 2), # rescale to [0, 1]
    transforms.Lambda(lambda t: t.permute(1, 2, 0)), # (C, H, W) to (H, W, C)
    transforms.Lambda(lambda t: t * 255.), # rescale to [0, 255]
    transforms.Lambda(lambda t: t.numpy().astype(np.uint8)), # convert to numpy array of type uint8
    transforms.ToPILImage() # convert to PIL image
])

reverse_transform(x_start.squeeze(0))

# Function that adds noise to a tensor then converts it back to a PIL image
def get_noisy_image(x_start, t):
    # Add noise at time t
    x_noisy = q_sample(x_start, t)

    # Convert to PIL image
    noisy_image = reverse_transform(x_noisy.squeeze(0))
    return noisy_image

# Create time step
t = torch.tensor([50])
# Get noisy image at time t
get_noisy_image(x_start, t)

Now, let's visualize the forward diffusion process by adding noise to an image at different time steps:

# Function that visualizes noisy images at various timesteps
def visualize_noisy_images(imgs, cols, rows = 1, titles = None, **kwargs):
    fig, ax = plt.subplots(rows, cols, figsize = (20, 10))
    axes = np.atleast_1d(ax).ravel() # also works when rows = cols = 1
    for i in range(len(imgs)):
        axes[i].imshow(imgs[i], **kwargs)
        if titles is not None:
            axes[i].set_title(f't = {titles[i]}')
        axes[i].axis('off')
    # overall title
    #plt.suptitle('Noisy images at various timesteps', fontsize = 16)
    plt.tight_layout()
    plt.show()

# Define time instance
steps = [0, 50, 100, 150, 200, 250, 299]

# Get noisy images at each time step
imgs = [get_noisy_image(x_start, torch.tensor([t])) for t in steps]

# Visualize noisy images
visualize_noisy_images(imgs, cols = len(imgs), rows = 1)

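The extract helper above relies on torch.gather. Here is a small standalone example of how gather picks values along a dimension:

import torch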

# Input tensor
input_tensor = torch.tensor([[1, 2, 3],
                             [4, 5, 6],
                             [7, 8, 9]])

# Index tensor
index_tensor = torch.tensor([[0, 2, 1],
                             [2, 1, 0]])

# Gather along dimension 0: output[i][j] = input_tensor[index_tensor[i][j]][j]
output_tensor = torch.gather(input_tensor, 0, index_tensor)

print(output_tensor)
tensor([[1, 8, 6],
        [7, 5, 3]])

0.3 Reverse diffusion process

After progressively adding noise to an image, if we could sample from q(\mathbf x_{t-1} | \mathbf x_t), we could run the forward process in reverse and denoise the image. However, q(\mathbf x_{t-1} | \mathbf x_t) is intractable, because computing it would require knowledge of the entire data distribution. Instead, we train a model p_\theta to approximate these conditional probabilities. The reverse diffusion process can thus be represented as:

p_{\theta}(\mathbf x_{t-1} | \mathbf x_t) = \mathcal N(\mathbf x_{t-1}; \mu_{\theta}(\mathbf x_t, t), \Sigma_{\theta}(\mathbf x_t, t))

Note

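The \theta subscript denotes parameters learned by the model. In general, the network would have to predict both the mean \mu_\theta and the covariance \Sigma_\theta of \mathbf x_{t-1} given \mathbf x_t. In DDPM, however, the covariance is fixed to \Sigma_\theta(\mathbf x_t, t) = \sigma_t^2 \mathbf I with untrained, time-dependent constants \sigma_t^2, so only the mean is learned.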

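In Equations 11 and 12 of the paper, the authors show that the mean can be reparameterized so that the network predicts the added noise \epsilon_\theta(\mathbf x_t, t) instead: \mu_\theta(\mathbf x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon_\theta(\mathbf x_t, t) \right).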
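A numerical check of this reparameterization, as a sketch with an assumed linear schedule and an arbitrary step t = 120: if the network predicted the true noise exactly, the Equation 11 mean \frac{1}{\sqrt{\alpha_t}}(\mathbf x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\epsilon) would coincide with the mean of the true posterior q(\mathbf x_{t-1} | \mathbf x_t, \mathbf x_0) given in Equation 7 of the paper.

```python
import torch

torch.manual_seed(0)
T = 300
betas = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t = 120  # 0-indexed step; alpha_bar_{t-1} is alphas_cumprod[t - 1]
a_t, b_t = alphas[t], betas[t]
abar_t, abar_prev = alphas_cumprod[t], alphas_cumprod[t - 1]

x0 = torch.randn(5, dtype=torch.float64)
eps = torch.randn(5, dtype=torch.float64)
x_t = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps  # closed-form forward sample

# Posterior mean of q(x_{t-1} | x_t, x_0) (Equation 7)
mu_tilde = (abar_prev.sqrt() * b_t / (1 - abar_t)) * x0 \
    + (a_t.sqrt() * (1 - abar_prev) / (1 - abar_t)) * x_t

# Noise-prediction form of the mean (Equation 11), using the true noise
mu_eps = (x_t - b_t / (1 - abar_t).sqrt() * eps) / a_t.sqrt()

print(torch.allclose(mu_tilde, mu_eps))  # True
```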


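As a result, the neural network becomes a noise predictor, optimized with the mean squared error (MSE) between the actual noise \epsilon and the predicted noise \epsilon_\theta:

Loss_t = \|\epsilon - \epsilon_{\theta}(\mathbf x_t, t)\|^2

where the network's input is the noisy sample obtained with the closed-form forward process, i.e. \epsilon_{\theta}(\mathbf x_t, t) = \epsilon_{\theta}(\sqrt{\bar\alpha_t} \mathbf x_0 + \sqrt{1 - \bar\alpha_t} \, \epsilon, t).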

The training procedure (Algorithm 1 in the paper) can be summarized as follows:

  • we take a random sample \mathbf x_0 from the data distribution q(\mathbf x)

  • we sample a time step t uniformly from \{1, \dots, T\}

  • we sample noise \epsilon \sim \mathcal N(\mathbf 0, \mathbf I) and corrupt the input with the amount of noise corresponding to time step t, obtaining \mathbf x_t = \sqrt{\bar\alpha_t} \mathbf x_0 + \sqrt{1 - \bar\alpha_t} \, \epsilon; the amount of noise is determined by the variance schedule \beta_t

  • the neural network is trained to predict the added noise \epsilon from \mathbf x_t and t
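These steps map onto a single optimization step. Here is a minimal sketch under stated assumptions: `ToyNoisePredictor` and `training_step` are hypothetical names, the toy model ignores t (the U-Net defined later does not), the batch is random data in [-1, 1], and the loss is the MSE from the paper. Time steps are 0-indexed, matching the code in this post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 300
betas = torch.linspace(0.0001, 0.02, T)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)

class ToyNoisePredictor(nn.Module):
    """Hypothetical stand-in for the U-Net: a single conv that ignores t."""
    def __init__(self, channels=1):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)

def training_step(model, optimizer, x0):
    b = x0.shape[0]
    t = torch.randint(0, T, (b,))                       # uniform time step per image
    noise = torch.randn_like(x0)                        # epsilon ~ N(0, I)
    abar = alphas_cumprod[t].view(b, 1, 1, 1)
    x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * noise  # corrupt x_0 in one shot
    loss = F.mse_loss(model(x_t, t), noise)             # regress the added noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

model = ToyNoisePredictor()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x0 = torch.rand(8, 1, 28, 28) * 2 - 1                  # fake batch scaled to [-1, 1]
losses = [training_step(model, optimizer, x0) for _ in range(5)]
print(losses)
```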

Jay Alammar illustrates this nicely in his blog post The Illustrated Stable Diffusion.

[Figure: DDPM]

This is a good place to define the loss function. Three common choices are:

Note
  • l1_loss: the L1 loss, also known as the mean absolute error (MAE), is the average absolute difference between the predicted and actual values.

  • l2_loss: the L2 loss, also known as the mean squared error (MSE), is the average squared difference between the predicted and actual values. This is the loss used in the DDPM paper.

  • Huber loss: the Huber loss combines the L1 and L2 losses. It behaves like the L2 loss for small errors and like the L1 loss for large errors, which makes it less sensitive to outliers than the L2 loss. It is defined as follows (with \delta = 1 it matches PyTorch's F.smooth_l1_loss, which p_losses uses for "huber"):

L_\delta(y_i, \hat y_i) = \begin{cases} \frac{1}{2}(y_i - \hat y_i)^2 & \text{for } |y_i - \hat y_i| \leq \delta \\ \delta \left(|y_i - \hat y_i| - \frac{1}{2} \delta\right) & \text{otherwise} \end{cases}
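Writing the piecewise definition out and comparing it with PyTorch's built-ins is a quick sanity check. This sketch assumes \delta = 1, in which case `F.huber_loss` and `F.smooth_l1_loss` should both agree with the formula; `huber` is a hypothetical helper, not part of the post's code.

```python
import torch
import torch.nn.functional as F

def huber(y, y_hat, delta=1.0):
    """Huber loss from the piecewise definition, averaged over all elements."""
    err = (y - y_hat).abs()
    quadratic = 0.5 * err ** 2                 # used where |error| <= delta
    linear = delta * (err - 0.5 * delta)       # used where |error| > delta
    return torch.where(err <= delta, quadratic, linear).mean()

y = torch.zeros(4)
y_hat = torch.tensor([0.1, -0.5, 2.0, -3.0])   # two small and two large errors

manual = huber(y, y_hat)                       # (0.005 + 0.125 + 1.5 + 2.5) / 4 = 1.0325
builtin = F.huber_loss(y_hat, y, delta=1.0)
smooth = F.smooth_l1_loss(y_hat, y)            # beta defaults to 1.0
print(manual.item(), builtin.item(), smooth.item())
```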

import torch.nn.functional as F
def p_losses(denoise_model, x_start, t, noise = None, loss_type = "l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    # Get noisy image
    x_noisy = q_sample(x_start = x_start, t = t, noise = noise)

    # Predict noise
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == "l1":
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == "l2":
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()
    
    return loss

0.4 The Network used in diffusion models

As we have seen, the network in a diffusion model is a noise predictor: it takes in a noisy image and its time step and returns a noise prediction with the same shape as the image. The network typically used is a U-Net, an encoder-decoder architecture with skip connections between the encoder and the decoder. The U-Net processes an image by progressively downsampling it in the encoder and upsampling it in the decoder. The skip connections give the decoder access to the encoder's feature maps, which helps mitigate the information loss that occurs during downsampling.

Okay, let's now implement these concepts in PyTorch. This implementation is based on Phil Wang's.

0.4.1 Network helpers

import torch
import torch.nn as nn
import torch.nn.functional as F
from inspect import isfunction
from functools import partial
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from torch import einsum
import math
# Return true if x is not None
def exists(x):
    return x is not None

# Return val if val is not none else d() if d is a function else d
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# Residual block i.e x + f(x)
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
    
# Upsample block + Conv2d
def Upsample(dim_in, dim_out = None):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = "nearest"), # expects input of shape (B, C, H, W)
        nn.Conv2d(in_channels = dim_in, out_channels = default(dim_out, dim_in), kernel_size = 3, padding = 'same')
    )

# Downsample block
def Downsample(dim_in, dim_out = None):
    return nn.Sequential(
       Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), # Halves the height and width and multiplies the channels by 4
       nn.Conv2d(in_channels = dim_in * 4, out_channels = default(dim_out, dim_in), kernel_size = 3, padding='same')
    )
# Test the above blocks
x = torch.randn(1, 3, 32, 32)
print("Input shape: ", x.shape)
print("Upsample shape: ", Upsample(3)(x).shape)
print("Downsample shape: ", Downsample(3)(x).shape)
Input shape:  torch.Size([1, 3, 32, 32])
Upsample shape:  torch.Size([1, 3, 64, 64])
Downsample shape:  torch.Size([1, 3, 16, 16])

0.5 Time embeddings

Because the network's parameters are shared across all time steps, it needs to know which time step t it is working on in order to determine how much noise to remove. This information is passed to the network as a time embedding of dimension dim, which is injected into the residual blocks. The authors used sinusoidal position embeddings (as in the Transformer) to provide a continuous and meaningful representation of the time steps. Read more about position embeddings at this fantastic blog. The formula for sinusoidal position embeddings is:

PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right)

PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)

where:

  • PE_{(pos, 2i)} is the (2i)-th dimension of the positional encoding at position pos (for diffusion models, pos is the time step t),
  • PE_{(pos, 2i+1)} is the (2i + 1)-th dimension,
  • i is the dimension index, and
  • d is the dimensionality of the positional encoding.

Let’s put this in code:

# Create a sinusoidal time encoding
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device # cpu or gpu
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1) # scale factor
        embeddings = torch.exp(torch.arange(half_dim, device = device) * -embeddings) # Decreasing exponential frequencies
        embeddings = time[:, None] * embeddings[None, :] # Multiply the frequencies with the time (T, 1) x (1, D/2) = (T, D/2)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim = -1) # Concatenate the sin and cos embeddings (T, D)
        return embeddings

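This results in a time embedding made of sine and cosine values at a range of frequencies. Note that this implementation concatenates all the sines followed by all the cosines instead of interleaving them as in the formula above, and spaces the frequencies as 10000^{-i/(d/2 - 1)}. The ordering makes no difference, since the embedding is fed into a learned linear layer. The embedding is later passed through an MLP and injected into each residual block as a per-channel scale and shift of the feature maps.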

# Test the above class
dim = 32
time_steps = torch.arange(0, 10).float()
print("Time steps: ", time_steps)
half_dim = dim // 2
print("Half dim: ", half_dim)
embeddings = math.log(10000) / (half_dim - 1)
print("Embeddings: ", embeddings)
embeddings = torch.exp(torch.arange(half_dim) * -embeddings)
print("Embeddings: ", embeddings)
embeddings = torch.matmul(time_steps[:, None], embeddings[None, :])
print("Embeddings shape: ", embeddings.shape)
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim = -1)
print("Embeddings shape: ", embeddings.shape)
print("Embeddings[0]: ", embeddings[0])

Time steps:  tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
Half dim:  16
Embeddings:  0.6140226914650789
Embeddings:  tensor([1.0000e+00, 5.4117e-01, 2.9286e-01, 1.5849e-01, 8.5770e-02, 4.6416e-02,
        2.5119e-02, 1.3594e-02, 7.3564e-03, 3.9811e-03, 2.1544e-03, 1.1659e-03,
        6.3096e-04, 3.4146e-04, 1.8479e-04, 1.0000e-04])
Embeddings shape:  torch.Size([10, 16])
Embeddings shape:  torch.Size([10, 32])
Embeddings[0]:  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
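Why call this representation continuous and meaningful? One useful property: the dot product of two embeddings is \sum_i \cos(\omega_i (t_1 - t_2)), so it depends only on how far apart the two time steps are, not on where they sit. The sketch below re-implements the `forward` method as a plain function (`sinusoidal_embedding` is a hypothetical name) and checks this numerically; every embedding also has squared norm d/2.

```python
import math
import torch

def sinusoidal_embedding(time, dim):
    """Same construction as SinusoidalPositionEmbeddings.forward, as a plain function."""
    half_dim = dim // 2
    freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
    args = time[:, None] * freqs[None, :]
    return torch.cat((args.sin(), args.cos()), dim=-1)

emb = sinusoidal_embedding(torch.arange(0, 100).double(), dim=32)

# Same gap (2 steps) at different positions gives the same dot product
same_gap = (torch.dot(emb[3], emb[5]).item(), torch.dot(emb[60], emb[62]).item())
# In this example, the nearby pair is more similar than the distant pair
near_far = (torch.dot(emb[10], emb[11]).item(), torch.dot(emb[10], emb[40]).item())
print(same_gap, near_far)
```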

0.6 Residual blocks of the U-Net

Phil Wang implements a ResNet block with weight-standardized convolutions, which have been observed to work well in combination with group normalization (see Kolesnikov et al.). Group Normalization is a normalization layer that divides the channels into groups and normalizes the features within each group. Weight Standardization (Qiao et al.) is a normalization technique that standardizes the weights of convolutional layers, which smooths the loss landscape.

[Figure: weight standardization]
# Weight standardization: standardize the weights of the conv layers
class WeightStandardizedConv2d(nn.Conv2d):
    # Weight-standardized conv, as used in Big Transfer (https://arxiv.org/abs/1912.11370)
    def __init__(self, in_channels, out_channels, kernel_size, padding = 'same'):
        super().__init__(in_channels, out_channels, kernel_size, padding = padding)
    
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        weight = self.weight # Weight of kernel of shape (output_chan, input_chan, kernel_size, kernel_size)
        mean = reduce(weight, "o ... -> o 1 1 1", "mean") # Per-filter mean: (output_chan, input_chan, k, k) -> (output_chan, 1, 1, 1)
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased = False)) # Per-filter variance; partial sets unbiased=False
        normalized_weight = (weight - mean) / (var + eps).sqrt() # Normalize the weights
        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups
        )

# Weight standardization conv + Group norm + SiLU    
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.conv = WeightStandardizedConv2d(in_channels = dim, out_channels = dim_out, kernel_size = 3)
        # Separate dim_out channels into groups
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        # Weight standardization + Group norm - Perfect combo
        x = self.conv(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift # Incorporate the time embedding
            
        x = self.act(x)
        return x
    
class ResnetBlock(nn.Module):

    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(in_features = time_emb_dim, out_features = dim_out * 2)) if exists(time_emb_dim) else None
        )

        self.block1 = Block(dim, dim_out, groups = groups) # Weight standardization conv + Group norm + SiLU
        self.block2 = Block(dim_out, dim_out, groups = groups) # Weight standardization conv + Group norm + SiLU
        self.res_conv = nn.Conv2d(in_channels = dim, out_channels = dim_out, kernel_size = 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(time_emb) and exists(self.mlp):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1) # Split into scale and shift, each of shape [b, dim_out, 1, 1]

        h = self.block1(x, scale_shift = scale_shift) # The time embedding modulates the first block
        h = self.block2(h)
        return self.res_conv(x) + h # Residual connection
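The weight standardization step inside `WeightStandardizedConv2d` can be checked on its own: afterwards, every output filter should have zero mean and (up to `eps`) unit variance. A pure-PyTorch sketch, using `Tensor.mean` and `Tensor.var` in place of `einops.reduce`, on a randomly initialised weight (an assumption):

```python
import torch

torch.manual_seed(0)
weight = torch.randn(16, 8, 3, 3) * 0.3 + 0.1   # (out_channels, in_channels, k, k)
eps = 1e-5

# Standardize each output filter over its (in_channels, k, k) entries
mean = weight.mean(dim=(1, 2, 3), keepdim=True)
var = weight.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
normalized_weight = (weight - mean) / (var + eps).sqrt()

print(normalized_weight.mean(dim=(1, 2, 3))[:4])                 # approximately 0
print(normalized_weight.var(dim=(1, 2, 3), unbiased=False)[:4])  # approximately 1
```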

Note

In the context of convolutional neural networks (CNNs), a “kernel size of 1” refers to using a 1x1 convolutional layer.

A convolutional layer with a 1x1 kernel looks at a single input position at a time. This might seem counterintuitive at first, because larger kernels are normally used to capture spatial patterns. However, a 1x1 convolutional layer has some specific use cases and advantages:

Channel-wise transformation: a 1x1 convolution is applied independently at each spatial position and does not look at neighboring positions. Instead, it mixes information across the channel dimension, acting like a fully connected layer applied to the channel vector of every pixel. In ResnetBlock, res_conv uses it to change the number of channels of the input x from dim to dim_out so that it can be added to h.
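The equivalence between a 1x1 convolution and a per-pixel linear layer is easy to verify. The sketch below reproduces `nn.Conv2d(kernel_size=1)` with `torch.einsum` on random data; the shapes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 8, 5, 5)                  # (B, C_in, H, W)
conv1x1 = nn.Conv2d(8, 16, kernel_size=1)    # maps 8 channels to 16, like res_conv

# Same computation as a linear layer applied to the channel vector at every pixel
W = conv1x1.weight.view(16, 8)               # (C_out, C_in)
manual = torch.einsum("oc,bchw->bohw", W, x) + conv1x1.bias.view(1, 16, 1, 1)

print(torch.allclose(conv1x1(x), manual, atol=1e-5))  # True
```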

0.7 Attention module

The DDPM authors added self-attention blocks between the convolutional blocks, at the 16 × 16 feature-map resolution.

Phil Wang implements two variants: regular multi-head self-attention as used in the Transformer, and linear attention as proposed by Shen et al., whose memory and computational costs grow linearly with the number of positions rather than quadratically as in the Transformer.

[Figure: types of attention]
# Attention block
class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_heads = 32):
        super().__init__()
        self.scale = dim_heads ** -0.5 # scale attention scores by 1 / sqrt(dim_heads)
        self.heads = heads
        hidden_dim = dim_heads * heads # Concatenate the heads
        self.to_qkv = nn.Conv2d(in_channels = dim, out_channels = hidden_dim * 3, kernel_size = 1, bias = False)
        self.to_out = nn.Conv2d(in_channels = hidden_dim, out_channels = dim, kernel_size = 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1) # Split into q, k, v, each of shape [b, hidden_dim, h, w]
        q, k, v = map(
            lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv
        )

        q = q * self.scale # Scale the query

        # Matrix multiplication between query and key
        #sim = einsum('b h d i, b h d j -> b h i j', q, k)
        sim = torch.matmul(q.transpose(-1, -2), k) # (n x d) @ (d x n) -> (n x n) similarity per head
        sim = sim - sim.amax(dim = -1, keepdim = True).detach() # Subtract the row max for numerical stability; softmax is unchanged

        # Softmax to get the attention weights
        attn = sim.softmax(dim = -1)

        # Matrix multiplication between attention weights and value
        #out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = torch.matmul(attn, v.transpose(-1, -2)) # B x H x n x n * B x H x n x head_size -> B x H x n x head_size  
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) # Concatenate the heads back to B x C x H x W as expected by nn.Conv2d
        return self.to_out(out)


# Linear attention block
class LinearAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(in_channels = dim, out_channels = hidden_dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(in_channels = hidden_dim, out_channels = dim, kernel_size = 1),
            nn.GroupNorm(1, dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1) # Split into q, k, v, each of shape [b, hidden_dim, h, w]
        q, k, v = map(
            lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv
        )

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        # context = torch.matmul(k, v.transpose(-1, -2)) # (d x n) @ (n x e) -> (d x e) context per head

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        # out = torch.matmul(context.transpose(-1, -2), q) # (e x d) @ (d x n) -> (e x n)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
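The einsum calls in `LinearAttention.forward` can be checked against plain matrix products, and the sizes of the intermediate matrices show where the linear cost comes from. This is a sketch on random tensors with arbitrarily chosen sizes (4 heads of width 32 on a 64×64 feature map): standard attention would need an n × n similarity matrix per head, while linear attention only builds a d × d context matrix. The n × n matrix is not materialized here; only its size is computed.

```python
import torch

torch.manual_seed(0)
b, heads, d, n = 1, 4, 32, 64 * 64           # n = H * W positions
q = torch.randn(b, heads, d, n)
k = torch.randn(b, heads, d, n)
v = torch.randn(b, heads, d, n)

# Standard attention would build an (n x n) similarity matrix per head
sim_shape = torch.Size([b, heads, n, n])

# Linear attention, as in LinearAttention.forward, builds a (d x d) context matrix
q_lin, k_lin = q.softmax(dim=-2), k.softmax(dim=-1)
context = torch.einsum("b h d n, b h e n -> b h d e", k_lin, v)
out = torch.einsum("b h d e, b h d n -> b h e n", context, q_lin)

# Equivalent matmul forms of the two einsums
context_mm = torch.matmul(k_lin, v.transpose(-1, -2))       # (d x n) @ (n x e) -> (d x e)
out_mm = torch.matmul(context_mm.transpose(-1, -2), q_lin)  # (e x d) @ (d x n) -> (e x n)

print(sim_shape.numel(), context.numel())                   # 67108864 4096
print(torch.allclose(out, out_mm, atol=1e-5))
```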

0.8 Pre-normalization

Recent research has shown that pre-normalization in Transformers leads to well-behaved gradients and faster convergence; see Xiong et al. and Nguyen et al.

[Figure: from Xiong et al.]

In this implementation, pre-normalization will be used to apply group normalization before the attention layer.

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim) # One group: normalizes over (C, H, W) per sample, like LayerNorm
        self.fn = fn

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)
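Since `PreNorm` uses `nn.GroupNorm(1, dim)`, it helps to know what a single group does. The sketch below, on random data, checks that it normalizes each sample over all channels and spatial positions together, i.e. like a LayerNorm over (C, H, W), rather than per channel as instance norm would.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 16, 8, 8)

group_norm = nn.GroupNorm(1, 16)  # the normalization used by PreNorm
# One group means statistics over (C, H, W) for each sample; this matches a
# LayerNorm over the last three dimensions while GroupNorm's affine scale is 1
# and shift is 0 (their values at initialisation)
layer_normed = F.layer_norm(x, normalized_shape=x.shape[1:])

print(torch.allclose(group_norm(x), layer_normed, atol=1e-5))  # True
```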

0.9 U-Net

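Now let's create a U-Net that takes in noisy images and their respective noise levels (time steps) and predicts the noise that was added to each image. First, a quick look at how the channel dimensions of each resolution level are computed, using a toy init_dim of 3: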

init_dim = 3
dim_mults = (1, 2, 4, 8)
dims = [init_dim, *map(lambda m: m * init_dim, dim_mults)]
dims, dims[:-1], dims[1:]
([3, 3, 6, 12, 24], [3, 3, 6, 12], [3, 6, 12, 24])
in_out = list(zip(dims[:-1], dims[1:]))
print("in_out: ", in_out)

for ind, (dim_in, dim_out) in enumerate(in_out):
    is_last = ind >= len(in_out) - 1
    print("dim_in, dim_out, is_last, ind: ", dim_in, dim_out, is_last, ind)
in_out:  [(3, 3), (3, 6), (6, 12), (12, 24)]
dim_in, dim_out, is_last, ind:  3 3 False 0
dim_in, dim_out, is_last, ind:  3 6 False 1
dim_in, dim_out, is_last, ind:  6 12 False 2
dim_in, dim_out, is_last, ind:  12 24 True 3

class Unet(nn.Module):
    def __init__(
            self, 
            dim,
            init_dim = None,
            out_dim = None,
            dim_mults = (1, 2, 4, 8),
            channels = 3,
            self_condition = False,
            resnet_block_groups = 4,
    ):
        super().__init__()

        # Determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(in_channels = input_channels, out_channels = init_dim, kernel_size = 1) # Only change the number of channels

        dims = [init_dim, *map(lambda m: m * init_dim, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # Time embeddings
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(in_features = dim, out_features = time_dim),
            nn.GELU(),
            nn.Linear(in_features = time_dim, out_features = time_dim),
        )

        # Layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind == (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList([
                    block_klass(dim = dim_in, dim_out = dim_in, time_emb_dim = time_dim), # Resnet block
                    block_klass(dim = dim_in, dim_out = dim_in, time_emb_dim = time_dim), # Resnet block
                    Residual(PreNorm(dim_in, LinearAttention(dim_in))), # Pre-norm Linear attention block + Residual
                    Downsample(dim_in = dim_in, dim_out = dim_out) if not is_last else nn.Conv2d(in_channels = dim_in, out_channels = dim_out, kernel_size = 3, padding = 'same')
                ])
                    
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(dim = mid_dim, dim_out = mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(dim = mid_dim)))
        self.mid_block2 = block_klass(dim = mid_dim, dim_out = mid_dim, time_emb_dim = time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList([
                    block_klass(dim = dim_out + dim_in, dim_out = dim_out, time_emb_dim = time_dim),
                    block_klass(dim = dim_out + dim_in, dim_out = dim_out, time_emb_dim = time_dim),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    Upsample(dim_in = dim_out, dim_out = dim_in) if not is_last else nn.Conv2d(in_channels = dim_out, out_channels = dim_in, kernel_size = 3, padding = 'same')


                ])
            )
        
        self.out_dim = default(out_dim, channels)

        self.final_res_block = block_klass(dim = init_dim * 2, dim_out = dim, time_emb_dim = time_dim) # x (init_dim channels) concatenated with r (init_dim channels)
        self.final_conv = nn.Conv2d(in_channels = dim, out_channels = self.out_dim, kernel_size = 1)

    def forward(self, x, time, x_self_cond = None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x, x_self_cond), dim = 1)

        # Make the input have desired number of channels
        x = self.init_conv(x)
        # Keep a copy of the initial features for the final skip connection
        r = x.clone()

        # Create time embeddings
        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
        
        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
        
    



        

Now that we have defined our noise predictor, we can prepare the data for training.

0.10 Define a PyTorch Dataset + DataLoader

Now, let’s define a PyTorch Dataset and DataLoader.

from datasets import load_dataset

# Load dataset from HuggingFace hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 32
from torchvision import transforms
from torch.utils.data import DataLoader

# Define transforms
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1) # rescale to [-1, 1]
])

# Define function (named apply_transforms so it doesn't shadow the torchvision transforms module)
def apply_transforms(examples):
    examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
    del examples["image"]

    return examples

# On the fly transforms
transformed_dataset = dataset.with_transform(apply_transforms).remove_columns("label")

# Create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size = batch_size, shuffle = True)

0.11 Sampling from the model

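So how do we use the model to generate images? Following Algorithm 2 in the paper, given \mathbf x_t we compute

\mathbf x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}} \epsilon_{\theta}(\mathbf x_t, t) \right) + \sigma_t \mathbf z, \quad \mathbf z \sim \mathcal N(\mathbf 0, \mathbf I)

where \epsilon_{\theta}(\mathbf x_t, t) is the noise predicted by the trained model, and no noise is added at the last step (t = 1). Since 1 - \alpha_t = \beta_t, the first term is exactly the mean \mu_\theta from Equation 11. For the variance, the authors report that \sigma_t^2 = \beta_t and \sigma_t^2 = \tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t gave similar results.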
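The two candidate variances, \beta_t and \tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t, can be compared directly. The sketch below assumes a linear schedule with T = 300 as earlier and computes \tilde\beta_t the same way as the `posterior_variance` line of the sampling code. It checks that \tilde\beta_t is 0 at the first step, never larger than \beta_t, and close to \beta_t at the last step.

```python
import torch
import torch.nn.functional as F

T = 300
betas = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
# alpha_bar_{t-1}: shift right by one and prepend alpha_bar_0 = 1
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

# beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)

print(posterior_variance[0].item())                  # 0.0 at the first step
print(bool((posterior_variance <= betas).all()))     # True: beta_tilde never exceeds beta
print((posterior_variance[-1] / betas[-1]).item())   # close to 1 at the last step
```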

The entire process can be summarized as:

[Figure: sampling algorithm]

Generating new images from a diffusion model involves:

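  • sample pure noise \mathbf x_T from a standard Gaussian distribution \mathcal N(\mathbf 0, \mathbf I)
  • for each time step t from T down to 1, use the noise predictor to denoise the image and obtain \mathbf x_{t-1}; the final step yields the generated sample \mathbf x_0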

Ideally, as Jay Alammar put it, we get an image that's closer to the images the model was trained on (not the exact images themselves, but the distribution: the world of pixel arrangements where the sky is usually blue and above the ground, people have two eyes, cats look a certain way, with pointy ears and clearly unimpressed).

import numpy as np
import torch
import torch.nn.functional as F

# Creating the sampling procedure
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# Define posterior variance i.e. equation 7
# F.pad(..., (1, 0), value = 1.0) prepends a single 1.0 to the 1-D tensor,
# so alphas_cumprod_prev[t] = alpha_bar_{t-1}, with alpha_bar_0 = 1
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.0)
posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # Implementing equation 11 (line 4 of Algorithm 2): the mean of x_{t-1} from the predicted noise
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        # Last step: return the mean without adding noise
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # Start from pure noise (for each example in the batch)
    img = torch.randn(shape, device = device)
    imgs = []
    for i in reversed(range(timesteps)):
        img = p_sample(model, img, torch.full((b,), i, device = device, dtype = torch.long), i) # x_{t-1}
        imgs.append(img.cpu().numpy())

    # Shape: (timesteps, batch_size, channels, image_size, image_size); the last entry is x_0
    return np.array(imgs)

@torch.no_grad()
def sample(model, image_size, batch_size = 16, channels = 3):
    return p_sample_loop(model, shape = (batch_size, channels, image_size, image_size))
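The `F.pad` and `posterior_variance` lines are easy to get subtly wrong, so here is a quick numeric check. It assumes the linear schedule from the DDPM paper (\beta from 10^{-4} to 0.02 over 1000 steps); the schedule defined earlier in this post may differ, but the two properties checked hold for any schedule with 0 < \beta_t < 1:

```python
import torch
import torch.nn.functional as F

# Assumed linear schedule (DDPM paper: beta_1 = 1e-4 to beta_T = 0.02, T = 1000)
timesteps = 1000
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# Prepend 1.0 so that alphas_cumprod_prev[t] = alpha_bar_{t-1}, with alpha_bar_0 = 1
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

# Posterior variance (beta tilde), used as sigma_t^2 in p_sample
posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)

print(alphas_cumprod_prev[:3])     # first entry is exactly 1.0
print(posterior_variance[0].item())  # 0.0: the first forward step has no posterior uncertainty
print(bool((posterior_variance <= betas).all()))  # beta tilde never exceeds beta
```

Because \bar\alpha_{t-1} \ge \bar\alpha_t, the factor (1 - \bar\alpha_{t-1}) / (1 - \bar\alpha_t) is at most 1, so \tilde\beta_t \le \beta_t. Replacing `posterior_variance` with `betas` in `p_sample` would give the other variance choice \sigma_t^2 = \beta_t from the paper.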

0.12 Train the model

Now let’s train the model and sample from it

from pathlib import Path

def num_to_groups(num, divisor):
    # Split `num` into chunks of at most `divisor`, e.g. num_to_groups(10, 4) -> [4, 4, 2]
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000
num_to_groups(4, 32)  # -> [4]

# Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Unet(
    dim = image_size,
    channels = channels,
    dim_mults = (1, 2, 4)
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)

# Model training
from torchvision.utils import save_image
epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        batch_size = batch["pixel_values"].shape[0]
        batch = batch["pixel_values"].to(device)
        
        # Sample time steps uniformly for every example in batch
        t = torch.randint(0, timesteps, (batch_size,), device = device, dtype = torch.long)

        loss = p_losses(model, batch, t, loss_type = "huber")

        if step % 500 == 0:
            print(f"Epoch: {epoch} Step: {step} Loss: {loss.item()}")
        
        # Zero out gradients
        optimizer.zero_grad()

        # Backpropagate
        loss.backward()

        # Update weights
        optimizer.step()

        # Save generated images (uncomment to enable)
        # if step != 0 and step % save_and_sample_every == 0:
        #     milestone = step // save_and_sample_every
        #     batches = num_to_groups(num = 4, divisor = batch_size)
        #     # sample() returns every time step as a NumPy array; keep only the final x_0 frames
        #     all_images_list = list(map(lambda n: torch.from_numpy(sample(model, image_size, batch_size = n, channels = channels)[-1]), batches))
        #     all_images = torch.cat(all_images_list, dim = 0)
        #     all_images = (all_images + 1) * 0.5 # rescale to [0, 1]
        #     save_image(all_images, str(results_folder / f'sample_{milestone}.png'), nrow = 6)
Epoch: 0 Step: 0 Loss: 0.5224699378013611
Epoch: 0 Step: 500 Loss: 0.041550733149051666
Epoch: 0 Step: 1000 Loss: 0.04898069426417351
Epoch: 0 Step: 1500 Loss: 0.04184384271502495
Epoch: 1 Step: 0 Loss: 0.05208709090948105
Epoch: 1 Step: 500 Loss: 0.031184278428554535
Epoch: 1 Step: 1000 Loss: 0.03631388023495674
Epoch: 1 Step: 1500 Loss: 0.03278825432062149
Epoch: 2 Step: 0 Loss: 0.026345837861299515
Epoch: 2 Step: 500 Loss: 0.027054786682128906
Epoch: 2 Step: 1000 Loss: 0.04869785159826279
Epoch: 2 Step: 1500 Loss: 0.023147860541939735
Epoch: 3 Step: 0 Loss: 0.027534985914826393
Epoch: 3 Step: 500 Loss: 0.036470092833042145
Epoch: 3 Step: 1000 Loss: 0.02620980143547058
Epoch: 3 Step: 1500 Loss: 0.04134272783994675
Epoch: 4 Step: 0 Loss: 0.025678491219878197
Epoch: 4 Step: 500 Loss: 0.02390681765973568
Epoch: 4 Step: 1000 Loss: 0.03659454733133316
Epoch: 4 Step: 1500 Loss: 0.02753089740872383
Epoch: 5 Step: 0 Loss: 0.04199949651956558
Epoch: 5 Step: 500 Loss: 0.02578447200357914
Epoch: 5 Step: 1000 Loss: 0.03484490513801575
Epoch: 5 Step: 1500 Loss: 0.034843917936086655
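Each iteration of the training loop above corresponds to Algorithm 1 (Training) in the paper. To make the steps explicit, here is a minimal self-contained sketch of a single iteration. It assumes the paper's linear \beta schedule (which may differ from the one defined earlier in this post), a hypothetical `ToyDenoiser` standing in for the `Unet`, and that the `huber` option of `p_losses` corresponds to `F.smooth_l1_loss`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed linear schedule, as in the DDPM paper
timesteps = 1000
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

class ToyDenoiser(nn.Module):
    """Hypothetical stand-in for the Unet: a single conv layer that ignores t."""
    def __init__(self, channels: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

def training_step(model, optimizer, x0):
    b = x0.shape[0]
    # 1. A random time step per example (indices 0..T-1, i.e. t = 1..T in the paper)
    t = torch.randint(0, timesteps, (b,), device=x0.device)
    # 2. The noise the model has to predict
    noise = torch.randn_like(x0)
    # 3. Jump straight to x_t using the closed form of q(x_t | x_0)
    a_bar = alphas_cumprod.to(x0.device)[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise
    # 4. Regress the predicted noise onto the true noise (Huber loss)
    loss = F.smooth_l1_loss(model(x_t, t), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

model = ToyDenoiser()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x0 = torch.rand(8, 1, 28, 28) * 2 - 1  # fake batch already scaled to [-1, 1]
print(training_step(model, optimizer, x0))
```

Step 3 is the reason training is cheap: the forward process never has to be simulated step by step, so each example can be trained at an independently chosen time step.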

0.13 Sampling from the model

To generate new images from the model, we sample pure noise \mathbf x_T from a Gaussian distribution and then apply the noise predictor at each time step t from T down to 1 to obtain \mathbf x_{t-1}, ending with \mathbf x_0.

# Sample 8 images
sampled_images = sample(model, image_size = image_size, batch_size = 8, channels = 1)

# Transpose to (B, H, W, C) as expected by matplotlib
sampled_images_c = sampled_images[-1].transpose(0, 2, 3, 1) # Get images at time step 0

# Visualize sampled images
visualize_noisy_images(sampled_images_c, cols = 4, rows = 2, cmap = 'gray')

The above shows the result of sampling starting from pure noise, i.e. torch.randn((BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)) with BATCH_SIZE = 8. Not bad: the model already generates images that look like clothing. Let's create a GIF of the denoising process:

import matplotlib.pyplot as plt
import matplotlib.animation as animation

image_index = 3

fig = plt.figure()
ims = []

# sampled_images[i] holds x_{T-1-i}, so the animation runs from noise to the final x_0
for i in range(timesteps):
    im = plt.imshow(sampled_images[i][image_index].transpose(1, 2, 0), animated = True, cmap = 'gray')
    # Add title
    #plt.title(f"Time step: {i}")
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval = 50, blit = True, repeat_delay = 1000)
# Use the Pillow writer explicitly for GIF output
animate.save('diffusion.gif', writer = 'pillow')
plt.close()
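With 1000 time steps and `interval = 50`, the GIF runs for about 50 seconds, and many of the early frames look like pure noise. One option, not from the original post, is to keep only every k-th frame while always ending on the final \mathbf x_0 frame; `select_frames` is a hypothetical helper:

```python
def select_frames(num_steps: int, every: int = 10) -> list[int]:
    """Indices of frames to keep: every `every`-th step, always ending on the last (x_0) frame."""
    if num_steps <= 0:
        return []
    indices = list(range(0, num_steps, every))
    if indices[-1] != num_steps - 1:
        indices.append(num_steps - 1)
    return indices

print(select_frames(1000, every=100))
# [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 999]
```

To use it, loop over `select_frames(timesteps)` instead of `range(timesteps)` when building `ims`; with `every=10` this gives 101 frames, or roughly 5 seconds at 50 ms per frame.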

[Figure: GIF of the sampling (denoising) process]

1 Resources