结合上次文章深入理解LSTM的代码实现。
一、LSTM
输入:x^t 代表当前时刻下的输入,h^(t-1)表示上一时刻的hidden state,c^(t-1)代表上一时刻的cell state。x^t、h^(t-1)、c^(t-1)一起作为输入。
输出:h^t为当前时刻的hidden state,c^t为当前时刻的cell state
c^t=g(z)f(zi)+c^(t-1)f(zf)
其中,对于传递下去的 c^t 改变得很慢,而 h^t 则在不同节点下往往会有很大的区别。LSTM通过对 c^t 的利用尽可能长时间地保留信息
参考文章:LSTM 详解及其代码实现
二、LSTM的代码
总的代码如下:
import torch as t
import torch.nn as nn
class LstmNet(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
# input_size输入元素个数,hidden_size隐藏层元素个数 和 num_layers隐藏层数
super(LstmNet, self).__init__()
# 隐藏单元数
self.hidden_size = hidden_size
# 隐藏层数
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
# batch_first如果为True,那么输入和输出Tensor的形状为(batch,seq,feature)
# 输出的全连接网络
self.fc = nn.Linear(hidden_size, num_classes)
def forward(self, x):
# 通过x.size(0)获取 batch中的元素个数
# h0 和 c0 的格式为:(层数*方向数,批次数,隐藏层数)
b_size = x.size(0)
h0 = t.zeros(self.num_layers, b_size, self.hidden_size)
c0 = t.zeros(self.num_layers, b_size, self.hidden_size)
lstm_out, _ = self.lstm(x, (h0, c0))
fc_out = self.fc(lstm_out[:, -1, :])
return fc_out
代码思路:
__init__()是构造函数,负责构造LSTM网络,其中包括input_size,hidden_size,num_layer,num_classes。
forward()是指在LSTM中所执行的一次正向传播。在LSTM中,这个函数的作用是根据当前的输入以及之前的状态来计算当前的输出和更新网络中的隐藏状态。
参考文章:forward函数在LSTM模型中
以上代码来自文章:LSTM 详解及其代码实现
1中调用的是构造函数__init__,2中调用的是forward()。
forward()方法是如何被调用的?
forward是在__call__中调用的,output, (hn, cn) = LSTM(input, (h0, c0))调用__call__方法,由于每个网络都重写forward方法。所以,当调用forward时,都调用的时重写之后的版本。
参考文章:pytorch的学习之路(一)| 模型的forward方法是如何被调用的
参数解释:
输入参数:
input_size:输入数据大小,可以理解为每个单词向量的长度
hidden_size:隐藏层元素个数
num_layers:隐藏层数
batch_first 默认为False,如果为True,那么输入和输出Tensor的形状为(batch,seq_length,feature)
num_classes:要分类的类别数,比如MNIST要识别数字,则num_classes为10,共有(0-9)10种类别
fc(fully connected)为全连接:就是将最后一层卷积得到特征图(矩阵)展开成一维向量,并未分类器提供输入。
全连接层在整个网络中起到分类器的作用。
使用全连接层的方式在LSTM进行时间序列预测中是一个重要的决策。
使用全连接层最终输出你想要的维度。
输出参数:
hn是一个三维的张量
第一维是 num_layers*num_directions,num_layers是我们定义的神经网络的层数,num_directions取值为1或2,为1时是普通LSTM,为2时是双向LSTM。这里是为1,所以直接是num_layers。
第二维是表示一批的样本数量batch_size
第三维表示隐藏层元素个数
cn是一样的
lstm_out的输出本来是(batch,seq_len,hidden_size),因为要接fc,所以使用lstm_out[:,-1,:],将seq_len这个维度去掉了。
参考文章:降维操作