Let matrix $A$ have size $M\times D$ and matrix $B$ have size $N\times D$, where each row is a $D$-dimensional vector:

$$A=\begin{bmatrix}\vec a_1\\\vec a_2\\\vdots\\\vec a_M\end{bmatrix},\qquad B=\begin{bmatrix}\vec b_1\\\vec b_2\\\vdots\\\vec b_N\end{bmatrix}$$
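The goal is to compute the Euclidean distance between every vector $\vec a_i$ in $A$ and every vector $\vec b_j$ in $B$. The most straightforward approach is brute force: loop over all $M\times N$ pairs and compute each distance separately, which takes $O(MN)$ loop iterations and $O(MND)$ arithmetic in total. Written as explicit loops this is very slow. If the loops can be replaced with matrix operations, the computation runs in optimized vectorized kernels and becomes much faster.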
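For reference, the brute-force approach can be sketched as a plain double loop over all pairs. This is a minimal pure-Python sketch; the function name `naive_pairwise_dist` and the list-of-lists input format are illustrative assumptions, not part of the original code.

```python
import math
from typing import List

Matrix = List[List[float]]


def naive_pairwise_dist(A: Matrix, B: Matrix) -> Matrix:
    """Return the M x N matrix of Euclidean distances between rows of A and rows of B."""
    dist = []
    for a in A:                # M iterations
        row = []
        for b in B:            # N iterations per row of A
            # O(D) work per pair: sum of squared coordinate differences
            sq = sum((x - y) ** 2 for x, y in zip(a, b))
            row.append(math.sqrt(sq))
        dist.append(row)
    return dist


A = [[0.0, 0.0], [3.0, 4.0]]
B = [[0.0, 0.0], [6.0, 8.0], [3.0, 0.0]]
print(naive_pairwise_dist(A, B))  # [[0.0, 10.0, 3.0], [5.0, 5.0, 4.0]]
```

Every one of the $M\times N$ pairs is visited individually, which is exactly the cost the matrix formulation below avoids.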
Derivation
First, compute the distance $D(i,j)$ between $\vec a_i$ and $\vec b_j$, where $a_{i,k}$ denotes the $k$-th component of $\vec a_i$:

$$\begin{aligned}D(i,j)&=\sqrt{(a_{i,1}-b_{j,1})^2+\cdots+(a_{i,D}-b_{j,D})^2}\\&=\sqrt{(a^2_{i,1}+\cdots+a^2_{i,D})+(b^2_{j,1}+\cdots+b^2_{j,D})-2\times(a_{i,1}b_{j,1}+\cdots+a_{i,D}b_{j,D})}\\&=\sqrt{||\vec a_i||^2+||\vec b_j||^2-2\times\vec a_i\vec b^T_j}\end{aligned}$$

The second step expands each square as $(a-b)^2=a^2+b^2-2ab$ and regroups the terms; in the last step, $\vec a_i\vec b_j^T$ is the inner product of the two row vectors.
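The expansion can be checked numerically on a single pair. The vectors below are arbitrary illustrative values, and the helper `dot` is a hypothetical name introduced only for this sketch.

```python
import math


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(u, v))


a_i = [1.0, 2.0, 3.0]
b_j = [4.0, 0.0, -1.0]

# Left-hand side: sqrt of the sum of squared coordinate differences
direct = math.sqrt(sum((x - y) ** 2 for x, y in zip(a_i, b_j)))

# Right-hand side: ||a_i||^2 + ||b_j||^2 - 2 * a_i . b_j
expanded = math.sqrt(dot(a_i, a_i) + dot(b_j, b_j) - 2 * dot(a_i, b_j))

print(direct, expanded)
```

Here $||\vec a_i||^2=14$, $||\vec b_j||^2=17$ and $\vec a_i\vec b_j^T=1$, so both sides equal $\sqrt{29}$.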
Applying this to every pair $(i,j)$ at once gives the whole $M\times N$ distance matrix:

$$\text{dist}=\sqrt{\begin{pmatrix}||\vec a_1||^2&\cdots&||\vec a_1||^2\\\vdots&\ddots&\vdots\\||\vec a_M||^2&\cdots&||\vec a_M||^2\end{pmatrix}_{M\times N}+\begin{pmatrix}||\vec b_1||^2&\cdots&||\vec b_N||^2\\\vdots&\ddots&\vdots\\||\vec b_1||^2&\cdots&||\vec b_N||^2\end{pmatrix}_{M\times N}-2\times AB^T}$$

where the square root is taken element-wise. The first matrix repeats $||\vec a_i||^2$ across every column of row $i$, the second repeats $||\vec b_j||^2$ down every row of column $j$, and entry $(i,j)$ of $AB^T$ is exactly $\vec a_i\vec b_j^T$, so entry $(i,j)$ of the result equals $D(i,j)$.
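To make the broadcasting in the matrix formula concrete, here is a pure-Python sketch that builds each of the three $M\times N$ terms explicitly and then takes the element-wise square root. Plain lists stand in for tensors, and the function name `matrix_form_dist` is hypothetical; a tensor library would perform these steps in vectorized kernels instead of Python loops.

```python
import math


def matrix_form_dist(A, B):
    """Pairwise Euclidean distances via ||a_i||^2 + ||b_j||^2 - 2 * A B^T."""
    M, N = len(A), len(B)
    a_sq = [sum(x * x for x in a) for a in A]  # length-M vector of ||a_i||^2
    b_sq = [sum(y * y for y in b) for b in B]  # length-N vector of ||b_j||^2
    # Term 1: ||a_i||^2 copied across all N columns (M x N)
    term1 = [[a_sq[i]] * N for i in range(M)]
    # Term 2: ||b_j||^2 copied down all M rows (M x N)
    term2 = [list(b_sq) for _ in range(M)]
    # Term 3: A B^T, the M x N matrix of inner products a_i . b_j
    gram = [[sum(x * y for x, y in zip(a, b)) for b in B] for a in A]
    # max(..., 0.0) guards against tiny negative values from floating-point rounding
    return [[math.sqrt(max(term1[i][j] + term2[i][j] - 2 * gram[i][j], 0.0))
             for j in range(N)] for i in range(M)]


A = [[0.0, 0.0], [3.0, 4.0]]
B = [[0.0, 0.0], [6.0, 8.0], [3.0, 0.0]]
print(matrix_form_dist(A, B))  # [[0.0, 10.0, 3.0], [5.0, 5.0, 4.0]]
```

Only the norms ($M+N$ values) and one matrix product are computed; the two repeated matrices are pure copies, which is why tensor libraries can represent them with broadcasting instead of materializing them.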
Code
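import torch


def euclidean_dist(x, y):
    """
    PARAMETER:
        x: torch.Tensor, with shape [m, d]
        y: torch.Tensor, with shape [n, d]
    RETURN:
        dist: torch.Tensor, with shape [m, n], where dist[i, j] = ||x_i - y_j||
    """
    m, n = x.size(0), y.size(0)
    # Square every element of x and sum along dim 1; keepdim=True gives shape (m, 1),
    # then expand() broadcasts it to (m, n): row i holds ||x_i||^2 in every column
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    # Same for y: (n, 1) -> (n, m), then transpose to (m, n): column j holds ||y_j||^2
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # addmm computes beta * dist + alpha * (x @ y.t()), i.e. dist - 2 * x y^T
    dist = torch.addmm(dist, x, y.t(), beta=1, alpha=-2)
    # Rounding can leave tiny negative entries; clamping before sqrt avoids NaN,
    # and the small positive floor keeps the gradient of sqrt finite
    return dist.clamp(min=1e-12).sqrt()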