
Notes on the Social-STGCNN code

References:
Code: https://github.com/abduallahmohamed/Social-STGCNN
TCN networks (with code): https://blog.csdn.net/Leon_winter/article/details/100124146
Tong Zhe's lecture on the Fourier transform: https://www.bilibili.com/video/BV1ft411J73y
Prof. Xu Zhiqin (Shanghai Jiao Tong University) on graph convolutions: https://www.bilibili.com/video/BV1ap4y117DD

utils.py

import os
import math
import sys

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as Func
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import torch.optim as optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from numpy import linalg as LA
import networkx as nx
from tqdm import tqdm
import time


def anorm(p1,p2): 
    # Inverse Euclidean distance between two points, used as the graph edge weight;
    # coincident points get weight 0 instead of dividing by zero.
    NORM = math.sqrt((p1[0]-p2[0])**2+ (p1[1]-p2[1])**2)
    if NORM ==0:
        return 0
    return 1/(NORM)
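
# Quick check (my own example, not in the original file):
assert anorm((0.0, 0.0), (3.0, 4.0)) == 1 / 5
assert anorm((1.0, 1.0), (1.0, 1.0)) == 0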
                
def seq_to_graph(seq_,seq_rel,norm_lap_matr = True):
    seq_ = seq_.squeeze()
    seq_rel = seq_rel.squeeze()
    seq_len = seq_.shape[2]
    max_nodes = seq_.shape[0]

    
    V = np.zeros((seq_len,max_nodes,2))
    A = np.zeros((seq_len,max_nodes,max_nodes))
    for s in range(seq_len):
        step_ = seq_[:,:,s]# seq_ is [N,2,T] (e.g. [57,2,8]); step_ is [N,2]: the xy of every pedestrian at time step s
        step_rel = seq_rel[:,:,s]# same shape [N,2], relative (frame-to-frame) coordinates
        for h in range(len(step_)): # len(step_) = number of pedestrians N; write pedestrian h's coordinates at step s
            V[s,h,:] = step_rel[h]# V is [T,N,2]: s indexes the time step, h the pedestrian, : the xy
            A[s,h,h] = 1# A is [T,N,N]; the self-connection is set to 1, every other entry stays 0 for now
            for k in range(h+1,len(step_)):
                l2_norm = anorm(step_rel[h],step_rel[k])
                A[s,h,k] = l2_norm# inverse distance between pedestrians h and k
                A[s,k,h] = l2_norm# symmetric: the same weight in both directions
        if norm_lap_matr: 
            G = nx.from_numpy_matrix(A[s,:,:])
            A[s,:,:] = nx.normalized_laplacian_matrix(G).toarray()# the normalized Laplacian has the same shape as the adjacency matrix, [T,N,N]
            
    return torch.from_numpy(V).type(torch.float),\
           torch.from_numpy(A).type(torch.float)
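
# A quick shape check of seq_to_graph (my own toy example, not part of the original file):
# 3 pedestrians over 4 time steps; V comes out as [T, N, 2] and A as [T, N, N].
# Note: nx.from_numpy_matrix used above requires networkx < 3.0.
rng = np.random.RandomState(0)
toy_abs = rng.randn(3, 2, 4)                                  # [N, 2, T] absolute xy
toy_rel = np.zeros_like(toy_abs)
toy_rel[:, :, 1:] = toy_abs[:, :, 1:] - toy_abs[:, :, :-1]    # frame-to-frame displacement
V_toy, A_toy = seq_to_graph(toy_abs, toy_rel, norm_lap_matr=True)
print(V_toy.shape, A_toy.shape)   # torch.Size([4, 3, 2]) torch.Size([4, 3, 3])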


def poly_fit(traj, traj_len, threshold):
    """
    Input:
    - traj: Numpy array of shape (2, traj_len)
    - traj_len: Len of trajectory
    - threshold: Minimum error to be considered for non linear traj
    Output:
    - int: 1 -> Non Linear 0-> Linear
    """
    t = np.linspace(0, traj_len - 1, traj_len)
    res_x = np.polyfit(t, traj[0, -traj_len:], 2, full=True)[1]
    res_y = np.polyfit(t, traj[1, -traj_len:], 2, full=True)[1]
    if res_x + res_y >= threshold:
        return 1.0
    else:
        return 0.0
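
# Small check of poly_fit (my own example): a straight line fits a quadratic with ~zero
# residual, so it is labelled linear (0.0); a sine-shaped path exceeds the threshold (1.0).
_t = np.arange(12, dtype=float)
print(poly_fit(np.stack([_t, 2 * _t]), 12, 0.002))       # 0.0 -> linear
print(poly_fit(np.stack([_t, np.sin(_t)]), 12, 0.002))   # 1.0 -> non-linear
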
def read_file(_path, delim='\t'):
    data = []
    if delim == 'tab':
        delim = '\t'
    elif delim == 'space':
        delim = ' '
    with open(_path, 'r') as f:
        for line in f:
            line = line.strip().split(delim)
            line = [float(i) for i in line]
            data.append(line)
    return np.asarray(data)


class TrajectoryDataset(Dataset):
    """轨迹数据集的数据加载器
    Dataloder for the Trajectory datasets"""
    def __init__(
        self, data_dir, obs_len=8, pred_len=8, skip=1, threshold=0.002,
        min_ped=1, delim='\t',norm_lap_matr = True):
        """
        Args:
        - data_dir: Directory containing dataset files in the format
        <frame_id> <ped_id> <x> <y>
        - obs_len: Number of time-steps in input trajectories
        - pred_len: Number of time-steps in output trajectories
        - skip: Number of frames to skip while making the dataset
        - threshold: Minimum error to be considered for non linear traj
        when using a linear predictor
        - min_ped: Minimum number of pedestrians that should be in a sequence
        - delim: Delimiter in the dataset files
        """
        super(TrajectoryDataset, self).__init__()

        self.max_peds_in_frame = 0
        self.data_dir = data_dir
        self.obs_len = obs_len
        self.pred_len = pred_len
        self.skip = skip
        self.seq_len = self.obs_len + self.pred_len
        self.delim = delim
        self.norm_lap_matr = norm_lap_matr

        all_files = os.listdir(self.data_dir)
        all_files = [os.path.join(self.data_dir, _path) for _path in all_files]
        num_peds_in_seq = []
        seq_list = []
        seq_list_rel = []
        loss_mask_list = []
        non_linear_ped = []
        for path in all_files:
            data = read_file(path, delim)
            frames = np.unique(data[:, 0]).tolist()# all unique frame ids in this file
            frame_data = []# one entry per frame: every row <frame_id, ped_id, x, y> of that frame
            for frame in frames:# loop over all frames
                frame_data.append(data[frame == data[:, 0], :])# gather the rows belonging to this frame
            num_sequences = int(
                math.ceil((len(frames) - self.seq_len + 1) / skip))# number of sliding windows of length seq_len (rounded up, with stride skip)

            for idx in range(0, num_sequences * self.skip + 1, skip):# slide a window of seq_len frames over the file
                curr_seq_data = np.concatenate(
                    frame_data[idx:idx + self.seq_len], axis=0)# stack the obs_len + pred_len (8 + 12 = 20) frames of this window into one array
                peds_in_curr_seq = np.unique(curr_seq_data[:, 1])# ids of every pedestrian appearing anywhere in this window
                self.max_peds_in_frame = max(self.max_peds_in_frame,len(peds_in_curr_seq))# keep track of the largest pedestrian count seen so far
                curr_seq_rel = np.zeros((len(peds_in_curr_seq), 2,# [N, 2, 20] relative coordinates (N = peds in this window)
                                         self.seq_len))
                curr_seq = np.zeros((len(peds_in_curr_seq), 2, self.seq_len))# [N, 2, 20] absolute coordinates
                curr_loss_mask = np.zeros((len(peds_in_curr_seq),# [N, 20] loss mask
                                           self.seq_len))
                num_peds_considered = 0
                _non_linear_ped = []
                for _, ped_id in enumerate(peds_in_curr_seq):# loop over every pedestrian in this window
                    curr_ped_seq = curr_seq_data[curr_seq_data[:, 1] ==ped_id, :]# rows of this pedestrian inside the window (frame_id, ped_id, x, y)
                    curr_ped_seq = np.around(curr_ped_seq, decimals=4)# round to 4 decimal places
                    pad_front = frames.index(curr_ped_seq[0, 0]) - idx# index of the pedestrian's first frame inside the window
                    pad_end = frames.index(curr_ped_seq[-1, 0]) - idx + 1# index just past the pedestrian's last frame
                    if pad_end - pad_front != self.seq_len:
                        continue# drop pedestrians that are not present in all 20 frames
                    curr_ped_seq = np.transpose(curr_ped_seq[:, 2:])# keep only the x,y columns and transpose to [2, 20]
                    curr_ped_seq = curr_ped_seq
                    # Make coordinates relative
                    rel_curr_ped_seq = np.zeros(curr_ped_seq.shape)# [2, 20]
                    rel_curr_ped_seq[:, 1:] = \
                        curr_ped_seq[:, 1:] - curr_ped_seq[:, :-1]# displacement: current frame minus previous frame
                    _idx = num_peds_considered
                    curr_seq[_idx, :, pad_front:pad_end] = curr_ped_seq# absolute trajectory of this pedestrian
                    curr_seq_rel[_idx, :, pad_front:pad_end] = rel_curr_ped_seq# relative trajectory
                    # Linear vs Non-Linear Trajectory
                    _non_linear_ped.append(
                        poly_fit(curr_ped_seq, pred_len, threshold))# 1.0 if the trajectory is non-linear, 0.0 if linear
                    curr_loss_mask[_idx, pad_front:pad_end] = 1# mark the valid time steps
                    num_peds_considered += 1

                if num_peds_considered > min_ped:
                    non_linear_ped += _non_linear_ped
                    num_peds_in_seq.append(num_peds_considered)# number of pedestrians kept in this window
                    loss_mask_list.append(curr_loss_mask[:num_peds_considered])# their loss masks
                    seq_list.append(curr_seq[:num_peds_considered])# absolute trajectories of the kept pedestrians, [N_kept, 2, 20]
                    seq_list_rel.append(curr_seq_rel[:num_peds_considered])# relative trajectories

        self.num_seq = len(seq_list)
        seq_list = np.concatenate(seq_list, axis=0)# concatenate over all windows: [total_peds, 2, 20]
        seq_list_rel = np.concatenate(seq_list_rel, axis=0)
        loss_mask_list = np.concatenate(loss_mask_list, axis=0)
        non_linear_ped = np.asarray(non_linear_ped)

        # Convert numpy -> Torch Tensor
        self.obs_traj = torch.from_numpy(
            seq_list[:, :, :self.obs_len]).type(torch.float)# [total_peds, 2, 8]
        self.pred_traj = torch.from_numpy(
            seq_list[:, :, self.obs_len:]).type(torch.float)# [total_peds, 2, 12]
        self.obs_traj_rel = torch.from_numpy(
            seq_list_rel[:, :, :self.obs_len]).type(torch.float)# [total_peds, 2, 8]
        self.pred_traj_rel = torch.from_numpy(
            seq_list_rel[:, :, self.obs_len:]).type(torch.float)# [total_peds, 2, 12]
        self.loss_mask = torch.from_numpy(loss_mask_list).type(torch.float)
        self.non_linear_ped = torch.from_numpy(non_linear_ped).type(torch.float)
        cum_start_idx = [0] + np.cumsum(num_peds_in_seq).tolist()
        self.seq_start_end = [
            (start, end)
            for start, end in zip(cum_start_idx, cum_start_idx[1:])# (start, end) row range of each window's pedestrians in the concatenated tensors
        ]
        # Convert to Graphs
        self.v_obs = [] 
        self.A_obs = [] 
        self.v_pred = [] 
        self.A_pred = [] 
        print("Processing Data .....")
        pbar = tqdm(total=len(self.seq_start_end)) 
        for ss in range(len(self.seq_start_end)):# one graph per window
            pbar.update(1)

            start, end = self.seq_start_end[ss]# row range of this window's pedestrians

            v_,a_ = seq_to_graph(self.obs_traj[start:end,:],self.obs_traj_rel[start:end, :],self.norm_lap_matr)# obs_traj[start:end] is one window's absolute trajectories, [N, 2, 8]; obs_traj_rel[start:end] the matching relative ones
            self.v_obs.append(v_.clone())# V of the observed part, [8, N, 2]; the list collects one graph per window
            self.A_obs.append(a_.clone())# A of the observed part, [8, N, N]
            v_,a_=seq_to_graph(self.pred_traj[start:end,:],self.pred_traj_rel[start:end, :],self.norm_lap_matr)# same for the prediction horizon: [12, N, 2] and [12, N, N]
            self.v_pred.append(v_.clone())
            self.A_pred.append(a_.clone())
        pbar.close()

    def __len__(self):
        return self.num_seq

    def __getitem__(self, index):
        start, end = self.seq_start_end[index]

        out = [
            self.obs_traj[start:end, :], self.pred_traj[start:end, :],
            self.obs_traj_rel[start:end, :], self.pred_traj_rel[start:end, :],
            self.non_linear_ped[start:end], self.loss_mask[start:end, :],
            self.v_obs[index], self.A_obs[index],
            self.v_pred[index], self.A_pred[index]

        ]
        return out
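
A rough usage sketch (my own, with placeholder paths; see train.py in the repo for the actual setup). The DataLoader keeps batch_size=1 because every window contains a different number of pedestrians, so the per-window tensors cannot be stacked; the effective batch is built later in train() by accumulating the loss over args.batch_size windows.

dset_train = TrajectoryDataset('./datasets/eth/train/', obs_len=8, pred_len=12,
                               skip=1, norm_lap_matr=True)
loader_train = DataLoader(dset_train, batch_size=1, shuffle=True, num_workers=0)
obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped, \
    loss_mask, V_obs, A_obs, V_tr, A_tr = next(iter(loader_train))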

model.py

import os
import math
import sys

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as Func
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import torch.optim as optim




class ConvTemporalGraphical(nn.Module):
    #Source : https://github.com/yysijie/st-gcn/blob/master/net/st_gcn.py
    r"""
    The basic module for applying a graph convolution.
    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel
        t_kernel_size (int): Size of the temporal convolving kernel
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides of
            the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``True``
    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes. 
    """
    def __init__(self,
                 in_channels,#2
                 out_channels,#5
                 kernel_size,#8
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super(ConvTemporalGraphical,self).__init__()
        self.kernel_size = kernel_size
        #dilation: a dilated 3x3 kernel covers a 5x5 receptive field, so output sizes are computed as for 5x5 (not used here, t_dilation defaults to 1)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,# out_channels filters, one per output channel
            kernel_size=(t_kernel_size, 1),#[1,1]
            padding=(t_padding, 0),#[0,0]
            stride=(t_stride, 1),#[1,1]
            dilation=(t_dilation, 1),#[1,1]
            bias=bias)# 1x1 kernel: [8,57] -> [8,57], only the channel dimension changes

    def forward(self, x, A):# A is [8,57,57]: the (normalized) inverse-distance adjacency for every time step
        assert A.size(0) == self.kernel_size
        x = self.conv(x)# [1,2,8,57] -> [1,5,8,57]; the 1x1 convolution only mixes the channel dimension
        x = torch.einsum('nctv,tvw->nctw', (x, A))# the XA part of the GCN (in the spectral view, graph convolution becomes multiplication by the normalized graph matrix): [1,5,8,57] x [8,57,57] -> [1,5,8,57]
# A comes from seq_to_graph in utils.py: self-loops are added (A+I) and the matrix is then normalized (the normalized Laplacian when norm_lap_matr=True)
# einsum hides the intermediate reshapes; e = torch.einsum('nctv,tvw->nctw', (a, b)) is equivalent to the batched matrix product below:
# d = a.view(n * c * t, 1, v).bmm(b.view(1, t, v, w).expand(n * c, t, v,w).reshape(-1, v, w)).view(n, c, t, w)
        return x.contiguous(), A
    # In other words, the [1,5,8,57] tensor and the [8,57,57] adjacency are multiplied per time
    # step: for each channel and each time step t, the length-57 node vector is multiplied by the
    # 57x57 matrix A[t], so every node aggregates the weighted features of all nodes present at
    # the same time step.
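
# Quick numeric check of the einsum equivalence above (my own sketch, not in the original file):
n, c, t, v, w = 2, 5, 8, 6, 6
a = torch.randn(n, c, t, v)
b = torch.randn(t, v, w)
e = torch.einsum('nctv,tvw->nctw', (a, b))
d = a.view(n * c * t, 1, v).bmm(b.view(1, t, v, w).expand(n * c, t, v, w).reshape(-1, v, w)).view(n, c, t, w)
assert torch.allclose(e, d, atol=1e-5)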


class st_gcn(nn.Module):
    r"""
    Applies a spatial temporal graph convolution over an input graph sequence.
    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (int, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``
    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,#2
                 out_channels,#5
                 kernel_size,#[3,8]
                 use_mdn = False,
                 stride=1,
                 dropout=0,
                 residual=True):
        super(st_gcn,self).__init__()
        
#         print("outstg",out_channels)

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        padding = ((kernel_size[0] - 1) // 2, 0)# pad the temporal dimension so the conv keeps T unchanged
        self.use_mdn = use_mdn
# gcn: [128,5,8,57] x [8,57,57] -> [128,5,8,57], the graph convolution over the node dimension
        self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                         kernel_size[1])
        

        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.PReLU(),
            nn.Conv2d(
                out_channels,#5
                out_channels,#5
                (kernel_size[0], 1),#[3,1]; output T = (8 - 3 + 2*1)/1 + 1 = 8, and the node dimension stays 57 - 1 + 1 = 57
                (stride, 1),#[1,1]
                padding,#[1,0]
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        if not residual:
            self.residual = lambda x: 0

        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x

        else:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,#2
                    out_channels,#5
                    kernel_size=1,
                    stride=(stride, 1)),#[1,1]
                nn.BatchNorm2d(out_channels),
            )# 1x1 conv on the residual branch: [8,57] -> [8,57], only the channels change (2 -> 5)

        self.prelu = nn.PReLU()

    def forward(self, x, A):
# x: [128,2,8,57], A: [8,57,57]
        res = self.residual(x)# [128,2,8,57] -> [128,5,8,57]
        x, A = self.gcn(x, A)# [128,2,8,57] x [8,57,57] -> [128,5,8,57]

        x = self.tcn(x) + res# the tcn keeps the shape [128,5,8,57]; add the residual branch
        
        if not self.use_mdn:
            x = self.prelu(x)# activation

        return x, A
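
# Shape check for one st_gcn block (my own sketch, not part of the original file):
# with in_channels=2, out_channels=5 and kernel_size=(3, 8), an input of [1, 2, 8, N]
# together with an adjacency of shape [8, N, N] comes back as [1, 5, 8, N].
layer = st_gcn(2, 5, (3, 8))
out, _ = layer(torch.randn(1, 2, 8, 6), torch.randn(8, 6, 6))
print(out.shape)   # torch.Size([1, 5, 8, 6])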

class social_stgcnn(nn.Module):
    def __init__(self,n_stgcnn =1,n_txpcnn=1,input_feat=2,output_feat=5,
                 seq_len=8,pred_seq_len=12,kernel_size=3):
        super(social_stgcnn,self).__init__()
        self.n_stgcnn= n_stgcnn#1
        self.n_txpcnn = n_txpcnn#5
                
        self.st_gcns = nn.ModuleList()
        self.st_gcns.append(st_gcn(input_feat,output_feat,(kernel_size,seq_len)))
        for j in range(1,self.n_stgcnn):# runs 0 times when n_stgcnn == 1
            self.st_gcns.append(st_gcn(output_feat,output_feat,(kernel_size,seq_len)))
        
        self.tpcnns = nn.ModuleList()
        self.tpcnns.append(nn.Conv2d(seq_len,pred_seq_len,3,padding=1))# 3x3 conv with padding 1 keeps the spatial dims ((5-3+2)/1+1 = 5); channels go seq_len -> pred_seq_len (8 -> 12)
        for j in range(1,self.n_txpcnn):#1,2,3,4
            self.tpcnns.append(nn.Conv2d(pred_seq_len,pred_seq_len,3,padding=1))# same: spatial dims unchanged
        self.tpcnn_ouput = nn.Conv2d(pred_seq_len,pred_seq_len,3,padding=1)
            
            
        self.prelus = nn.ModuleList()
        for j in range(self.n_txpcnn):#0,1,2,3,4
            self.prelus.append(nn.PReLU())# default initial slope 0.25


        
    def forward(self,v,a):

        for k in range(self.n_stgcnn):# n_stgcnn iterations (here 1)
            v,a = self.st_gcns[k](v,a)# [128,5,8,57]
            
        v = v.view(v.shape[0],v.shape[2],v.shape[1],v.shape[3])# [128,5,8,57] -> [128,8,5,57]: the time dimension becomes the channel dimension for the TXP-CNN
        
        v = self.prelus[0](self.tpcnns[0](v))# [128,12,5,57]

        for k in range(1,self.n_txpcnn-1):# k = 1 .. n_txpcnn-2 (1, 2, 3 when n_txpcnn == 5)
            v =  self.prelus[k](self.tpcnns[k](v)) + v# [128,12,5,57], residual TXP-CNN layers
            
        v = self.tpcnn_ouput(v)# [128,12,5,57]
        v = v.view(v.shape[0],v.shape[2],v.shape[1],v.shape[3])# [128,12,5,57] -> [128,5,12,57]
        
        
        return v,a
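
A minimal forward-pass sketch (my own example) with the configuration the notes above assume (n_stgcnn=1, n_txpcnn=5); N = 6 pedestrians is arbitrary:

model = social_stgcnn(n_stgcnn=1, n_txpcnn=5, input_feat=2, output_feat=5,
                      seq_len=8, pred_seq_len=12, kernel_size=3)
V = torch.randn(1, 2, 8, 6)   # [batch, feat, obs_len, nodes]
A = torch.randn(8, 6, 6)      # [obs_len, nodes, nodes]
V_pred, _ = model(V, A)
print(V_pred.shape)           # torch.Size([1, 5, 12, 6])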

train.py

def train(epoch):
    global metrics,loader_train
    model.train()
    loss_batch = 0 
    batch_count = 0
    is_fst_loss = True
    loader_len = len(loader_train)
    turn_point =int(loader_len/args.batch_size)*args.batch_size+ loader_len%args.batch_size -1# simplifies to loader_len - 1: the index of the last window, used to flush the final (possibly partial) accumulation


    for cnt,batch in enumerate(loader_train): 
        batch_count+=1

        #Get data
        batch = [tensor.cuda() for tensor in batch]
        obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,\
         loss_mask,V_obs,A_obs,V_tr,A_tr = batch



        optimizer.zero_grad()
        #Forward
        #V_obs = batch,seq,node,feat
        #V_obs_tmp = batch,feat,seq,node
        # V_obs starts as [1, 8, N, 2] (the DataLoader batch size is 1); after the permute it is [1, 2, 8, N]
        V_obs_tmp =V_obs.permute(0,3,1,2)# [1,2,8,N]; A_obs is the adjacency, [1,8,N,N]

        V_pred,_ = model(V_obs_tmp,A_obs.squeeze())# A_obs.squeeze() drops the batch dimension -> [8,N,N]
        # [1,5,12,N] -> [1,12,N,5]: per node and time step, the 5 Gaussian parameters
        V_pred = V_pred.permute(0,2,3,1)
        
        

        V_tr = V_tr.squeeze()# [12,N,2]: ground-truth relative positions over the prediction horizon
        A_tr = A_tr.squeeze()# [12,N,N]
        V_pred = V_pred.squeeze()# [12,N,5]

        if batch_count%args.batch_size !=0 and cnt != turn_point :
            # keep accumulating the loss until a full "virtual batch" of args.batch_size windows is reached
            l = graph_loss(V_pred,V_tr)
            if is_fst_loss :
                loss = l
                is_fst_loss = False
            else:
                loss += l

        else:
            # average over the accumulated windows, then take a single optimizer step
            loss = loss/args.batch_size
            is_fst_loss = True
            loss.backward()
            
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(),args.clip_grad)


            optimizer.step()
            #Metrics
            loss_batch += loss.item()
            print('TRAIN:','\t Epoch:', epoch,'\t Loss:',loss_batch/batch_count)# running average: total accumulated loss divided by the number of windows seen so far
            
    metrics['train_loss'].append(loss_batch/batch_count)
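
The if/else above is gradient accumulation: the DataLoader yields one variable-size graph at a time, the per-window losses are summed, and only once every args.batch_size windows (or at turn_point) is the averaged loss backpropagated and the optimizer stepped. A minimal sketch of the same pattern (my own rewrite with a hypothetical my_loader; calling backward() per window is mathematically equivalent but uses less memory than keeping all the summed graphs, and the final partial batch that turn_point handles is ignored here):

accum_steps = 128                      # plays the role of args.batch_size
optimizer.zero_grad()
for step, (v_obs, a_obs, v_gt) in enumerate(my_loader, start=1):
    v_pred, _ = model(v_obs.permute(0, 3, 1, 2), a_obs.squeeze())
    loss = graph_loss(v_pred.permute(0, 2, 3, 1).squeeze(), v_gt.squeeze()) / accum_steps
    loss.backward()                    # gradients accumulate across windows
    if step % accum_steps == 0:
        optimizer.step()               # one update per virtual batch of accum_steps windows
        optimizer.zero_grad()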

metrics.py

def bivariate_loss(V_pred,V_trgt):# V_pred: [12,N,5], V_trgt: [12,N,2]
    #mux, muy, sx, sy, corr
    #assert V_pred.shape == V_trgt.shape
    # eq 24 & 25 in Graves (2013)
    # walkthrough of those equations: https://blog.csdn.net/qq_41185868/article/details/104875246
    # For reference, the 1-D Gaussian PDF is f(x) = 1/(sigma*sqrt(2*pi)) * exp(-(x-mu)^2/(2*sigma^2)); below is its bivariate version with correlation rho
    normx = V_trgt[:,:,0]- V_pred[:,:,0]
    normy = V_trgt[:,:,1]- V_pred[:,:,1]

    sx = torch.exp(V_pred[:,:,2]) #sx
    sy = torch.exp(V_pred[:,:,3]) #sy
    corr = torch.tanh(V_pred[:,:,4]) #corr
    
    sxsy = sx * sy

    z = (normx/sx)**2 + (normy/sy)**2 - 2*((corr*normx*normy)/sxsy)
    negRho = 1 - corr**2

    # Numerator
    result = torch.exp(-z/(2*negRho))
    # Normalization factor
    denom = 2 * np.pi * (sxsy * torch.sqrt(negRho))

    # Final PDF calculation
    result = result / denom

    # Numerical stability
    epsilon = 1e-20

    result = -torch.log(torch.clamp(result, min=epsilon))
    result = torch.mean(result)
    
    return result
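
A quick numerical sanity check of the formula above (my own sketch; it assumes scipy is installed and that bivariate_loss is in scope): for a single node and time step, the loss should equal the negative log-density of the corresponding bivariate normal.

import math
import torch
from scipy.stats import multivariate_normal

mux, muy, sx, sy, rho = 0.3, -0.2, 0.5, 0.8, 0.1
V_pred = torch.tensor([[[mux, muy, math.log(sx), math.log(sy), math.atanh(rho)]]])  # [1, 1, 5]
V_trgt = torch.tensor([[[0.1, 0.4]]])                                               # [1, 1, 2]

cov = [[sx**2, rho*sx*sy], [rho*sx*sy, sy**2]]
ref = -multivariate_normal(mean=[mux, muy], cov=cov).logpdf([0.1, 0.4])
print(float(bivariate_loss(V_pred, V_trgt)), ref)   # the two values should agree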

These are just my own notes; feel free to point out any mistakes.
