由于有一批数据的图片尺寸偏小、质量不佳,为了方便标注,使用数据增强将所有图片尺寸统一固定为1080P(1920x1080)。
# -*- coding: UTF-8 -*-
"""
@Project :yolov5_relu_fire_smoke_v1.4
@IDE :PyCharm
@Author :沐枫
@Date :2024/4/2 20:28
添加白边(用白色填充)做数据增强,最后所有的图片尺寸固定为1080P(1920x1080)
"""
import os
import multiprocessing
from concurrent import futures
from copy import deepcopy
import cv2
import numpy as np
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom
def xyxy2xywh(x):
    """
    Convert corner boxes [x1, y1, x2, y2] to center boxes [cx, cy, w, h].

    (x1, y1) is the top-left and (x2, y2) the bottom-right corner. Only the
    first four entries of the last axis are rewritten, so any leading shape
    (a single box or an n x 4 array) works.

    :param x: array whose last axis starts with [x1, y1, x2, y2]
    :return: a new array (same shape and dtype) holding [cx, cy, w, h]
    """
    left, top, right, bottom = (x[..., k] for k in range(4))
    out = np.copy(x)
    out[..., 0] = (left + right) / 2  # center x
    out[..., 1] = (top + bottom) / 2  # center y
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    return out
def xywh2xyxy(x: np.ndarray):
    """
    Convert center boxes [cx, cy, w, h] to corner boxes [x1, y1, x2, y2].

    (x1, y1) is the top-left and (x2, y2) the bottom-right corner. Only the
    first four entries of the last axis are rewritten; the input is not modified.

    :param x: array whose last axis starts with [cx, cy, w, h]
    :return: a new array (same shape and dtype) holding [x1, y1, x2, y2]
    """
    cx, cy = x[..., 0], x[..., 1]
    half_w, half_h = x[..., 2] / 2, x[..., 3] / 2
    out = np.copy(x)
    out[..., 0] = cx - half_w  # top-left x
    out[..., 1] = cy - half_h  # top-left y
    out[..., 2] = cx + half_w  # bottom-right x
    out[..., 3] = cy + half_h  # bottom-right y
    return out
def decodeVocAnnotation(voc_xml_path, class_index_dict):
    """
    Parse one Pascal-VOC xml file into a list of boxes.

    Every <object> element becomes one tuple
    ``(cls_index, xmin, ymin, xmax, ymax)``. The class index is looked up in
    ``class_index_dict`` by the object's <name> text (surrounding whitespace
    stripped); an unknown name raises KeyError. Coordinates are returned as
    ints; values written as floats (e.g. "12.5") are rounded to the nearest
    integer instead of failing.

    An empty list means the file has no objects (a background image), so
    callers should check for it.

    :param voc_xml_path: path of the xml file, read as UTF-8
    :param class_index_dict: class name -> class index
    :return: [(cls_index, x1, y1, x2, y2), ...]
    :raises ValueError: if voc_xml_path does not end with ".xml"
    """
    if not voc_xml_path.endswith(".xml"):
        raise ValueError("voc_xml_path must endswith .xml")
    with open(voc_xml_path, 'r', encoding='utf-8') as xml_file:
        root = ET.parse(xml_file).getroot()

    information = []
    for obj in root.iter('object'):
        name = obj.find('name').text.strip()
        # VOC stores boxes in corner format
        box = obj.find('bndbox')
        # int(float(...)) accepts both "12" and "12.5"; round instead of truncating
        coords = [int(round(float(box.find(tag).text)))
                  for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        information.append((class_index_dict[name], *coords))
    return information
def create_voc_xml(image_folder, image_filename, width: int, height: int, labels,
                   save_root, class_name_dict, conf_thresh_dict=None):
    """
    Write a Pascal-VOC style xml annotation for one image.

    The file is written to ``<save_root>/<stem of image_filename>.xml``, and
    missing directories are created. It is always encoded as UTF-8, which
    matches how decodeVocAnnotation reads xml files back; the platform default
    encoding (e.g. GBK on Chinese Windows) would break non-ASCII class names.

    :param image_folder: text stored in the <folder> tag (relative path of the image)
    :param image_filename: text stored in the <filename> tag, e.g. 000001.jpg
    :param width: image width stored in <size>
    :param height: image height stored in <size>
    :param labels: boxes, one row each: [class_index, xmin, ymin, xmax, ymax], or
        [class_index, xmin, ymin, xmax, ymax, conf] when conf_thresh_dict is given.
        NumPy arrays and plain lists are both accepted; values are truncated to int.
    :param save_root: directory the xml is written into
    :param class_name_dict: class_index -> class name, used for the <name> tag
    :param conf_thresh_dict: class_index -> confidence threshold; rows whose conf is
        below their class threshold are skipped. None means rows carry no confidence.
    :return: None
    """
    root = ET.Element("annotation")
    ET.SubElement(root, "folder").text = str(image_folder)
    ET.SubElement(root, "filename").text = str(image_filename)
    size = ET.SubElement(root, "size")
    ET.SubElement(size, "width").text = str(width)
    ET.SubElement(size, "height").text = str(height)
    ET.SubElement(size, "depth").text = "3"  # channel count, always written as 3

    for label in labels:
        # np.asarray lets plain lists through as well as numpy rows
        row = np.asarray(label)
        # Class index and corners are forced to integers (truncated, like int32 casting)
        class_index, x1, y1, x2, y2 = (int(v) for v in row[:5].astype(np.int32))
        if conf_thresh_dict is not None:
            conf = row[5]
            # Drop boxes below the per-class confidence threshold
            if conf < conf_thresh_dict[class_index]:
                continue
        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = class_name_dict[class_index]
        ET.SubElement(obj, "pose").text = "Unspecified"
        ET.SubElement(obj, "truncated").text = "0"
        ET.SubElement(obj, "difficult").text = "0"
        bndbox = ET.SubElement(obj, "bndbox")
        for tag, value in (("xmin", x1), ("ymin", y1), ("xmax", x2), ("ymax", y2)):
            ET.SubElement(bndbox, tag).text = str(value)

    # Round-trip through minidom only to pretty-print with 4-space indentation
    xml_doc = minidom.parseString(ET.tostring(root, encoding="utf-8"))
    pretty_xml = xml_doc.toprettyxml(indent=" " * 4)
    save_path = os.path.join(save_root, f"{os.path.splitext(image_filename)[0]}.xml")
    save_dir = os.path.dirname(save_path)
    # os.makedirs("") raises, so only create a directory when there is one
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as xml_file:
        xml_file.write(pretty_xml)
def resize_and_pad(image: np.ndarray, labels: np.ndarray, width=1920, height=1080):
    """
    Place an image on a fixed-size white canvas and move its boxes to match.

    Steps:
      1. If the image is smaller than the target in BOTH dimensions, it is first
         upscaled by SCALE (2x), and the boxes are scaled with it.
      2. If it is still smaller in both dimensions, it is pasted unchanged in the
         center of the canvas. Otherwise it is resized, keeping its aspect ratio,
         so that it fits inside width x height, and then pasted in the center.
      3. The uncovered part of the canvas is white (255).

    :param image: image array of shape (H, W) or (H, W, C); the canvas is uint8
    :param labels: boxes as rows (cls_id, cx, cy, w, h). The caller's object is NOT
        modified, because a float64 copy is made. Integer arrays and lists are accepted.
    :param width: canvas width
    :param height: canvas height
    :return: (canvas, labels). canvas has shape (height, width[, C]) and dtype uint8;
        labels is a float64 array of the moved boxes (cls_id, cx, cy, w, h).
    """
    SCALE = 2

    def _resize(image: np.ndarray, labels: np.ndarray, width, height):
        """Fit the image into width x height if needed and center it on a white canvas.

        Shifts/scales ``labels`` (a private float array) in place and returns it.
        """
        img_h, img_w = image.shape[:2]
        if img_w < width and img_h < height:
            # Already smaller in both dimensions: paste as-is
            placed = image
        else:
            ratio = min(width / img_w, height / img_h)
            # round() instead of int(): e.g. 1919.9999 -> 1920 rather than leaving a
            # 1-pixel white seam. Clamp to the canvas and to at least one pixel.
            new_w = min(width, max(1, round(img_w * ratio)))
            new_h = min(height, max(1, round(img_h * ratio)))
            # cv2.resize returns single-channel results as 2-D arrays; restore the input layout
            placed = cv2.resize(image, (new_w, new_h)).reshape((new_h, new_w) + image.shape[2:])
            labels[..., 1:] *= ratio  # centers and sizes scale alike
        placed_h, placed_w = placed.shape[:2]
        dw = (width - placed_w) // 2
        dh = (height - placed_h) // 2
        canvas = np.full((height, width) + image.shape[2:], 255, dtype=np.uint8)
        canvas[dh:dh + placed_h, dw:dw + placed_w] = placed
        # Pasting only moves the centers; widths and heights are unchanged
        labels[..., 1] += dw
        labels[..., 2] += dh
        return canvas, labels

    # Private float copy: the caller's array stays untouched and in-place float
    # scaling works even when integer labels are passed in.
    labels = np.array(labels, dtype=np.float64)
    if labels.ndim == 1 and labels.size == 0:
        labels = labels.reshape(0, 5)
    img_h, img_w = image.shape[:2]
    if img_w < width and img_h < height:
        # Small image: enlarge first (image and boxes by the same factor)
        channel_shape = image.shape[2:]
        image = cv2.resize(image, (img_w * SCALE, img_h * SCALE)).reshape(
            (img_h * SCALE, img_w * SCALE) + channel_shape)
        labels[..., 1:] *= SCALE
    return _resize(image, labels, width=width, height=height)
def run(image_path, xml_root,
        image_root, save_image_root, save_xml_root,
        class_index_dict, class_name_dict):
    """
    Pad/resize one image to 1920x1080 and save it together with a new VOC xml.

    The annotation is looked up at the same relative location under ``xml_root``
    that the image has under ``image_root``. The output image is saved as
    ``aug_<file name>`` under ``save_image_root`` in the same relative
    sub-directory, so nested folders with equal names cannot overwrite each
    other. Its xml goes to the matching directory under ``save_xml_root``.

    The following are skipped with a printed message:
      - images without an xml,
      - images whose xml has no objects,
      - images that cannot be read or written.

    Any other exception is printed rather than raised, so one bad file does not
    stop the batch.
    """
    image_file = os.path.basename(image_path)
    image_name, suffix = os.path.splitext(image_file)
    # Build paths structurally (relpath/splitext) instead of str.replace, so a suffix
    # or root string that also occurs elsewhere in the path cannot corrupt it.
    rel_dir = os.path.relpath(os.path.dirname(image_path), image_root)
    xml_path = os.path.normpath(os.path.join(xml_root, rel_dir, image_name + ".xml"))
    if not os.path.exists(xml_path):
        print(f"\n{image_path} no xml\n")
        return
    try:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising
            print(f"\n{image_path} cannot be read\n")
            return
        labels = decodeVocAnnotation(xml_path, class_index_dict)
        if not labels:
            print(f"\n{image_path} no label\n")
            return
        # (cls_id, x1, y1, x2, y2) -> (cls_id, cx, cy, w, h)
        labels = np.array(labels, dtype=np.float64)
        labels[:, 1:] = xyxy2xywh(labels[:, 1:])
        new_image, new_labels = resize_and_pad(image, labels, width=1920, height=1080)
        new_img_h, new_img_w = new_image.shape[:2]
        # Back to corners; round before the int cast so sub-pixel results don't truncate
        new_labels[:, 1:] = xywh2xyxy(new_labels[:, 1:])
        int_labels = np.round(new_labels).astype(np.int32)

        save_image_dir = os.path.normpath(os.path.join(save_image_root, rel_dir))
        save_xml_dir = os.path.normpath(os.path.join(save_xml_root, rel_dir))
        save_image_path = os.path.join(save_image_dir, f"aug_{image_file}")
        os.makedirs(save_image_dir, exist_ok=True)
        os.makedirs(save_xml_dir, exist_ok=True)
        # cv2.imwrite reports failure via its return value; don't write an xml without an image
        if not cv2.imwrite(save_image_path, new_image):
            print(f"\n{save_image_path} write failed\n")
            return
        create_voc_xml(image_folder=os.path.relpath(save_image_path, save_image_root),
                       image_filename=os.path.basename(save_image_path),
                       width=new_img_w,
                       height=new_img_h,
                       labels=int_labels,
                       save_root=save_xml_dir,
                       class_name_dict=class_name_dict)
        print(f"\r{image_path}", end='')
    except Exception as e:
        # Worker boundary: report and keep processing the remaining files
        print(f"{image_path} {run.__name__}:{e}")
def run_process(root_file_list,
                image_root, xml_root, save_image_root, save_xml_root,
                class_index_dict, class_name_dict):
    """
    Process one batch of images inside a worker, using a pool of 5 threads.

    Each (directory, filename) pair is joined into a path and submitted as one
    ``run`` call. The call returns only after every submitted task has finished,
    because leaving the executor's ``with`` block waits for them.
    """
    # Everything run() needs besides the image path, in run()'s parameter order
    shared_args = (xml_root, image_root, save_image_root, save_xml_root,
                   class_index_dict, class_name_dict)
    with futures.ThreadPoolExecutor(max_workers=5) as pool:
        for directory, filename in root_file_list:
            pool.submit(run, os.path.join(directory, filename), *shared_args)
if __name__ == '__main__':
    # Class name -> index; class_name_dict is derived from it so the two always agree
    class_index_dict = {
        "fire": 0,
        "smoke": 1,
    }
    class_name_dict = {index: name for name, index in class_index_dict.items()}

    data_root = os.path.abspath(r"Z:\Datasets\FireSmoke_v4")
    # Root of the original images
    image_root = os.path.join(data_root, "images")
    # Root of the VOC xml annotations (same sub-directory layout as image_root)
    xml_root = os.path.join(data_root, "annotations")
    # Output roots. NOTE: save_image_root lies inside image_root, so the walk below
    # prunes it; otherwise generated images would be augmented again as aug_aug_*.
    save_image_root = os.path.join(image_root, "aug-pad")
    save_xml_root = os.path.join(xml_root, "aug-pad")

    if not os.path.isdir(image_root):
        raise SystemExit(f"image_root not found: {image_root}")

    # Directories to skip. A walked directory is skipped when its path contains
    # os.sep + name anywhere (substring match, so e.g. "smoke" also skips "smoke_v2").
    exclude_dirs = [os.sep + name for name in (
        "background", "candle_fire", "AUG", "smoke", "val", "aug-merge",
        "cut_aug", "miniFire", "net", "new_data", "realScenario", "TSMCandle",
    )]
    image_suffixes = (".jpg", ".jpeg", ".bmp", ".png")

    max_workers = 10  # number of worker processes
    print(f"max_workers:{max_workers}")
    # Number of images handed to one run_process task
    max_file_num = 3000
    # Arguments shared by every run_process task (everything except the file batch)
    common_args = (image_root, xml_root, save_image_root, save_xml_root,
                   class_index_dict, class_name_dict)

    # Tune the process count to your machine; too many processes makes it slower
    pool = multiprocessing.Pool(processes=max_workers)
    async_results = []
    batch = []
    for root, dirs, files in os.walk(image_root):
        # Prune the output directory so generated images are never re-processed
        dirs[:] = [d for d in dirs if os.path.join(root, d) != save_image_root]
        if any(x in root for x in exclude_dirs):
            continue
        for file in files:
            if os.path.splitext(file)[1].lower() not in image_suffixes:
                continue
            batch.append((root, file))
            if len(batch) >= max_file_num:
                # Hand the list over and start a new one rather than clearing it:
                # the pool serializes queued tasks from a helper thread, so a
                # submitted list must not be mutated afterwards.
                async_results.append(pool.apply_async(run_process, (batch,) + common_args))
                batch = []
    # Submit whatever is left after the walk
    if batch:
        async_results.append(pool.apply_async(run_process, (batch,) + common_args))

    # No more tasks; wait for all workers to finish
    pool.close()
    pool.join()
    # apply_async only re-raises a worker's exception from .get(); surface them here
    for result in async_results:
        try:
            result.get()
        except Exception as e:
            print(f"\nrun_process failed: {e}")
    print("\nFinish ...")