利用卷积神经网络(CNN)中的 U-Net 模型,对河流遥感图像分割数据集进行图像分割,以深度学习方法完成河流遥感图像分割任务。
河流遥感图像分割数据集:共 8975 张 400×400 的图像,其中训练集 5385 张,验证集与测试集各 1795 张。
针对河流遥感图像分割任务,可以采用深度学习方法,特别是利用 U-Net 等卷积神经网络(CNN)模型进行图像分割。
数据准备
首先,定义一个自定义数据集类来加载并预处理数据。这里假设图像和标签图(ground truth)分别存放在两个文件夹中,且每张图像都有对应的标签图,例如图像 `xxx.jpg` 对应的标签图为 `xxx_mask.png`。
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class RiverSegmentationDataset(Dataset):
    """Paired (image, mask) dataset for binary river segmentation.

    Every regular, non-hidden file in ``image_dir`` is treated as an input
    image. For an image named ``xxx.jpg`` the mask is read from
    ``mask_dir/xxx_mask.png``. A name without a ``.jpg`` suffix is looked up
    unchanged in ``mask_dir``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ground-truth masks.
        transform: Optional callable applied to the RGB image. It is also
            applied to the mask unless ``mask_transform`` is given.
        mask_transform: Optional callable applied to the grayscale ("L")
            mask. Pass a separate one when the image transform interpolates,
            e.g. a nearest-neighbour resize, so label values are not blended.
            Defaults to ``transform`` to keep the original behaviour.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = transform if mask_transform is None else mask_transform
        # os.listdir order is arbitrary. Sort it so sample order (and thus
        # validation output) is reproducible. Skip hidden files such as
        # .DS_Store and sub-directories, which PIL cannot open.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def _mask_name(self, image_name):
        """Map an image file name to its mask file name (``x.jpg`` -> ``x_mask.png``)."""
        # Only replace the suffix. str.replace would also rewrite a ".jpg"
        # appearing in the middle of a file name.
        if image_name.endswith('.jpg'):
            return image_name[:-len('.jpg')] + '_mask.png'
        return image_name

    def __getitem__(self, idx):
        image_name = self.images[idx]
        img_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, self._mask_name(image_name))
        # convert() returns an independent, fully loaded image, so the file
        # handles can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")
        if self.transform:
            image = self.transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask
# Preprocessing: resize to 400x400, then convert the PIL image to a float
# tensor in [0, 1]. Presumably a no-op resize for the 400x400 dataset
# described above.
# NOTE(review): the same pipeline is applied to the masks, and Resize's
# default interpolation is bilinear. If source images are not already
# 400x400, consider a nearest-neighbour transform for the masks.
transform = transforms.Compose(
    [transforms.Resize((400, 400)), transforms.ToTensor()]
)

# Replace the placeholder paths with the real image / mask folders.
train_dataset = RiverSegmentationDataset(
    'path_to_train_images', 'path_to_train_masks', transform=transform
)
val_dataset = RiverSegmentationDataset(
    'path_to_val_images', 'path_to_val_masks', transform=transform
)

# Shuffle only the training data. Validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8)
模型定义
接下来,我们定义一个U-Net模型结构用于图像分割任务。
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
    """Compact four-level U-Net that outputs per-pixel probabilities.

    Args:
        in_channels: Number of channels of the input image (3 for RGB).
        out_channels: Number of output maps (1 for a binary mask).

    ``forward`` maps an ``(N, in_channels, H, W)`` tensor to an
    ``(N, out_channels, H, W)`` tensor passed through a sigmoid, i.e. values
    in ``[0, 1]``. ``H`` and ``W`` need not be multiples of 8: each upsampled
    decoder map is zero-padded to the size of its skip connection before
    concatenation. They must be at least 8 so the third pooling step does not
    produce an empty map.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        # Encoder: four double-conv stages; resolution is halved between stages.
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.pool = nn.MaxPool2d(2)
        # Decoder: each transposed conv doubles resolution. The following
        # conv block takes the upsampled map concatenated with the matching
        # encoder output, hence the doubled input channel count.
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)
        # 1x1 conv to the output maps (one channel for binary segmentation).
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv + ReLU layers; padding=1 preserves spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` on its last two dims so they match ``ref``'s."""
        # Floor-dividing in MaxPool2d and then doubling in ConvTranspose2d can
        # only lose pixels, so both differences are >= 0.
        diff_h = ref.size(2) - x.size(2)
        diff_w = ref.size(3) - x.size(3)
        if diff_h == 0 and diff_w == 0:
            return x
        return F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                         diff_h // 2, diff_h - diff_h // 2])

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        dec3 = self._pad_to(self.upconv3(enc4), enc3)
        dec3 = self.dec3(torch.cat((dec3, enc3), dim=1))
        dec2 = self._pad_to(self.upconv2(dec3), enc2)
        dec2 = self.dec2(torch.cat((dec2, enc2), dim=1))
        dec1 = self._pad_to(self.upconv1(dec2), enc1)
        dec1 = self.dec1(torch.cat((dec1, enc1), dim=1))
        # Sigmoid turns logits into probabilities (pairs with nn.BCELoss).
        return torch.sigmoid(self.out_conv(dec1))
训练过程
定义训练循环:
# Model, loss and optimizer. UNet ends in a sigmoid, so its outputs are
# probabilities and plain BCELoss (binary cross-entropy) is the matching loss.
model = UNet()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_images, batch_masks in train_loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_masks)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so the printed figure is a per-sample mean.
        running_loss += batch_loss.item() * batch_images.size(0)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader.dataset)}")
模型优化
可以考虑使用学习率调度器和早停策略来优化模型性能:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Lower the learning rate (by the scheduler's default factor) once the value
# passed to scheduler.step() has stopped decreasing for more than 5 epochs.
# mode='min' means lower is better.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5)
def validate(model, val_loader, criterion, device=None):
    """Return the mean per-sample loss of *model* over *val_loader*.

    The model is switched to ``eval()`` and autograd is disabled. Each
    batch's loss is weighted by its batch size. The total is divided by the
    number of samples actually iterated, rather than ``len(dataset)``, so
    ``drop_last=True`` or a subsetting sampler does not skew the mean.

    Args:
        model: Network to evaluate. It is left in eval mode.
        val_loader: Iterable of ``(images, masks)`` tensor batches.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Optional device each batch is moved to before the forward
            pass (should match where *model* lives). ``None`` keeps batches
            where the loader produced them.

    Returns:
        The average loss as a Python float.

    Raises:
        ValueError: If *val_loader* yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for images, masks in val_loader:
            if device is not None:
                images, masks = images.to(device), masks.to(device)
            batch_size = images.size(0)
            total_loss += criterion(model(images), masks).item() * batch_size
            num_samples += batch_size
    if num_samples == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / num_samples
# Combined train / validate loop. Each epoch trains on train_loader, then
# evaluates on val_loader. The scheduler reacts to the validation loss, and
# training stops early once the validation loss has not improved for
# `early_stop_patience` epochs. The best weights seen are restored at the end.
early_stop_patience = 10  # larger than the scheduler's patience, so the LR is cut before stopping
best_val_loss = float('inf')
best_state = None
epochs_without_improvement = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, masks in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    train_loss = running_loss / len(train_loader.dataset)

    val_loss = validate(model, val_loader, criterion)
    scheduler.step(val_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss}, Validation Loss: {val_loss}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns live references to the parameters. Clone them
        # so later optimizer steps don't overwrite the snapshot.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= early_stop_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

if best_state is not None:
    model.load_state_dict(best_state)
推理及可视化
推理并可视化结果:
import matplotlib.pyplot as plt
def visualize_predictions(model, dataloader, num_images=5, threshold=0.5, device=None):
    """Plot input image, ground-truth mask and thresholded prediction.

    One matplotlib figure with three panels is shown per sample, for up to
    *num_images* samples taken in loader order. All samples of a batch are
    used before the next batch is fetched.

    Args:
        model: Segmentation network returning probabilities (switched to eval).
        dataloader: Iterable of ``(images, masks)`` tensor batches.
        num_images: Maximum number of samples to display.
        threshold: Probability above which a pixel is predicted as foreground.
        device: Optional device for the forward pass. Results are moved back
            to the CPU for plotting.
    """
    model.eval()
    shown = 0
    with torch.no_grad():
        for images, masks in dataloader:
            # Check before running the model so no extra batch is processed.
            if shown >= num_images:
                break
            inputs = images if device is None else images.to(device)
            preds = (model(inputs) > threshold).float().cpu()
            for image, mask, pred in zip(images, masks, preds):
                if shown >= num_images:
                    break
                fig, axarr = plt.subplots(1, 3)
                # Image tensors are (C, H, W); imshow expects (H, W, C).
                axarr[0].imshow(image.permute(1, 2, 0).cpu().numpy())
                axarr[0].set_title('Image')
                axarr[1].imshow(mask.squeeze().cpu().numpy(), cmap='gray')
                axarr[1].set_title('Ground truth')
                axarr[2].imshow(pred.squeeze().numpy(), cmap='gray')
                axarr[2].set_title('Prediction')
                for ax in axarr:
                    ax.axis('off')
                plt.show()
                # Release the figure so repeated calls don't accumulate them.
                plt.close(fig)
                shown += 1


visualize_predictions(model, val_loader)