Bootstrap

YOLO 数据集切割

项目场景:

最近在做一个手机镜头缺陷检测项目，数据集由电子显微镜拍摄制作，图片分辨率为 1920×1080。


问题描述

直接使用原始分辨率的数据集训练时效果不佳，于是尝试将图片切割成 640×640 的小图再进行测试。下面记录一份 Python 代码：它用滑动窗口把原始图片切割成小分辨率图片，并同步换算、重新保存对应的 YOLO 标注信息。注意：只有完全落在切片内的目标框才会写入该切片的标注，不包含任何目标的切片不会被保存。


解决方案:

使用前请根据实际情况修改以下参数：

# Class names, indexed by YOLO class id -- edit to match your dataset.
class_names = ['class1', 'class2', 'class3']
# Dataset paths -- fill these in before running (left empty so the file still parses).
images_dir = r""  # directory containing the original .jpg images
labels_dir = r""  # directory containing the matching YOLO .txt label files
image_save_dir = r""  # directory that receives the cropped images
label_save_dir = r""  # directory that receives the cropped labels
# Cropping parameters
crop_size = 640  # side length of each square crop, in pixels
step_size = 320  # sliding-window stride, in pixels (320 of 640 = 50% overlap)
# Mode flag
flag = 0  # 1: crop and save without waiting for key presses; 0: browse images and labels one by one
import cv2
import os

def load_labels(label_path):
    """Read a YOLO-format label file.

    Each non-empty line is ``class_id x_center y_center width height``; the
    coordinates are expected to be normalized by the image width/height.
    Extra trailing fields on a line are ignored.

    Args:
        label_path: path to the ``.txt`` label file.

    Returns:
        A list of ``(class_id, x_center, y_center, width, height)`` tuples,
        with ``class_id`` as ``int`` and the other four values as ``float``.

    Raises:
        FileNotFoundError: if ``label_path`` does not exist.
        ValueError: if a non-empty line has fewer than five fields, or a field
            cannot be parsed as a number.
    """
    labels = []
    with open(label_path, 'r', encoding='utf-8') as file:
        for line_no, line in enumerate(file, start=1):
            parts = line.split()
            # Blank lines (e.g. a trailing newline at end of file) are common; skip them
            if not parts:
                continue
            if len(parts) < 5:
                raise ValueError(
                    f"{label_path}:{line_no}: expected 5 fields, got {len(parts)}"
                )
            class_id = int(parts[0])
            x_center, y_center, width, height = map(float, parts[1:5])
            labels.append((class_id, x_center, y_center, width, height))
    return labels

def draw_yolo_bboxes(image, labels, class_names):
    """Draw YOLO-format boxes and their class names onto ``image`` in place.

    Args:
        image: image array as returned by ``cv2.imread``; it is modified in
            place. Both 3-channel and single-channel (2-D) arrays are accepted.
        labels: iterable of ``(class_id, x_center, y_center, width, height)``
            tuples whose coordinates are normalized by the image width/height.
        class_names: sequence mapping class id to display name. An id outside
            ``0 .. len(class_names) - 1`` is drawn as its numeric value instead
            of raising (or, for negative ids, silently picking a wrong name).
    """
    # shape[:2] works for both (h, w, c) colour images and (h, w) grayscale ones
    h, w = image.shape[:2]
    for class_id, x_center, y_center, bbox_width, bbox_height in labels:
        # Convert the normalized center/size back to pixel corner coordinates
        x_min = int((x_center - bbox_width / 2) * w)
        y_min = int((y_center - bbox_height / 2) * h)
        x_max = int((x_center + bbox_width / 2) * w)
        y_max = int((y_center + bbox_height / 2) * h)

        # Draw the bounding box
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)

        # Add the label text
        if 0 <= class_id < len(class_names):
            label_text = class_names[class_id]
        else:
            label_text = str(class_id)
        # putText's origin is the text baseline; for boxes touching the top edge
        # y_min - 10 would place the text outside the image, so clamp it to
        # roughly one glyph height at font scale 0.5
        text_y = max(y_min - 10, 12)
        cv2.putText(image, label_text, (x_min, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

def convert_labels_to_original(labels, img_width, img_height):
    """Scale normalized YOLO labels up to pixel units.

    Args:
        labels: iterable of ``(class_id, x_center, y_center, width, height)``
            tuples with coordinates normalized by the image size.
        img_width: image width in pixels; multiplies x_center and width.
        img_height: image height in pixels; multiplies y_center and height.

    Returns:
        A new list of ``(class_id, x_center, y_center, width, height)`` tuples
        in pixel units; the class id is passed through unchanged.
    """
    return [
        (cls_id, cx * img_width, cy * img_height, bw * img_width, bh * img_height)
        for cls_id, cx, cy, bw, bh in labels
    ]

def convert_labels_to_yolo(labels, img_width, img_height):
    """Normalize pixel-unit labels into YOLO format.

    Args:
        labels: iterable of ``(class_id, x_center, y_center, width, height)``
            tuples in pixel units.
        img_width: reference width in pixels; divides x_center and width.
        img_height: reference height in pixels; divides y_center and height.

    Returns:
        A new list of ``(class_id, x_center, y_center, width, height)`` tuples
        normalized by the given size; the class id is passed through unchanged.
    """
    return [
        (cls_id, cx / img_width, cy / img_height, bw / img_width, bh / img_height)
        for cls_id, cx, cy, bw, bh in labels
    ]

def _window_starts(length, crop_size, step_size):
    """Return the start offsets of full-size windows along one image axis."""
    if step_size <= 0:
        raise ValueError(f"step_size must be positive, got {step_size}")
    last = length - crop_size
    if last < 0:
        # The image is smaller than one crop along this axis: no full crop fits
        return []
    starts = list(range(0, last + 1, step_size))
    # The stride rarely lands exactly on the far edge; add one edge-aligned
    # window so the last rows/columns of the image are not silently dropped
    if starts[-1] != last:
        starts.append(last)
    return starts


def _labels_inside_window(original_labels, x, y, crop_size):
    """Return pixel labels fully inside the window at (x, y), shifted to window coordinates."""
    crop_labels = []
    for class_id, x_center, y_center, bbox_width, bbox_height in original_labels:
        x_min = x_center - bbox_width / 2
        y_min = y_center - bbox_height / 2
        x_max = x_center + bbox_width / 2
        y_max = y_center + bbox_height / 2
        # NOTE(review): boxes cut by the window border are dropped here even
        # though part of the object remains visible in the crop's pixels;
        # consider clipping them if that hurts training.
        if x_min >= x and x_max <= x + crop_size and y_min >= y and y_max <= y + crop_size:
            crop_labels.append((class_id, x_center - x, y_center - y, bbox_width, bbox_height))
    return crop_labels


def crop_and_save(image, original_labels, crop_size, step_size, image_save_dir, label_save_dir, image_filename, class_names):
    """Slide a square window over ``image`` and save every crop that contains objects.

    Window positions advance by ``step_size`` in row-major order. When the
    stride does not land on the far edge, one extra window aligned to that
    edge is added so the whole image is covered (e.g. a 1080-pixel-high image
    with crop 640 / step 320 uses y = 0, 320, 440). For each window, labels
    whose box lies entirely inside it are shifted into window coordinates,
    re-normalized to YOLO format, and written as
    ``<stem>_crop_<id>.jpg`` / ``<stem>_crop_<id>.txt``. Windows without any
    fully-enclosed box are skipped, so ``crop_id`` counts saved crops only.

    Args:
        image: image array as returned by ``cv2.imread``.
        original_labels: ``(class_id, x_center, y_center, width, height)``
            tuples in pixel units of ``image``.
        crop_size: side length of each square crop, in pixels.
        step_size: stride between window positions, in pixels; must be > 0.
        image_save_dir: directory for the cropped images.
        label_save_dir: directory for the cropped label files.
        image_filename: source file name; its stem prefixes the output names.
        class_names: unused; kept for interface compatibility.

    Raises:
        ValueError: if ``step_size`` is not positive.
        OSError: if ``cv2.imwrite`` reports failure for a crop image (it
            returns False instead of raising, e.g. for some non-ASCII paths).
    """
    img_height, img_width = image.shape[:2]
    stem = os.path.splitext(image_filename)[0]
    x_starts = _window_starts(img_width, crop_size, step_size)
    crop_id = 0

    for y in _window_starts(img_height, crop_size, step_size):
        for x in x_starts:
            crop_labels = _labels_inside_window(original_labels, x, y, crop_size)
            if not crop_labels:
                continue

            # Save the cropped image; fail loudly instead of writing orphan labels
            crop = image[y:y + crop_size, x:x + crop_size]
            crop_image_path = os.path.join(image_save_dir, f"{stem}_crop_{crop_id}.jpg")
            if not cv2.imwrite(crop_image_path, crop):
                raise OSError(f"cv2.imwrite failed for {crop_image_path}")

            # Convert and save the new YOLO labels
            crop_yolo_labels = convert_labels_to_yolo(crop_labels, crop_size, crop_size)
            crop_label_path = os.path.join(label_save_dir, f"{stem}_crop_{crop_id}.txt")
            with open(crop_label_path, 'w') as f:
                for label in crop_yolo_labels:
                    f.write(f"{label[0]} {label[1]} {label[2]} {label[3]} {label[4]}\n")

            crop_id += 1


# Class names, indexed by YOLO class id -- edit to match your dataset
class_names = ['class1', 'class2', 'class3']

# Dataset paths
#images_dir = r"F:\ProjectSpace\17.DataSet\03.MobilePhoneLens\00.ConvexLens\mydata\train\images"
images_dir = r"F:\ProjectSpace\17.DataSet\03.MobilePhoneLens\00.ConvexLens\mydata\train\save_image"
#labels_dir = r"F:\ProjectSpace\17.DataSet\03.MobilePhoneLens\00.ConvexLens\mydata\train\labels_txt"
labels_dir = r"F:\ProjectSpace\17.DataSet\03.MobilePhoneLens\00.ConvexLens\mydata\train\save_label"
image_save_dir = r"F:\ProjectSpace\17.DataSet\03.MobilePhoneLens\00.ConvexLens\mydata\train\save_image"  # directory for cropped images
label_save_dir = r"F:\ProjectSpace\17.DataSet\03.MobilePhoneLens\00.ConvexLens\mydata\train\save_label"  # directory for cropped labels
# NOTE(review): the input and output directories above are identical, so crops
# land next to the source images. The file list is taken once before the loop,
# so a single run is fine, but running the script again will crop the crops.

# Create the output directories if they do not exist
os.makedirs(image_save_dir, exist_ok=True)
os.makedirs(label_save_dir, exist_ok=True)

# Cropping parameters
crop_size = 640  # side length of each square crop, in pixels
step_size = 320  # sliding-window stride, in pixels

# Collect the images to process: case-insensitive .jpg match, sorted because
# os.listdir order is arbitrary
image_files = sorted(f for f in os.listdir(images_dir) if f.lower().endswith('.jpg'))

# Index of the image currently being processed
current_index = 0

# Mode flag
flag = 0  # 1: crop and save without waiting for key presses; 0: browse images and labels one by one

while True:
    if current_index >= len(image_files):
        print("No more images.")
        break

    image_filename = image_files[current_index]
    image_path = os.path.join(images_dir, image_filename)
    # Build the label name from the stem; str.replace('.jpg', '.txt') would also
    # rewrite a '.jpg' occurring elsewhere in the file name
    label_path = os.path.join(labels_dir, os.path.splitext(image_filename)[0] + '.txt')

    # Read the image; cv2.imread returns None (instead of raising) on failure
    image = cv2.imread(image_path)
    if image is None:
        print(f"Skipping unreadable image: {image_path}")
        current_index += 1
        continue

    # Read the labels; a missing label file is treated as an image with no objects
    if os.path.exists(label_path):
        labels = load_labels(label_path)
    else:
        print(f"No label file for {image_filename}; treating it as having no objects.")
        labels = []
    # Convert the labels to pixel units of the original image
    original_labels = convert_labels_to_original(labels, image.shape[1], image.shape[0])

    if flag == 0:
        draw_yolo_bboxes(image, labels, class_names)
        # Show the image
        cv2.imshow('Image', image)
        # Wait for a key press; keep only the low byte because some
        # platforms/backends set extra high bits in waitKey's return value
        key = cv2.waitKey(0) & 0xFF
        if key == ord(' '):  # Space: show the next image
            current_index += 1
        elif key == 27:  # ESC: quit
            break
    else:
        current_index += 1
        # Crop the image and save the crops together with their new labels
        crop_and_save(image, original_labels, crop_size, step_size, image_save_dir, label_save_dir, image_filename,
                      class_names)

cv2.destroyAllWindows()