SATELLITE DATA FOR AGRICULTURAL ECONOMISTS: Theory and Practice
Lecture date: 22-01-2026
1 Google Colaboratory (Deep Learning)
Agricultural economists often need precise spatial data to analyze crop distributions, model economic impacts and assess encroachment into protected areas caused by agricultural expansion, among other tasks. In this session, we introduce the basics of using a deep learning approach to segment tea fields, with a practical example from Kericho, the heart of tea growing in Kenya. We obtained a Sentinel-2 satellite image (10 m resolution) from Google Earth Engine (GEE); you can get the notebook directly here. We obtained the labels by manually digitizing tea plantations within GEE. The labels cover only a small portion of the downloaded satellite image. This lets us train the model on that portion and then use it to segment tea fields elsewhere, beyond the training area. We use torchgeo for this modeling task for the following reasons:
- It is simple to use and takes care of many tedious steps such as georeferencing, chipping and label creation.
- Its RandomGeoSampler lets us draw as many training samples as possible from the region of interest.
2 Loading libraries
Since torchgeo is not pre-installed in Google Colaboratory, we have to install it. We also install torchseg, which provides the segmentation models. The other supporting libraries are already pre-installed in Colab, so we simply import them.
# Install libraries not already available in colab
!pip install torchgeo torchseg
# Import the necessary libraries
import json
import torch
import rasterio
import torchseg
import torchgeo
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
import rasterio.sample
import geopandas as gpd
import albumentations as A
from rasterio.plot import show
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchgeo.transforms import AppendNDVI
from rasterio.transform import from_bounds
from albumentations.pytorch import ToTensorV2
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler
from torchgeo.datasets import VectorDataset, RasterDataset, stack_samples
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
After installing the libraries, we take the important step of confirming the working directory. Both our .tif and .gpkg files will be read from this location, so we need to be sure of the path. In Colab, we can print it with the !pwd command.
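The same check can also be done in plain Python, which additionally tells us whether the input files are actually there. The sketch below uses a hypothetical helper, find_missing_files, that is not part of any library. It prints the working directory and lists which expected files are absent. The file names are the ones used later in this session; adjust them to your own data.

```python
import os


def find_missing_files(directory, filenames):
    """Return the expected files that are not present in `directory` (hypothetical helper)."""
    return [name for name in filenames if not os.path.isfile(os.path.join(directory, name))]


if __name__ == "__main__":
    cwd = os.getcwd()
    print("Working directory:", cwd)
    # File names used in this session; adjust them to your own data
    expected = ["win2025_dl.tif", "full_labels_proj.gpkg"]
    missing = find_missing_files(cwd, expected)
    print("Missing files:", missing if missing else "none")
```

An empty list means both the raster and the labels can be found from the current working directory.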
3 Visualize the data
Let us confirm that our data is as expected, that is, that the .tif and .gpkg data overlap. We define a simple function that does the plotting. Its arguments are:
- raster_path: path to the raster file
- vector_path: path to the vector file
- rgb_bands: a tuple specifying the band order for RGB (rasterio band indices start at 1)
- stretch = True: a boolean, whether to apply contrast stretching to enhance visualization
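To see what the contrast stretch does, here is a minimal pure-Python sketch of the same idea for a single band. It assumes linear interpolation between sorted ranks, which is NumPy's default np.percentile method. The helper names percentile and stretch_band are illustrative; they are not rasterio or NumPy functions. Values at or below the 2nd percentile map to 0, values at or above the 98th percentile map to 1, and everything in between is scaled linearly. As in the plotting function, zeros are treated as nodata.

```python
def percentile(values, q):
    """Percentile with linear interpolation between sorted ranks (NumPy's default method)."""
    data = sorted(values)
    rank = (q / 100) * (len(data) - 1)
    lower = int(rank)
    upper = min(lower + 1, len(data) - 1)
    weight = rank - lower
    return data[lower] + weight * (data[upper] - data[lower])


def stretch_band(values, low=2, high=98):
    """Linearly rescale a band to 0-1 between its low/high percentiles, ignoring zeros."""
    valid = [v for v in values if v > 0]  # zeros are treated as nodata
    p_low, p_high = percentile(valid, low), percentile(valid, high)
    span = (p_high - p_low) or 1.0  # avoid division by zero on flat bands
    return [min(max((v - p_low) / span, 0.0), 1.0) for v in values]


band = list(range(1, 102))  # reflectance-like values 1..101
stretched = stretch_band(band)
print(stretched[0], stretched[50], stretched[-1])
```

Clipping the extreme 2% at each end keeps a few very bright pixels (clouds, roofs) from compressing the rest of the image into a narrow, dark range.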
def plot_rgb_with_geopandas(raster_path, vector_path, rgb_bands=(3, 2, 1), stretch=True):
    with rasterio.open(raster_path) as src:
        img = src.read(list(rgb_bands)).astype(np.float32)
        transform = src.transform
        raster_crs = src.crs
    if stretch:
        # Stretch the RGB bands individually
        img_stretched = np.zeros_like(img, dtype=np.float32)
        for i in range(img.shape[0]):
            band = img[i]
            p2, p98 = np.percentile(band[band > 0], (2, 98))  # Ignore zeros/NaNs (nodata)
            band = np.clip((band - p2) / (p98 - p2), 0, 1)
            img_stretched[i] = band
        img = img_stretched
    fig, ax = plt.subplots(figsize=(10, 10))
    rasterio.plot.show(
        img,
        ax=ax,
        transform=transform,
        cmap=None,
        adjust=None
    )
    # Read the vector, reproject it to the raster's CRS and plot it
    gdf = gpd.read_file(vector_path).to_crs(raster_crs)
    gdf.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=1.5)
    ax.set_xlabel("Easting (m)")
    ax.set_ylabel("Northing (m)")
    ax.legend(["Polygons"], loc="upper right")
    plt.title("Enhanced RGB Image with Polygons Overlay")
    plt.show()
if __name__ == "__main__":
    raster_file = "/content/win2025_dl.tif"
    polygons = "/content/full_labels_proj.gpkg"
    plot_rgb_with_geopandas(raster_path=raster_file, vector_path=polygons)
4 Defining the dataset
With torchgeo, we do NOT have to pre-chip the satellite image into small chips of, say, 256 by 256 pixels; it does this on the fly. However, we need to tell it where our data are. The folder can hold several images (say, if we had another region in western Kenya). For now, it holds only win2025_dl.tif, the image we obtained from GEE, in our working directory. We define the dataset class as follows:
# Define the GeoTiff dataset class
class GeoTiffDataset(RasterDataset):
    filename_glob = "win2025_dl.tif"
raster_data = GeoTiffDataset(paths = "/content")
Here, raster_data is a blueprint of the satellite image in the directory. torchgeo indexes the file's extent and CRS, and it reads pixels only when patches are requested. Next, we do the same for the label data. The labels are stored in a .gpkg file with a column stating whether each polygon is tea or not tea, that is, the class column. Here, we call the class column tea_no_tea, but you can name it as you wish. This column is very important because the library uses it to create a binary label layer under the hood, which it then intersects with the satellite image. We define label_data by subclassing VectorDataset in the same way, passing label_name = "tea_no_tea"; the LabelDataset class in Section 16 shows this pattern.
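Conceptually, turning labelled polygons into a label layer means two steps. First, for every pixel centre, find which polygon (if any) contains it. Second, write that polygon's tea_no_tea value into the pixel. torchgeo's VectorDataset relies on rasterio's rasterization for this. The pure-Python sketch below only illustrates the idea with a ray-casting point-in-polygon test. rasterize_polygons and its arguments are hypothetical names.

```python
def point_in_polygon(x, y, polygon):
    """Ray-casting test: count how many polygon edges a ray to the right of (x, y) crosses."""
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        if (y1 > y) != (y2 > y):  # the edge straddles the horizontal line through y
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def rasterize_polygons(shapes, bounds, res, fill=0):
    """Burn (polygon, class_value) pairs into a grid; row 0 is the northern edge.

    bounds = (minx, miny, maxx, maxy) in map units; res = pixel size in the same units.
    """
    minx, miny, maxx, maxy = bounds
    n_cols = round((maxx - minx) / res)
    n_rows = round((maxy - miny) / res)
    grid = [[fill] * n_cols for _ in range(n_rows)]
    for row in range(n_rows):
        y = maxy - (row + 0.5) * res  # pixel-centre y coordinate
        for col in range(n_cols):
            x = minx + (col + 0.5) * res  # pixel-centre x coordinate
            for polygon, value in shapes:
                if point_in_polygon(x, y, polygon):
                    grid[row][col] = value  # later shapes overwrite earlier ones
    return grid


# A 40 m x 40 m area at 10 m resolution with one "tea" (1) polygon on its western half
tea_field = [(0, 0), (20, 0), (20, 40), (0, 40)]
mask = rasterize_polygons([(tea_field, 1)], bounds=(0, 0, 40, 40), res=10)
for row in mask:
    print(row)
```

Each printed row reads [1, 1, 0, 0]: the two western pixel columns fall inside the tea polygon, and the rest keep the fill value 0 (no tea).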
5 Combining raster_data and label_data
Now that we have blueprints of both raster_data and label_data, the next step is to intersect the two. This behaves like cropping or clipping the raster to the extent of label_data. Remember that when we plotted the raster and vector, the polygons covered only a small portion of the raster. The larger raster extent is not a problem because the cropping is done on the fly. Chips for training the model are taken only from where the two datasets overlap, and regions outside label_data are never sampled. As simple as it can get, we achieve this intersection with the & operator.
# Create the intersection of the raster and vector/label datasets
training_data = raster_data & label_data  # A simple & achieves this
Notice the printed message Converting LabelDataset res from (0.0001, 0.0001) to (10, 10). It tells us that our vector labels will be converted to a binary raster under the hood, with the same pixel size as the Sentinel-2 image. It is also important to note that the pixels in both layers are aligned.
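The & operator keeps only the area covered by both datasets. Reduced to bounding boxes, the idea is the simple overlap computation sketched below. intersect_bounds is a hypothetical helper. torchgeo's IntersectionDataset additionally handles CRS, resolution and multiple files.

```python
def intersect_bounds(a, b):
    """Overlap of two boxes given as (minx, miny, maxx, maxy); None if they do not overlap."""
    minx, miny = max(a[0], b[0]), max(a[1], b[1])
    maxx, maxy = min(a[2], b[2]), min(a[3], b[3])
    if minx >= maxx or miny >= maxy:
        return None
    return (minx, miny, maxx, maxy)


# Hypothetical extents in metres: a large image and a small labelled area partly inside it
image_bounds = (0, 0, 10_000, 10_000)
label_bounds = (2_000, 3_000, 4_000, 12_000)
print(intersect_bounds(image_bounds, label_bounds))
```

The result, (2000, 3000, 4000, 10000), is the only region from which training chips can be drawn. If the boxes did not overlap at all, there would be nothing to sample.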
6 Create sampler
Since we did not pre-chip the satellite image, we need a sampler that extracts small chips from the original image and passes them to the model during training. This matters because we cannot push the whole image through the model at once. Here we use RandomGeoSampler, which picks patches at random locations within the region of interest. Because the locations are random, patches can overlap. This lets us draw more patches from the study area than a fixed, non-overlapping set of pre-cut chips would give. We achieve this as follows:
# Define the sampler that will execute the task of extracting samples
sampler = RandomGeoSampler(dataset = training_data, size = 32, length = 500)
- dataset = training_data: the dataset from which patches are drawn.
- size = 32: the width and height of each patch, in pixels.
- length = 500: the number of patches drawn per epoch.
Remember that the patches can overlap, and with random sampling two patches may even coincide. Overlapping patches are still useful because each one shows the model a slightly different window of the landscape. This is a big advantage when training data are limited, as in our case.
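The core idea of random geo-sampling can be sketched in a few lines. Pick a random lower-left corner so that a size x size patch fits inside the sampling area, and repeat length times. This is a simplified stand-in, not torchgeo's implementation, which also works across multiple files through a spatial index. random_patches is a hypothetical name. Because corners are drawn independently, patches can overlap.

```python
import random


def random_patches(bounds, res, size, length, seed=None):
    """Draw `length` random square patches of size x size pixels that fit inside `bounds`.

    bounds = (minx, miny, maxx, maxy) in map units; res = pixel size in map units.
    Returns a list of patch bounding boxes in the same (minx, miny, maxx, maxy) form.
    """
    rng = random.Random(seed)
    minx, miny, maxx, maxy = bounds
    extent = size * res  # patch width/height in map units
    if maxx - minx < extent or maxy - miny < extent:
        raise ValueError("Sampling area is smaller than one patch")
    patches = []
    for _ in range(length):
        x0 = rng.uniform(minx, maxx - extent)
        y0 = rng.uniform(miny, maxy - extent)
        patches.append((x0, y0, x0 + extent, y0 + extent))
    return patches


# 500 patches of 32 x 32 pixels at 10 m inside a hypothetical 2 km x 2 km labelled area
area = (0.0, 0.0, 2_000.0, 2_000.0)
patches = random_patches(area, res=10, size=32, length=500, seed=0)
print(len(patches), patches[0])
```

A 2 km x 2 km area holds only (2000 / 320)^2, or about 39, non-overlapping 320 m chips. Random sampling happily returns 500 distinct, overlapping windows from the same ground.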
7 Create data loader
Once we have a sampler that extracts small patches within the region of interest, the next task is to create a data loader. The data loader collects the patches from the sampler into batches and feeds them to the model for training, validation, testing or prediction, as the case may be. We create it as follows:
# Initialize the dataloader. It provides batches of extracted samples for model training
dataloader = DataLoader(         # The function call
    dataset = training_data,     # The dataset from which the sampler will obtain patches
    batch_size = 50,             # Number of patches per batch
    sampler = sampler,           # The sampler we defined to pick patches from training_data
    collate_fn = stack_samples   # Function that combines the chipped patches into a batch
)
8 Check the data loader
Next, we check whether the data loader works as expected, that is, whether each batch holds 50 images and 50 labels for the model.
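With length = 500 patches and batch_size = 50, each pass over the loader yields 500 / 50 = 10 batches. The plain-Python sketch below mimics that grouping. Each sample is a dictionary, and collating turns a list of dictionaries into one dictionary of lists. torchgeo's stack_samples does the same but stacks the tensors with torch.stack. make_batches and collate_dicts are hypothetical names.

```python
def collate_dicts(samples):
    """Turn a list of sample dicts into one dict holding a list per key."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


def make_batches(samples, batch_size):
    """Group samples into consecutive batches of at most batch_size (the last may be smaller)."""
    return [collate_dicts(samples[i:i + batch_size]) for i in range(0, len(samples), batch_size)]


# 500 fake samples standing in for (image, mask) patches
samples = [{"image": f"img_{i}", "mask": f"mask_{i}"} for i in range(500)]
batches = make_batches(samples, batch_size=50)
print(len(batches), len(batches[0]["image"]), len(batches[0]["mask"]))
```

The loop below should therefore report a batch length of 50 for both images and masks.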
for batch in dataloader:
    image = batch["image"]
    mask = batch["mask"]
    print(f"Image batch length: {len(image)}")
    print(f"Mask batch length: {len(mask)}")
    break
batch = next(iter(dataloader))
images = batch["image"][:4]
masks = batch["mask"][:4]
def normalize_image(img, eps=1e-6):
    """
    Select the RGB bands of a multiband image tensor and normalize them for display,
    clipping extreme values and scaling to 0-1.
    """
    # Correct band indices for RGB: [Red, Green, Blue] = [2, 1, 0]
    rgb = img[[2, 1, 0]]
    # Clip to 2nd-98th percentile to reduce effect of outliers
    vmin = torch.quantile(rgb, 0.02)
    vmax = torch.quantile(rgb, 0.98)
    rgb = torch.clamp(rgb, vmin.item(), vmax.item())
    # Normalize to 0-1
    rgb = (rgb - vmin) / (vmax - vmin + eps)
    return rgb
# Plot images and masks side by side
plt.figure(figsize=(12, 6))
for i in range(4):
    # Plot image (normalized RGB)
    plt.subplot(2, 4, i+1)
    rgb_img = normalize_image(images[i]).permute(1, 2, 0).cpu().numpy()
    plt.imshow(rgb_img)
    plt.title(f"Image {i+1}")
    plt.axis('off')
    # Plot mask (binary)
    plt.subplot(2, 4, i+5)
    plt.imshow(masks[i].cpu(), cmap='gray', vmin=0, vmax=1)
    plt.title(f"Label {i+1}")
    plt.axis('off')
plt.tight_layout()
plt.show()
9 Define the model
Next, we use torchseg to define the model that will learn from the patches that the sampler extracts and the loader batches. torchseg offers several architectures; we use U-Net, a very common deep learning model for segmentation tasks. In this setup, encoder_name="resnet18" specifies the encoder, the feature-extractor part of the network. encoder_weights=False means we do not use weights pre-trained on ImageNet; we train from scratch on our own data. in_channels=10 indicates that our Sentinel-2 image has 10 channels (bands), unlike the 3 channels of an RGB photo. Lastly, classes=2 sets the number of output classes, which in our case is 2: tea and non-tea. This is the same torchseg.Unet(...) call used in Section 16.
10 Move the model to device
After defining the model, we need to move it to the right device. If a GPU is available, we should use it to speed up training; if our machine has no GPU, we fall back to the CPU. A lot of data goes through the model at a time, and GPUs excel at this kind of parallel computation. In code, this is the device = torch.device("cuda" if torch.cuda.is_available() else "cpu") line used in Section 16, followed by model = model.to(device).
11 Defining the loss function and optimizer
We then set up the loss function and the optimizer. The loss tells the model how wrong its predictions are on each batch, and the optimizer uses it to update the weights. CrossEntropyLoss is suitable for both multiclass and binary classification. We tell it to ignore pixels labelled -1. Our labels contain only 0 (no tea) and 1 (tea). Sometimes, however, unknown pixels are labelled -1 (or another code), and such pixels should be excluded from the loss. Adam (Adaptive Moment Estimation) is an optimizer that adapts the step size for each weight during training; it often converges faster than plain stochastic gradient descent (SGD). lr=0.001 sets the initial learning rate, a small number that controls how big the weight updates are. 0.001 is a common default that often works well. If it is too low, training can be very slow; if it is too high, training can become unstable.
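To make ignore_index concrete, the sketch below computes pixel-wise cross-entropy by hand. For each pixel, it takes the negative log of the softmax probability assigned to the true class, then averages over pixels whose label is not -1. This mirrors the default mean reduction of CrossEntropyLoss(ignore_index=-1). It is a plain-Python illustration with a hypothetical function name, not PyTorch's implementation.

```python
import math


def cross_entropy_ignore(logits, targets, ignore_index=-1):
    """Mean cross-entropy over pixels whose target is not ignore_index.

    logits: one list of class scores per pixel; targets: one class index per pixel.
    Returns NaN when every pixel is ignored, since the average is then undefined.
    """
    losses = []
    for scores, target in zip(logits, targets):
        if target == ignore_index:
            continue  # unknown pixel: contributes nothing to the loss
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_sum_exp - scores[target])  # = -log softmax(scores)[target]
    return sum(losses) / len(losses) if losses else float("nan")


# Three pixels with scores for (no tea, tea); the third pixel is unlabelled (-1)
logits = [[0.0, 0.0], [2.0, 0.0], [5.0, 5.0]]
targets = [0, 0, -1]
print(round(cross_entropy_ignore(logits, targets), 4))
```

The unlabelled third pixel has no effect on the result. The loss is the average of log 2 for the undecided first pixel and a smaller value for the confident, correct second pixel.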
12 Evaluation metrics
When training the model, we need to know whether it is learning from the samples it sees in each epoch, so we record evaluation metrics. Many metrics exist, such as loss, accuracy, intersection over union (IoU), F1 and recall. For now, we track only two, loss and accuracy. We keep them in a dictionary called metrics with one list per metric, and the training loop appends to each list after every epoch.
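The sketch below shows this bookkeeping in miniature. It has a metrics dictionary with one list per metric, and it computes pixel accuracy only over labelled pixels (label != -1), in percent, as the training loop reports it. The helper names are hypothetical, and flat Python lists stand in for the PyTorch tensors.

```python
def masked_accuracy(preds, labels, ignore_index=-1):
    """Percentage of correctly predicted pixels, counting only pixels with a valid label."""
    valid = [(p, t) for p, t in zip(preds, labels) if t != ignore_index]
    if not valid:
        return float("nan")
    correct = sum(1 for p, t in valid if p == t)
    return correct / len(valid) * 100


def log_epoch(metrics, loss, accuracy):
    """Append one epoch's values to the metrics dictionary."""
    metrics["loss"].append(loss)
    metrics["accuracy"].append(accuracy)


metrics = {"loss": [], "accuracy": []}
acc = masked_accuracy(preds=[1, 0, 1, 1], labels=[1, 1, -1, 1])  # 2 of 3 valid pixels correct
log_epoch(metrics, loss=12.5, accuracy=acc)
print(metrics)
```

Because the dictionary contains only lists of numbers, it can be written directly with json.dump and read back later for plotting.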
13 Training loop
Most of the ingredients of the training loop are already defined. num_epochs is the number of passes the model makes over the training patches; with RandomGeoSampler, each epoch draws a fresh set of 500 random patches. We can set it higher if the model is still learning. However, making it too large prolongs training and wastes energy, so set it modestly. Note also that we save the evaluation metrics to a JSON file so that we can load them later for visualization.
# Training loop
num_epochs = 50
best_accuracy = 0.0  # Track best accuracy (or IoU)
metrics = {"loss": [], "accuracy": []}  # One list per metric, filled once per epoch
# Loss ignores pixels with value -1
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    total_correct = 0
    total_pixels = 0
    with tqdm(dataloader, desc=f"Epoch {epoch + 1} / {num_epochs}") as pbar:
        for batch in pbar:
            images = batch["image"][:, :10, :, :].to(device)  # Keep the 10 bands the model expects
            masks = batch["mask"].to(device)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks.long())
            epoch_loss += loss.item()
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # ---- Accuracy ignoring -1 pixels ----
            preds = torch.argmax(outputs, dim=1)
            valid = masks != -1  # Boolean mask of valid pixels
            total_correct += (preds[valid] == masks[valid]).sum().item()
            total_pixels += valid.sum().item()  # Count only valid pixels
            pbar.set_postfix(loss=loss.item())
    # Accuracy computed only on valid pixels
    epoch_accuracy = total_correct / total_pixels * 100
    metrics["loss"].append(epoch_loss)
    metrics["accuracy"].append(epoch_accuracy)
    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Loss: {epoch_loss:.4f}, "
          f"Accuracy (valid only): {epoch_accuracy:.2f}%")
    # Save training metrics
    with open("/content/training_metrics.json", "w") as f:
        json.dump(metrics, f)
    print("Training metrics saved to '/content/training_metrics.json'")
    # Save best model based on accuracy
    if epoch_accuracy > best_accuracy:
        best_accuracy = epoch_accuracy
        torch.save(model.state_dict(), "/content/best_tea_model.pth")
        print(f"Best model saved with accuracy: {best_accuracy:.2f}%")
14 Visualizing training metrics
In this section, we visualize the evaluation metrics. Normally, the accuracy increases from one epoch to the next and plateaus at some point, while the loss decreases. Once the accuracy plateaus, further epochs add little, and the saved best model can be used for prediction. If the loss is still decreasing at the final epoch, the model has probably not converged yet, so consider increasing the number of epochs. Keep in mind that these are training metrics; the advanced code in Section 16 also tracks validation metrics.
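You can judge "still learning" versus "plateaued" a little more objectively than by eye. One way is to compare the best value in the last few epochs with the value just before them. The sketch below is one simple, hypothetical rule. has_plateaued, window and min_gain are illustrative choices, not a standard API. The rule works on the lists stored in training_metrics.json.

```python
def has_plateaued(values, window=3, min_gain=0.5):
    """True if the best of the last `window` values improves on the value just before
    them by less than `min_gain` (use it on accuracy; pass negated losses for loss)."""
    if len(values) <= window:
        return False  # not enough epochs to judge
    recent_best = max(values[-window:])
    return recent_best - values[-window - 1] < min_gain


accuracy_history = [50.0, 70.0, 80.0, 85.0, 85.1, 85.2, 85.2]
print(has_plateaued(accuracy_history))          # gains of only 0.2 points over 3 epochs
print(has_plateaued([10.0, 20.0, 30.0, 40.0]))  # still improving
```

Both the window and the minimum gain are judgment calls. A larger window makes the rule more tolerant of noisy epochs.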
# Load training metrics from JSON file
with open("/content/training_metrics.json", "r") as f:
metrics = json.load(f)
# Extract loss and accuracy values
loss_values = metrics["loss"]
accuracy_values = metrics["accuracy"]
epochs = range(1, len(loss_values) + 1)
# Create the plots
fig, ax1 = plt.subplots()
# Plot loss
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Loss", color="tab:red")
ax1.plot(epochs, loss_values, label="Loss", color="tab:red")
ax1.tick_params(axis="y", labelcolor="tab:red")
# Create a second y-axis for accuracy
ax2 = ax1.twinx()
ax2.set_ylabel("Accuracy (%)", color="tab:blue")
ax2.plot(epochs, accuracy_values, label="Accuracy", color="tab:blue")
ax2.tick_params(axis="y", labelcolor="tab:blue")
fig.tight_layout()
plt.title("Training Loss & Accuracy Over Epochs")
plt.show()
15 Prediction over the whole study area
This is the main reason we trained the model: to predict tea plantations outside the training region. We apply the trained model to the whole image, merge the patch predictions into one binary map, and export it for further analysis.
# Set model to evaluation mode
model.eval()
# Create a GridGeoSampler for full-coverage prediction
grid_sampler = GridGeoSampler(
    dataset = raster_data,
    size = 32,   # Same as training
    stride = 4   # Overlapping patches to reduce edge artifacts
)
We do not include label (mask) layers in the prediction data loader, because the labels are what we want to predict, so we only need the images. In a typical machine learning workflow, the images could also come from elsewhere: you can train a model in region A and use it to make predictions in region B. The same applies to time.
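Along one image axis, a grid sampler places a patch start every stride pixels. The sketch below adds one final patch flush with the edge so that every pixel is covered. It also counts how many patches cover each pixel, which is what the count array in the prediction code tracks. grid_offsets and coverage are hypothetical helpers, and torchgeo's own edge handling may differ.

```python
def grid_offsets(length, size, stride):
    """Start positions of size-pixel windows every `stride` pixels along an axis of `length` pixels.

    A final window flush with the far edge is added if the regular steps leave pixels uncovered.
    """
    if length <= size:
        return [0]
    starts = list(range(0, length - size + 1, stride))
    if starts[-1] + size < length:
        starts.append(length - size)
    return starts


def coverage(length, size, stride):
    """Number of windows covering each pixel along the axis."""
    counts = [0] * length
    for start in grid_offsets(length, size, stride):
        for i in range(start, min(start + size, length)):
            counts[i] += 1
    return counts


print(grid_offsets(100, 32, 16))
counts = coverage(100, 32, 4)
print(min(counts), max(counts))
```

With size 32 and stride 4, interior pixels are covered by 32 / 4 = 8 patches per axis, or 64 in two dimensions, which still fits the uint8 count array. Averaging the 0/1 votes and rounding, as the prediction code does, amounts to a majority vote. Note that np.round sends an exact tie (0.5) to 0.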
# Re-create the image-only dataset for prediction. It points to the same file as
# raster_data above, so the grid_sampler defined earlier still applies.
class GeoTiffDataset(RasterDataset):
    filename_glob = "*win2025_dl.tif"
raster_data = GeoTiffDataset(paths = "/content")
# Create prediction dataloader
pred_dataloader = DataLoader(
    dataset = raster_data,
    batch_size = 50,
    sampler = grid_sampler,
    collate_fn = stack_samples
)
# Confirm that the batch has no masks
for batch in pred_dataloader:
    image = batch["image"]
    print(f"Image batch length: {len(image)}")
    print(batch.keys())
    break
# Get raster metadata from the first file for output
with rasterio.open(raster_data.files[0]) as src:
    profile = src.profile
    transform = src.transform
    crs = src.crs
    height = src.height
    width = src.width
# Initialize empty arrays
full_pred = np.zeros((height, width), dtype=np.float32)
count = np.zeros((height, width), dtype=np.uint8)
# Make predictions
with torch.no_grad():
    for batch in tqdm(pred_dataloader, desc="Making predictions"):
        # Keep the 10 bands the model was built for
        images = batch["image"]
        if images.shape[1] > 10:
            images = images[:, :10, :, :]
        images = images.to(device)
        # Get predictions
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1).cpu().numpy()
        # Process each prediction
        for i in range(len(preds)):
            bbox = batch["bounds"][i]
            minx, maxx = bbox.minx, bbox.maxx
            miny, maxy = bbox.miny, bbox.maxy
            # Convert the patch's upper-left corner to pixel coordinates
            col_off, row_off = ~transform * (minx, maxy)
            col_off, row_off = int(round(col_off)), int(round(row_off))
            # Get predicted patch dimensions
            pred_height, pred_width = preds[i].shape[-2], preds[i].shape[-1]
            # Calculate valid window in output array
            row_start = max(row_off, 0)
            col_start = max(col_off, 0)
            row_end = min(row_off + pred_height, height)
            col_end = min(col_off + pred_width, width)
            # Calculate corresponding window in prediction
            pred_row_start = max(0, -row_off)
            pred_col_start = max(0, -col_off)
            pred_row_end = pred_row_start + (row_end - row_start)
            pred_col_end = pred_col_start + (col_end - col_start)
            # Only proceed if we have valid dimensions
            if (row_end > row_start) and (col_end > col_start):
                full_pred[row_start:row_end, col_start:col_end] += preds[i, pred_row_start:pred_row_end, pred_col_start:pred_col_end]
                count[row_start:row_end, col_start:col_end] += 1
# Average predictions (a majority vote over overlapping patches)
valid_mask = count > 0
full_pred[valid_mask] = full_pred[valid_mask] / count[valid_mask]
full_pred = np.round(full_pred).astype(np.uint8)
full_pred[~valid_mask] = 255  # nodata
# Update and save output
profile.update(
    dtype=rasterio.uint8,
    count=1,
    compress='lzw',
    nodata=255
)
with rasterio.open('/content/dl_prediction.tif', 'w', **profile) as dst:
    dst.write(full_pred, 1)
    dst.write_colormap(1, {
        0: (0, 100, 0, 255),    # Green for class 0
        1: (255, 215, 0, 255),  # Gold for class 1
        255: (0, 0, 0, 0)       # Transparent for nodata
    })
print("Prediction saved as '/content/dl_prediction.tif'")
Finally, we can read the predicted output and visualize it in grayscale.
with rasterio.open('/content/dl_prediction.tif') as src:
    pred = src.read(1)
# Mask the nodata value (255) so that classes 0 and 1 remain distinguishable
pred = np.ma.masked_equal(pred, 255)
# Plot
plt.imshow(pred, cmap = 'gray')
plt.axis('off')
plt.show()
In the following code, we evaluate the raster-based predictions against ground-truth points, that is, points whose class (tea or no tea) is known from actual field visits.
We perform the following steps:
- Load the prediction raster.
- Load the evaluation points from a GeoPackage.
- Extract the predicted value at each point location.
- Filter out nodata values (255).
- Compute the accuracy, confusion matrix and a classification report.
Required inputs:
- dl_prediction.tif: raster file containing the predicted class labels.
- eval.gpkg: GeoPackage containing point geometries and a 'tea_no_tea' label field.
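The scikit-learn calls return these quantities directly. For intuition, here is a small pure-Python version of the same bookkeeping for the binary case, skipping points that fall on nodata (255). evaluate_points is a hypothetical helper. Its confusion matrix follows scikit-learn's layout, with rows for the true class and columns for the predicted class.

```python
def evaluate_points(true_labels, predicted, nodata=255, classes=(0, 1)):
    """Accuracy, confusion matrix and per-class precision/recall, ignoring nodata predictions."""
    pairs = [(t, p) for t, p in zip(true_labels, predicted) if p != nodata]
    index = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for t, p in pairs:
        matrix[index[t]][index[p]] += 1  # row = true class, column = predicted class
    accuracy = sum(matrix[i][i] for i in range(len(classes))) / len(pairs)
    report = {}
    for c, i in index.items():
        predicted_c = sum(matrix[r][i] for r in range(len(classes)))  # column total
        actual_c = sum(matrix[i])  # row total
        report[c] = {
            "precision": matrix[i][i] / predicted_c if predicted_c else 0.0,
            "recall": matrix[i][i] / actual_c if actual_c else 0.0,
        }
    return accuracy, matrix, report


truth = [1, 0, 1, 1, 0]
preds = [1, 0, 0, 1, 255]  # the last point falls on nodata and is dropped
acc, cm, rep = evaluate_points(truth, preds)
print(acc, cm, rep[1])
```

In this toy case, every point predicted as tea really is tea (precision 1.0), but one of the three tea points was missed (recall 2/3). The classification report shows this trade-off per class.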
# 1. Load prediction raster
with rasterio.open('/content/dl_prediction.tif') as src:
    pred_raster = src.read(1)
    transform = src.transform
    raster_crs = src.crs
# 2. Load evaluation points and reproject them to the raster's CRS
gdf = gpd.read_file('/content/eval.gpkg').to_crs(raster_crs)
# 3. Extract predicted values at point locations
coords = [(x, y) for x, y in zip(gdf.geometry.x, gdf.geometry.y)]
with rasterio.open('/content/dl_prediction.tif') as src:
    sampled_pred = list(src.sample(coords))
sampled_pred = np.array([val[0] for val in sampled_pred])
# 4. Remove nodata points (255)
valid_idx = sampled_pred != 255
sampled_pred = sampled_pred[valid_idx]
true_labels = gdf['tea_no_tea'].to_numpy()[valid_idx]
# 5. Run evaluation
print("Accuracy:", accuracy_score(true_labels, sampled_pred))
print("Confusion Matrix:\n", confusion_matrix(true_labels, sampled_pred))
print("Classification Report:\n", classification_report(true_labels, sampled_pred))
16 Advanced code
In this code, we implement slightly more advanced functionality for a U-Net model that performs semantic segmentation of satellite imagery (GeoTIFF) with labelled vector data (GeoPackage). Key features include:
- Data augmentation: uses albumentations for on-the-fly augmentations, including RandomRotate90 and horizontal/vertical flips.
- Samplers: RandomGeoSampler for training/validation/testing (1000/300/200 samples), GridGeoSampler for inference (stride = 16).
- Metrics tracked: pixel-wise accuracy and IoU (Jaccard index) averaged over classes, with NaN handling for classes absent from a batch.
- Validation: the best model is saved based on validation IoU. All three random samplers draw from the same labelled area, so validation and test patches can overlap training patches, and these scores are likely optimistic.
- Outputs: metrics (loss, accuracy, IoU) logged per epoch to training_metrics.csv.
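The key property of the augmentation step is that the image and its mask receive exactly the same random rotation or flip. Otherwise, the labels would no longer line up with the pixels. albumentations guarantees this when the image and mask are passed in the same call. The pure-Python sketch below shows such paired transforms on small nested lists. Its helper names are hypothetical.

```python
import random


def rot90(grid):
    """Rotate a 2-D list 90 degrees counter-clockwise (like np.rot90 with k=1)."""
    return [list(row) for row in zip(*grid)][::-1]


def hflip(grid):
    """Mirror a 2-D list left-right."""
    return [row[::-1] for row in grid]


def vflip(grid):
    """Mirror a 2-D list top-bottom."""
    return grid[::-1]


def augment_pair(image, mask, rng, p=0.5):
    """Apply the same random rotation and flips (each with probability p) to a band and its mask."""
    if rng.random() < p:
        k = rng.randint(1, 3)  # number of quarter turns
        for _ in range(k):
            image, mask = rot90(image), rot90(mask)
    for flip in (hflip, vflip):
        if rng.random() < p:  # one random draw per transform, shared by image and mask
            image, mask = flip(image), flip(mask)
    return image, mask


rng = random.Random(42)
image_band = [[0.1, 0.2], [0.3, 0.4]]
mask = [[0, 1], [1, 1]]
aug_image, aug_mask = augment_pair(image_band, mask, rng)
print(aug_image, aug_mask)
```

Rotations and flips are safe for overhead imagery because a tea field looks like a tea field from any orientation. Augmentation belongs only in the training loader; validation, test and prediction patches should be left as they are.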
# Select the GPU if available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
augmentation = A.Compose([
    A.RandomRotate90(p = 0.5),
    A.HorizontalFlip(p = 0.5),
    A.VerticalFlip(p = 0.5)
])
def stack_samples_aug(batch, augment = True):
    # Replace NaNs in the imagery with 0 and collect the masks
    images = [torch.nan_to_num(sample['image'], nan = 0.0) for sample in batch]
    labels = [sample['mask'] for sample in batch]
    augmented_images = []
    augmented_labels = []
    for img, lbl in zip(images, labels):
        img_np = img.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for albumentations
        lbl_np = lbl.squeeze(0).numpy()
        if augment:
            # The same random transform is applied to the image and its mask
            augmented = augmentation(image = img_np, mask = lbl_np)
            img_np, lbl_np = augmented['image'], augmented['mask']
        augmented_images.append(torch.tensor(np.ascontiguousarray(img_np)).permute(2, 0, 1))
        augmented_labels.append(torch.tensor(np.ascontiguousarray(lbl_np)).unsqueeze(0))
    images = torch.stack(augmented_images)
    labels = torch.stack(augmented_labels).squeeze(1).long()
    return images, labels
def stack_samples_noaug(batch):
    # Same batching without augmentation, for validation and testing
    return stack_samples_aug(batch, augment = False)
dataloader = DataLoader(
    dataset = training_data,
    batch_size = 50,
    sampler = sampler,
    collate_fn = stack_samples_aug
)
class GeoTiffDataset(RasterDataset):
    filename_glob = 's2_large.tif'
raster_data = GeoTiffDataset(paths = '/content/')
class LabelDataset(VectorDataset):
    filename_glob = 'vect.gpkg'
label_data = LabelDataset(paths = '/content/', label_name = 'tea_no_tea')
label_data.is_image = False
print(raster_data)
print(label_data)
dataset = raster_data & label_data
train_sampler = RandomGeoSampler(dataset=dataset, size = 32, length = 1000)
val_sampler = RandomGeoSampler(dataset=dataset, size = 32, length = 300)
test_sampler = RandomGeoSampler(dataset=dataset, size = 32, length = 200)
# Prediction covers the whole image, so it samples raster_data (no masks) on a grid
pred_sampler = GridGeoSampler(dataset=raster_data, size = 32, stride = 16)
train_loader = DataLoader(dataset = dataset, sampler = train_sampler, batch_size = 50, collate_fn = stack_samples_aug)
val_loader = DataLoader(dataset = dataset, sampler = val_sampler, batch_size = 50, collate_fn = stack_samples_noaug)
test_loader = DataLoader(dataset = dataset, sampler = test_sampler, batch_size = 50, collate_fn = stack_samples_noaug)
pred_loader = DataLoader(dataset = raster_data, sampler = pred_sampler, batch_size = 50, collate_fn = stack_samples)
model = torchseg.Unet(
encoder_name = 'resnet18',
encoder_weights = False,
in_channels = 10,
classes = 2,
).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
# Metrics
def compute_accuracy(output, labels):
    _, preds = torch.max(output, dim = 1)
    correct = (preds == labels).sum().item()
    total = labels.numel()
    return correct/total
def compute_iou(output, labels, num_classes = 2):
    _, preds = torch.max(output, dim = 1)
    ious = []
    for cls in range(num_classes):
        intersection = ((preds == cls) & (labels == cls)).sum().item()
        union = ((preds == cls) | (labels == cls)).sum().item()
        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append(intersection / union)
    mean_iou = torch.nanmean(torch.tensor(ious))
    return mean_iou.item()
metrics_df = pd.DataFrame(columns=['epoch', 'train_loss', 'train_accuracy', 'train_iou', 'val_loss', 'val_accuracy', 'val_iou'])
best_val_iou = 0.0
best_epoch = 0
best_model_path = '/content/best_model.pth'
epochs = 10
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    running_accuracy = 0.0
    running_iou = 0.0
    batch_count = 0
    with tqdm(total = len(train_loader), desc = f"Epoch {epoch + 1}/{epochs} [Train]", dynamic_ncols = True, leave = True, position = 0) as pbar:
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_accuracy += compute_accuracy(output, labels)
            running_iou += compute_iou(output, labels)
            batch_count += 1
            avg_loss = running_loss / batch_count
            avg_accuracy = running_accuracy / batch_count
            avg_iou = running_iou / batch_count
            pbar.set_postfix(loss = f"{avg_loss:.4f}", accuracy = f"{avg_accuracy:.2%}", iou=f"{avg_iou:.4f}")
            pbar.update(1)
    print(f"Training - Epoch {epoch + 1}/{epochs}: Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_accuracy:.2%}, Avg IoU: {avg_iou:.4f}")
    model.eval()
    val_loss = 0.0
    val_accuracy = 0.0
    val_iou = 0.0
    val_batch_count = 0
    with torch.no_grad():
        with tqdm(total = len(val_loader), desc=f"Epoch {epoch + 1}/{epochs} [Validation]", dynamic_ncols=True, leave = True, position = 0) as pbar:
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                output = model(images)
                loss = loss_fn(output, labels)
                val_loss += loss.item()
                val_accuracy += compute_accuracy(output, labels)
                val_iou += compute_iou(output, labels)
                val_batch_count += 1
                avg_val_loss = val_loss / val_batch_count
                avg_val_accuracy = val_accuracy / val_batch_count
                avg_val_iou = val_iou / val_batch_count
                pbar.set_postfix(loss = f"{avg_val_loss:.4f}", accuracy=f"{avg_val_accuracy:.2%}", iou = f"{avg_val_iou:.4f}")
                pbar.update(1)
    if avg_val_iou > best_val_iou:
        best_val_iou = avg_val_iou
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved at epoch: {best_epoch} with validation IoU of: {best_val_iou:.4f}")
    print(f"Validation - Epoch {epoch + 1}/{epochs}: Avg Loss: {avg_val_loss:.4f}, Avg Accuracy: {avg_val_accuracy:.2%}, Avg IoU: {avg_val_iou:.4f}")
    new_row = pd.DataFrame([{
        "epoch": epoch + 1,
        "train_loss": avg_loss,
        "train_accuracy": avg_accuracy,
        "train_iou": avg_iou,
        "val_loss": avg_val_loss,
        "val_accuracy": avg_val_accuracy,
        "val_iou": avg_val_iou
    }])
    metrics_df = pd.concat([metrics_df, new_row], ignore_index = True)
    metrics_df.to_csv("training_metrics.csv", index= False)
    print("Metrics saved to training_metrics.csv")
17 Exercise (Optional)
Given the advanced model, make prediction over the whole study region. Compare with the simpler model before.
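One simple way to compare the two models is to compute, for each class, the IoU between their prediction maps (or between each map and reference labels). The per-class values are then averaged, with NaN for absent classes, as compute_iou in Section 16 does. The sketch below works on flat lists of class values and skips a nodata value. It is a hypothetical helper, not part of torchgeo or torchseg.

```python
import math


def mean_iou(map_a, map_b, classes=(0, 1), nodata=255):
    """Per-class IoU between two label sequences and their mean, ignoring nodata and absent classes."""
    pairs = [(a, b) for a, b in zip(map_a, map_b) if a != nodata and b != nodata]
    ious = {}
    for c in classes:
        intersection = sum(1 for a, b in pairs if a == c and b == c)
        union = sum(1 for a, b in pairs if a == c or b == c)
        ious[c] = intersection / union if union else math.nan
    present = [v for v in ious.values() if not math.isnan(v)]
    return ious, (sum(present) / len(present) if present else math.nan)


simple_model = [0, 0, 1, 1, 255]
advanced_model = [0, 1, 1, 1, 0]
per_class, miou = mean_iou(simple_model, advanced_model)
print(per_class, round(miou, 4))
```

For full rasters, the two prediction arrays read with rasterio can be flattened (for example with .ravel().tolist()) before being passed in. A vectorized NumPy version would be much faster on large images.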
18 Summary
In this script, we demonstrate a pipeline for a geospatial image segmentation task using a U-Net model, with simple data augmentation (rotations and flips) to improve generalization. Using both raster and vector geospatial data, we train the model with cross-entropy loss and evaluate it with multiple metrics, including pixel-wise accuracy and mean IoU over classes. The best model is saved automatically based on validation IoU. All training and validation metrics are logged to a CSV file for further analysis in other environments such as R. This approach provides a scalable framework for satellite image segmentation tasks that balances augmentation, model optimization and detailed performance tracking. We hope you find it useful for related tasks.
19 References
Buslaev, A., Parinov, A., Khvedchenya, E., Iglovikov, V. I., & Kalinin, A. A. (2018). Albumentations: Fast and flexible image augmentations. ArXiv e-prints. arXiv:1809.06839.