Stage2 和 Stage3 的结构与 Stage1 相同,区别只在于 block 的重复次数:Stage2 中 block 重复 4 次,Stage3 中 block 重复 16 次。
具体细节这里不再赘述。
# Run the backbone stage by stage: stages are attributes named stage0, stage1, ...
# up to self.num_stages - 1. Each stage consumes and returns the full
# (template, online_template, search) triple.
for i in range(self.num_stages):
template, online_template, search = getattr(self, f'stage{i}')(template,
online_template, search)
# Only the template and search features are returned; online_template is dropped here.
return template, search
返回模板特征图和搜索特征图(后续主要使用的是搜索特征图)。
# Backbone yields the final template and search features.
template, search = self.backbone(template, online_template, search)
# Forward the corner head
# Only the search-region features are passed on to the box head.
return self.forward_box_head(search)
接下来进入 head。这里 search 特征图的大小为 (8, 384, 20, 20),即 (batch, channel, h, w)。
def forward_box_head(self, search):
    """Predict one box per sample from the search-region feature map.

    :param search: search feature map of shape (b, c, h, w)
    :return: tuple ``(out, outputs_coord_new)`` where ``outputs_coord_new``
        has shape (b, 1, 4) in (cx, cy, w, h) format and ``out`` is
        ``{'pred_boxes': outputs_coord_new}``.
    :raises KeyError: if ``self.head_type`` is not ``"CORNER"``.
    """
    if self.head_type != "CORNER":
        # Only the corner head is implemented; name the offending value so a
        # misconfiguration is diagnosable instead of an empty KeyError.
        raise KeyError(f"unsupported head_type: {self.head_type!r}")
    # run the corner head
    b = search.size(0)
    # box_head yields corner boxes (x0, y0, x1, y1); convert to centre/size form.
    outputs_coord = box_xyxy_to_cxcywh(self.box_head(search))
    outputs_coord_new = outputs_coord.view(b, 1, 4)
    out = {'pred_boxes': outputs_coord_new}
    return out, outputs_coord_new
接下来进入 self.box_head,
其具体实现如下:
class Corner_Predictor(nn.Module):
    """Corner predictor head.

    Two parallel convolutional branches map the input feature map to a
    top-left and a bottom-right score map (one channel each); a soft-argmax
    over each map yields the expected corner coordinates.
    """

    def __init__(self, inplanes=64, channel=256, feat_sz=20, stride=16, freeze_bn=False):
        """
        :param inplanes: number of channels of the input feature map
        :param channel: width of the first conv layer; each later layer halves it
        :param feat_sz: side length of the score map (the coordinate grid is
            feat_sz x feat_sz)
        :param stride: pixels per feature-map cell, used to scale coordinates
        :param freeze_bn: forwarded to every ``conv`` block
        """
        super(Corner_Predictor, self).__init__()
        self.feat_sz = feat_sz
        self.stride = stride
        # Image side length in pixels; forward() divides coordinates by it.
        self.img_sz = self.feat_sz * self.stride
        # top-left corner branch: channel -> /2 -> /4 -> /8 -> 1 score map
        self.conv1_tl = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_tl = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_tl = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_tl = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_tl = nn.Conv2d(channel // 8, 1, kernel_size=1)
        # bottom-right corner branch: same layout, separate weights
        self.conv1_br = conv(inplanes, channel, freeze_bn=freeze_bn)
        self.conv2_br = conv(channel, channel // 2, freeze_bn=freeze_bn)
        self.conv3_br = conv(channel // 2, channel // 4, freeze_bn=freeze_bn)
        self.conv4_br = conv(channel // 4, channel // 8, freeze_bn=freeze_bn)
        self.conv5_br = nn.Conv2d(channel // 8, 1, kernel_size=1)
        # Pixel coordinates of every cell, flattened row-major to length
        # feat_sz * feat_sz so they line up with the flattened score map:
        # flat index i * feat_sz + j has x = j * stride and y = i * stride.
        self.indice = torch.arange(0, self.feat_sz).view(-1, 1) * self.stride
        n_cells = self.feat_sz * self.feat_sz
        coord_x = self.indice.repeat((self.feat_sz, 1)).view((n_cells,)).float()
        coord_y = self.indice.repeat((1, self.feat_sz)).view((n_cells,)).float()
        # Non-persistent buffers instead of a hard-coded .cuda(): they follow
        # the module through .to()/.cuda()/DataParallel and work on CPU,
        # without adding keys to the state_dict.
        self.register_buffer('coord_x', coord_x, persistent=False)
        self.register_buffer('coord_y', coord_y, persistent=False)

    def forward(self, x, return_dist=False, softmax=True):
        """Predict normalised corner coordinates.

        :param x: input feature map; presumably (b, inplanes, feat_sz, feat_sz)
            — TODO confirm against callers
        :param return_dist: if True, also return the per-corner distributions
        :param softmax: passed to soft_argmax; with return_dist=True it selects
            softmax probabilities (True) or raw flattened scores (False)
        :return: tensor (b, 4) = (x_tl, y_tl, x_br, y_br) / img_sz; with
            return_dist=True a tuple (coords, dist_tl, dist_br)
        """
        score_map_tl, score_map_br = self.get_score_map(x)
        if return_dist:
            coorx_tl, coory_tl, prob_vec_tl = self.soft_argmax(score_map_tl,
                                                               return_dist=True,
                                                               softmax=softmax)
            coorx_br, coory_br, prob_vec_br = self.soft_argmax(score_map_br,
                                                               return_dist=True,
                                                               softmax=softmax)
            coords = torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz
            return coords, prob_vec_tl, prob_vec_br
        coorx_tl, coory_tl = self.soft_argmax(score_map_tl)
        coorx_br, coory_br = self.soft_argmax(score_map_br)
        return torch.stack((coorx_tl, coory_tl, coorx_br, coory_br), dim=1) / self.img_sz
get_score_map 通过两个结构相同、权重独立的卷积分支,分别生成 top-left(左上角)和 bottom-right(右下角)两张预测图(score map)。
def get_score_map(self, x):
    """Run both corner branches on ``x``.

    Each branch applies its five layers (conv1 .. conv5) in order; the
    top-left branch is evaluated before the bottom-right one.

    :param x: input feature map
    :return: ``(score_map_tl, score_map_br)``
    """
    def apply_in_order(layers, feat):
        # Feed the output of each layer into the next one.
        for layer in layers:
            feat = layer(feat)
        return feat

    tl_layers = (self.conv1_tl, self.conv2_tl, self.conv3_tl, self.conv4_tl, self.conv5_tl)
    br_layers = (self.conv1_br, self.conv2_br, self.conv3_br, self.conv4_br, self.conv5_br)
    score_map_tl = apply_in_order(tl_layers, x)
    score_map_br = apply_in_order(br_layers, x)
    return score_map_tl, score_map_br
soft_argmax 先把预测图展平并做 softmax,再以得到的概率为权重,对每个位置的坐标求期望(即 soft-argmax):
def soft_argmax(self, score_map, return_dist=False, softmax=True):
    """Soft-argmax of a score map: the probability-weighted mean coordinate.

    :param score_map: score map that flattens to rows of ``feat_sz * feat_sz``
        values (presumably (b, 1, feat_sz, feat_sz) — confirm with callers)
    :param return_dist: if True, also return a per-position distribution
    :param softmax: when ``return_dist`` is True, return the softmax
        probabilities (True) or the raw flattened scores (False)
    :return: ``(exp_x, exp_y)`` or ``(exp_x, exp_y, dist)``; exp_x and exp_y
        hold one value per flattened row
    """
    n_cells = self.feat_sz * self.feat_sz
    flat_scores = score_map.view((-1, n_cells))  # (batch, feat_sz * feat_sz)
    probs = nn.functional.softmax(flat_scores, dim=1)
    # Expected pixel coordinate under the softmax distribution.
    expected_x = (probs * self.coord_x).sum(dim=1)
    expected_y = (probs * self.coord_y).sum(dim=1)
    if not return_dist:
        return expected_x, expected_y
    distribution = probs if softmax else flat_scores
    return expected_x, expected_y, distribution
这里的操作是求角点坐标在特征图上的期望值。forward 中分别对左上角和右下角两张预测图调用 soft_argmax,最后返回左上角的 x、y 坐标和右下角的 x、y 坐标(均除以 img_sz 做了归一化)。
def box_xyxy_to_cxcywh(x):
    """Convert boxes from corner form (x0, y0, x1, y1) to (cx, cy, w, h).

    :param x: tensor whose last dimension holds the four corner values
    :return: tensor of the same shape with the last dimension replaced by
        (center_x, center_y, width, height)
    """
    left, top, right, bottom = x.unbind(-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    width = right - left
    height = bottom - top
    return torch.stack([center_x, center_y, width, height], dim=-1)
box_xyxy_to_cxcywh 将所得的 (x0, y0, x1, y1) 角点坐标转换为 (cx, cy, w, h) 格式,即中心点坐标加宽高。
最后计算损失,包括 GIoU、IoU 和 L1 三项。
对 GIoU 不熟悉的读者可以参考论文《Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression》。
MixFormer(论文解读与代码讲解)补充