import torch
import torch.nn as nn
import torchvision.transforms.functional
# Implement the double 3X3 convolution blocks
# The original paper did not use padding, but we will use padding to keep the image size the same
class double_convolution(nn.Module):
"""
This class implements the double convolution block which consists of two 3X3 convolution layers,
each followed by a ReLU activation function.
"""
def __init__(self, in_channels, out_channels): # Initialize the class
super().__init__() # Initialize the parent class
# First 3X3 convolution layer
self.first_cnn = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
self.act1 = nn.ReLU()
# Second 3X3 convolution layer
self.second_cnn = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
self.act2 = nn.ReLU()
# Pass the input through the double convolution block
def forward(self, x):
x = self.first_cnn(x)
x = self.act1(x)
x = self.act2(self.second_cnn(x))
return x
# Implement the Downsample block that occurs after each double convolution block
class down_sample(nn.Module):
"""
This class implements the downsample block which consists of a Max Pooling layer with a kernel size of 2.
The Max Pooling layer halves the image size, reducing the spatial resolution of the feature maps
while retaining the most important features.
"""
def __init__(self):
super().__init__()
self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
# Pass the input through the downsample block
def forward(self, x):
x = self.max_pool(x)
return x
# Implement the UpSample block that occurs in the decoder part of the network
class up_sample(nn.Module):
"""
This class implements the upsample block which consists of a convolution transpose layer with a kernel size of 2.
The convolution transpose layer doubles the image size, increasing the spatial resolution of the feature maps
while retaining the learned features.
"""
def __init__(self, in_channels, out_channels):
super().__init__()
# Convolution transpose layer
self.up_sample = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 2, stride = 2)
# Pass the input through the upsample block
def forward(self, x):
x = self.up_sample(x)
return x
# Implement the crop and concatenate block that occurs in the decoder part of the network
# This block concatenates the output of the upsample block with the output of the corresponding downsample block
# The output of the crop and concatenate block is then passed through a double convolution block
class crop_and_concatenate(nn.Module):
"""
This class implements the crop and concatenate block which combines the output of the upsample block
with the corresponding features from the contracting path through skip connections,
allowing the network to recover the fine-grained details lost during downsampling
and produce a high-resolution output segmentation map.
"""
# def forward(self, upsampled, bypass):
# # Crop the feature map from the contracting path to match the size of the upsampled feature map
# bypass = torchvision.transforms.functional.center_crop(img = bypass, output_size = [upsampled.shape[2], upsampled.shape[3]])
# # Concatenate the upsampled feature map with the cropped feature map from the contracting path
# x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
# return x
# Alternatively, resize (interpolate) the upsampled feature map to match the size of the feature map from the contracting path
def forward(self, upsampled, bypass):
upsampled = torchvision.transforms.functional.resize(img = upsampled, size = bypass.shape[2:], antialias=True)
x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
return x
# Quick scratch checks of the ConvTranspose2d and MaxPool2d output shapes:
# m = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
# input = torch.randn(1, 1024, 28, 28)
# m(input).shape
# m = nn.MaxPool2d(kernel_size = 2, stride = 2)
# xx = torch.randn(1, 1, 143, 143)
# m(xx).shape
## Implement the UNet architecture
class UNet(nn.Module):
# in_channels: number of channels in the input image
# out_channels: number of channels in the output image
def __init__(self, in_channels, out_channels):
super().__init__()
# Define the contracting path: convolution blocks followed by downsample blocks
self.down_conv = nn.ModuleList(double_convolution(in_chans, out_chans) for in_chans, out_chans in
[(in_channels, 64), (64, 128), (128, 256), (256, 512)]) # Double convolution blocks in the contracting path
self.down_samples = nn.ModuleList(down_sample() for _ in range(4))
# Define the bottleneck layer
self.bottleneck = double_convolution(in_channels = 512, out_channels = 1024)
# Define the expanding path: upsample blocks followed by convolution blocks
self.up_samples = nn.ModuleList(up_sample(in_chans, out_chans) for in_chans, out_chans in
[(1024, 512), (512, 256), (256, 128), (128, 64)]) # List of upsample blocks
self.concat = nn.ModuleList(crop_and_concatenate() for _ in range(4))
self.up_conv = nn.ModuleList(double_convolution(in_chans, out_chans) for in_chans, out_chans in
[(1024, 512), (512, 256), (256, 128), (128, 64)]) # Double convolution blocks in the expanding path; each input has upsampled channels + skip channels (e.g. 512 + 512 = 1024)
# Final 1X1 convolution layer to produce the output segmentation map:
# The primary purpose of 1x1 convolutions is to transform the channel dimension of the feature map,
# while leaving the spatial dimensions unchanged.
self.final_conv = nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1)
# Pass the input through the UNet architecture
def forward(self, x):
# Pass the input through the contracting path
skip_connections = [] # List to store the outputs of the downsample blocks
for down_conv, down_sample in zip(self.down_conv, self.down_samples):
x = down_conv(x)
skip_connections.append(x)
x = down_sample(x)
# Pass the output of the contracting path through the bottleneck layer
x = self.bottleneck(x)
# Pass the output of the bottleneck layer through the expanding path
skip_connections = skip_connections[::-1] # Reverse the list of skip connections
for up_sample, concat, up_conv in zip(self.up_samples, self.concat, self.up_conv):
x = up_sample(x)
x = concat(x, skip_connections.pop(0)) # Remove the first element from the list of skip connections
x = up_conv(x)
# Pass the output of the expanding path through the final convolution layer
x = self.final_conv(x)
return x

UNets 101
1 Implementing the original U-Net model
This code implements a U-Net model for semantic segmentation from the paper U-Net: Convolutional Networks for Biomedical Image Segmentation:
A U-Net consists of an encoder (the contracting path), a series of convolution and pooling layers which reduce the spatial resolution of the input, followed by a decoder (the expanding path), a series of transposed convolution and convolution layers which restore the spatial resolution. The encoder and decoder are connected by a bottleneck block, which processes the feature maps at the lowest spatial resolution and the highest channel count (1024 channels in this implementation).
The key innovation of U-Net is the addition of skip connections that connect the contracting path to the corresponding layers in the expanding path, allowing the network to recover fine-grained details lost during downsampling.
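As a small, hedged illustration (not part of the original text), the snippet below traces how the spatial resolution and channel count evolve along the contracting path of the UNet class defined above, using an arbitrary 1x3x128x128 input: with padding=1 each double_convolution preserves the spatial size, while every 2x2 max pool halves it.
unet = UNet(in_channels = 3, out_channels = 1)
x = torch.randn(1, 3, 128, 128)
for conv, pool in zip(unet.down_conv, unet.down_samples):
    x = conv(x)        # spatial size unchanged, channels grow: 64, 128, 256, 512
    print(x.shape)     # shape of the skip connection saved at this level
    x = pool(x)        # spatial size halved: 128 -> 64 -> 32 -> 16
print(unet.bottleneck(x).shape)  # torch.Size([1, 1024, 8, 8])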
1.1 Sanity check for the model
import torchsummary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels = 3, out_channels = 1).to(device)
dummy_input = torch.randn((1, 3, 32, 32)).to(device)
mask = model(dummy_input)
mask.shape
# See how data flows through the network
torchsummary.summary(model, input_size = (3, 32, 32))

2 Going deeper: Some more interesting UNets
This section implements a U-Net model that incorporates some of the recent advances in deep learning (a short illustrative sketch follows the list):
- Residual networks for U-Nets: the key idea behind ResNets is the use of residual connections, which allow information to propagate directly through the network without being modified by the layers in between. The residual connection is achieved by adding the input of a block to its output, so that the output becomes y = f(x) + x. The shortcut connection skips one or more layers, with any change in dimensions compensated by a 1x1 convolutional layer.
- Group normalization: normalizes the activations of a layer across groups of channels instead of across the entire batch. See more explanations and comparisons between different normalizations in this blog post.
- Swish activation function: a self-gated activation function defined as f(x) = x * sigmoid(x). It has been shown to outperform ReLU and other activation functions on deeper models across a number of challenging datasets.
- Attention gated U-Nets: a modification of the U-Net architecture that uses attention gates to selectively focus on the most relevant parts of the input image. Each attention gate projects a skip connection and a gating signal with 1x1 convolutions, passes their sum through a sigmoid to obtain an attention map, and multiplies the skip connection element-wise by that map, so that the gated output becomes y = x * sigmoid(phi(x, g)). The attention gates are applied to the outputs of the contracting path before they are concatenated in the expanding path.
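As a minimal sketch (not from the original text; the tensor and layer sizes are arbitrary choices for illustration), the snippet below shows the first three ideas: the residual formula y = f(x) + x with a 1x1 shortcut convolution compensating a channel change, group normalization over groups of channels, and the fact that PyTorch's nn.SiLU() is exactly the Swish function x * sigmoid(x).
x = torch.randn(2, 16, 8, 8)                           # arbitrary example tensor
f = nn.Conv2d(16, 32, kernel_size = 3, padding = 1)    # "f(x)": a layer that changes the channel count
shortcut = nn.Conv2d(16, 32, kernel_size = 1)          # 1x1 convolution compensates the channel change
y = f(x) + shortcut(x)                                 # residual connection: y = f(x) + x
gn = nn.GroupNorm(num_groups = 8, num_channels = 32)   # 32 channels normalized in 8 groups of 4
y = gn(y)
swish = nn.SiLU()                                      # Swish: f(x) = x * sigmoid(x)
assert torch.allclose(swish(x), x * torch.sigmoid(x))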
import torch
import torch.nn as nn
import torchvision
# Define a Residual block
class residual_block(nn.Module):
"""
This class implements a residual block: two 3x3 convolution layers, each followed by group normalization and a Swish (SiLU) activation, plus a shortcut (residual) connection.
"""
def __init__(self, in_channels, out_channels, n_groups = 8):
super().__init__()
# First convolution layer
self.first_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
self.first_norm = nn.GroupNorm(num_groups = n_groups, num_channels = out_channels)
self.act1 = nn.SiLU() # Swish activation function
# Second convolution layer
self.second_conv = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1)
self.second_norm = nn.GroupNorm(num_groups = n_groups, num_channels = out_channels)
self.act2 = nn.SiLU() # Swish activation function
# If the number of input channels is not equal to the number of output channels,
# then use a 1X1 convolution layer to compensate for the difference in dimensions
# This allows the input to have the same dimensions as the output of the residual block
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1)
else:
# Pass the input as is
self.shortcut = nn.Identity()
# Pass the input through the residual block
def forward(self, x):
# Store the input for the shortcut connection
residual = x
# Pass the input through the first convolution layer, group normalization and activation
x = self.act1(self.first_norm(self.first_conv(x)))
# Pass the output of the first convolution layer through the second convolution layer
x = self.act2(self.second_norm(self.second_conv(x)))
# Add the input to the output of the second convolution layer
# This is the skip connection
x = x + self.shortcut(residual)
return x
# Implement the DownSample block that occurs after each residual block
class down_sample(nn.Module):
def __init__(self):
super().__init__()
self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
# Pass the input through the downsample block
def forward(self, x):
x = self.max_pool(x)
return x
# Implement the UpSample block that occurs in the decoder path/expanding path
class up_sample(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
# Convolution transpose layer to upsample the input
self.up_sample = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 2, stride = 2)
# Pass the input through the upsample block
def forward(self, x):
x = self.up_sample(x)
return x
# Implement the crop and concatenate layer
class crop_and_concatenate(nn.Module):
def forward(self, upsampled, bypass):
# Crop the upsampled feature map to match the dimensions of the bypass feature map
upsampled = torchvision.transforms.functional.resize(upsampled, size = bypass.shape[2:], antialias=True)
x = torch.cat([upsampled, bypass], dim = 1) # Concatenate along the channel dimension
return x
# Implement an attention block
class attention_block(nn.Module):
def __init__(self, skip_channels, gate_channels, inter_channels = None, n_groups = 8):
super().__init__()
if inter_channels is None:
inter_channels = skip_channels // 2
# Implement W_g i.e the convolution layer that operates on the gate signal
# Upsample gate signal to be the same size as the skip connection
self.W_g = up_sample(in_channels = gate_channels, out_channels = skip_channels)
#self.W_g_norm = nn.GroupNorm(num_groups = n_groups, num_channels = skip_channels)
#self.W_g_act = nn.SiLU() # Swish activation function
# Implement W_x i.e the convolution layer that operates on the skip connection
self.W_x = nn.Conv2d(in_channels = skip_channels, out_channels = inter_channels, kernel_size = 1)
#self.W_x_norm = nn.GroupNorm(num_groups = n_groups, num_channels = inter_channels)
#self.W_x_act = nn.SiLU() # Swish activation function
# Implement phi i.e the convolution layer that operates on the output of W_x + W_g
self.phi = nn.Conv2d(in_channels = inter_channels, out_channels = 1, kernel_size = 1)
#self.phi_norm = nn.GroupNorm(num_groups = n_groups, num_channels = 1)
#self.phi_act = nn.SiLU() # Swish activation function
# Implement the sigmoid activation function
self.sigmoid = nn.Sigmoid()
# Implement the Swish activation function
self.act = nn.SiLU()
# Final 1x1 convolution and group normalization applied to the gated skip connection
self.W_out = nn.Conv2d(in_channels = skip_channels, out_channels = skip_channels, kernel_size = 1)
self.final_norm = nn.GroupNorm(num_groups = n_groups, num_channels = skip_channels)
# Pass the input through the attention block
def forward(self, skip_connection, gate_signal):
# Upsample the gate signal to match the channels of the skip connection
gate_signal = self.W_g(gate_signal)
# Ensure that the sizes of the skip connection and the gate signal match before addition
if gate_signal.shape[2:] != skip_connection.shape[2:]:
gate_signal = torchvision.transforms.functional.resize(gate_signal, size = skip_connection.shape[2:], antialias=True)
# Project the upsampled gate signal to the intermediate channels (note: this implementation reuses the same 1x1 projection W_x for both the gate signal and the skip connection)
gate_signal = self.W_x(gate_signal)
# Project the skip connection to the intermediate channels
skip_signal = self.W_x(skip_connection)
# Add the skip connection and the gate signal
add_xg = gate_signal + skip_signal
# Pass the output of the addition through the activation function
add_xg = self.act(add_xg)
# Pass the output of attention through a 1x1 convolution layer to obtain the attention map
attention_map = self.sigmoid(self.phi(add_xg))
# Multiply the skip connection with the attention map
# Perform element-wise multiplication
skip_connection = torch.mul(skip_connection, attention_map)
skip_connection = self.W_out(skip_connection) # 1x1 convolution defined in __init__ (avoids creating a new, untrained Conv2d on every forward pass)
skip_connection = self.act(self.final_norm(skip_connection))
return skip_connection
## Implement a residual attention U-Net
class ResidualAttentionUnet(nn.Module):
def __init__(self, in_channels, out_channels, n_groups = 4, n_channels = [64, 128, 256, 512, 1024]):
super().__init__()
# Define the contracting path: residual blocks followed by downsampling
self.down_conv = nn.ModuleList(residual_block(in_chans, out_chans) for in_chans, out_chans in
[(in_channels, n_channels[0]), (n_channels[0], n_channels[1]), (n_channels[1], n_channels[2]), (n_channels[2], n_channels[3])])
self.down_samples = nn.ModuleList(down_sample() for _ in range(4))
# Define the bottleneck residual block
self.bottleneck = residual_block(n_channels[3], n_channels[4])
# Define the attention blocks
self.attention_blocks = nn.ModuleList(attention_block(skip_channels = residuals_chans, gate_channels = gate_chans) for gate_chans, residuals_chans in
[(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])
# Define the expanding path: upsample blocks, followed by crop and concatenate, followed by residual blocks
self.upsamples = nn.ModuleList(up_sample(in_chans, out_chans) for in_chans, out_chans in
[(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])
self.concat = nn.ModuleList(crop_and_concatenate() for _ in range(4))
self.up_conv = nn.ModuleList(residual_block(in_chans, out_chans) for in_chans, out_chans in
[(n_channels[4], n_channels[3]), (n_channels[3], n_channels[2]), (n_channels[2], n_channels[1]), (n_channels[1], n_channels[0])])
# Final 1X1 convolution layer to produce the output segmentation map:
# The primary purpose of 1x1 convolutions is to transform the channel dimension of the feature map,
# while leaving the spatial dimensions unchanged.
self.final_conv = nn.Conv2d(in_channels = n_channels[0], out_channels = out_channels, kernel_size = 1)
# Pass the input through the residual attention U-Net
def forward(self, x):
# Store the skip connections
skip_connections = []
# # Store the gate signals
# gate_signals = []
# Pass the input through the contracting path
for down_conv, down_sample in zip(self.down_conv, self.down_samples):
x = down_conv(x)
skip_connections.append(x)
#gate_signals.append(x)
x = down_sample(x)
# Pass the output of the contracting path through the bottleneck
x = self.bottleneck(x)
skip_connections.append(x)
# Attention on the residual connections
#skip_connections = skip_connections[::-1]
n = len(skip_connections)
# Pair each gate signal with the skip connection one level shallower in the encoder:
# (bottleneck, s3), (s3, s2), (s2, s1), (s1, s0)
indices = [(n - 1 - i, n - 2 - i) for i in range(n - 1)]
attentions = []
for i, g_x in enumerate(indices):
g_gate = g_x[0]
x_residual = g_x[1]
attn = self.attention_blocks[i](skip_connections[x_residual], skip_connections[g_gate])
attentions.append(attn)
#attentions = attentions[::-1]
# Pass the output of the attention blocks through the expanding path
for up_sample, concat, up_conv in zip(self.upsamples, self.concat, self.up_conv):
x = up_sample(x)
x = concat(x, attentions.pop(0))
x = up_conv(x)
# Pass the output of the expanding path through the final convolution layer
x = self.final_conv(x)
return x

2.1 Sanity check
## Sanity check
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResidualAttentionUnet(in_channels = 3, out_channels = 1).to(device)
x = torch.randn((1, 3, 32, 32)).to(device)
mask = model(x)
mask.shape
# See how data flows through the network
torchsummary.summary(model, input_size = (3, 32, 32))