PyTorch中的Embedding Layer
一、语法格式
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None,
max_norm=None, norm_type=2.0,
scale_grad_by_freq=False, sparse=False, _weight=None)
1、参数说明
(1)num_embeddings(int)
:语料库字典大小;
(2)embedding_dim(int)
:每个嵌入向量的大小;
(3)padding_idx(int, optional)
:输出遇到此下标时用零填充;
(4)max_norm(float, optional)
:重新归一化词嵌入,使它们的范数小于提供的值;
(5)norm_type(float, optional)
:对应max_norm选项计算p范数时的p,默认值为2;
(上面的4、5两个参数基本不用,通常使用kaiming和xavier初始化参数)
(6)scale_grad_by_freq(boolean, optional)
:将通过小批量中单词频率的倒数来缩放梯度,默认为False。注意!这里的词频指的是自动获取当前小批量中的词频,而非整个词典;
(7) sparse(bool, optional)
:如果为True,则与权重矩阵相关的梯度转变为稀疏张量。
稀疏张量指反向传播时只更新当前使用词的权重矩阵,以加快更新速度。但是,即使设置 sparse=True ,权重矩阵也未必稀疏更新,原因如下:
- 与优化器相关,使用SGD、Adam等优化器时包含momentum项,导致不相关词的Embedding依然会叠加动量,无法稀疏更新;
- 使用weight_decay,即正则项计入损失值。
基本上通常需要设置的参数是前三个
2、变量说明
Embedding.weight
为 可学习参数 ,形状为 (num_embeddings, embedding_dim) ,初始化为标准正态分布 (N(0, 10)) 。
输入:input(*),数据类型LongTensor,一般为[mini-batch, nums of index]。
输出:output( * , embedding_dim),其中 * 是输入的形状。
二、实例
import torch
import torch.nn as nn
# 1、定义查找表的形状为10*3
embedding = nn.Embedding(10, 3)
# 2、查看Embedding初始化权重信息
embedding.weight
print(embedding.weight)
# 3、定义输入
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
# 4、将输入中的每个词转换为词嵌入
a = embedding(input)
print(a)
输出结果:
Parameter containing:
tensor([[-1.7372, -0.7281, -1.9509],
[-1.1080, 0.7775, -0.7351],
[ 0.9606, 2.3034, 1.1976],
[-0.6429, 2.1996, -0.0045],
[-0.6949, -1.9427, -0.3486],
[-2.4980, -0.7219, 1.0658],
[-1.4095, 1.7520, 0.7215],
[-0.2162, 0.7108, 0.9062],
[-2.3733, 0.1184, -0.9335],
[-0.0870, 0.1308, -0.6418]], requires_grad=True)
tensor([[[ 0.2644, 0.4962, -2.5476],
[ 1.3521, -0.2055, 0.9044],
[-0.3781, 0.0259, -1.7972],
[-1.0164, -0.5694, -1.0062]],
[[-0.3781, 0.0259, -1.7972],
[-1.6988, -1.1996, -1.7316],
[ 1.3521, -0.2055, 0.9044],
[-1.1474, 0.9734, -0.2874]]], grad_fn=<EmbeddingBackward0>)
Process finished with exit code 0
requires_grad=True,所以weight是可学习的。
三、初始化
Enbedding Layer是如何初始化权重矩阵(即查找表)的??
观察nn.Embedding对应的源码:
class Embedding(Module):
............
if _weight is None:
self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
self.reset_parameters()
else:
............
def reset_parameters(self) -> None:
init.normal_(self.weight)
............
更新weight时主要使用了实例方法self.reset_parameters(),而这个实例方法又调用了初始化(init)模块中的normal_方法。
题外话
对于CNN中的参数:
-可学习的参数:卷积层和全连接层的权重、bias、BatchNorm的 [公式] 等。
-不可学习的参数(超参数):学习率、batch_size、weight_decay、模型的深度宽度分辨率等。