9.4.3 池化模块
在深度学习中,池化(Pooling)是一种降采样操作,通常用于减小数据维度、提取关键特征、减少计算量以及防止过拟合。在本项目中,文件pooling_module.py定义了模型所使用的几种池化机制对应的函数,包括SLSTM(Sum LSTM)、CS-LSTM(Convolutional Social LSTM)和SGAN/Polar-Pooling,此外还定义了用于位置编码的模块PositionalEncoding。具体而言,该文件中的函数对社交嵌入数据执行不同类型的池化操作,其中SGAN/Polar-Pooling是基于与相邻车辆的位置关系进行的池化;PositionalEncoding模块则用于对输入进行位置编码。
# Select the compute device: the first CUDA GPU when CUDA is enabled in the
# configuration, otherwise the CPU.  The conditional must be written without
# a leading `"cuda:0" or ...`: a non-empty string is truthy, so `or` would
# short-circuit and yield "cuda:0" even when `use_cuda` is False.
device = "cuda:0" if args["use_cuda"] else "cpu"
# Main Pooling Function
# 主要池化函数
'positon encoder'
# 位置编码器
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding that is added to its input.

    A table of shape ``(max_len, 1, d_model)`` is precomputed on the
    module-level ``device`` and registered as the buffer ``pe`` (it is not a
    trainable parameter).  Even feature columns hold sines and odd feature
    columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the encoded (last) dimension.  Both even and odd
            values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        position = position.to(device)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        div_term = div_term.to(device)
        # (max_len, ceil(d_model / 2)) table of angles shared by sin and cos.
        angles = position * div_term

        pe = torch.zeros(max_len, d_model).to(device)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angle table to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Reshape (max_len, d_model) -> (max_len, 1, d_model) so the table
        # broadcasts over the second dimension of the input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Move ``x`` to ``device`` and add the encoding of its first positions.

        Args:
            x: Tensor whose first dimension indexes the position and whose
                shape broadcasts against ``(x.size(0), 1, d_model)``.
                ``x.size(0)`` must not exceed ``max_len``.

        Returns:
            ``x`` (on ``device``) plus ``pe[:x.size(0)]``.
        """
        x = x.to(device)
        return x + self.pe[:x.size(0), :]
def nbrs_pooling(net, soc_enc, masks, nbrs, nbrs_enc, hist_enc_1):
    """Apply the pooling scheme named by ``net.pooling`` to ``soc_enc``.

    Args:
        net: Model object; ``net.pooling`` selects the scheme.
        soc_enc: Social encoding to be pooled.
        masks, nbrs, nbrs_enc, hist_enc_1: Only forwarded to
            ``polar_pooling`` for the ``'sgan'`` / ``'polar'`` schemes.

    Returns:
        The result of ``s_pooling`` for ``'slstm'``, of ``cs_pooling`` for
        ``'cslstm'``, or of ``polar_pooling`` for ``'sgan'`` / ``'polar'``
        (which is a tuple).  For any other value ``soc_enc`` is returned
        unchanged.
    """
    mode = net.pooling
    if mode == 'slstm':
        return s_pooling(net, soc_enc)
    if mode == 'cslstm':
        return cs_pooling(net, soc_enc)
    if mode in ('sgan', 'polar'):
        return polar_pooling(net, soc_enc, masks, nbrs, nbrs_enc, hist_enc_1)
    # Unknown scheme: pass the encoding through untouched.
    return soc_enc
# SLSTM
def s_pooling(net, soc_enc):
    """Sum-pool the social tensor over non-overlapping windows.

    The tensor is zero-padded at the bottom when ``net.grid_size[0]`` is not
    a multiple of ``net.kernel_size[0]``.  Sum pooling is obtained by
    average pooling and multiplying by the window area.  The result is
    flattened to rows of ``net.kernel_size[0] * net.encoder_size`` values
    and passed through ``net.leaky_relu``.

    Args:
        net: Model object providing ``grid_size``, ``kernel_size``,
            ``encoder_size`` and ``leaky_relu``.
        soc_enc: Social tensor accepted by ``AvgPool2d``.

    Returns:
        The pooled, flattened and activated tensor.
    """
    k_rows, k_cols = net.kernel_size[0], net.kernel_size[1]

    # Pad at the bottom so the height becomes a multiple of the kernel height.
    leftover = net.grid_size[0] % k_rows
    if leftover:
        soc_enc = nn.ZeroPad2d((0, 0, 0, k_rows - leftover))(soc_enc)

    # mean * window area == sum over the window.
    window_mean = torch.nn.AvgPool2d((k_rows, k_cols))(soc_enc)
    window_sum = k_rows * k_cols * window_mean

    flat = window_sum.view(-1, k_rows * net.encoder_size)
    return net.leaky_relu(flat)
# CS-LSTM:应用卷积社交池化
def cs_pooling(net, soc_enc):
    """Apply convolutional social pooling to the social tensor.

    The tensor goes through ``net.soc_conv``, ``net.leaky_relu``,
    ``net.conv_3x1``, ``net.leaky_relu`` and ``net.soc_maxpool`` in that
    order, and is then flattened to rows of ``net.soc_embedding_size``.

    Args:
        net: Model object providing the layers named above and
            ``soc_embedding_size``.
        soc_enc: Social tensor fed to ``net.soc_conv``.

    Returns:
        Tensor of shape ``(-1, net.soc_embedding_size)``.
    """
    features = net.leaky_relu(net.soc_conv(soc_enc))
    features = net.leaky_relu(net.conv_3x1(features))
    pooled = net.soc_maxpool(features)
    return pooled.view(-1, net.soc_embedding_size)
# 在提出的极坐标池化和SGAN中使用的池化操作
def polar_pooling(net, soc_enc, masks, nbrs, nbrs_enc, hist_enc_1):
    """Pool neighbour encodings per sample (used for SGAN and polar pooling).

    For every sample that has at least one neighbour, the embedding of the
    neighbours' last positions is concatenated with the neighbour encodings,
    the attention output is appended as extra rows, the result is passed
    through ``net.mlp_pre_pool`` and max-pooled over the row dimension.

    Args:
        net: Model object providing ``bottleneck_dim``, ``use_cuda``,
            ``batch_norm``, ``tanh``, ``pre4att``, ``attention``,
            ``rel_pos_embedding`` and ``mlp_pre_pool``.
        soc_enc: Social tensor; it is only used as input to the attention
            branch (flattened to three dimensions).
        masks: 4-D tensor with one entry per sample along dim 0.  It is
            summed over dim 3 and the number of non-zero entries that remain
            for a sample is taken as that sample's neighbour count.
        nbrs: Neighbour tensor indexed as ``[:, neighbour, :]``; the last
            entry along dim 0 is fed to ``net.rel_pos_embedding``
            (presumably time-major trajectories -- TODO confirm).
        nbrs_enc: Neighbour encodings, one row per neighbour, stored in the
            same consecutive order as dim 1 of ``nbrs``.
        hist_enc_1: History encoding concatenated with the social tensor
            before attention (exact shape not visible here -- TODO confirm).

    Returns:
        A tuple ``(pooled, attn_weights)``.  ``pooled`` is a float tensor of
        shape ``(masks.shape[0], net.bottleneck_dim)`` whose rows stay zero
        for samples without neighbours.  ``attn_weights`` are the attention
        weights from the last sample that had neighbours, or ``None`` when
        no sample had any neighbour.
    """
    sum_masks = masks.sum(dim=3)
    soc_enc_1 = soc_enc
    soc_enc = torch.zeros(masks.shape[0], net.bottleneck_dim).float()
    if net.use_cuda:
        soc_enc = soc_enc.cuda()
        nbrs_enc = nbrs_enc.cuda()
        hist_enc_1 = hist_enc_1.cuda()
        soc_enc_1 = soc_enc_1.cuda()

    # Defined up front so the return below is valid even when the loop never
    # reaches the attention step (no sample has neighbours).
    soft_attn_weights_ha = None
    cntr = 0
    for ind in range(masks.shape[0]):
        no_nbrs = sum_masks[ind].nonzero().size()[0]
        if no_nbrs == 0:
            continue

        # Neighbours of all samples are stored back to back; `cntr` is the
        # offset of the current sample's first neighbour.
        curr_nbr_pos = nbrs[:, cntr:cntr + no_nbrs, :]
        curr_nbr_enc = nbrs_enc[cntr:cntr + no_nbrs, :]
        cntr += no_nbrs
        end_nbr_pos = curr_nbr_pos[-1]

        # These reshapes are idempotent, so repeating them on every
        # iteration leaves the tensors unchanged after the first pass.
        soc_enc_1 = soc_enc_1.contiguous().view(
            soc_enc_1.shape[0], soc_enc_1.shape[1], -1)
        hist_enc_1 = hist_enc_1.squeeze()
        hist_enc_1 = hist_enc_1.unsqueeze(2)
        new_hs = torch.cat((soc_enc_1, hist_enc_1), 2)

        # Positional encoding.
        # NOTE(review): the encoded tensor is overwritten by the permute of
        # the *un-encoded* `new_hs` on the next statement, so the positional
        # encoding currently has no effect on the output.  Left as is to
        # keep the numerical behaviour -- confirm the intent.
        pe = PositionalEncoding(d_model=40, max_len=5000)
        new_hs_per = pe(new_hs)

        # Attention.
        new_hs_per = new_hs.permute(0, 2, 1)
        weight = net.pre4att(net.tanh(new_hs_per))
        new_hidden_ha, soft_attn_weights_ha = net.attention(weight, new_hs_per)
        new_hidden_ha = torch.cat([new_hidden_ha, new_hidden_ha], dim=1)

        # Position embedding.
        rel_pos_embedding = net.rel_pos_embedding(end_nbr_pos)
        mlp_h_input = torch.cat([rel_pos_embedding, curr_nbr_enc], dim=1)
        mlp_h_input = torch.cat([mlp_h_input, new_hidden_ha], dim=0)

        # Batch normalisation cannot compute batch statistics from a single
        # row, so switch the MLP to eval mode (running estimates) for this
        # call only and restore the previous mode afterwards.
        if mlp_h_input.shape[0] == 1 and net.batch_norm:
            was_training = net.mlp_pre_pool.training
            net.mlp_pre_pool.eval()
            try:
                curr_pool_h = net.mlp_pre_pool(mlp_h_input)
            finally:
                net.mlp_pre_pool.train(was_training)
        else:
            curr_pool_h = net.mlp_pre_pool(mlp_h_input)

        # Max over rows gives one pooled vector for this sample.
        curr_pool_h = curr_pool_h.max(0)[0]
        soc_enc[ind] = curr_pool_h
    return soc_enc, soft_attn_weights_ha
具体来说,在上述代码中包括了如下三种不同的池化机制。
- SLSTM 池化(s_pooling):基于Sum LSTM的池化操作,先在网格底部进行零填充,再用"平均池化结果乘以池化窗口大小"的方式实现求和池化,从而提取关键特征。
- CS-LSTM 池化(cs_pooling):基于Convolutional Social LSTM的池化操作,使用卷积层和最大池化来对社交嵌入进行特征提取。
- SGAN/Polar-Pooling(polar_pooling):提供了SGAN(Social GAN)和Polar-Pooling 池化的实现。这一部分涉及到了位置编码、注意力机制以及多个子模块的组合,用于处理社交嵌入数据,并返回相应的输出。
总体而言,这些池化操作用于整合社交嵌入数据,提取关键信息,为模型的下一步预测任务提供更有代表性的输入。