
Fast Computation of Euclidean Distances Between Matrices

Let matrix $A$ be of size $M\times D$ and matrix $B$ of size $N\times D$. Write $A=[\vec a_1,\vec a_2,\cdots,\vec a_M]$ and $B=[\vec b_1,\vec b_2,\cdots,\vec b_N]$ for their rows, so each $\vec a_i$ and each $\vec b_j$ is a $1\times D$ row vector.
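We want the Euclidean distance between every vector $\vec a_i$ in $A$ and every vector $\vec b_j$ in $B$. The most naive approach is brute force. It loops over all $M\times N$ pairs and computes each distance separately, which takes $O(MND)$ operations and is very slow when written as explicit loops. Expressing the same computation as matrix operations still does $O(MND)$ work. However, it hands that work to optimized, vectorized routines, which is much faster in practice.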
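For reference, the brute-force baseline can be sketched in plain Python. The name `naive_euclidean_dist` and the list-of-lists representation are illustrative choices, not part of the original. Each of the $M\times N$ pairs costs a loop over $D$ coordinates.

```python
import math
from typing import List

Matrix = List[List[float]]


def naive_euclidean_dist(A: Matrix, B: Matrix) -> Matrix:
    """Pairwise Euclidean distances by brute force: M*N pairs, D operations each."""
    dist = []
    for a in A:                      # iterate over the M rows of A
        row = []
        for b in B:                  # iterate over the N rows of B
            # sum of squared coordinate differences over the D dimensions
            s = sum((ak - bk) ** 2 for ak, bk in zip(a, b))
            row.append(math.sqrt(s))
        dist.append(row)
    return dist


A = [[0.0, 0.0], [1.0, 1.0]]
B = [[3.0, 4.0], [1.0, 0.0], [0.0, 0.0]]
d = naive_euclidean_dist(A, B)
print(d[0])  # [5.0, 1.0, 0.0]
```

This is handy as a correctness check for any faster implementation, even though it is too slow for real workloads.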

Derivation


First, compute the distance $d(i,j)$ between $\vec a_i$ and $\vec b_j$ by expanding each squared difference:

$$
\begin{aligned}
d(i,j)&=\sqrt{(a_{i,1}-b_{j,1})^2+\cdots+(a_{i,D}-b_{j,D})^2}\\
&=\sqrt{(a_{i,1}^2+\cdots+a_{i,D}^2)+(b_{j,1}^2+\cdots+b_{j,D}^2)-2\times(a_{i,1}b_{j,1}+\cdots+a_{i,D}b_{j,D})}\\
&=\sqrt{||\vec a_i||^2+||\vec b_j||^2-2\times\vec a_i\vec b_j^T}
\end{aligned}
$$
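A quick numerical check of this identity on one pair of 3-dimensional vectors may help. The values and the helper names `dist_direct` and `dist_expanded` are chosen purely for illustration. With $\vec a=(1,2,3)$ and $\vec b=(4,6,3)$ we get $||\vec a||^2=14$, $||\vec b||^2=61$ and $\vec a\vec b^T=25$. So $14+61-2\times25=25$, and both forms give $5$.

```python
import math


def dist_direct(a, b):
    # First line of the derivation: sqrt(sum_k (a_k - b_k)^2)
    return math.sqrt(sum((ak - bk) ** 2 for ak, bk in zip(a, b)))


def dist_expanded(a, b):
    # Last line of the derivation: sqrt(||a||^2 + ||b||^2 - 2 * a b^T)
    aa = sum(ak * ak for ak in a)
    bb = sum(bk * bk for bk in b)
    ab = sum(ak * bk for ak, bk in zip(a, b))
    return math.sqrt(aa + bb - 2 * ab)


a = [1.0, 2.0, 3.0]
b = [4.0, 6.0, 3.0]
print(dist_direct(a, b), dist_expanded(a, b))  # 5.0 5.0
```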

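Generalizing to all pairs at once gives the full $M\times N$ distance matrix, with the square root taken element-wise: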
$$
\text{dist}=\sqrt{\begin{pmatrix}||\vec a_1||^2&\cdots&||\vec a_1||^2\\\vdots&\ddots&\vdots\\||\vec a_M||^2&\cdots&||\vec a_M||^2\end{pmatrix}_{M\times N}+\begin{pmatrix}||\vec b_1||^2&\cdots&||\vec b_N||^2\\\vdots&\ddots&\vdots\\||\vec b_1||^2&\cdots&||\vec b_N||^2\end{pmatrix}_{M\times N}-2\times AB^T}
$$
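The matrix formula can be sketched in NumPy, which is swapped in here only for illustration; the function name `euclidean_dist_np` is hypothetical. Rather than building the two $M\times N$ matrices of squared norms explicitly, it keeps them as arrays of shape $(M,1)$ and $(1,N)$ and lets broadcasting expand them. That plays the same role as `expand()` in the PyTorch code. It also clips tiny negative values caused by floating-point cancellation before taking the square root.

```python
import numpy as np


def euclidean_dist_np(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Pairwise distances via ||a_i||^2 + ||b_j||^2 - 2 A B^T (sketch)."""
    aa = np.sum(A ** 2, axis=1)[:, None]   # shape (M, 1): ||a_i||^2, broadcast across columns
    bb = np.sum(B ** 2, axis=1)[None, :]   # shape (1, N): ||b_j||^2, broadcast across rows
    sq = aa + bb - 2.0 * (A @ B.T)         # shape (M, N) after broadcasting
    # Cancellation can leave tiny negatives; clip to 0 before the element-wise sqrt
    return np.sqrt(np.maximum(sq, 0.0))


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))
B = rng.standard_normal((5, 3))
# Brute-force reference for comparison
loop = np.array([[np.linalg.norm(a - b) for b in B] for a in A])
print(np.allclose(euclidean_dist_np(A, B), loop))  # True
```

Only one $M\times N$ matrix product is computed. The squared norms cost just $O((M+N)D)$, so nearly all the work is in the optimized `A @ B.T`.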


Code


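import torch


def euclidean_dist(x, y):
    """
    PARAMETER:
    x: torch.Tensor, with shape [m, d]
    y: torch.Tensor, with shape [n, d]
    RETURN:
    dist: torch.Tensor, with shape [m, n]; dist[i, j] is the Euclidean
          distance between x[i] and y[j]
    """
    m, n = x.size(0), y.size(0)
    # Square every element of x and sum along dim 1, giving shape (m, 1); expand() broadcasts it to (m, n)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    # Same for y: (n, 1) -> (n, m), then transpose to (m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # addmm computes beta * dist + alpha * (x @ y.t()), i.e. ||a_i||^2 + ||b_j||^2 - 2 * a_i b_j^T
    dist = torch.addmm(dist, x, y.t(), beta=1, alpha=-2)
    # Rounding can make entries slightly negative (sqrt would give NaN); the small positive
    # floor also keeps the gradient of sqrt finite when two vectors coincide
    return dist.clamp(min=1e-12).sqrt()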