
Deformable-DETR Code Implementation

GitHub: https://github.com/fundamentalvision/Deformable-DETR

Pretrained model: r50_deformable_detr-checkpoint.pth
Dataset: COCO 2017 format

Command: python main.py --coco_path "data/coco" --batch_size=1 --num_workers=1 --output_dir="outputs" --start_epoch 0 --resume=r50_deformable_detr-checkpoint.pth

Modify the code so that the initial validation pass is skipped and training starts right away:

# main.py
# if args.resume:
Change to: if args.resume and args.start_epoch != 0:

Change the number of classes

#models/deformable_detr.py
def build(args):
    #num_classes = 12 if args.dataset_file != 'coco' else 91
    num_classes = 9 if args.dataset_file != 'coco' else 10
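
To pick the value for num_classes, it can help to look at the category ids in the annotation file. The original DETR code comments describe num_classes as the largest object id plus one; the sketch below assumes that convention also applies here. The function name and the sample category list are hypothetical and only serve as an illustration.

```python
def num_classes_from_categories(categories):
    """Return max category id + 1 for a COCO-style category list.

    Assumption: the model expects num_classes = max_obj_id + 1,
    as described in the comments of the original DETR code.
    """
    return max(cat["id"] for cat in categories) + 1


# Hypothetical dataset with 9 categories whose ids run from 1 to 9
categories = [{"id": i, "name": f"class_{i}"} for i in range(1, 10)]
print(num_classes_from_categories(categories))  # 10
```

With nine categories numbered from 1, this gives 10, which matches the value used in the modified build function above.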

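Problems
1. cannot import name '_NewEmptyTensorOp' from 'torchvision.ops'
Solution: torchvision was already at the latest version, yet the error still appeared. The installed version is 0.10+, and the version check in util/misc.py reads only the first three characters of the version string, so it takes the version to be 0.1 and tries to import _NewEmptyTensorOp. Deleting lines 30-59 and 490-501 of util/misc.py fixed the problem.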
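
The version mix-up is easy to reproduce in isolation. The sketch below is not the code from util/misc.py; it only shows why taking the first three characters of "0.10.0" gives 0.1, and how comparing integer (major, minor) tuples avoids that. The helper name parse_major_minor is hypothetical.

```python
def parse_major_minor(version: str) -> tuple:
    """Return (major, minor) as integers from a string like '0.10.0+cu111'."""
    core = version.split("+")[0]        # drop a local build suffix such as '+cu111'
    major, minor = core.split(".")[:2]  # keep only the first two components
    return int(major), int(minor)


version = "0.10.0"

# Truncating to three characters turns '0.10.0' into '0.1'
truncated = float(version[:3])
print(truncated < 0.5)  # True: 0.10 is mistaken for a version older than 0.5

# Comparing integer tuples keeps 0.10 newer than 0.5
print(parse_major_minor(version) < (0, 5))  # False
```

An alternative to deleting the lines would be to replace the string-slice check with a tuple comparison of this kind, but that is a design choice and not something the repository is known to do.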

2. [Error screenshot]

Solution: go into models/ops and run python setup.py install (on Windows), or build from inside that folder with sh ./make.sh (on Linux).
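3. [Error screenshot]
Set the dataset path as the default value in the code, so that --coco_path can be left off the command line.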

Command: python main.py --batch_size=1 --num_workers=1 --output_dir="outputs" --start_epoch 0 --resume=r50_deformable_detr-checkpoint.pth
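
A minimal sketch of what setting the path as a default looks like with argparse. This is a stand-alone illustration and not the actual parser from main.py: build_parser is a hypothetical helper, the default batch size is an arbitrary placeholder, and "data/coco" is simply the path used in the first command of this post.

```python
import argparse


def build_parser(default_coco_path: str = "data/coco") -> argparse.ArgumentParser:
    """Build a tiny parser whose --coco_path falls back to a default value."""
    parser = argparse.ArgumentParser(add_help=False)
    # With a default set here, --coco_path becomes optional on the command line
    parser.add_argument("--coco_path", default=default_coco_path, type=str)
    parser.add_argument("--batch_size", default=2, type=int)  # placeholder default
    return parser


# Same style of invocation as the command above: no --coco_path given
args = build_parser().parse_args(["--batch_size=1"])
print(args.coco_path)   # data/coco
print(args.batch_size)  # 1
```

Passing --coco_path explicitly still overrides the default, so the change does not stop you from pointing to another dataset later.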

4. When training on your own dataset, the pretrained weights need to be modified

import torch

# Load the official pretrained checkpoint
pretrained_weights = torch.load('r50_deformable_detr-checkpoint.pth')

# Resize the classification heads and the query embedding
num_class = 20  # number of classes in your own dataset
for i in range(6):  # class_embed.0 through class_embed.5
    pretrained_weights['model'][f'class_embed.{i}.weight'].resize_(num_class + 1, 256)
    pretrained_weights['model'][f'class_embed.{i}.bias'].resize_(num_class + 1)
# 50 is the number of queries; change it to match --num_queries in main.py
pretrained_weights['model']['query_embed.weight'].resize_(50, 512)
torch.save(pretrained_weights, 'deformable_detr-r50_%d.pth' % num_class)
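
After saving, it is worth checking that every resized tensor has the shape you expect. The sketch below builds the expected shapes as a plain dictionary, using only the key names and sizes that appear in the script above (6 class heads, hidden size 256, query embedding width 512). The helper names expected_head_shapes and find_mismatches are hypothetical, and no PyTorch is needed for the comparison itself.

```python
def expected_head_shapes(num_class, num_queries, hidden_dim=256, num_layers=6):
    """Return the shapes the resized checkpoint entries should have."""
    shapes = {}
    for i in range(num_layers):
        shapes[f"class_embed.{i}.weight"] = (num_class + 1, hidden_dim)
        shapes[f"class_embed.{i}.bias"] = (num_class + 1,)
    shapes["query_embed.weight"] = (num_queries, 2 * hidden_dim)
    return shapes


def find_mismatches(actual, expected):
    """Return {key: (actual_shape, expected_shape)} for entries that differ."""
    return {
        key: (actual.get(key), shape)
        for key, shape in expected.items()
        if actual.get(key) != shape
    }


expected = expected_head_shapes(num_class=20, num_queries=50)
print(expected["class_embed.0.weight"])  # (21, 256)
print(expected["query_embed.weight"])    # (50, 512)

# Simulated "actual" shapes in which one bias was left at a stale size of 91
actual = dict(expected)
actual["class_embed.3.bias"] = (91,)
print(find_mismatches(actual, expected))  # {'class_embed.3.bias': ((91,), (21,))}
```

To use it on the real file, the actual dictionary can be built from the saved checkpoint as {k: tuple(v.shape) for k, v in checkpoint['model'].items()}, where checkpoint is the object returned by torch.load.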
 

Reference: https://www.jianshu.com/p/b364534fd0a7
