无论前端 backbone 如何,最终总会输出一个 feature map。如何在这个 feature map 上得到预测(pred)的 box 呢?
第一步:先得到 feature map 上所有中心点的坐标 (x, y)。
举例:
以大小为 (4, 4, 5, 6) 的特征图为例:第一个 4 为 batch,第二个 4 为 channel,5 为 h(高),6 为 w(宽)。
#以feature的大小5*6 生成单位网格
import torch
# Build integer grid coordinates for a 5 x 6 (h x w) feature map.
# indexing="ij" makes the row-major layout explicit: yy changes along rows
# and xx along columns, both with shape [5, 6].
yy, xx = torch.meshgrid(torch.arange(5), torch.arange(6), indexing="ij")
print(yy, xx)
#得到如下
tensor([[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4]])
tensor([[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5]])
# Stack the two grids along a new leading axis (torch.stack defaults to
# dim=0), x grid first and y grid second.
mesh = torch.stack((xx, yy))
print(mesh)
#输出如下,shape为[2, 5, 6]
tensor([[[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4, 5]],
[[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3],
[4, 4, 4, 4, 4, 4]]])
# Prepend a size-1 axis, tile it 4 times and cast to float32;
# the result has shape [4, 2, 5, 6].
# NOTE(review): the original note calls this leading axis the batch
# dimension, but the later torch.cat along dim=1 needs its size to equal
# the number of rows in anchor_wh (4 anchors) -- confirm which is intended.
mesh = mesh[None, ...].repeat(4, 1, 1, 1).to(torch.float32)
print(mesh)
tensor([[[[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.]],
[[0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4.]]],
[[[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.]],
[[0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4.]]],
[[[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.]],
[[0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4.]]],
[[[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.],
[0., 1., 2., 3., 4., 5.]],
[[0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4.]]]])
# Next, tile the preset anchor box sizes (one [w, h] pair per anchor)
# over the whole grid.
nGh, nGw = 5, 6  # feature-map height and width used in this example

# anchor_wh has to be a tensor (a plain list has no unsqueeze/repeat).
anchor_wh = torch.tensor([
    [5.0174, 15.0521],
    [7.0833, 21.2500],
    [10.0347, 24.7917],
    [20.0694, 18.8889],
])

# [4, 2] -> [4, 2, 1, 1] -> [4, 2, nGh, nGw]: every grid cell receives a
# copy of each anchor's pair.
anchor_offset_mesh = anchor_wh.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, nGh, nGw)
print(anchor_offset_mesh)
tensor([[[[ 5.0174, 5.0174, 5.0174, 5.0174, 5.0174, 5.0174],
[ 5.0174, 5.0174, 5.0174, 5.0174, 5.0174, 5.0174],
[ 5.0174, 5.0174, 5.0174, 5.0174, 5.0174, 5.0174],
[ 5.0174, 5.0174, 5.0174, 5.0174, 5.0174, 5.0174],
[ 5.0174, 5.0174, 5.0174, 5.0174, 5.0174, 5.0174]],
[[15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521],
[15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521],
[15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521],
[15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521],
[15.0521, 15.0521, 15.0521, 15.0521, 15.0521, 15.0521]]],
[[[ 7.0833, 7.0833, 7.0833, 7.0833, 7.0833, 7.0833],
[ 7.0833, 7.0833, 7.0833, 7.0833, 7.0833, 7.0833],
[ 7.0833, 7.0833, 7.0833, 7.0833, 7.0833, 7.0833],
[ 7.0833, 7.0833, 7.0833, 7.0833, 7.0833, 7.0833],
[ 7.0833, 7.0833, 7.0833, 7.0833, 7.0833, 7.0833]],
[[21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500],
[21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500],
[21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500],
[21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500],
[21.2500, 21.2500, 21.2500, 21.2500, 21.2500, 21.2500]]],
[[[10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347],
[10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347],
[10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347],
[10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347],
[10.0347, 10.0347, 10.0347, 10.0347, 10.0347, 10.0347]],
[[24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917],
[24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917],
[24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917],
[24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917],
[24.7917, 24.7917, 24.7917, 24.7917, 24.7917, 24.7917]]],
[[[20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694],
[20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694],
[20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694],
[20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694],
[20.0694, 20.0694, 20.0694, 20.0694, 20.0694, 20.0694]],
[[18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889],
[18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889],
[18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889],
[18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889],
[18.8889, 18.8889, 18.8889, 18.8889, 18.8889, 18.8889]]]])
# Concatenate the grid coordinates with the tiled anchor sizes along dim=1,
# giving shape [4, 4, 5, 6]. The original referenced an undefined name
# `mesh2`; the [4, 2, 5, 6] grid tensor built earlier is called `mesh`.
anchor_offset_mesh = torch.cat([mesh, anchor_offset_mesh], dim=1)
# Move the 4 box values to the last axis and flatten everything else, so
# each row is one box.
f = anchor_offset_mesh.permute(0, 2, 3, 1).contiguous().view(-1, 4)
print(f)
# The printed bounding boxes are in (x, y, w, h) format:
tensor([[ 0.0000, 0.0000, 5.0174, 15.0521],
[ 1.0000, 0.0000, 5.0174, 15.0521],
[ 2.0000, 0.0000, 5.0174, 15.0521],
[ 3.0000, 0.0000, 5.0174, 15.0521],
[ 4.0000, 0.0000, 5.0174, 15.0521],
[ 5.0000, 0.0000, 5.0174, 15.0521],
[ 0.0000, 1.0000, 5.0174, 15.0521],
[ 1.0000, 1.0000, 5.0174, 15.0521],
[ 2.0000, 1.0000, 5.0174, 15.0521],
[ 3.0000, 1.0000, 5.0174, 15.0521],
[ 4.0000, 1.0000, 5.0174, 15.0521],
[ 5.0000, 1.0000, 5.0174, 15.0521],
[ 0.0000, 2.0000, 5.0174, 15.0521],
[ 1.0000, 2.0000, 5.0174, 15.0521],
[ 2.0000, 2.0000, 5.0174, 15.0521],
[ 3.0000, 2.0000, 5.0174, 15.0521],
[ 4.0000, 2.0000, 5.0174, 15.0521],
[ 5.0000, 2.0000, 5.0174, 15.0521],
[ 0.0000, 3.0000, 5.0174, 15.0521],
[ 1.0000, 3.0000, 5.0174, 15.0521],
[ 2.0000, 3.0000, 5.0174, 15.0521],
[ 3.0000, 3.0000, 5.0174, 15.0521],
[ 4.0000, 3.0000, 5.0174, 15.0521],
[ 5.0000, 3.0000, 5.0174, 15.0521],
[ 0.0000, 4.0000, 5.0174, 15.0521],
[ 1.0000, 4.0000, 5.0174, 15.0521],
[ 2.0000, 4.0000, 5.0174, 15.0521],
[ 3.0000, 4.0000, 5.0174, 15.0521],
[ 4.0000, 4.0000, 5.0174, 15.0521],
[ 5.0000, 4.0000, 5.0174, 15.0521],
[ 0.0000, 0.0000, 7.0833, 21.2500],
[ 1.0000, 0.0000, 7.0833, 21.2500],
[ 2.0000, 0.0000, 7.0833, 21.2500],
[ 3.0000, 0.0000, 7.0833, 21.2500],
[ 4.0000, 0.0000, 7.0833, 21.2500],
[ 5.0000, 0.0000, 7.0833, 21.2500],
[ 0.0000, 1.0000, 7.0833, 21.2500],
[ 1.0000, 1.0000, 7.0833, 21.2500],
[ 2.0000, 1.0000, 7.0833, 21.2500],
[ 3.0000, 1.0000, 7.0833, 21.2500],
[ 4.0000, 1.0000, 7.0833, 21.2500],
[ 5.0000, 1.0000, 7.0833, 21.2500],
[ 0.0000, 2.0000, 7.0833, 21.2500],
[ 1.0000, 2.0000, 7.0833, 21.2500],
[ 2.0000, 2.0000, 7.0833, 21.2500],
[ 3.0000, 2.0000, 7.0833, 21.2500],
[ 4.0000, 2.0000, 7.0833, 21.2500],
[ 5.0000, 2.0000, 7.0833, 21.2500],
[ 0.0000, 3.0000, 7.0833, 21.2500],
[ 1.0000, 3.0000, 7.0833, 21.2500],
[ 2.0000, 3.0000, 7.0833, 21.2500],
[ 3.0000, 3.0000, 7.0833, 21.2500],
[ 4.0000, 3.0000, 7.0833, 21.2500],
[ 5.0000, 3.0000, 7.0833, 21.2500],
[ 0.0000, 4.0000, 7.0833, 21.2500],
[ 1.0000, 4.0000, 7.0833, 21.2500],
[ 2.0000, 4.0000, 7.0833, 21.2500],
[ 3.0000, 4.0000, 7.0833, 21.2500],
[ 4.0000, 4.0000, 7.0833, 21.2500],
[ 5.0000, 4.0000, 7.0833, 21.2500],
[ 0.0000, 0.0000, 10.0347, 24.7917],
[ 1.0000, 0.0000, 10.0347, 24.7917],
[ 2.0000, 0.0000, 10.0347, 24.7917],
[ 3.0000, 0.0000, 10.0347, 24.7917],
[ 4.0000, 0.0000, 10.0347, 24.7917],
[ 5.0000, 0.0000, 10.0347, 24.7917],
[ 0.0000, 1.0000, 10.0347, 24.7917],
[ 1.0000, 1.0000, 10.0347, 24.7917],
[ 2.0000, 1.0000, 10.0347, 24.7917],
[ 3.0000, 1.0000, 10.0347, 24.7917],
[ 4.0000, 1.0000, 10.0347, 24.7917],
[ 5.0000, 1.0000, 10.0347, 24.7917],
[ 0.0000, 2.0000, 10.0347, 24.7917],
[ 1.0000, 2.0000, 10.0347, 24.7917],
[ 2.0000, 2.0000, 10.0347, 24.7917],
[ 3.0000, 2.0000, 10.0347, 24.7917],
[ 4.0000, 2.0000, 10.0347, 24.7917],
[ 5.0000, 2.0000, 10.0347, 24.7917],
[ 0.0000, 3.0000, 10.0347, 24.7917],
[ 1.0000, 3.0000, 10.0347, 24.7917],
[ 2.0000, 3.0000, 10.0347, 24.7917],
[ 3.0000, 3.0000, 10.0347, 24.7917],
[ 4.0000, 3.0000, 10.0347, 24.7917],
[ 5.0000, 3.0000, 10.0347, 24.7917],
[ 0.0000, 4.0000, 10.0347, 24.7917],
[ 1.0000, 4.0000, 10.0347, 24.7917],
[ 2.0000, 4.0000, 10.0347, 24.7917],
[ 3.0000, 4.0000, 10.0347, 24.7917],
[ 4.0000, 4.0000, 10.0347, 24.7917],
[ 5.0000, 4.0000, 10.0347, 24.7917],
[ 0.0000, 0.0000, 20.0694, 18.8889],
[ 1.0000, 0.0000, 20.0694, 18.8889],
[ 2.0000, 0.0000, 20.0694, 18.8889],
[ 3.0000, 0.0000, 20.0694, 18.8889],
[ 4.0000, 0.0000, 20.0694, 18.8889],
[ 5.0000, 0.0000, 20.0694, 18.8889],
[ 0.0000, 1.0000, 20.0694, 18.8889],
[ 1.0000, 1.0000, 20.0694, 18.8889],
[ 2.0000, 1.0000, 20.0694, 18.8889],
[ 3.0000, 1.0000, 20.0694, 18.8889],
[ 4.0000, 1.0000, 20.0694, 18.8889],
[ 5.0000, 1.0000, 20.0694, 18.8889],
[ 0.0000, 2.0000, 20.0694, 18.8889],
[ 1.0000, 2.0000, 20.0694, 18.8889],
[ 2.0000, 2.0000, 20.0694, 18.8889],
[ 3.0000, 2.0000, 20.0694, 18.8889],
[ 4.0000, 2.0000, 20.0694, 18.8889],
[ 5.0000, 2.0000, 20.0694, 18.8889],
[ 0.0000, 3.0000, 20.0694, 18.8889],
[ 1.0000, 3.0000, 20.0694, 18.8889],
[ 2.0000, 3.0000, 20.0694, 18.8889],
[ 3.0000, 3.0000, 20.0694, 18.8889],
[ 4.0000, 3.0000, 20.0694, 18.8889],
[ 5.0000, 3.0000, 20.0694, 18.8889],
[ 0.0000, 4.0000, 20.0694, 18.8889],
[ 1.0000, 4.0000, 20.0694, 18.8889],
[ 2.0000, 4.0000, 20.0694, 18.8889],
[ 3.0000, 4.0000, 20.0694, 18.8889],
[ 4.0000, 4.0000, 20.0694, 18.8889],
[ 5.0000, 4.0000, 20.0694, 18.8889]])