Bootstrap

(9-5)基于感知轨迹预测模型(BAT)的目标行为预测系统:模型超参数

9.4.2  模型超参数

文件model_args.py定义了模型的参数配置信息,包括使用CUDA、编码器大小、解码器大小、输入和输出长度、网格大小、动力学嵌入大小、输入嵌入大小、纬度和经度类别数、训练标志、动力学矩阵和中心性输入输出大小等多个参数。此外,还包括输入维度、意图模块的使用、池化机制选择等设置。其中,根据池化机制的选择,定义了相应的参数如卷积核大小、SOC卷积深度等。最后,还设置了数据集类中的时间历史、未来长度、跳过因子以及预训练和训练的时期数等。

# 定义模型参数
args = {}
 
# 是否使用CUDA
args['use_cuda'] = True
 
# 编码器大小
args['encoder_size'] = 64
 
# 解码器大小
args['decoder_size'] = 128
 
# 输入历史长度
args['in_length'] = 16
 
# 输出未来长度
args['out_length'] = 25
 
# 网格大小
args['grid_size'] = (13, 3)
 
# 动力学嵌入大小
args['dyn_embedding_size'] = 32
 
# 动力学矩阵和中心性输入大小
args['dyn_matrix_and_centralit_input'] = 39
 
# 输入嵌入大小
args['input_embedding_size'] = 32
 
# 纬度类别数
args['num_lat_classes'] = 3
 
# 经度类别数
args['num_lon_classes'] = 3
 
# 训练标志
args['train_flag'] = True
 
# 动力学矩阵和中心性输出大小
args['dyn_matrix_and_centralit_output'] = 32
 
# 输入维度(2D或3D)
args['input_dim'] = 3
 
# 是否使用意图模块
args['intention_module'] = True
 
# 选择池化机制
args['pooling'] = 'polar'
 
# 池化为'slstm'时的卷积核大小
if args['pooling'] == 'slstm':
    args['kernel_size'] = (4, 3)
 
# 池化为'cslstm'时的SOC卷积深度和3x1卷积深度
elif args['pooling'] == 'cslstm':
    args['soc_conv_depth'] = 64
    args['conv_3x1_depth'] = 16
 
# 池化为'sgan'或'polar'时的瓶颈维度和SGAN批归一化标志
elif args['pooling'] == 'sgan' or args['pooling'] == 'polar':
    args['bottleneck_dim'] = 256
    args['sgan_batch_norm'] = False
 
# 数据集类中的时间历史、未来长度、跳过因子以及预训练和训练时期数等设置
args['t_hist'] = 30
args['t_fut'] = 50
args['skip_factor'] = 2
args['pretrainEpochs'] = 6
args['trainEpochs'] = 5
 
# 用于评估的预测视野
args['pred_horiz'] = 5

未完待续

;