本章节介绍 ViTGAN 生成对抗网络中 MSA(多头自注意力)部分的代码实现。
目录(文章发布后会补上链接):
- 网络结构简介
- Mapping Network 实现
- PositionalEmbedding 实现
- MLP 实现
- MSA多头注意力 实现
- SLN自调制 实现
- CoordinatesPositionalEmbedding 实现
- ModulatedLinear 实现
- Siren 实现
- Generator生成器 实现
- PatchEmbedding 实现
- ISN 实现
- Discriminator鉴别器 实现
- VITGAN 实现
MSA多头注意力 简介
MSA(多头自注意力):在生成器中使用点积(乘积)计算注意力,在鉴别器中使用欧氏距离计算注意力。实现参考 TensorFlow 官网的 Transformer 教程,教程地址:https://www.tensorflow.org/text/tutorials/transformer
代码实现
import tensorflow as tf
class MSA(tf.keras.layers.Layer):
    """Multi-head attention layer with a selectable similarity measure.

    When ``discriminator`` is falsy the attention logits are the dot product
    of queries and keys; when it is truthy they are the pairwise Euclidean
    distance between queries and keys. In both cases the logits are divided
    by ``sqrt(depth)`` and normalized with a softmax over the key axis.

    Args:
        d_model: Width of the query/key/value projections and of the output.
        num_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.
        discriminator: Whether this layer is used in the discriminator
            (selects Euclidean-distance attention instead of dot-product).

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads, discriminator):
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.discriminator = discriminator  # whether this is the discriminator
        self.depth = d_model // self.num_heads  # per-head feature width

        # Bias-free linear projections for queries, keys, values and output.
        self.wq = tf.keras.layers.Dense(d_model, use_bias=False)
        self.wk = tf.keras.layers.Dense(d_model, use_bias=False)
        self.wv = tf.keras.layers.Dense(d_model, use_bias=False)
        self.dense = tf.keras.layers.Dense(d_model, use_bias=False)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).

        Args:
            x: Tensor whose last dimension has size ``d_model``.
            batch_size: Size of the leading (batch) dimension of ``x``.

        Returns:
            Tensor of shape (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def scaled_dot_product_attention(self, q, k, v, mask):
        """Calculate the attention output and the attention weights.

        q, k, v must have matching leading dimensions.
        k, v must have matching penultimate dimension, i.e.
        seq_len_k == seq_len_v.

        Args:
            q: query, shape == (..., seq_len_q, depth)
            k: key, shape == (..., seq_len_k, depth)
            v: value, shape == (..., seq_len_v, depth_v)
            mask: Tensor broadcastable to (..., seq_len_q, seq_len_k), or
                None. Positions where the mask is 1 receive a large negative
                logit and therefore (almost) zero attention weight.

        Returns:
            Tuple ``(output, attention_weights)`` with shapes
            (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
        """
        if self.discriminator:
            # Euclidean distance: broadcast every query against every key,
            # giving (..., seq_len_q, seq_len_k, depth) before the reduction.
            diff = tf.expand_dims(q, axis=-2) - tf.expand_dims(k, axis=-3)
            squared_dist = tf.math.reduce_sum(tf.math.square(diff), axis=-1)
            # The derivative of sqrt is infinite at 0, which turns into NaN
            # gradients whenever a query coincides with a key (e.g. all-zero
            # tokens, since the projections have no bias). Route those
            # entries around sqrt; the forward value (0) is unchanged.
            is_zero = tf.math.equal(squared_dist, 0)
            safe_squared = tf.where(
                is_zero, tf.ones_like(squared_dist), squared_dist)
            matmul_qk = tf.where(
                is_zero,
                tf.zeros_like(squared_dist),
                tf.math.sqrt(safe_squared))
        else:
            # Dot product.
            matmul_qk = tf.matmul(q, k, transpose_b=True)

        # Scale by sqrt(depth). Cast to the logits' dtype rather than a
        # hard-coded float32 so non-float32 compute dtypes do not hit a
        # dtype mismatch in the division.
        dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # Add the mask to the scaled tensor.
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)

        # Softmax is normalized on the last axis (seq_len_k) so that the
        # scores add up to 1.
        attention_weights = tf.nn.softmax(
            scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights

    def call(self, v, k, q, mask=None):
        """Apply multi-head attention.

        Args:
            v: Value input; per the shape comments below it is expected to
                be (batch_size, seq_len_v, features).
            k: Key input, expected (batch_size, seq_len_k, features).
            q: Query input, expected (batch_size, seq_len_q, features).
            mask: Optional mask forwarded to
                ``scaled_dot_product_attention``. Defaults to None.

        Returns:
            Tensor of shape (batch_size, seq_len_q, d_model).
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len_q, d_model)
        k = self.wk(k)  # (batch_size, seq_len_k, d_model)
        v = self.wv(v)  # (batch_size, seq_len_v, d_model)

        q = self.split_heads(q, batch_size)  # (batch, heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch, heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch, heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # The attention weights are computed but not returned by this layer.
        scaled_attention, _ = self.scaled_dot_product_attention(
            q, k, v, mask)

        # Move heads next to depth, then merge them back into d_model.
        scaled_attention = tf.transpose(
            scaled_attention, perm=[0, 2, 1, 3])  # (batch, seq_len_q, heads, depth)
        concat_attention = tf.reshape(
            scaled_attention,
            (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        return self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
if __name__ == "__main__":
    # Smoke test: run generator-style (dot-product) attention as
    # self-attention on a random batch and print the output shape.
    msa = MSA(d_model=128, num_heads=8, discriminator=False)
    sample = tf.random.uniform(shape=(2, 5, 128), dtype=tf.float32)
    # The same tensor serves as value, key and query; no mask is applied.
    result = msa(sample, sample, sample, None)
    tf.print(tf.shape(result))