使用U-Net处理Postdam数据集进行语义分割任务 如何从准备数据到训练和评估一个基于U-Net的模型。训练使用遥感影像分析研究数据集
Postdam数据集
遥感影像-语义分割数据集:Postdam数据集
像素大小 512*512
训练图片为.tif 标签图片为.tif
数据集 (train3678张 val920张)
使用U-Net处理Postdam数据集进行语义分割,但它同样适用于其他类型的分割任务,包括遥感影像分析。以下是详细的步骤和代码示例,帮助同学从准备数据到训练和评估一个基于U-Net的模型。
1. 安装依赖
首先确保你的环境中已经安装了必要的库:
pip install torch torchvision tifffile segmentation-models-pytorch albumentations matplotlib
2. 数据准备
假设你的Postdam数据集已经被组织为包含.tif格式的图像和标签文件夹。我们将使用segmentation-models-pytorch
来简化U-Net的实现,并利用albumentations
进行数据增强。
创建自定义的数据加载器
创建一个Python脚本来加载和预处理数据。这里我们假定每个图像都有一个对应的标签(也是.tif格式),并且它们都位于同一目录下,但分别在不同的子文件夹中(例如,images/
和masks/
)。
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import tifffile
class PostdamDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.images = os.listdir(image_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, index):
img_path = os.path.join(self.image_dir, self.images[index])
mask_path = os.path.join(self.mask_dir, self.images[index].replace('.tif', '_mask.tif'))
image = np.array(tifffile.imread(img_path))
mask = np.array(tifffile.imread(mask_path), dtype=np.float32)
mask[mask == 255.0] = 1.0 # 如果你的标签是以255表示背景或其他值,请根据实际情况调整
if self.transform is not None:
augmentations = self.transform(image=image, mask=mask)
image = augmentations["image"]
mask = augmentations["mask"]
return image, mask
# 数据增强和转换
train_transform = A.Compose(
[
A.Resize(height=512, width=512),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)
val_transform = A.Compose(
[
A.Resize(height=512, width=512),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)
3. 模型定义
我们可以直接使用segmentation-models-pytorch
提供的U-Net实现。
import segmentation_models_pytorch as smp
def get_model():
model = smp.Unet(
encoder_name="resnet34", # 使用resnet34作为编码器
encoder_weights="imagenet", # 使用预训练权重
in_channels=3, # 输入通道数(RGB图像)
classes=6, # 输出类别数(根据实际数据集调整)
)
return model
4. 训练模型
接下来,编写Python脚本来训练模型。
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
def train_model(model, train_loader, val_loader, epochs, optimizer, criterion, device):
for epoch in range(epochs):
model.train()
running_loss = 0.0
for images, masks in train_loader:
images = images.to(device)
masks = masks.long().to(device) # 确保目标张量类型为long
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, masks.squeeze(1)) # squeeze维度以匹配loss函数要求
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
epoch_loss = running_loss / len(train_loader.dataset)
print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss}")
# 验证阶段略去,通常包括计算验证集上的损失和准确率等指标
# 加载数据
train_dataset = PostdamDataset(image_dir="path/to/train/images/", mask_dir="path/to/train/masks/", transform=train_transform)
val_dataset = PostdamDataset(image_dir="path/to/val/images/", mask_dir="path/to/val/masks/", transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
# 初始化模型、优化器和损失函数
model = get_model().to("cuda") # 或者 "cpu"
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
# 开始训练
train_model(model, train_loader, val_loader, epochs=20, optimizer=optimizer, criterion=criterion, device="cuda")
5. 可视化预测结果
训练完成后,可以对测试集中的图片进行预测并可视化结果。
def visualize_predictions(model, loader, device):
model.eval()
with torch.no_grad():
for images, masks in loader:
images = images.to(device)
outputs = model(images)
preds = torch.argmax(outputs, dim=1).cpu().numpy()
for i in range(len(images)):
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(images[i].permute(1, 2, 0).cpu().numpy())
plt.title('Image')
plt.subplot(1, 2, 2)
plt.imshow(preds[i], cmap='gray')
plt.title('Prediction')
plt.show()
break # 仅展示一张图片为例
# 使用方法
visualize_predictions(model, val_loader, "cuda")
仅供参考,代码 ,步骤提供了一个完整的流程,从数据准备、模型定义到训练和结果可视化的完整指南,特别适用于使用U-Net进行Postdam数据集的语义分割任务。