Bootstrap

遥感图像变化检测数据集制作：包括数据集划分、数据集裁剪，以及用数据增强的方式扩充数据集。

代码如下：

 

import cv2
import os
import numpy as np
def _read_image(path):
    """Load *path* with ``cv2.imread`` in its default (colour, 8-bit BGR) mode.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise surface
            later as an obscure error when the image is sliced.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cannot read image: {path}")
    return img


def _write_image(path, img):
    """Write *img* to *path*; raise ``OSError`` if ``cv2.imwrite`` reports failure.

    ``cv2.imwrite`` signals failure (e.g. a missing directory) by returning
    ``False`` rather than raising, so the result must be checked explicitly.
    """
    if not cv2.imwrite(path, img):
        raise OSError(f"cannot write image: {path}")


def _crop_starts(length, crop_size, step):
    """Return window start offsets that tile one axis of size *length*.

    Starts are ``0, step, 2*step, ...`` while a full window still fits. If that
    leaves an uncovered strip at the far end, one extra window aligned to the
    edge is appended so the whole axis is covered. Returns ``[]`` when
    ``length < crop_size``.
    """
    starts = list(range(0, length - crop_size + 1, step))
    last = length - crop_size
    if starts and starts[-1] != last:
        starts.append(last)
    return starts


def _adjust_hsv(img, sat_scale=1.0, hue_shift=0):
    """Return a copy of 8-bit BGR *img* with saturation scaled and hue rotated.

    OpenCV stores 8-bit hue in ``[0, 180)``, so the shift wraps modulo 180 and
    one hue unit is 2 degrees.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hue = (hsv[..., 0].astype(np.int16) + hue_shift) % 180
    sat = np.clip(hsv[..., 1].astype(np.float32) * sat_scale, 0, 255)
    hsv[..., 0] = hue.astype(np.uint8)
    hsv[..., 1] = sat.astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)


def _augment(crop_a, crop_b, crop_label):
    """Yield augmented ``(A, B, Label)`` triples for one tile (10 per tile).

    Geometric transforms (identity, flips, rotations) are applied identically
    to all three images so the label stays pixel-aligned with A and B.
    ``cv2.flip(img, -1)`` is not listed separately because it equals a
    180-degree rotation.

    Photometric transforms are applied to A and B only. The label is passed
    through unchanged, because it is a change map rather than a photograph and
    intensity edits would corrupt its values.
    """
    geometric = (
        lambda im: im,
        lambda im: cv2.flip(im, 1),  # horizontal flip
        lambda im: cv2.flip(im, 0),  # vertical flip
        lambda im: cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE),
        lambda im: cv2.rotate(im, cv2.ROTATE_180),
        lambda im: cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE),  # 270 deg
    )
    for transform in geometric:
        yield transform(crop_a), transform(crop_b), transform(crop_label)

    photometric = (
        lambda im: cv2.addWeighted(im, 1.5, im, 0, 0),  # gain x1.5, saturating
        lambda im: cv2.addWeighted(im, 1, im, 0, -50),  # offset -50, darker
        lambda im: _adjust_hsv(im, sat_scale=1.3),  # more saturated
        lambda im: _adjust_hsv(im, hue_shift=10),  # hue +20 degrees
    )
    for transform in photometric:
        yield transform(crop_a), transform(crop_b), crop_label


def main(*, img_a_path="a.jpg", img_b_path="b.jpg", label_path="label.jpg",
         crop_size=256, step=128, train_ratio=0.8, val_ratio=0.1,
         out_dir="."):
    """Build a change-detection dataset from one image pair and its label.

    Pipeline:
      1. Load image A, image B and the label image.
      2. Tile all three with the same sliding window (``crop_size`` square,
         ``step`` stride) over the full 2-D extent, edge included.
      3. Split the tile positions, in row-major order, into train/val/test.
      4. Augment each tile (see ``_augment``) and write every triple to
         ``<out_dir>/<split>/{A,B,Label}/NNNN.jpg`` immediately. Memory use is
         therefore bounded by one tile, not by the whole augmented dataset.

    Args:
        img_a_path: Path of image A.
        img_b_path: Path of image B; must have the same height/width as A.
        label_path: Path of the label image; must match A's height/width.
        crop_size: Side length in pixels of each square tile.
        step: Stride in pixels between neighbouring tiles.
        train_ratio: Fraction of tile positions assigned to ``train``.
        val_ratio: Fraction of tile positions assigned to ``val``. The
            remainder goes to ``test``.
        out_dir: Root directory under which the split folders are created.

    Raises:
        FileNotFoundError: if an input image cannot be read.
        ValueError: on mismatched image sizes, images smaller than
            ``crop_size``, non-positive ``crop_size``/``step``, or invalid ratios.
        OSError: if an output image cannot be written.

    NOTE(review): labels are saved as JPEG to keep the original file layout.
    JPEG is lossy, so mask values can drift near edges; consider PNG if exact
    label values matter.
    """
    img_a = _read_image(img_a_path)
    img_b = _read_image(img_b_path)
    img_label = _read_image(label_path)

    # Derive the grid from the actual image size rather than assuming one.
    height, width = img_a.shape[:2]
    for name, img in (("B", img_b), ("label", img_label)):
        if img.shape[:2] != (height, width):
            raise ValueError(
                f"{name} image is {img.shape[1]}x{img.shape[0]}, "
                f"expected {width}x{height} to match image A"
            )
    if crop_size <= 0 or step <= 0:
        raise ValueError("crop_size and step must be positive")
    if height < crop_size or width < crop_size:
        raise ValueError(
            f"images ({width}x{height}) are smaller than crop_size={crop_size}"
        )
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError("ratios must be non-negative and sum to at most 1")

    # Full 2-D grid of window origins: rows vary independently of columns.
    positions = [
        (y, x)
        for y in _crop_starts(height, crop_size, step)
        for x in _crop_starts(width, crop_size, step)
    ]

    # Split *positions* before augmenting, so all augmented copies of a tile
    # stay in that tile's split. Row-major order makes each split a mostly
    # contiguous band of the scene.
    n_total = len(positions)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    splits = (
        ("train", positions[:n_train]),
        ("val", positions[n_train:n_train + n_val]),
        ("test", positions[n_train + n_val:]),
    )

    for split, split_positions in splits:
        out_dirs = [os.path.join(out_dir, split, sub) for sub in ("A", "B", "Label")]
        for directory in out_dirs:
            os.makedirs(directory, exist_ok=True)

        idx = 0
        for y, x in split_positions:
            rows = slice(y, y + crop_size)
            cols = slice(x, x + crop_size)
            # Copy to contiguous memory so every OpenCV call gets a plain array.
            tiles = tuple(
                np.ascontiguousarray(img[rows, cols])
                for img in (img_a, img_b, img_label)
            )
            for triple in _augment(*tiles):
                filename = f"{idx:04d}.jpg"
                for directory, tile in zip(out_dirs, triple):
                    _write_image(os.path.join(directory, filename), tile)
                idx += 1


if __name__ == "__main__":
    # Build the dataset with the default input paths when run as a script.
    main()