import torch
# Defining variance schedules
def linear_beta_schedule(timesteps, beta_start = 0.0001, beta_end = 0.02):
    beta_schedule = torch.linspace(beta_start, beta_end, timesteps)
    return beta_schedule
def quadratic_beta_schedule(timesteps, beta_start = 0.0001, beta_end = 0.02):
    # Interpolate linearly in sqrt-space, then square to get back to beta values
    beta_schedule = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
    return beta_schedule
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas = torch.cos((x / timesteps + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas = alphas / alphas[0]
    beta_schedule = 1 - (alphas[1:] / alphas[: -1])
    return torch.clip(input = beta_schedule, min = 0.0001, max = 0.9999)
def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
Denoising Diffusion Probabilistic Models
This blog post is basically my learning notes from trying to understand and implement the paper Denoising Diffusion Probabilistic Models. All the resources used in understanding the paper are listed at the end of the post.
0.1 What are diffusion probabilistic models?
Before we even define what diffusion models are, let’s first understand what diffusion means in the context of generative models. The term diffusion refers to the process of iteratively transforming an image, i.e. adding controlled noise to the data. Diffusion models are models that learn to denoise the perturbed images.
This setup therefore has two components:
- A forward diffusion process that adds noise to the data
- A reverse diffusion process that denoises the data
0.2 Forward diffusion process
Given a data point \mathbf x_0 sampled from a distribution q(\mathbf x) (in math, this is written as \mathbf x_0 \sim q(\mathbf x)), the forward diffusion process gradually adds a small amount of Gaussian noise to the data point at each time step t according to a known variance schedule \beta_t. The variance schedule \beta_t determines the amount of noise added at each time step and can be linear, quadratic, cosine, etc. The forward diffusion process q(\mathbf x_t | \mathbf x_{t-1}) can be mathematically summarized as:
q(\mathbf x_t | \mathbf x_{t-1}) = \mathcal N(\mathbf x_t; \sqrt{1 - \beta_t} \mathbf x_{t-1} , \beta_t \mathbf I)
From the image above, we can quickly break down this equation:
q(\mathbf x_t | \mathbf x_{t-1}) represents the probability distribution of the data point at time t given the data point at time step t-1. In other words, it tells us how likely we are to find a data point at time t if we know the data point at time t-1.
\mathcal N() means that the probability distribution of the data point at time t given the data point at time step t-1 is a normal distribution with a mean of \sqrt{1 - \beta_t} \mathbf x_{t-1} and a variance of \beta_t \mathbf I.
So at each time step t, \mathbf x_t is generated by taking the previous data point \mathbf x_{t-1}, attenuating it by \sqrt{1 - \beta_t} and adding Gaussian noise with variance \beta_t. Therefore as t increases, \mathbf x_t becomes more and more noisy until it becomes pure Gaussian noise.
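A single forward step can be sketched in a few lines of plain Python. This is only an illustration of the equation above on a list of floats, not the tensor code used later in the post; the function name `forward_step` and the three-value data point are made up for the example.

```python
import math
import random

def forward_step(x_prev, beta_t, rng):
    """One step of q(x_t | x_{t-1}) applied element-wise to a list of floats."""
    keep = math.sqrt(1.0 - beta_t)   # attenuation of the previous sample
    std = math.sqrt(beta_t)          # standard deviation of the added noise
    return [keep * v + std * rng.gauss(0.0, 1.0) for v in x_prev]

rng = random.Random(0)               # fixed seed, only for reproducibility
x0 = [0.5, -0.25, 1.0]               # hypothetical 3-"pixel" data point
x1 = forward_step(x0, beta_t=0.02, rng=rng)
x_same = forward_step(x0, beta_t=0.0, rng=rng)   # beta = 0 adds no noise
print(x1)
print(x_same)
```

With \beta_t = 0 the step leaves the data untouched, and larger values of \beta_t trade more of the signal for noise.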
0.2.1 Defining the forward diffusion process in code
Let’s try and put some of the evil-looking math above into code. First, we’ll define various variance schedules \beta_t that determine the amount of noise added at each time step. A linear schedule was used in the paper.
However, the paper Improved Denoising Diffusion Probabilistic Models has shown that using a cosine schedule leads to better results since linear schedules tend to destroy information much more quickly.
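To get a feel for that claim, the sketch below compares how much of the original signal survives (the cumulative product of 1 - \beta_t) under a linear and a cosine schedule. It assumes T = 1000 steps with the same \beta range as above, and the helper names `linear_alpha_bar` and `cosine_alpha_bar` are hypothetical.

```python
import math

def linear_alpha_bar(t, T=1000, beta_start=0.0001, beta_end=0.02):
    """Product of (1 - beta_i) over the first t steps of a linear schedule."""
    step = (beta_end - beta_start) / (T - 1)
    prod = 1.0
    for i in range(t):
        prod *= 1.0 - (beta_start + i * step)
    return prod

def cosine_alpha_bar(t, T=1000, s=0.008):
    """f(t) / f(0) with f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2)."""
    f = lambda u: math.cos((u / T + s) / (1 + s) * math.pi * 0.5) ** 2
    return f(t) / f(0)

# Fraction of signal variance left at a quarter, half and three quarters of the way
for t in (250, 500, 750):
    print(t, round(linear_alpha_bar(t), 4), round(cosine_alpha_bar(t), 4))
```

Halfway through, the linear schedule has already removed most of the signal, while the cosine schedule still keeps roughly half of it.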
Say we are at time step t = 0 and we want to obtain an image at time step t = 56. We would have to apply the forward diffusion process 56 times, which is time consuming and computationally expensive. Conveniently, it has been shown that we can sample \mathbf x_t at any time step t directly from \mathbf x_0 as: q(\mathbf x_t | \mathbf x_0) = \mathcal N(\mathbf x_t; \sqrt{\bar \alpha_t} \mathbf x_0, (1 - \bar \alpha_t) \mathbf I) where \alpha_t = 1 - \beta_t and \bar{\alpha}_t = \prod_{i=1}^t \alpha_i.
Let’s put the above into code, assuming T = 300 steps:
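The shortcut can be checked numerically. The sketch below, in plain Python and assuming a linear schedule with T = 300, applies the one-step rule repeatedly while tracking the coefficient on \mathbf x_0 and the accumulated noise variance, then compares both with the closed-form values.

```python
import math

T = 300
betas = [0.0001 + i * (0.02 - 0.0001) / (T - 1) for i in range(T)]  # linear schedule

# Step-by-step recursion: track the coefficient on x_0 and the noise variance.
coef, var = 1.0, 0.0
alpha_bar = 1.0
for beta in betas:
    coef = math.sqrt(1.0 - beta) * coef    # signal shrinks by sqrt(1 - beta_t)
    var = (1.0 - beta) * var + beta        # old noise shrinks, new noise is added
    alpha_bar *= 1.0 - beta                # running product of alpha_i

print(round(coef, 4), round(math.sqrt(alpha_bar), 4))   # mean coefficient, both ways
print(round(var, 4), round(1.0 - alpha_bar, 4))         # noise variance, both ways
```

Both routes agree, so a single draw from q(\mathbf x_t | \mathbf x_0) is equivalent to t sequential noising steps.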
import matplotlib.pyplot as plt
timesteps = 300

# Define beta schedule
betas = linear_beta_schedule(timesteps)

# Define alphas
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim = 0)

# Calculating q(x_t | x_0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod) # standard deviation

# Function that allows us to extract the appropriate t index for a batch of indices
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu()) # get corresponding alphas for given time index
    return out.reshape(batch_size, *((1, ) * (len(x_shape) - 1))).to(t.device) # Bs, 1, 1, 1
# Define forward diffusion q(x_t | x_0)
def q_sample(x_start, t, noise = None):
    if noise is None:
        noise = torch.randn_like(x_start)

    # Extract alphas and sqrt_alphas_cumprod for given time index
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape) # mean
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) # standard deviation

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
sqrt_alphas_cumprod.gather(-1, torch.tensor(299)) # get corresponding alphas for given time index
tensor(0.2192)
Now that we have defined the forward diffusion process, let’s see how we can use it to add noise to an image. Noise is added to PyTorch tensors rather than Pillow images, so we first convert the image to a tensor and rescale its values to [-1, 1].
from torchvision import transforms
from PIL import Image
image = Image.open('images/blog_image.png')

image_size = 128
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(), # convert to tensor of shape (C, H, W) in range [0, 1]
    transforms.Lambda(lambda t: (t * 2) - 1) # rescale to [-1, 1]
])
# Transform image
x_start = transform(image).unsqueeze(0)
x_start.shape, x_start.min(), x_start.max()
(torch.Size([1, 3, 128, 128]), tensor(-0.8196), tensor(1.))
We also define the reverse transform, which takes in a PyTorch tensor in the range [-1, 1] and turns it back into a Pillow image:
import numpy as np
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda t: (t + 1) / 2), # rescale to [0, 1]
    transforms.Lambda(lambda t: t.permute(1, 2, 0)), # (C, H, W) to (H, W, C)
    transforms.Lambda(lambda t: t * 255.), # rescale to [0, 255]
    transforms.Lambda(lambda t: t.numpy().astype(np.uint8)), # convert to numpy array of type uint8
    transforms.ToPILImage() # convert to PIL image
])
reverse_transform(x_start.squeeze(0))
# Function that adds noise to a tensor then converts it back to a PIL image
def get_noisy_image(x_start, t):
    # Add noise at time t
    x_noisy = q_sample(x_start, t)

    # Convert to PIL image
    noisy_image = reverse_transform(x_noisy.squeeze(0))
    return noisy_image
# Create time step
t = torch.tensor([50])
# Get noisy image at time t
get_noisy_image(x_start, t)
Now, let’s visualize the forward diffusion process by adding noise to an image at different time steps:
# Function that visualizes noisy images at various timesteps
def visualize_noisy_images(imgs, cols, rows = 1, title = False, **kwargs):
    fig, ax = plt.subplots(rows, cols, figsize = (20, 10))
    for i in range(len(imgs)):
        ax.ravel()[i].imshow(imgs[i], **kwargs)
        if title:
            ax.ravel()[i].set_title(f't = {steps[i]}')
        ax.ravel()[i].axis('off')
    # overall title
    #plt.suptitle('Noisy images at various timesteps', fontsize = 16)
    plt.tight_layout()
    plt.show()
# Define time instance
steps = [0, 50, 100, 150, 200, 250, 299]

# Get noisy images at each time step
imgs = [get_noisy_image(x_start, torch.tensor([t])) for t in steps]

# Visualize noisy images
visualize_noisy_images(imgs, cols = len(imgs), rows = 1)
import torch
# Input tensor
input_tensor = torch.tensor([[1, 2, 3],
                             [4, 5, 6],
                             [7, 8, 9]])

# Index tensor
index_tensor = torch.tensor([[0, 2, 1],
                             [2, 1, 0]])

# Gather along dimension 0: output[i][j] = input_tensor[index_tensor[i][j]][j]
output_tensor = torch.gather(input_tensor, 0, index_tensor)

print(output_tensor)
tensor([[1, 8, 6],
[7, 5, 3]])
0.3 Reverse diffusion process
After progressively adding noise to an image, if we can then sample from q(\mathbf x_{t-1} | \mathbf x_t), we can reverse the forward process and denoise the image. However, estimating q(\mathbf x_{t-1} | \mathbf x_t) is not trivial since it would require us to know the distribution of all images. This requires us to use a model that can learn to approximate these conditional probabilities to denoise the images. The reverse diffusion process can thus be represented as:
p_{\theta}(\mathbf x_{t-1} | \mathbf x_t) = \mathcal N(\mathbf x_{t-1}; \mu_{\theta}(\mathbf x_t, t), \Sigma_{\theta}(\mathbf x_t, t))
The little \theta subscript means that these are parameters that need to be learned by the model. Therefore, the network has to learn to predict the mean and variance of the distribution of the data point at time t-1 given the data point \mathbf x_t at time t.
In Equations 11 and 12 of the paper, the authors demonstrate that the mean can be parameterized to make the neural network learn the added noise.
As a result, the neural network becomes a noise predictor that is optimized using an MSE between the predicted noise \epsilon_\theta and the actual noise \epsilon_t, i.e.: Loss_t = ||\epsilon_t - \epsilon_{\theta}(\mathbf x_t, t)||^2
where \epsilon_{\theta}(\mathbf x_t, t) = \epsilon_{\theta}(\sqrt{\bar \alpha_t} \mathbf x_0 + \sqrt{1 - \bar \alpha_t} \epsilon_t, t)
The training procedure thus becomes:
In simple terms:
- we take a random sample \mathbf x_0 from the data distribution q(\mathbf x)
- we sample a time step t from a uniform distribution
- we sample noise \mathbf \epsilon from a normal distribution and corrupt the input by the amount of noise found at time step t to obtain \mathbf x_t. The amount of noise is determined by the variance schedule \beta_t.
- the neural network is trained to predict the added noise \mathbf \epsilon
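The steps above can be sketched in plain Python for a single data point. This is only an illustration under simplifying assumptions: the data point is a short list of floats, `training_step` is a hypothetical helper, and `zero_predictor` stands in for an untrained network; the tensor version used in this post is `p_losses`.

```python
import math
import random

def training_step(x0, alpha_bars, noise_predictor, rng):
    """One pass of the steps above on a single data point (a list of floats)."""
    t = rng.randrange(len(alpha_bars))                 # uniform time step
    eps = [rng.gauss(0.0, 1.0) for _ in x0]            # noise to be predicted
    a = math.sqrt(alpha_bars[t])
    b = math.sqrt(1.0 - alpha_bars[t])
    x_t = [a * v + b * e for v, e in zip(x0, eps)]     # corrupted input
    pred = noise_predictor(x_t, t)
    return sum((e - p) ** 2 for e, p in zip(eps, pred)) / len(x0)  # MSE

# Cumulative products of (1 - beta_t) for a linear schedule with T = 300
alpha_bars, prod = [], 1.0
for i in range(300):
    prod *= 1.0 - (0.0001 + i * (0.02 - 0.0001) / 299)
    alpha_bars.append(prod)

zero_predictor = lambda x_t, t: [0.0] * len(x_t)       # hypothetical untrained "model"
loss = training_step([0.5, -0.25, 1.0], alpha_bars, zero_predictor, random.Random(0))
print(loss)
```

A predictor that returned the sampled noise exactly would drive this loss to zero, which is what training pushes the network towards.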
Jay Alammar illustrates this nicely in the blog post The Illustrated Stable Diffusion.
This would be a good place to define the loss function as follows:
l1_loss: The L1 loss aka MAE is the average absolute difference between the predicted and actual values.
l2_loss: The L2 loss aka MSE is the average squared difference between the predicted and actual values.
Huber loss: The Huber loss is a combination of the L1 and L2 losses. It uses the L2 loss for small errors and the L1 loss for large errors, which makes it less sensitive to outliers than the L2 loss. It is defined as:
\text{huber}_{\delta}(y_i, \hat{y_i}) = \begin{cases} \frac{1}{2}(y_i - \hat{y_i})^2 & \text{for } |y_i - \hat{y_i}| \leq \delta \\ \delta (|y_i - \hat{y_i}| - \frac{1}{2} \delta) & \text{otherwise} \end{cases}
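To see how the three losses treat an outlier, here is a small plain-Python sketch. The function names are hypothetical and \delta = 1 is an assumption made for the example.

```python
def l1_loss(y, y_hat):
    """Mean absolute error."""
    return sum(abs(a - b) for a, b in zip(y, y_hat)) / len(y)

def l2_loss(y, y_hat):
    """Mean squared error."""
    return sum((a - b) ** 2 for a, b in zip(y, y_hat)) / len(y)

def huber_loss(y, y_hat, delta=1.0):
    """Quadratic for small errors, linear for large ones."""
    total = 0.0
    for a, b in zip(y, y_hat):
        err = abs(a - b)
        if err <= delta:
            total += 0.5 * err ** 2                  # quadratic near zero
        else:
            total += delta * (err - 0.5 * delta)     # linear in the tails
    return total / len(y)

y, y_hat = [0.0, 0.0], [0.5, 3.0]    # one small error, one outlier
print(l1_loss(y, y_hat), l2_loss(y, y_hat), huber_loss(y, y_hat))
```

The outlier dominates the L2 loss, while the Huber loss grows only linearly with it.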
import torch.nn.functional as F
def p_losses(denoise_model, x_start, t, noise = None, loss_type = "l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    # Get noisy image
    x_noisy = q_sample(x_start = x_start, t = t, noise = noise)

    # Predict noise
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == "l1":
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == "l2":
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()
    return loss
0.4 The Network used in diffusion models
As we have previously seen, a diffusion model is a noise predictor that takes in a noised image and returns a noise prediction. The typical network used in diffusion models is a U-Net. A U-Net is an autoencoder that has skip connections between the encoder and the decoder. The U-Net processes an image by progressively downsampling it in the encoder and upsampling it in the decoder. The skip connections allow the decoder to access the encoder’s feature maps which helps in mitigating the information loss that can occur during downsampling in the encoder.
Okay, let’s now implement these concepts in PyTorch. This is based on Phil Wang’s implementation.
0.4.1 Network helpers
import torch
import torch.nn as nn
import torch.nn.functional as F
from inspect import isfunction
from functools import partial
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from torch import einsum
import math
# Return true if x is not None
def exists(x):
    return x is not None
# Return val if val is not None else d() if d is a function else d
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
# Residual block i.e x + f(x)
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
# Upsample block + Conv2d
def Upsample(dim_in, dim_out = None):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = "nearest"), # Expects B.s, C, R, Cols
        nn.Conv2d(in_channels = dim_in, out_channels = default(dim_out, dim_in), kernel_size = 3, padding = 'same')
    )
# Downsample block
def Downsample(dim_in, dim_out = None):
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2), # Halves the height and width and multiplies the channels by 4
        nn.Conv2d(in_channels = dim_in * 4, out_channels = default(dim_out, dim_in), kernel_size = 3, padding='same')
    )
# Test the above blocks
x = torch.randn(1, 3, 32, 32)
print("Input shape: ", x.shape)
print("Upsample shape: ", Upsample(3)(x).shape)
print("Downsample shape: ", Downsample(3)(x).shape)
Input shape: torch.Size([1, 3, 32, 32])
Upsample shape: torch.Size([1, 3, 64, 64])
Downsample shape: torch.Size([1, 3, 16, 16])
0.5 Time embeddings
In the forward and reverse diffusion processes, we need to know the time step t at which we are in order to determine the amount of noise to add/remove. This information can be passed to the network using a time embedding of dim dimension and added at the residual blocks. The authors used sinusoidal position embeddings (as used in the Transformer) to provide a continuous and meaningful representation of the time steps. Read more about position embeddings at this fantastic blog. The formula for sinusoidal position embeddings is:
$$ PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right) \\
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right) $$
where
- PE_{(pos, 2i)} represents the (2i)-th dimension of the positional encoding for the word at position pos,
- PE_{(pos, 2i+1)} represents the (2i + 1)-th dimension,
- i is the dimension index, and
- d is the dimensionality of the positional encoding.
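As a quick sanity check of the formula, the plain-Python sketch below computes the encoding for a single position. The function name `positional_encoding` is hypothetical, and the sketch interleaves sine and cosine exactly as the formula is written; the module defined in this post instead concatenates all sines followed by all cosines and spaces the frequencies slightly differently, which carries the same kind of information in a different layout.

```python
import math

def positional_encoding(pos, d):
    """Interleaved sin/cos encoding of a single position, following the formula above."""
    pe = []
    for i in range(d // 2):
        angle = pos / (10000 ** (2 * i / d))
        pe.extend([math.sin(angle), math.cos(angle)])  # dimensions 2i and 2i + 1
    return pe

print(positional_encoding(0, 8))                        # position 0: sines are 0, cosines are 1
print([round(v, 4) for v in positional_encoding(3, 8)])
```

Low dimension indices oscillate quickly with the position and high ones slowly, so every time step gets a distinct pattern.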
Let’s put this in code:
# Create a sinusoidal time encoding
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    def forward(self, time):
        device = time.device # cpu or gpu
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1) # scale factor
        embeddings = torch.exp(torch.arange(half_dim, device = device) * -embeddings) # Decreasing exponential frequencies
        embeddings = time[:, None] * embeddings[None, :] # Multiply the frequencies with the time (T, 1) x (1, D/2) = (T, D/2)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim = -1) # Concatenate the sin and cos embeddings (T, D)
        return embeddings
This results in a time embedding that contains sine and cosine values at several frequencies for each time step. The time embedding is then passed to each residual block, where it conditions the feature maps.
# Test the above class
dim = 32
time_steps = torch.arange(0, 10).float()
print("Time steps: ", time_steps)
half_dim = dim // 2
print("Half dim: ", half_dim)
embeddings = math.log(10000) / (half_dim - 1)
print("Embeddings: ", embeddings)
embeddings = torch.exp(torch.arange(half_dim) * -embeddings)
print("Embeddings: ", embeddings)
embeddings = torch.matmul(time_steps[:, None], embeddings[None, :])
print("Embeddings shape: ", embeddings.shape)
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim = -1)
print("Embeddings shape: ", embeddings.shape)
print("Embeddings[0]: ", embeddings[0])
Time steps: tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
Half dim: 16
Embeddings: 0.6140226914650789
Embeddings: tensor([1.0000e+00, 5.4117e-01, 2.9286e-01, 1.5849e-01, 8.5770e-02, 4.6416e-02,
2.5119e-02, 1.3594e-02, 7.3564e-03, 3.9811e-03, 2.1544e-03, 1.1659e-03,
6.3096e-04, 3.4146e-04, 1.8479e-04, 1.0000e-04])
Embeddings shape: torch.Size([10, 16])
Embeddings shape: torch.Size([10, 32])
Embeddings[0]: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
0.6 Residual blocks of the U-Net
Phil Wang implements a weight standardized ResNet block, which has been observed to work better with group normalization (see Kolesnikov et al.). Group Normalization is a normalization layer that divides channels into groups and normalizes the features within each group. Weight Standardization is a normalization technique that smooths the loss landscape by standardizing the weights in convolutional layers (Qiao et al.).
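The idea behind weight standardization fits in a few lines of plain Python. This is a sketch on one flattened filter with made-up values, not the layer used in this post, which performs the same computation per output channel on the convolution weights.

```python
import math

def standardize(weights, eps=1e-5):
    """Shift and scale one filter's weights to zero mean and (nearly) unit variance."""
    n = len(weights)
    mean = sum(weights) / n
    var = sum((w - mean) ** 2 for w in weights) / n   # biased (population) variance
    return [(w - mean) / math.sqrt(var + eps) for w in weights]

kernel = [0.2, -1.3, 0.7, 2.4, -0.5, 0.1, 1.1, -0.9, 0.3]   # hypothetical flattened 3x3 filter
w_hat = standardize(kernel)
new_mean = sum(w_hat) / len(w_hat)
new_var = sum((w - new_mean) ** 2 for w in w_hat) / len(w_hat)
print(round(new_mean, 6), round(new_var, 6))
```

The small `eps` keeps the division stable when a filter's weights are almost constant, at the cost of a variance that is a hair below 1.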
# Weight standardization: standardize the weights of the conv layers
class WeightStandardizedConv2d(nn.Conv2d):
    # (https://arxiv.org/abs/1912.11370)
    def __init__(self, in_channels, out_channels, kernel_size, padding = 'same'):
        super().__init__(in_channels, out_channels, kernel_size, padding = padding)
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        weight = self.weight # Weight of kernel of shape (output_chan, input_chan, kernel_size, kernel_size)
        mean = reduce(weight, "o ... -> o 1 1 1", "mean") # Mean of the kernel (output_chan, input_chan, kernel_size, kernel_size) -> (output_chan, 1, 1, 1)
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased = False)) # Partial ensures that the function is called with the unbiased parameter set to False
        normalized_weight = (weight - mean) / (var + eps).sqrt() # Normalize the weights
        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups
        )
# Weight standardization conv + Group norm + SiLU
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.conv = WeightStandardizedConv2d(in_channels = dim, out_channels = dim_out, kernel_size = 3)
        # Separate dim_out channels into groups
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()
    def forward(self, x, scale_shift = None):
        # Weight standardization + Group norm - Perfect combo
        x = self.conv(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift # Incorporate the time embedding

        x = self.act(x)
        return x
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(in_features = time_emb_dim, out_features = dim_out * 2)) if exists(time_emb_dim) else None
        )
        self.block1 = Block(dim, dim_out, groups = groups) # Weight standardization conv + Group norm + SiLU
        self.block2 = Block(dim_out, dim_out, groups = groups) # Weight standardization conv + Group norm + SiLU
        self.res_conv = nn.Conv2d(in_channels = dim, out_channels = dim_out, kernel_size = 1) if dim != dim_out else nn.Identity()
    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(time_emb) and exists(self.mlp):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1) # Split the time embedding into 2 parts [b c/2, 1, 1] [b c/2, 1, 1]

        h = self.block1(x, scale_shift = scale_shift) # First block is conditioned on the time embedding
        h = self.block2(h)
        return self.res_conv(x) + h # Residual connection
In the context of convolutional neural networks (CNNs), a “kernel size of 1” refers to using a 1x1 convolutional layer.
A convolutional layer with a kernel size of 1x1 operates on local input patches of size 1x1. This might seem counterintuitive at first because a convolutional layer with a larger kernel size is often used to capture spatial hierarchies and patterns. However, a 1x1 convolutional layer has some specific use cases and advantages:
Channel-wise transformation: A 1x1 convolution is applied independently at each spatial location of the input tensor. It doesn’t consider neighboring spatial locations. Instead, it mixes the values along the channel dimension, which makes it a cheap way to change the number of channels.
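A minimal plain-Python sketch of this behaviour is shown below. It assumes an image stored as nested lists and omits the bias term; `conv1x1` is a hypothetical helper, not part of PyTorch.

```python
def conv1x1(image, weight):
    """image: [C_in][H][W] nested lists, weight: [C_out][C_in]; returns [C_out][H][W]."""
    c_in, h, w = len(image), len(image[0]), len(image[0][0])
    return [
        [[sum(weight[o][c] * image[c][y][x] for c in range(c_in)) for x in range(w)]
         for y in range(h)]
        for o in range(len(weight))
    ]

image = [[[1.0, 2.0], [3.0, 4.0]],        # channel 0 of a hypothetical 2x2 image
         [[10.0, 20.0], [30.0, 40.0]]]    # channel 1
mixed = conv1x1(image, [[1.0, 1.0]])      # 2 channels -> 1 channel (sum of both)
same = conv1x1(image, [[1.0, 0.0], [0.0, 1.0]])   # identity mixing keeps the image
print(mixed)
```

Each output pixel depends only on the input pixel at the same location, which is why the spatial size never changes.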
0.7 Attention module
The DDPM authors used attention within the convolutional blocks.
Phil Wang implements two variants: self-attention as used in the Transformer, and linear attention as used in Shen et al., whose memory and computational costs grow linearly with input size as opposed to quadratically as in the Transformer.
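The linear cost comes from regrouping the matrix products. Standard attention applies a softmax to the full n x n matrix of query-key scores, so that matrix has to be built. Linear attention normalizes the queries and keys separately, which allows the keys and values to be multiplied first into a small d x d matrix. The plain-Python sketch below, with made-up values and no normalization at all, only demonstrates that the two groupings give the same result.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(r) for r in zip(*m)]

# Hypothetical tiny example: n = 3 positions, d = 2 features per head
q = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
k = [[1.0, 1.0], [0.0, 1.0], [2.0, 0.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

quadratic = matmul(matmul(q, transpose(k)), v)   # builds an n x n matrix first
linear = matmul(q, matmul(transpose(k), v))      # builds a d x d matrix first
print(quadratic)
print(linear)
```

For an image, n is the number of pixels, so avoiding the n x n intermediate matters far more than in this toy case.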
# Attention block
class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_heads = 32):
        super().__init__()
        self.scale = dim_heads ** -0.5 # scale attention by sqrt(dim_heads)
        self.heads = heads
        hidden_dim = dim_heads * heads # Concatenate the heads
        self.to_qkv = nn.Conv2d(in_channels = dim, out_channels = hidden_dim * 3, kernel_size = 1, bias = False)
        self.to_out = nn.Conv2d(in_channels = hidden_dim, out_channels = dim, kernel_size = 1)
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1) # Split the output into 3 parts [b, c, h, w] -> [b, c/3, h, w] x 3
        q, k, v = map(
            lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv
        )
        q = q * self.scale # Scale the query

        # Matrix multiplication between query and key
        #sim = einsum('b h d i, b h d j -> b h i j', q, k)
        sim = torch.matmul(q.transpose(-1, -2), k) # Make query n x dk and key n x dk
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()

        # Softmax to get the attention weights
        attn = sim.softmax(dim = -1)

        # Matrix multiplication between attention weights and value
        #out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = torch.matmul(attn, v.transpose(-1, -2)) # B x H x n x n * B x H x n x head_size -> B x H x n x head_size
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) # Rearrange/concatenate to B x C x R x Cols as expected by nn.Conv2d
        return self.to_out(out)
# Linear attention block
class LinearAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(in_channels = dim, out_channels = hidden_dim * 3, kernel_size=1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(in_channels = hidden_dim, out_channels = dim, kernel_size = 1),
            nn.GroupNorm(1, dim)
        )
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1) # Split the output into 3 parts [b, c, h, w] -> [b, c/3, h, w] x 3
        q, k, v = map(
            lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv
        )
        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        # context = torch.matmul(k.transpose(-1, -2), v) # Make key dk x n and value n x dv
        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        # out = torch.matmul(context, q) # B x H x dk x dv * B x H x dq x n -> B x H x dk x n
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)
0.8 Pre-normalization
Recent research has shown that pre-normalization in Transformers leads to well-behaved gradients and faster convergence: see Xiong et al. and Nguyen et al.
In this implementation, pre-normalization will be used to apply group normalization before the attention layer.
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.GroupNorm(1, dim) # A single group: normalizes over all channels and spatial positions of each sample
        self.fn = fn
    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)
0.9 U-Net
Now let’s create a U-Net that takes in noisy images and their respective noise levels and predicts the noise that was added to the image.
init_dim = 3
dim_mults = (1, 2, 4, 8)
dims = [init_dim, *map(lambda m: m * init_dim, dim_mults)]
dims, dims[:-1], dims[1:]
([3, 3, 6, 12, 24], [3, 3, 6, 12], [3, 6, 12, 24])
in_out = list(zip(dims[:-1], dims[1:]))
print("in_out: ", in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
    is_last = ind >= len(in_out) - 1
    print("dim_in, dim_out, is_last, ind: ", dim_in, dim_out, is_last, ind)
in_out: [(3, 3), (3, 6), (6, 12), (12, 24)]
dim_in, dim_out, is_last, ind: 3 3 False 0
dim_in, dim_out, is_last, ind: 3 6 False 1
dim_in, dim_out, is_last, ind: 6 12 False 2
dim_in, dim_out, is_last, ind: 12 24 True 3
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults = (1, 2, 4, 8),
        channels = 3,
        self_condition = False,
        resnet_block_groups = 4,
    ):
        super().__init__()
        # Determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(in_channels = input_channels, out_channels = init_dim, kernel_size = 1) # Only change the number of channels
        dims = [init_dim, *map(lambda m: m * init_dim, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # Time embeddings
        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(in_features = dim, out_features = time_dim),
            nn.GELU(),
            nn.Linear(in_features = time_dim, out_features = time_dim),
        )
        # Layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind == (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList([
                    block_klass(dim = dim_in, dim_out = dim_in, time_emb_dim = time_dim), # Resnet block
                    block_klass(dim = dim_in, dim_out = dim_in, time_emb_dim = time_dim), # Resnet block
                    Residual(PreNorm(dim_in, LinearAttention(dim_in))), # Pre-norm Linear attention block + Residual
                    Downsample(dim_in = dim_in, dim_out = dim_out) if not is_last else nn.Conv2d(in_channels = dim_in, out_channels = dim_out, kernel_size = 3, padding = 'same')
                ])
            )
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(dim = mid_dim, dim_out = mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(dim = mid_dim)))
        self.mid_block2 = block_klass(dim = mid_dim, dim_out = mid_dim, time_emb_dim = time_dim)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList([
                    block_klass(dim = dim_out + dim_in, dim_out = dim_out, time_emb_dim = time_dim),
                    block_klass(dim = dim_out + dim_in, dim_out = dim_out, time_emb_dim = time_dim),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    Upsample(dim_in = dim_out, dim_out = dim_in) if not is_last else nn.Conv2d(in_channels = dim_out, out_channels = dim_in, kernel_size = 3, padding = 'same')
                ])
            )
        self.out_dim = default(out_dim, channels)
        self.final_res_block = block_klass(dim = dim * 2, dim_out = dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv2d(in_channels = dim, out_channels = self.out_dim, kernel_size = 1)
    def forward(self, x, time, x_self_cond = None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x, x_self_cond), dim = 1)

        # Make the input have desired number of channels
        x = self.init_conv(x)
        # Return a copy of x
        r = x.clone()

        # Create time embeddings
        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)

            h.append(x)
            x = block2(x, t)
            x = attn(x)

            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
Now that we have defined our noise predictor, let’s set up the data for the training process.
0.10 Define a PyTorch Dataset + DataLoader
Now, let’s define a PyTorch Dataset and DataLoader.
from datasets import load_dataset
# Load dataset from HuggingFace hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 32
from torchvision import transforms
from torch.utils.data import DataLoader
# Define transforms
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1) # rescale to [-1, 1]
])
# Define function
def transforms(examples):
    examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
    del examples["image"]
    return examples
# On the fly transforms
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# Create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size = batch_size, shuffle = True)
0.11 Sampling from the model
So how do we use the model to generate images? The paper describes the sampling procedure in Algorithm 2.
\epsilon_{\theta}(\mathbf x_t, t) refers to the noise predicted by the trained model. The authors showed that the variance \sigma_t^2 can be set as: \sigma_t^2 = \frac{1 - \bar \alpha_{t-1}}{1 - \bar \alpha_t} \beta_t
The entire process can be summarized as follows. Generating new images from a diffusion model involves:
- sample pure noise \mathbf x_T from a Gaussian distribution
- for each time step t from T to 0, use the noise predictor to denoise the image and obtain \mathbf x_{t-1}
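One denoising step can be sketched on a single scalar in plain Python. This is a simplified illustration, not the tensor code below: the helper names and the values of the clean sample and the noise are made up, and the "model" is assumed to predict the noise perfectly. Under that assumption, the very first step of the schedule is undone exactly.

```python
import math

def reverse_mean(x_t, eps_pred, beta_t, alpha_bar_t):
    """Mean of p(x_{t-1} | x_t) for a scalar, given a predicted noise value."""
    alpha_t = 1.0 - beta_t
    return (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_t)

def posterior_variance(beta_t, alpha_bar_prev, alpha_bar_t):
    """Variance of the noise added back after each step except the last one."""
    return beta_t * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)

# First step of the schedule (index 0): alpha_bar equals alpha, and the previous product is 1.
beta0 = 0.0001
alpha_bar0 = 1.0 - beta0
x0, eps = 0.7, -1.3                                   # hypothetical clean value and noise
x1 = math.sqrt(alpha_bar0) * x0 + math.sqrt(1.0 - alpha_bar0) * eps

recovered = reverse_mean(x1, eps, beta0, alpha_bar0)  # pretend the model predicted eps exactly
print(recovered, posterior_variance(beta0, 1.0, alpha_bar0))
```

The posterior variance is zero at this index, which matches the sampling code returning the mean without extra noise at the final step.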
Ideally, as Jay Alammar put it, we get an image that’s closer to the images the model was trained on (not the exact images themselves, but the distribution - the world of pixel arrangements where the sky is usually blue and above the ground, people have two eyes, cats look a certain way: pointy ears and clearly unimpressed).
# Creating the sampling procedure
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# Define posterior variance i.e equation 7
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.0) # drop the last entry and prepend a 1.0 so that index t holds the cumulative product at t-1
posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # Implementing equation 11 or equation 4 in the sampling section
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # Start from pure noise (for each example in the batch)
    img = torch.randn(shape, device = device)
    imgs = []
    for i in reversed(range(timesteps)):
        img = p_sample(model, img, torch.full((b,), i, device = device, dtype = torch.long), i) # x_t-1

        imgs.append(img.cpu().numpy())
    return np.array(imgs)
@torch.no_grad()
def sample(model, image_size, batch_size = 16, channels = 3):
    return p_sample_loop(model, shape = (batch_size, channels, image_size, image_size))
0.12 Train the model
Now let’s train the model and sample from it.
from pathlib import Path
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000
num_to_groups(4, 32)
[4]
# Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Unet(
    dim = image_size,
    channels = channels,
    dim_mults = (1, 2, 4)
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)

# Model training
from torchvision.utils import save_image
epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        batch_size = batch["pixel_values"].shape[0]
        batch = batch["pixel_values"].to(device)

        # Sample time steps uniformly for every example in batch
        t = torch.randint(0, timesteps, (batch_size,), device = device, dtype = torch.long)

        loss = p_losses(model, batch, t, loss_type = "huber")

        if step % 500 == 0:
            print(f"Epoch: {epoch} Step: {step} Loss: {loss.item()}")
        # Zero out gradients
        optimizer.zero_grad()
        # Backpropagate
        loss.backward()
        # Update weights
        optimizer.step()
        # Save generated images
        # if step !=0 and step % save_and_sample_every == 0:
        #     milestone = step // save_and_sample_every
        #     batches = num_to_groups(num = 4, divisor = batch_size)
        #     all_images_list = list(map(lambda n: sample(model, image_size, batch_size = n, channels = channels), batches))
        #     all_images = torch.cat(all_images_list, dim = 0)
        #     all_images = (all_images + 1) * 0.5 # rescale to [0, 1]
        #     save_image(all_images, str(results_folder / f'sample_{milestone}.png'), nrow = 6)
Epoch: 0 Step: 0 Loss: 0.5224699378013611
Epoch: 0 Step: 500 Loss: 0.041550733149051666
Epoch: 0 Step: 1000 Loss: 0.04898069426417351
Epoch: 0 Step: 1500 Loss: 0.04184384271502495
Epoch: 1 Step: 0 Loss: 0.05208709090948105
Epoch: 1 Step: 500 Loss: 0.031184278428554535
Epoch: 1 Step: 1000 Loss: 0.03631388023495674
Epoch: 1 Step: 1500 Loss: 0.03278825432062149
Epoch: 2 Step: 0 Loss: 0.026345837861299515
Epoch: 2 Step: 500 Loss: 0.027054786682128906
Epoch: 2 Step: 1000 Loss: 0.04869785159826279
Epoch: 2 Step: 1500 Loss: 0.023147860541939735
Epoch: 3 Step: 0 Loss: 0.027534985914826393
Epoch: 3 Step: 500 Loss: 0.036470092833042145
Epoch: 3 Step: 1000 Loss: 0.02620980143547058
Epoch: 3 Step: 1500 Loss: 0.04134272783994675
Epoch: 4 Step: 0 Loss: 0.025678491219878197
Epoch: 4 Step: 500 Loss: 0.02390681765973568
Epoch: 4 Step: 1000 Loss: 0.03659454733133316
Epoch: 4 Step: 1500 Loss: 0.02753089740872383
Epoch: 5 Step: 0 Loss: 0.04199949651956558
Epoch: 5 Step: 500 Loss: 0.02578447200357914
Epoch: 5 Step: 1000 Loss: 0.03484490513801575
Epoch: 5 Step: 1500 Loss: 0.034843917936086655
0.13 Sampling from the model
To generate new images from the model, we need to sample pure noise \mathbf x_T from a Gaussian distribution and then use the noise predictor to denoise the image and obtain \mathbf x_{t-1} for each time step t from T to 0.
# Sample 8 images
sampled_images = sample(model, image_size = image_size, batch_size = 8, channels = 1)

# Transpose to (B, H, W, C) as expected by matplotlib
sampled_images_c = sampled_images[-1].transpose(0, 2, 3, 1) # Get images at time step 0

# Visualize sampled images
visualize_noisy_images(sampled_images_c, cols = 4, rows = 2, cmap = 'gray')
The above shows the result of sampling starting from pure noise i.e torch.randn((BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)) where BATCH_SIZE = 8. Not bad, at least the model can generate images that look like clothing. Let’s create a gif of the denoising process:
import matplotlib.animation as animation
image_index = 3

fig = plt.figure()
ims = []

for i in range(timesteps):
    im = plt.imshow(sampled_images[i][image_index].transpose(1, 2, 0), animated = True, cmap = 'gray')
    # Add title
    #plt.title(f"Time step: {i}")
    ims.append([im])
animate = animation.ArtistAnimation(fig, ims, interval = 50, blit = True, repeat_delay = 1000)
animate.save('diffusion.gif')
plt.close()