如何以U-Net为基础模型,在水体分割遥感图像数据集上训练深度学习图像分割模型
水体分割遥感图像数据集
2841张卫星拍摄的水体图像集合,每张图像都配有对应的mask标签,其中白色代表水体,黑色代表水体以外的其他地物。
1
1
针对水体分割的遥感图像数据集,我们可以使用深度学习模型来完成图像分割任务。本文采用U-Net作为基础模型,它适用于遥感图像的分割任务。以下是完整的流程,包括数据准备、模型定义、训练过程、评估与优化,以及推理和可视化。
数据准备
首先需要定义一个自定义的数据集类来加载和预处理你的数据集。假设图像和mask标签分别存储在两个文件夹中,每张图像都有对应的mask。
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class WaterBodySegmentationDataset(Dataset):
    """Paired image / water-mask dataset read from two directories.

    Every regular file in ``image_dir`` is treated as one sample, and its
    mask is looked up under the same file name in ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks (same file names as the
            images).
        transform: Optional callable applied to the RGB image.
        mask_transform: Optional callable applied to the grayscale mask.
            When omitted, ``transform`` is applied to the mask as well, so
            existing callers keep their behaviour. Pass a separate callable
            when the image pipeline contains steps that must not touch the
            mask.

    Each item is an ``(image, mask)`` pair; the image is loaded as "RGB" and
    the mask as single-channel "L" before the transforms run.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        # Sub-directories are skipped because they cannot be opened as images.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, convert and transform the ``idx``-th image/mask pair."""
        file_name = self.images[idx]
        img_path = os.path.join(self.image_dir, file_name)
        # The mask shares the image's file name but lives in mask_dir.
        mask_path = os.path.join(self.mask_dir, file_name)

        # convert() returns an independent in-memory copy, so the underlying
        # files can be closed right away instead of lingering until GC.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")  # make sure the mask is grayscale

        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask
# Preprocessing shared by the training and validation datasets.
transform = transforms.Compose(
    [
        transforms.Resize((400, 400)),  # adjust the target size to your data
        transforms.ToTensor(),
    ]
)

# Datasets: one image directory and one mask directory per split.
train_dataset = WaterBodySegmentationDataset(
    image_dir='path_to_train_images',
    mask_dir='path_to_train_masks',
    transform=transform,
)
val_dataset = WaterBodySegmentationDataset(
    image_dir='path_to_val_images',
    mask_dir='path_to_val_masks',
    transform=transform,
)

# Loaders: only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False)
模型定义
接下来,我们定义一个U-Net模型结构用于图像分割任务。
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
    """U-Net with four encoder stages and three decoder stages.

    The encoder halves the resolution three times with 2x2 max pooling; the
    decoder doubles it three times with transposed convolutions and
    concatenates the matching encoder feature map at every stage.

    Args:
        in_channels: Number of channels of the input image (default 3).
        out_channels: Number of output maps (default 1, binary mask).

    ``forward`` takes a 4-D ``(batch, channels, height, width)`` tensor and
    returns sigmoid probabilities with the same height and width. Heights
    and widths that are not multiples of 8 are handled by zero-padding the
    upsampled maps; each side must still be large enough to survive three
    poolings (at least 8 pixels).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.pool = nn.MaxPool2d(2)
        # Decoder: each dec* block sees the upsampled map concatenated with
        # the skip connection, hence twice the upconv output channels.
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)
        # Output layer: 1x1 convolution down to the requested number of maps.
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _pad_to(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max pooling floors odd sizes, so after upsampling the decoder map can
        be one pixel short of the encoder map; concatenation would then fail.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            return upsampled
        # F.pad takes the last dimension first: (left, right, top, bottom).
        return F.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def forward(self, x):
        """Run the network and return per-pixel probabilities in [0, 1]."""
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        dec3 = self._pad_to(self.upconv3(enc4), enc3)
        dec3 = self.dec3(torch.cat((dec3, enc3), dim=1))

        dec2 = self._pad_to(self.upconv2(dec3), enc2)
        dec2 = self.dec2(torch.cat((dec2, enc2), dim=1))

        dec1 = self._pad_to(self.upconv1(dec2), enc1)
        dec1 = self.dec1(torch.cat((dec1, enc1), dim=1))

        # Sigmoid turns the logits into probabilities.
        return torch.sigmoid(self.out_conv(dec1))
训练过程
定义训练循环:
# Model, loss and optimizer used by the training loop below.
model = UNet()
criterion = nn.BCELoss()  # binary cross-entropy on the model's sigmoid output
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that dividing by the
        # dataset length below yields a per-sample average.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader.dataset)}")
模型优化
可以考虑使用学习率调度器和早停策略来优化模型性能。下面的代码演示了基于验证损失的学习率调度器(ReduceLROnPlateau);早停策略未在示例中实现,可依据同一验证损失自行添加:
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Plateau scheduler in 'min' mode: the value handed to scheduler.step() is
# expected to decrease; patience is set to 5.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5)
def validate(model, val_loader, criterion):
    """Return the per-sample average of ``criterion`` over ``val_loader``.

    The model is switched to evaluation mode and gradients are disabled for
    the duration of the pass. Each batch loss is weighted by its batch size
    and the total is divided by ``len(val_loader.dataset)``.
    """
    model.eval()
    with torch.no_grad():
        weighted_total = sum(
            (
                criterion(model(batch_images), batch_masks).item()
                * batch_images.size(0)
                for batch_images, batch_masks in val_loader
            ),
            0.0,
        )
    sample_count = len(val_loader.dataset)
    return weighted_total / sample_count
for epoch in range(num_epochs):
    # ... training steps for this epoch go here ...

    # Evaluate on the validation split, then hand the loss to the scheduler.
    val_loss = validate(model=model, val_loader=val_loader, criterion=criterion)
    scheduler.step(val_loss)

    print(f"Validation Loss: {val_loss}")
推理及可视化
推理并可视化结果:
import matplotlib.pyplot as plt
def visualize_predictions(model, dataloader, num_images=5):
    """Show image / ground-truth mask / predicted mask side by side.

    One figure is drawn per batch, using the first sample of that batch;
    iteration stops once ``num_images`` batches have been displayed.
    Predictions are binarised at a 0.5 threshold.
    """
    model.eval()
    with torch.no_grad():
        for batch_index, (images, masks) in enumerate(dataloader):
            if batch_index >= num_images:
                break

            probabilities = model(images)
            preds = (probabilities > 0.5).float()  # threshold at 0.5

            # Only the first sample of the batch is displayed.
            source_image = images[0].permute(1, 2, 0).numpy()  # CHW -> HWC
            true_mask = masks[0].squeeze().numpy()
            predicted_mask = preds[0].squeeze().numpy()

            fig, axarr = plt.subplots(1, 3)
            axarr[0].imshow(source_image)  # original image
            axarr[1].imshow(true_mask, cmap='gray')  # ground-truth label
            axarr[2].imshow(predicted_mask, cmap='gray')  # prediction
            plt.show()


visualize_predictions(model, val_loader)
通过上述步骤,您可以有效地利用水体分割的遥感图像数据集进行水体检测任务。请根据实际情况调整代码中的细节。