
(9-6) A Target Behavior Prediction System Based on the Perception-Aware Trajectory Prediction Model (BAT): The Pooling Module

9.4.3  The Pooling Module

In deep learning, pooling is a downsampling operation commonly used to reduce data dimensionality, extract key features, cut computation, and guard against overfitting. In this project, the file pooling_module.py defines the functions for the model's different pooling mechanisms: SLSTM (Sum LSTM), CS-LSTM (Convolutional Social LSTM), and SGAN/Polar-Pooling. It also contains a module for positional encoding (PositionalEncoding). Concretely, the functions in this file apply different kinds of pooling to the social embedding data, including SGAN/Polar-Pooling, which is based on the positional relationships with neighboring vehicles; the PositionalEncoding module adds positional information to its input.
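For reference, the PositionalEncoding class in the code below implements the standard sinusoidal encoding from the Transformer; this matches the sin/cos construction in the code, where for position $pos$ and dimension pair index $i$:

$$\mathrm{PE}_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right), \qquad \mathrm{PE}_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)$$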

import math

import torch
import torch.nn as nn

# Select the compute device. (The original line read
# `device = "cuda:0" or "cuda:1" if args["use_cuda"] else "cpu"`, which always
# evaluates to "cuda:0" because of operator precedence; `args` is expected to
# be provided by the surrounding project configuration.)
device = "cuda:0" if args["use_cuda"] else "cpu"

# Positional encoder
class PositionalEncoding(nn.Module):
 
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Precompute the sinusoidal table once: sines on even indices,
        # cosines on odd indices.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1).to(device)
        # Register as a buffer: saved with the model but not trained.
        self.register_buffer('pe', pe)
 
    def forward(self, x):
        # Add the fixed positional encoding for the first x.size(0) positions.
        x = x.to(device)
        return x + self.pe[:x.size(0), :]
 
# Main pooling function: dispatch the social encoding to the mechanism
# selected by net.pooling.
def nbrs_pooling(net, soc_enc, masks, nbrs, nbrs_enc, hist_enc_1):
    if net.pooling == 'slstm':
        soc_enc = s_pooling(net, soc_enc)
    elif net.pooling == 'cslstm':
        soc_enc = cs_pooling(net, soc_enc)
    elif net.pooling == 'sgan' or net.pooling == 'polar':
        # Note: polar_pooling returns a (soc_enc, attention_weights) tuple,
        # unlike the other two branches, which return a plain tensor.
        soc_enc = polar_pooling(net, soc_enc, masks, nbrs, nbrs_enc, hist_enc_1)
    return soc_enc
 
# SLSTM
def s_pooling(net, soc_enc):
 
    # Zero-pad from the bottom so the grid height divides the kernel height
    bottom_pad = net.grid_size[0] % net.kernel_size[0]
 
    if bottom_pad != 0:
        pad_layer = nn.ZeroPad2d((0, 0, 0, net.kernel_size[0] - bottom_pad))
        soc_enc = pad_layer(soc_enc)
 
    # Sum pooling, implemented as average pooling scaled by the kernel area
    avg_pool = torch.nn.AvgPool2d((net.kernel_size[0], net.kernel_size[1]))
    soc_enc = net.kernel_size[0] * net.kernel_size[1] * avg_pool(soc_enc)
    soc_enc = soc_enc.view(-1, net.kernel_size[0] * net.encoder_size)
    soc_enc = net.leaky_relu(soc_enc)
 
    return soc_enc
 
# CS-LSTM: apply convolutional social pooling
 
def cs_pooling(net, soc_enc):
 
    soc_enc = net.soc_maxpool(net.leaky_relu(
        net.conv_3x1(net.leaky_relu(net.soc_conv(soc_enc)))))
    soc_enc = soc_enc.view(-1, net.soc_embedding_size)
    return soc_enc
 
# Pooling operation used in the proposed polar pooling and in SGAN
 
def polar_pooling(net, soc_enc, masks, nbrs, nbrs_enc, hist_enc_1):
    sum_masks = masks.sum(dim=3)
    soc_enc_1 = soc_enc
    soc_enc = torch.zeros(masks.shape[0], net.bottleneck_dim).float()
    soft_attn_weights_ha = None  # defined up front so the return is safe even with no neighbors
    if net.use_cuda:
        soc_enc = soc_enc.cuda()
        nbrs_enc = nbrs_enc.cuda()
        hist_enc_1 = hist_enc_1.cuda()
        soc_enc_1 = soc_enc_1.cuda()
 
 
    cntr = 0  # running offset into the flattened neighbor tensors
    for ind in range(masks.shape[0]):
        no_nbrs = sum_masks[ind].nonzero().size()[0]
        if no_nbrs > 0:
            curr_nbr_pos = nbrs[:, cntr:cntr+no_nbrs, :]
            curr_nbr_enc = nbrs_enc[cntr:cntr+no_nbrs, :]
            cntr += no_nbrs
 
            end_nbr_pos = curr_nbr_pos[-1]  # neighbor positions at the last time step

            soc_enc_1 = soc_enc_1.contiguous().view(soc_enc_1.shape[0], soc_enc_1.shape[1], -1)
            hist_enc_1 = hist_enc_1.squeeze()
            hist_enc_1 = hist_enc_1.unsqueeze(2)

            # Concatenate the social grid encoding with the history encoding
            new_hs = torch.cat((soc_enc_1, hist_enc_1), 2)
            # Positional encoding
 
            # NOTE: instantiating PositionalEncoding on every iteration is
            # wasteful; d_model must match new_hs.shape[2] (hard-coded to 40 here)
            pe = PositionalEncoding(d_model=40, max_len=5000)
            new_hs_per = pe(new_hs)
 
            # Attention over the encoded sequence
            # (the original code permuted `new_hs` here, which silently discarded
            # the positional encoding; permute the encoded tensor instead)
            new_hs_per = new_hs_per.permute(0, 2, 1)
            weight = net.pre4att(net.tanh(new_hs_per))
            new_hidden_ha, soft_attn_weights_ha = net.attention(weight, new_hs_per)
            new_hidden_ha = torch.cat([new_hidden_ha, new_hidden_ha], dim=1)
 
            # Relative-position embedding
            rel_pos_embedding = net.rel_pos_embedding(end_nbr_pos)
            mlp_h_input = torch.cat([rel_pos_embedding, curr_nbr_enc], dim=1)
            mlp_h_input = torch.cat([mlp_h_input, new_hidden_ha], dim=0)
            # With a single neighbor, BatchNorm cannot compute batch statistics,
            # so switch the layer to eval mode and use its running estimates.
            # (The original `== 1 & net.batch_norm` parsed as `== (1 & ...)` due
            # to operator precedence; `and` is the intended logic.)
            if mlp_h_input.shape[0] == 1 and net.batch_norm:
                net.mlp_pre_pool.eval()
 
            curr_pool_h = net.mlp_pre_pool(mlp_h_input)

            # Element-wise max over the pooled rows (SGAN-style max pooling)
            curr_pool_h = curr_pool_h.max(0)[0]
            soc_enc[ind] = curr_pool_h
    return soc_enc, soft_attn_weights_ha

Specifically, the code above implements the following three pooling mechanisms.

  1. SLSTM pooling (s_pooling): sum pooling in the style of Sum LSTM; the window sum is implemented as average pooling scaled by the kernel area (a minimal check of this equivalence is sketched after this list).
  2. CS-LSTM pooling (cs_pooling): convolutional social pooling in the style of CS-LSTM, which extracts features from the social embedding with convolutional layers followed by max pooling.
  3. SGAN/Polar-Pooling (polar_pooling): the implementation of SGAN (Social GAN) and polar pooling. It combines positional encoding, an attention mechanism, and several submodules to process the social embedding data, returning both the pooled encoding and the attention weights.
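As a quick sanity check on the sum-pooling trick in s_pooling (item 1), the following minimal sketch verifies on random data that average pooling scaled by the kernel area reproduces a direct window sum; the shapes here are illustrative assumptions, not values from the project.

import torch
import torch.nn as nn

# Illustrative shapes: a batch of 4 grids, 32 channels, 12 x 3 cells,
# pooled with a 4 x 3 kernel (height chosen divisible for simplicity).
x = torch.randn(4, 32, 12, 3)
kh, kw = 4, 3

# Sum pooling as written in s_pooling: average pooling scaled by the kernel area.
sum_via_avg = kh * kw * nn.AvgPool2d((kh, kw))(x)

# Reference: regroup each non-overlapping window and sum it directly.
ref = x.view(4, 32, 12 // kh, kh, 3 // kw, kw).sum(dim=(3, 5))

print(torch.allclose(sum_via_avg, ref, atol=1e-6))  # True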

Overall, these pooling operations aggregate the social embedding data, distill its key information, and provide a more representative input for the model's subsequent prediction task.
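To make the dispatch concrete, here is a minimal usage sketch that drives nbrs_pooling through its 'slstm' branch (assuming it is imported from pooling_module.py). The dummy net object and all shapes are illustrative assumptions; the real attribute values come from the project's model configuration.

import types
import torch
import torch.nn as nn

# Stand-in for the real model: only the attributes the 'slstm' branch
# of nbrs_pooling reads are provided (all values assumed for illustration).
net = types.SimpleNamespace(
    pooling='slstm',
    grid_size=(13, 3),    # 13 longitudinal x 3 lateral grid cells
    kernel_size=(4, 3),   # pooling kernel
    encoder_size=32,      # per-cell encoding width
    leaky_relu=nn.LeakyReLU(0.1),
)

# Social tensor: (batch, channels=encoder_size, grid height, grid width).
soc_enc = torch.randn(8, 32, 13, 3)

# masks/nbrs/nbrs_enc/hist_enc_1 are consumed only by the polar branch,
# so placeholders suffice here.
pooled = nbrs_pooling(net, soc_enc, masks=None, nbrs=None,
                      nbrs_enc=None, hist_enc_1=None)
print(pooled.shape)  # torch.Size([8, 128]) with these assumed sizes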
