红外船舶数据集 6000张 yolo格式
红外船舶数据集 6000张 yolo格式
红外船舶数据集包含6000张图片,并且已经标注为YOLO格式。我们可以使用这些数据来训练一个目标检测模型。我们使用YOLOv8架构来进行目标检测任务。
以下是详细的步骤和代码示例,帮助你开始使用这个数据集进行训练和评估。
代码仅供参考!!!
项目结构
ship_detection/
├── main.py
├── train.py
├── evaluate.py
├── infer.py
├── visualize.py
├── datasets/
│ ├── ship_images/
│ │ ├── images/
│ │ │ ├── train/
│ │ │ └── val/
│ │ ├── labels/
│ │ │ ├── train/
│ │ │ └── val/
│ │ ├── train.txt
│ │ └── val.txt
├── best_ship_model.pt
├── requirements.txt
└── data.yaml
文件内容
requirements.txt
opencv-python==4.5.3.56
torch==2.0.0+cu117
matplotlib
numpy
pandas
albumentations
ultralytics
data.yaml
假设你的数据集自带 data.yaml
文件,其内容应类似于:
train: ./datasets/ship_images/images/train
val: ./datasets/ship_images/images/val
nc: 1 # 假设只有一个类别:ship
names: ['ship']
数据准备
-
确认数据集目录结构:
确保你的数据集已经按照以下结构组织好:datasets/ └── ship_images/ ├── images/ │ ├── train/ │ └── val/ ├── labels/ │ ├── train/ │ └── val/ ├── train.txt └── val.txt
-
检查
data.yaml
文件:
确认data.yaml
文件中的路径和类别信息正确无误。
训练脚本
train.py
from ultralytics import YOLO
# 设置随机种子以保证可重复性
import torch
torch.manual_seed(42)
# 定义数据集路径
dataset_config = 'data.yaml'
# 加载预训练的YOLOv8模型
model = YOLO('yolov8n.pt')
# 训练模型
results = model.train(
imgsz=640,
batch=16,
epochs=50,
data=dataset_config,
weights='yolov8n.pt',
name='ship_detection',
project='runs/train'
)
# 打印训练结果
print(results)
评估脚本
evaluate.py
from ultralytics import YOLO
# 初始化YOLOv8模型
model_path = 'runs/train/ship_detection/weights/best.pt'
# 加载模型
model = YOLO(model_path)
# 评估模型
metrics = model.val(data='data.yaml', imgsz=640)
# 打印评估结果
print(metrics)
推理脚本
infer.py
from ultralytics import YOLO
import cv2
import numpy as np
import matplotlib.pyplot as plt
# 初始化YOLOv8模型
model_path = 'runs/train/ship_detection/weights/best.pt'
model = YOLO(model_path)
def detect_and_visualize(image_path):
results = model.predict(source=image_path, conf=0.25, iou=0.45, agnostic_nms=False)
for result in results:
boxes = result.boxes.cpu().numpy()
im_array = result.orig_img
for box in boxes:
r = box.xyxy[0].astype(int)
cls = int(box.cls[0])
conf = box.conf[0]
label = f'{model.names[cls]} {conf:.2f}'
# 绘制边界框
cv2.rectangle(im_array, (r[0], r[1]), (r[2], r[3]), (0, 255, 0), 2)
# 绘制标签
cv2.putText(im_array, label, (r[0], r[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
# 显示图像
plt.imshow(cv2.cvtColor(im_array, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()
if __name__ == "__main__":
image_path = 'path/to/your/image.jpg' # 替换为你的图像路径
detect_and_visualize(image_path)
可视化脚本
visualize.py
from ultralytics import YOLO
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random
# 初始化YOLOv8模型
model_path = 'runs/train/ship_detection/weights/best.pt'
model = YOLO(model_path)
def visualize_dataset(dataset_path, num_samples=5):
with open(dataset_path, 'r') as f:
lines = f.readlines()
sample_paths = random.sample(lines, num_samples)
for path in sample_paths:
path = path.strip()
results = model.predict(source=path, conf=0.25, iou=0.45, agnostic_nms=False)
for result in results:
boxes = result.boxes.cpu().numpy()
im_array = result.orig_img
for box in boxes:
r = box.xyxy[0].astype(int)
cls = int(box.cls[0])
conf = box.conf[0]
label = f'{model.names[cls]} {conf:.2f}'
# 绘制边界框
cv2.rectangle(im_array, (r[0], r[1]), (r[2], r[3]), (0, 255, 0), 2)
# 绘制标签
cv2.putText(im_array, label, (r[0], r[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
# 显示图像
plt.imshow(cv2.cvtColor(im_array, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()
if __name__ == "__main__":
dataset_path = 'datasets/ship_images/val.txt' # 替换为你的验证集路径
visualize_dataset(dataset_path)
运行步骤总结
-
克隆项目仓库(如果有的话):
git clone https://github.com/yourusername/ship_detection.git cd ship_detection
-
安装依赖项:
conda create --name ship_det_env python=3.8 conda activate ship_det_env pip install -r requirements.txt
-
准备数据集:
- 确保你的数据集已经按照上述结构组织好。
- 确认
data.yaml
文件中的路径和类别信息正确无误。
-
训练模型:
python train.py
-
评估模型:
python evaluate.py
-
运行推理:
python infer.py
-
可视化数据集:
python visualize.py
操作界面
- 选择图片进行检测: 修改
infer.py
中的image_path
变量指向你要检测的图片路径,然后运行python infer.py
。 - 批量检测: 在
visualize.py
中设置dataset_path
为你想要可视化的数据集路径,然后运行python visualize.py
。
详细解释
requirements.txt
列出项目所需的所有Python包及其版本。
data.yaml
配置数据集路径和类别信息,用于YOLOv8模型训练。
train.py
加载预训练的YOLOv8模型并使用自定义数据集进行训练。训练完成后打印训练结果。
evaluate.py
加载训练好的YOLOv8模型并对验证集进行评估,打印评估结果。
infer.py
对单张图像进行预测并可视化检测结果。
visualize.py
对数据集中的一些样本进行预测并可视化检测结果。
希望这些详细的信息和代码能够帮助你顺利实施和优化你的红外船舶检测系统。