一、背景
BEV方案中,将图像视角转换到BEV视角的方法对模型性能影响较大,FastBEV的速度较快,但投影效果上限不高,LSS投影上限较高,但速度较慢 (耗时相对较高)。是否有折中的方案,在耗时增加相对较少的情况下,提升模型的上限(中高算力平台下,提升模型能力)?
二、视角转换关键算子-----gridsample
这是pytorch官网对gridsample算子使用方法说明,其支持4-D(FastBEV/IMP)和5-D(LSS)采样,将图像特征提取到对应的BEV特征中,完成相机视角转换:https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
5-D gridsample相比4-D gridsample耗时剧增,假如在某智驾芯片上,4-D gridsample耗时是2ms,相同条件下5-D gridsample的耗时可能是200ms(具体耗时受特征图通道数影响),这种耗时急剧上升的方案,很难在智驾中落地应用。
三、LSS投影优化
1.先来对比4-D gridsample和5-D gridsample的输入输出关系:
4-D gridsample
input: (N, C, H_in, W_in);
bev_grid: (N, H_out, W_out, 2), 这里的2表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y);
output: (N, C, H_out, W_out)
5-D gridsample
input: (N, C, H_in, W_in);
for循环提取每个C通道的输入特征进行softmax处理input_i:(N, D, H_in, W_in),按照dim=1堆叠起来,得到深度输入input_2:(N, C, D, H_in, W_in), 这里的D表示深度估计的通道数;
bev_grid: (N, Z_out, H_out, W_out, 3), 这里的3表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y,d), d为深度估计;
output: (N, C, Z_out, H_out, W_out);
由于获取深度信息需要用到5-D gridsample,想要降低耗时,考虑减少特征图通道对耗时的影响,即做5-D gridsample时,将通道C设为1;
2.具体方法-----拆解5-D gridsample
将5-D gridsample拆解为一个4-D gridsample和一个单通道(C=1)的5-D gridsample,4-D gridsample负责提取多通道特征信息,单通道5-D gridsample负责提取深度特征信息,最后将两个特征信息相乘,得到多通道下的深度信息,等效变换过程如下:
step1:
4-D gridsample
input: (N, C, H_in, W_in);
bev_grid: (N, Z_out, H_out, W_out, 2), 这里的2表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y);
for循环提取每个Z_out下的bev_grid_i: (N, Z_out, H_out, W_out, 2),通过4-D gridsample分别得到输出特征图output_i: (N, C, H_out, W_out),按照dim=2堆叠起来,得到最终的BEV特征图output_1(没有深度概率信息):
output_1: (N, C, Z_out, H_out, W_out)
step2:
单通道5-D gridsample
input: (N, C, H_in, W_in);
input经过softmax处理后的特征图input_2: (N, D, H_in, W_in),这里的D表示深度估计的通道数;将input_2在dim=1上扩展一个维度,得到input_3:(N, 1, D, H_in, W_in)
bev_grid: (N, Z_out, H_out, W_out, 3), 这里的3表示bev_grid坐标通过相机内外参投影到图像上的坐标(x,y,d), d为深度估计;
output_2: (N, 1, Z_out, H_out, W_out);
step3:
将output_1和output_2相乘得到有深度概率信息的BEV特征图
output = outptu_1 * output_2 = (N, C, Z_out, H_out, W_out) * (N, 1, Z_out, H_out, W_out) = (N, 1, Z_out, H_out, W_out)
四、部分代码
1.IPM的BEV网格坐标索引
class UpdateIndicesIPM:
def __init__(self, height, range, voxel_size, feature_size, downsample):
self.height = height
self.range = range
self.voxel_size = voxel_size
self.feature_size = feature_size
self.ds_matrix = np.eye(4)
self.ds_matrix[:2] /= downsample
def __call__(self, data):
num = len(data["cam2egoes"])
ego2feats = torch.zeros((num, 4, 4), dtype=torch.float32)
for i in range(num):
ego2cam = np.linalg.inv(data["cam2egoes"][i])
tmp = np.eye(4)
tmp[:3, :3] = data["cam_intrinsics"][i]
ego2feats[i] = torch.tensor(self.ds_matrix @ tmp @ ego2cam)
grid = torch.stack(torch.meshgrid([
torch.arange(self.range[0], self.range[3], self.voxel_size[0]),
torch.arange(self.range[1], self.range[4], self.voxel_size[1]),
torch.tensor(self.height), torch.tensor(1.0)
], indexing="ij")) # [4, 188, 64, 4, 1]
grid_h, grid_w = grid.shape[1:3]
grid = grid.view(1, 4, -1).expand(num, 4, -1) # [7, 4, 192512]
points_2d = torch.bmm(ego2feats[:, :3, :], grid)
x = (points_2d[:, 0] / points_2d[:, 2]).round().long()
y = (points_2d[:, 1] / points_2d[:, 2]).round().long()
z = points_2d[:, 2]
valid = ~((x >= 0) & (y >= 0) & (x < self.feature_size[1]) &
(y < self.feature_size[0]) & (z > 0))
x[valid] = 0
y[valid] = 0
x = (x.float() / self.feature_size[1] * 2.) - 1.0
y = (y.float() / self.feature_size[0] * 2.) - 1.0
indices = torch.cat([x.unsqueeze(2), y.unsqueeze(2)], dim=2)
indices = indices.reshape(-1, grid_h, grid_w, len(self.height), 2) # batch, num_img, bev_w, bev_h, num_height, 2
data["indices"] = indices
return data
2.FastBEV
class FastBevTransform(nn.Module):
def __init__(self, feats_channels, num_height):
super().__init__()
self._num_height = num_height
self._conv = nn.Conv2d(feats_channels * num_height, feats_channels, kernel_size=1)
self._grid_sample = GridSample(mode="nearest",
padding_mode="zeros",
align_corners=True)
self._cat = Concat(dim=1)
def forward(self, feats, indices):
# feats: (7B, C, H, W), indices: (7B, Hg, Wg, Z, 2)
bev_feats = []
for i in range(self._num_height):
output = self._grid_sample(feats, indices[:,:,:,i])
bev_feats.append(output)
bev_feats = self._cat(bev_feats) # (7B, Z*C, Hg, Wg)
bev_feats = self._conv(bev_feats) # (7B, C, Hg, Wg)
return bev_feats
3.LSS的BEV网格坐标索引
class UpdateIndicesLSS:
def __init__(self, height, range, voxel_size, feature_size,
resolution, max_num_depth, downsample):
self.height = height
self.range = range
self.voxel_size = voxel_size
self.feature_size = feature_size
self.resolution = resolution
self.max_num_depth = max_num_depth
self.ds = np.eye(3)
self.ds[:2] /= downsample
def __call__(self, data):
num = len(data["cam2egoes"])
ego2cams = torch.zeros((num, 4, 4), dtype=torch.float32)
cam2feats = torch.zeros((num, 3, 3), dtype=torch.float32)
for i in range(num):
ego2cams[i] = torch.tensor(np.linalg.inv(data["cam2egoes"][i]))
cam2feats[i] = torch.tensor(self.ds @ data["cam_intrinsics"][i])
grid = torch.stack(torch.meshgrid([
torch.arange(self.range[0], self.range[3], self.voxel_size[0]),
torch.arange(self.range[1], self.range[4], self.voxel_size[1]),
torch.tensor(self.height), torch.tensor(1.0)
], indexing="ij")) # [4, 188, 64, 4, 1]
grid_h, grid_w = grid.shape[1:3]
grid4 = grid.view(1, 4, -1).expand(num, 4, -1) # [7, 4, 192512]
points_2d = torch.bmm(ego2cams[:, :3, :], grid4)
x = (points_2d[:, 0] / points_2d[:, 2]) # [7, 48128]
y = (points_2d[:, 1] / points_2d[:, 2]) # [7, 48128]
z = points_2d[:, 2] # [7, 48128]
r = points_2d.norm(dim=1) # [B*N, Hg*Wg]
d = torch.floor(r / self.resolution)
distortions = torch.tensor(np.array(data["cam_distortions"]).T)
k1,k2,k3,p1,p2,k4,k5,k6 = distortions[:,:,None]
fovs = torch.tensor(data['crop_fovs']).unsqueeze(-1) / 2.0
in_fov = np.abs(np.arctan2(points_2d[:, 0], z)) < fovs
r2 = x**2 + y**2
ratio = (1 + k1 * r2 + k2 * r2**2 + k3 * r2**3) / (1 + k4 * r2 + k5 * r2**2 + k6 * r2**3)
x_undist = x * ratio + 2 * p1 * x * y + p2 * (r2 + 2 * x**2)
y_undist = y * ratio + p1 * (r2 + 2 * y**2) + 2 * p2 * x * y
x = cam2feats[:, 0, [0]] * x_undist + cam2feats[:, 0, [2]]
y = cam2feats[:, 1, [1]] * y_undist + cam2feats[:, 1, [2]]
valid = ~((x >= 0) & (y >= 0) & (x < self.feature_size[1]) & \
(y < self.feature_size[0]) & (z > 0) & in_fov & \
(d >= 0) & (d < self.max_num_depth)) # [7, 48128]
x[valid], y[valid], d[valid] = -1, -1, -1
x = (x.float() / self.feature_size[1] * 2.) - 1.0
y = (y.float() / self.feature_size[0] * 2.) - 1.0
d = (d.float() / self.max_num_depth * 2.) - 1.0
indices = torch.cat([x[:,:,None], y[:,:,None], d[:,:,None]], dim=2) # [7, 48128, 3]
indices = indices.reshape(-1, grid_h, grid_w, len(self.height), 3) # batch*num_img, bev_w, bev_h, num_height, 3(x, y, d)
data["indices"] = indices.permute(0, 3, 1, 2, 4) # batch*num_img, num_height, bev_w, bev_h, 3(x, y, d)
return data
4.LSS的BEV投影
class LssBevTransform(nn.Module):
def __init__(self, feats_channels, num_height, max_num_depth):
super().__init__()
self._num_height = num_height
self._max_num_depth = max_num_depth
self.ms_cam = MS_CAM(feats_channels * num_height)
self._depth_proj = nn.Sequential(
nn.Conv2d(feats_channels, max_num_depth, kernel_size=3, padding=1),
nn.Softmax(dim=1)
)
self._grid_sample = GridSample(mode="nearest",
padding_mode="zeros",
align_corners=True)
self._cat = Concat(dim=1)
self._blocks = nn.Sequential(
nn.Conv2d(feats_channels * num_height, feats_channels, kernel_size=1),
nn.BatchNorm2d(feats_channels),
nn.ReLU(inplace=True)
)
def simplify_bev(self, feats, indices):
depths = self._depth_proj(feats)[:, None]
import ipdb
ipdb.set_trace()
pass
def forward(self, feats, indices):
# feats: (B*N, C, H, W)
# indices: (B*N, Z, X, Y, 3) where 3 dims represent (w, h, d).
bev_feats = self._sample_bev_feats(feats, indices[..., :2]) # (B*N, C, Z, X, Y)
depth_feats = self._sample_depth_feats(feats, indices) # (B*N, 1, Z, X, Y)
final_feats = bev_feats * depth_feats # (B*N, C, Z, Y, X)
N, C, Z, Y, X = final_feats.shape
final_feats = final_feats.view(N, C * Z, Y, X) # (B*N, Z*C, Hg, Wg)
final_feats = final_feats*self.ms_cam(final_feats)
final_feats = self._blocks(final_feats) # (B*N, C, Hg, Wg)
return final_feats
def _sample_bev_feats(self, feats, indices):
bev_feats = [self._grid_sample(feats, indices[:, i]) for i in range(self._num_height)]
return torch.stack(bev_feats, dim=2) # (B*N, C, Z, Y, X)
def _sample_depth_feats(self, feats, indices):
depths = self._depth_proj(feats)[:, None] # (B*N, 1, D, H, W)
return self._grid_sample(depths, indices) # (B*N, 1, Z, X, Y)
五、展望
LSS投影时将input_3:(N, 1, D, H_in, W_in)中D和H_in进行reshape合并后得(N, 1, D*H_in, W_in),可以完全通过4-D gridsample提取特征,耗时进一步降低,等效替代测试代码如下:
#!/usr/bin/env python3
import unittest
import torch
import torch.nn.functional as F
class GridSampleTest(unittest.TestCase):
def test_grid_sample_equivalence(self):
D, H, W = 100, 144, 256
Y, X = 64, 128
# Generate random features.
feats_5d = torch.randn(1, 1, D, H, W)
# Generate random indices.
d = torch.randint(high=D, size=(Y, X))
h = torch.randint(high=H, size=(Y, X))
w = torch.randint(high=W, size=(Y, X))
# Prepare grid for 5D grid_sample.
indices_5d = torch.stack([
2.0 * w / (W - 1) - 1.0,
2.0 * h / (H - 1) - 1.0,
2.0 * d / (D - 1) - 1.0
], dim=-1).view(1, 1, Y, X, 3)
bev_feats_5d = F.grid_sample(
feats_5d, indices_5d, mode="nearest", align_corners=True
).view(Y, X)
# Flatten D and H dimensions and prepare grid for 4D grid_sample.
dh = d * H + h
indices_4d = torch.stack([
2.0 * w / (W - 1) - 1.0,
2.0 * dh / (D * H - 1) - 1.0
], dim=-1).view(1, Y, X, 2)
feats_4d = feats_5d.view(1, 1, D * H, W)
bev_feats_4d = F.grid_sample(
feats_4d, indices_4d, mode="nearest", align_corners=True
).view(Y, X)
# Check if the results are close.
self.assertTrue(torch.allclose(bev_feats_5d, bev_feats_4d, atol=1e-6))
if __name__ == "__main__":
unittest.main()