本文是在参考资料1的基础上加入更多细节完成,并非完全原创,感谢原创同学,尊重支持原创才能让社区更加健康。
这次在模型优化的时候加入了一个RNN结构,TensorFlow里有封装好的RNN函数,我们可以直接调用,RNN详细介绍见参考资料2
TensorFlow官网给的标准API:
注意: 这个是TF1.0版本下的,在2.0以上版本,dynamic_rnn是在 tf.compat.v1.nn.dynamic_rnn
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
参数说明:
-
cell: LSTM、GRU等的记忆单元。cell参数代表一个LSTM或GRU的记忆单元,也就是一个cell,是RNN中最小的单元结构。例如,cell = tf.nn.rnn_cell.LSTMCell((num_units),其中,num_units表示rnn cell中神经元个数,也就是下文的cell.output_size。返回一个LSTM或GRU cell,作为参数传入。多个cell组成了一个完整的RNN结构。
-
inputs: 输入的训练或测试数据,一般格式为[batch_size, max_time, embed_size],其中batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,embed_size表示嵌入的词向量的维度。
-
sequence_length: 是一个list,假设你输入了三句话,且三句话的长度分别是5,10,25,那么sequence_length=[5,10,25]。
-
time_major: 决定了输出tensor的格式,如果为True, 张量的形状必须为 [max_time, batch_size,cell.output_size]。如果为False, tensor的形状必须为[batch_size, max_time, cell.output_size],cell.output_size表示rnn cell中神经元个数。
-
返回值:元组(outputs, states)
-
outputs: outputs很容易理解,就是每个cell会有一个输出
-
states: states表示最终的状态,也就是序列中最后一个cell输出的状态。一般情况下states的形状为 [batch_size, cell.output_size ],但当输入的cell为
BasicLSTMCell
时,state 的形状为 [2,batch_size, cell.output_size ],其中2也对应着 LSTM 中的 cell state 和 hidden state。
可以看到,tf.nn.dynamic_rnn
这个方法有两个返回值,outputs
和 states
,那这俩是什么关系呢?另外 state 的形状为什么会发生改变?
先回答第二个问题,为什么 state 的形状会发生改变?
首先看当 cell
是LSTM类型时,states形状为 [2,batch_size,cell.output_size ];当cell为GRU时,states形状为[batch_size, cell.output_size ]。其原因是因为 LSTM 和 GRU 的结构本身不同,如下面两个图所示,这是 LSTM 的 cell 结构,每个 cell 会有两个输出:
C
t
C_{t}
Ct 和
h
t
h_t
ht,上面这个图是输出
C
t
C_t
Ct,代表哪些信息应该被记住哪些应该被遗忘; 下面这个图是输出
h
t
h_t
ht,代表这个cell的最终输出,LSTM的 states 是由
C
t
C_t
Ct 和
h
t
h_t
ht 组成的,即 states = (c, h)。
当 cell 为 GRU 时,state 就只有一个了,原因是GRU将
C
t
C_t
Ct 和
h
t
h_t
ht 进行了简化,将其合并成了
h
t
h_t
ht,如下图所示,GRU将遗忘门和输入门合并成了更新门,另外 cell 不再有细胞状态 cell state,只有hidden state
再回答第一个问题,outputs 和 states,这俩是什么关系呢?
对于不同的 cell
类型,outputs
和 states
的关系是有差异的。
如果 cell 为 LSTM,那 states 是个 tuple,分别代表
C
t
C_t
Ct 和
h
t
h_t
ht,其中
h
t
h_t
ht 与outputs中对应的最后一个时刻(即最后一个cell)的输出相等,这里再细说一下,outputs 输出的是每个 cell 的 h,也就是说整个RNN结构里有多少个 cell,outputs 就有多少个 h值,而 states 输出的是最后一个cell 的 C 和 h,它是h 和 outputs 里最后一个h 值是一样的;如果cell为GRU,那么同理,states其实就是
h
t
h_t
ht
Talk is cheap , show me your code
import tensorflow as tf
import numpy as np
def dynamic_rnn(rnn_type='lstm'):
# 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time), 控制序列长度, 4代表每个序列的维度
X = np.random.randn(3, 6, 4)
# 第二个输入的实际长度为4
X[1, 4:] = 0
#记录三个输入的实际步长
X_lengths = [6, 4, 6]
rnn_hidden_size = 5
if rnn_type == 'lstm':
cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
else:
cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
o1, s1 = session.run([outputs, last_states])
print(np.shape(o1))
print("*"*20)
print(o1)
print("*"*20)
print(np.shape(s1))
print("*"*20)
print(s1)
if __name__ == '__main__':
dynamic_rnn(rnn_type='lstm')
cell类型为LSTM,输入的形状为 [ 3, 6, 4 ],经过 tf.nn.dynamic_rnn
后 outputs
的形状为 [ 3, 6, 5 ],states
形状为 [ 2, 3, 5 ],其中 state
第一部分为 c,代表 cell state
,第二部分为 h,代表 hidden state
,这就是形状里的第一维2的构成,3是 batch_size,因为我们一次性输入的是3条序列,5是每个输出向量的维度。可以看到 hidden state
与 对应的 outputs
的最后一行是相等的。另外需要注意的是输入一共有三个序列,但第二个序列的长度只有4,可以看到 outputs
中对应的两行值都为0,所以 hidden state
对应的是最后一个不为0的部分。tf.nn.dynamic_rnn
通过设置 sequence_length
来实现这一逻辑。
输出结果1:
(3, 6, 5)
********************
[[[ 0.0146346 -0.04717453 -0.06930042 -0.06065602 0.02456717]
[-0.05580321 0.08770171 -0.04574306 -0.01652854 -0.04319528]
[ 0.09087799 0.03535907 -0.06974291 -0.03757408 -0.15553619]
[ 0.10003044 0.10654698 0.21004055 0.13792148 -0.05587583]
[ 0.13547596 -0.014292 -0.0211154 -0.10857875 0.04461256]
[ 0.00417564 -0.01985144 0.00050634 -0.13238986 0.14323784]]
[[ 0.04893576 0.14289175 0.17957205 0.09093887 -0.0507192 ]
[ 0.17696126 0.09929577 0.21185635 0.20386451 0.11664373]
[ 0.15658667 0.03952745 -0.03425637 0.00773833 -0.03546742]
[-0.14002582 -0.18578786 -0.08373584 -0.25964601 0.04090167]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]]
[[ 0.18564152 0.01531695 0.13752453 0.17188506 0.19555427]
[ 0.13703949 0.14272294 0.21313036 0.07417354 0.0477547 ]
[ 0.23021792 0.04455495 0.10204565 0.17159792 0.34148467]
[ 0.0386402 0.0387848 0.02134559 0.00110381 0.08414687]
[ 0.01386241 -0.02629686 -0.0733538 -0.03194245 0.13606553]
[ 0.01859433 -0.00585316 -0.04007138 0.03811594 0.21708331]]]
********************
(2, 3, 5)
********************
LSTMStateTuple(
c=array([[ 0.00909146, -0.03747076, 0.0008946 , -0.23459786, 0.29565899],
[-0.18409266, -0.30463044, -0.28033809, -0.49032542, 0.12597639],
[ 0.04494702, -0.01359631, -0.06706629, 0.06766361, 0.40794032]]),
h=array([[ 0.00417564, -0.01985144, 0.00050634, -0.13238986, 0.14323784],
[-0.14002582, -0.18578786, -0.08373584, -0.25964601, 0.04090167],
[ 0.01859433, -0.00585316, -0.04007138, 0.03811594, 0.21708331]])
)
cel l类型为 GRU,我们看看到,输入的形状为 [ 3, 6, 4 ],经过 tf.nn.dynamic_rnn
后 outputs
的形状为 [ 3, 6, 5 ],state形状为 [ 3, 5 ]。可以看到 state
与 对应的 outputs
的最后一行是相等的
输出结果2:
(3, 6, 5)
********************
[[[-0.05190962 -0.13519617 0.02045928 -0.0821183 0.28337528]
[ 0.0201574 0.03779418 -0.05092804 0.02958051 0.12232347]
[ 0.14884441 -0.26075898 0.1821795 -0.03454954 0.18424161]
[-0.13854156 -0.26565378 0.09567164 -0.03960079 0.14000589]
[-0.2605973 -0.39901657 0.12495693 -0.19295695 0.52423598]
[-0.21596414 -0.63051687 0.20837501 -0.31775378 0.77519457]]
[[-0.1979659 -0.30253523 0.0248779 -0.17981144 0.41815343]
[ 0.34481129 -0.05256187 0.1643036 0.00739746 0.27384158]
[ 0.49703664 0.22241165 0.27344766 0.00093435 0.09854949]
[ 0.23312444 0.156997 0.25482553 0.0138156 -0.02302272]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]]
[[-0.06401732 0.08605342 -0.03936866 -0.02287695 0.16947652]
[-0.1775206 -0.2801672 -0.0387468 -0.20264583 0.58125297]
[ 0.39408762 -0.44066425 0.25826641 -0.18851604 0.36172166]
[ 0.0536013 -0.29902928 0.08891931 -0.03930039 0.0743423 ]
[ 0.02304702 -0.0612499 0.09113458 -0.05169013 0.29876455]
[-0.06711324 0.014125 -0.05856332 -0.05632359 -0.00390189]]]
********************
(3, 5)
********************
[[-0.21596414 -0.63051687 0.20837501 -0.31775378 0.77519457]
[ 0.23312444 0.156997 0.25482553 0.0138156 -0.02302272]
[-0.06711324 0.014125 -0.05856332 -0.05632359 -0.00390189]]
总结一下:
tf.nn.dynamic_rnn
这个函数可以控制RNN cell 个数,构建适合业务场景需求的RNN 结构。
参考资料:
1、https://zhuanlan.zhihu.com/p/43041436
2、https://www.jianshu.com/p/9dc9f41f0b29