import torch
import torch.nn as nn
class Encoder(nn.Module):
def __init__(self, feature_size, hidden_size, num_layers):
super(Encoder, self).__init__()
self.hidden_size = hidden_size # 隐层大小
self.num_layers = num_layers # lstm层数
# # 加入一层全连接扩展特征
# self.fc_extend = nn.Linear(feature_size, 128)
# feature_size为特征维度,就是每个时间点对应的特征数量,这里为8
self.lstm = nn.LSTM(feature_size, hidden_size, num_layers, batch_first=True,
bidirectional=True, dropout=0.3)
def forward(self, x, hidden=None):
batch_size = x.shape[0] # 获取批次大小
# x = self.fc_extend(x)
# 初始化隐层状态
if hidden is None:
h_0 = x.data.new(self.num_layers*2, batch_size, self.hidden_size).fill_(0).float()
c_0 = x.data.new(self.num_layers*2, batch_size, self.hidden_size).fill_(0).float()
else:
h_0,