Deformable-DETR代码实现
GitHub:https://github.com/fundamentalvision/Deformable-DETR
预训练模型
数据集:coco2017 格式
命令:python main.py --coco_path “data/coco” --batch_size=1 --num_workers=1 --output_dir=“outputs” --start_epoch 0 --resume=r50_deformable_detr-checkpoint.pth
修改代码使得开始跳过验证,直接开始训练
#main.py
# if args.resume:
改为:if args.resume and args.start_epoch is not 0:
修改类别
#models/deformable_detr.py
def build(args):
#num_classes = 12 if args.dataset_file != 'coco' else 91
num_classes = 9 if args.dataset_file != 'coco' else 10
问题
1.cannot import name_ NewEmptyTensorop from torchvision. op S.
解决:我的版本本来就是最新版本了,但还是出现这些的问题,这是因为我现在的版本是0.10+的版本,而由于在torchvision 0.10版本的util/misc.py中,代码只检查前3个字符,因此它认为我的版本是0.1,并试图导入_NewEmptyTensorOp,通过删除util/misc.py中的30-59和490-501行就解决了这个问题。
2.
解决:进入models/ops中运行python setup.py install(适用于Win系统),或者进入该文件夹下编译sh ./make.sh(适用于iux系统)
3.
默认里设置路径
命令:python main.py --batch_size=1 --num_workers=1 --output_dir=“outputs” --start_epoch 0 --resume=r50_deformable_detr-checkpoint.pth
4.训练自己数据集时要改权值
import torch
#加载官方提供的权重文件
pretrained_weights = torch.load('r50_deformable_detr-checkpoint.pth')
#修改相关权重
num_class = 20 # 自己数据集分类数
pretrained_weights['model']['class_embed.0.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.0.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.1.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.1.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.2.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.2.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.3.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.3.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.4.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.4.bias'].resize_(num_class+1)
pretrained_weights['model']['class_embed.5.weight'].resize_(num_class+1, 256)
pretrained_weights['model']['class_embed.5.bias'].resize_(num_class+1)
pretrained_weights['model']['query_embed.weight'].resize_(50, 512) # 此处50对应生成queries的数量,根据main.py中--num_queries数量修改
torch.save(pretrained_weights, 'deformable_detr-r50_%d.pth'%num_class)
参考:https://www.jianshu.com/p/b364534fd0a7