看名字就知道是算余弦相似度,但是有个烦人的参数dim
,本文主要解决如下几个问题
dim
参数到底有什么作用?如何设置dim参数- 两个矩阵使用该函数算余弦相似度到底是按列向量来算还是按行向量来算?
- 如果想要算矩阵中每个行向量两两之间的相似度,如何计算?
1. dim的作用
- 实验一: dim=0
import torch.nn.functional as F
import torch
import math
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
b = torch.tensor([[5, 6], [7, 8]], dtype=torch.float)
def check(vec_a, vec_b):
dot = 0
for i in range(len(vec_a)):
dot += vec_a[i]*vec_b[i]
vec_a_sq_sum = math.sqrt(sum([item*item for item in vec_a]))
vec_b_sq_sum = math.sqrt(sum([item*item for item in vec_b]))
return dot/(vec_a_sq_sum*vec_b_sq_sum)
if __name__ == "__main__":
res = F.cosine_similarity(a, b, dim=0)
print(res)
check1 = check([1,3], [5,7])
check2 = check([2,4], [6,8])
print(check1)
print(check2)
如上计算a, b的余弦相似度,dim=0,得到如下结果:
tensor([0.9558, 0.9839])
0.95577900872195
0.9838699100999074
因此对于二维矩阵,dim=0表示对应列的列向量之间进行cos相似度计算
- 实验2: dim = 1
if __name__ == "__main__":
res = F.cosine_similarity(a, b, dim=1)
print(res)
check1 = check([1,2], [5,6])
check2 = check([3,4], [7,8])
print(check1)
print(check2)
得到如下结果
tensor([0.9734, 0.9972])
0.973417168333576
0.997164120486613
因此dim=1表示相对应的行向量之间的余弦相似度计算, 默认情况下dim=1,即当不设置dim参数时,是计算行向量之间的相似度
2. 如何计算两两之间的相似度?
从上面可以看出,输入两个2*2的矩阵,F.cosine_similarity
函数输出的是1*2的结果,即指算了对应行或对应列向量之间的相似度,那么如何计算行向量两两之间的相似度呢,即输入两个22的矩阵,输出为22的相似度矩阵
将上面main函数改成如下:
if __name__ == "__main__":
x = a.unsqueeze(1)
y = b.unsqueeze(0)
print("x", x)
print("y", y)
res = F.cosine_similarity(x, y, dim=-1)
print("res", res)
check11 = check([1,2], [5,6])
check12 = check([1,2], [7,8])
check21 = check([3,4], [5,6])
check22 = check([3,4], [7,8])
print(check11, check12, check21, check22)
输出如下结果
x tensor([[[1., 2.]],
[[3., 4.]]])
y tensor([[[5., 6.],
[7., 8.]]])
res tensor([[0.9734, 0.9676],
[0.9987, 0.9972]])
0.973417168333576 0.9676172723968437 0.9986876634765887 0.997164120486613
可看到上面程序实现了行向量间两两求相似度的需求
个人猜测其原理如下:
x.shape=[2, 1, 2]
,y.shape=[1, 2, 2]
,在三维情况下dim=-1(dim=2)表示行,从x中取一行,x_vec.shape = [1,2] , 从y中取一行, y_vec.shape = [2,2],这样计算相似度为两个,同理进行下一行计算,有点像numpy中矩阵自动扩展计算类似
3. 总结
如果想要按行计算相似度,dim设置成-1即可,如果要想两两计算相似度,需要使用unsqueeze
函数进行增加矩阵维度。