Volker Strassen
1 矩阵乘法
矩阵乘法是机器学习中最基本的运算之一,对其进行优化是许多性能优化的关键。按定义直接计算时,将两个大小为 $N \times N$ 的矩阵相乘需要 $N^3$ 次乘法运算。多年来,人们在更好、更聪明的矩阵乘法算法方面取得了长足的进步,而沃尔克·斯特拉森(Volker Strassen)于1969年发表的算法是其中的第一个:它首次证明了朴素的 $O(n^3)$ 运行时间并不是最优的。
Strassen算法的基本思想是将A和B各分为4个子矩阵(共8个),然后只用7次(而不是朴素分块所需的8次)子矩阵乘法递归地计算出C的各个子矩阵。这种策略称为分而治之。
2 伪代码
- 如上图所示,将矩阵A和B各划分为4个大小为 $N/2 \times N/2$ 的子矩阵。
- 递归计算7个矩阵乘法。
- 计算C的子矩阵。
- 将这些子矩阵组合到我们的新矩阵C中
3 复杂性
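- 最坏情况时间复杂度:$\Theta(n^{\log_2 7}) \approx \Theta(n^{2.8074})$
- 最佳情况时间复杂度:同样为 $\Theta(n^{\log_2 7})$ —— 算法执行的运算次数只取决于矩阵规模,与元素取值无关
- 空间复杂度:递归深度为 $\Theta(\log n)$,但每一层都要为子矩阵及其和、差分配临时存储,总的额外空间为 $\Theta(n^2)$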
年青时正在发愁的 Volker Strassen
4 算法的详细解释
矩阵相乘在进行3D变换时经常用到。在应用中,通常直接按矩阵乘法的定义进行计算。这种算法用到了大量的循环和乘法运算,效率不高。而矩阵相乘的计算效率在很大程度上影响着整个程序的运行速度,所以对矩阵相乘算法进行改进是必要的。
我们先讨论二阶矩阵的计算方法。
对于二阶矩阵
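$$
A = \begin{pmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{pmatrix}, \qquad
B = \begin{pmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{pmatrix}
$$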
先计算下面7个量(1)
x1 = (a11 + a22) * (b11 + b22);
x2 = (a21 + a22) * b11;
x3 = a11 * (b12 - b22);
x4 = a22 * (b21 - b11);
x5 = (a11 + a12) * b22;
x6 = (a21 - a11) * (b11 + b12);
x7 = (a12 - a22) * (b21 + b22);
再设C = AB。根据矩阵相乘的规则,C的各元素为(2)
c11 = a11 * b11 + a12 * b21
c12 = a11 * b12 + a12 * b22
c21 = a21 * b11 + a22 * b21
c22 = a21 * b12 + a22 * b22
比较(1)(2),C的各元素可以表示为(3)
c11 = x1 + x4 - x5 + x7
c12 = x3 + x5
c21 = x2 + x4
c22 = x1 + x3 - x2 + x6
根据以上的方法,我们就可以计算4阶矩阵了,先将4阶矩阵A和B划分成四块2阶矩阵,分别利用公式计算它们的乘积,再使用(1)(3)来计算出最后结果。
下面的源程序给出了三种实现:专用于 4×4 矩阵的完全展开版本(Multiply4x4)、只做一层分块的分治版本(Multipy)以及完全递归的版本(Strassen),读者可按需选择。
5 源程序
using System;
using System.Text;
namespace Legal.Truffer.Algorithm
{
/// <summary>
/// Matrix multiplication routines based on V. Strassen's method.
/// </summary>
/// <remarks>
/// Every routine works on the project's <c>Matrix</c> type and relies on its flat indexer
/// (<c>m[k]</c>), its 2-D indexer (<c>m[row, col]</c>), the <c>Matrix(rows, cols)</c> and
/// <c>Matrix(rows, cols, double[])</c> constructors, and its <c>+</c>, <c>-</c> and
/// <c>*</c> operators.
/// </remarks>
public static class Matrix_Calculator
{
    #region Strassen's method for [4x4] x [4x4] (fully unrolled fast path)

    /// <summary>
    /// Multiplies two 2x2 matrices with Strassen's seven products instead of the usual eight.
    /// </summary>
    /// <param name="fOut_11">Receives element (1,1) of the product.</param>
    /// <param name="fOut_12">Receives element (1,2) of the product.</param>
    /// <param name="fOut_21">Receives element (2,1) of the product.</param>
    /// <param name="fOut_22">Receives element (2,2) of the product.</param>
    /// <param name="f1_11">Left operand, element (1,1); the other <c>f1_*</c> follow the same naming.</param>
    /// <param name="f1_12">Left operand, element (1,2).</param>
    /// <param name="f1_21">Left operand, element (2,1).</param>
    /// <param name="f1_22">Left operand, element (2,2).</param>
    /// <param name="f2_11">Right operand, element (1,1); the other <c>f2_*</c> follow the same naming.</param>
    /// <param name="f2_12">Right operand, element (1,2).</param>
    /// <param name="f2_21">Right operand, element (2,1).</param>
    /// <param name="f2_22">Right operand, element (2,2).</param>
    private static void Multiply2X2(out double fOut_11, out double fOut_12, out double fOut_21, out double fOut_22,
        double f1_11, double f1_12, double f1_21, double f1_22,
        double f2_11, double f2_12, double f2_21, double f2_22)
    {
        double x1 = (f1_11 + f1_22) * (f2_11 + f2_22);
        double x2 = (f1_21 + f1_22) * f2_11;
        double x3 = f1_11 * (f2_12 - f2_22);
        double x4 = f1_22 * (f2_21 - f2_11);
        double x5 = (f1_11 + f1_12) * f2_22;
        double x6 = (f1_21 - f1_11) * (f2_11 + f2_12);
        double x7 = (f1_12 - f1_22) * (f2_21 + f2_22);

        fOut_11 = x1 + x4 - x5 + x7;
        fOut_12 = x3 + x5;
        fOut_21 = x2 + x4;
        fOut_22 = x1 - x2 + x3 + x6;
    }

    /// <summary>
    /// Multiplies two 4x4 matrices by applying Strassen's formulas at the 2x2-block level,
    /// where each of the seven block products is itself computed by <see cref="Multiply2X2"/>
    /// (49 scalar multiplications in total).
    /// </summary>
    /// <param name="a">Left operand; expected to be 4x4 (not checked).</param>
    /// <param name="b">Right operand; expected to be 4x4 (not checked).</param>
    /// <returns>A new 4x4 matrix holding <c>a * b</c>.</returns>
    /// <remarks>
    /// Operands are read through the flat indexer as <c>m[4 * row + col]</c>, so the 2x2 blocks are
    /// M11 = {0,1,4,5}, M12 = {2,3,6,7}, M21 = {8,9,12,13}, M22 = {10,11,14,15}, and the result
    /// array is filled in the same row-major order.
    /// NOTE(review): this assumes <c>Matrix</c> stores its elements row-major; confirm against the Matrix type.
    /// </remarks>
    public static Matrix Multiply4x4(Matrix a, Matrix b)
    {
        // c_k_0 .. c_k_3 hold elements (1,1), (1,2), (2,1), (2,2) of the 2x2 block product M(k+1).
        double c_0_0, c_0_1, c_0_2, c_0_3;
        double c_1_0, c_1_1, c_1_2, c_1_3;
        double c_2_0, c_2_1, c_2_2, c_2_3;
        double c_3_0, c_3_1, c_3_2, c_3_3;
        double c_4_0, c_4_1, c_4_2, c_4_3;
        double c_5_0, c_5_1, c_5_2, c_5_3;
        double c_6_0, c_6_1, c_6_2, c_6_3;

        // M1 = (A11 + A22) * (B11 + B22)
        Multiply2X2(out c_0_0, out c_0_1, out c_0_2, out c_0_3,
            a[0] + a[10], a[1] + a[11], a[4] + a[14], a[5] + a[15],
            b[0] + b[10], b[1] + b[11], b[4] + b[14], b[5] + b[15]);
        // M2 = (A21 + A22) * B11
        Multiply2X2(out c_1_0, out c_1_1, out c_1_2, out c_1_3,
            a[8] + a[10], a[9] + a[11], a[12] + a[14], a[13] + a[15],
            b[0], b[1], b[4], b[5]);
        // M3 = A11 * (B12 - B22)
        Multiply2X2(out c_2_0, out c_2_1, out c_2_2, out c_2_3,
            a[0], a[1], a[4], a[5],
            b[2] - b[10], b[3] - b[11], b[6] - b[14], b[7] - b[15]);
        // M4 = A22 * (B21 - B11)
        Multiply2X2(out c_3_0, out c_3_1, out c_3_2, out c_3_3,
            a[10], a[11], a[14], a[15],
            b[8] - b[0], b[9] - b[1], b[12] - b[4], b[13] - b[5]);
        // M5 = (A11 + A12) * B22
        Multiply2X2(out c_4_0, out c_4_1, out c_4_2, out c_4_3,
            a[0] + a[2], a[1] + a[3], a[4] + a[6], a[5] + a[7],
            b[10], b[11], b[14], b[15]);
        // M6 = (A21 - A11) * (B11 + B12)
        Multiply2X2(out c_5_0, out c_5_1, out c_5_2, out c_5_3,
            a[8] - a[0], a[9] - a[1], a[12] - a[4], a[13] - a[5],
            b[0] + b[2], b[1] + b[3], b[4] + b[6], b[5] + b[7]);
        // M7 = (A12 - A22) * (B21 + B22)
        Multiply2X2(out c_6_0, out c_6_1, out c_6_2, out c_6_3,
            a[2] - a[10], a[3] - a[11], a[6] - a[14], a[7] - a[15],
            b[8] + b[10], b[9] + b[11], b[12] + b[14], b[13] + b[15]);

        // C11 = M1 + M4 - M5 + M7, C12 = M3 + M5, C21 = M2 + M4, C22 = M1 - M2 + M3 + M6,
        // laid out row by row.
        return new Matrix(4, 4, new double[4 * 4] {
            c_0_0 + c_3_0 - c_4_0 + c_6_0,
            c_0_1 + c_3_1 - c_4_1 + c_6_1,
            c_2_0 + c_4_0,
            c_2_1 + c_4_1,
            c_0_2 + c_3_2 - c_4_2 + c_6_2,
            c_0_3 + c_3_3 - c_4_3 + c_6_3,
            c_2_2 + c_4_2,
            c_2_3 + c_4_3,
            c_1_0 + c_3_0,
            c_1_1 + c_3_1,
            c_0_0 - c_1_0 + c_2_0 + c_5_0,
            c_0_1 - c_1_1 + c_2_1 + c_5_1,
            c_1_2 + c_3_2,
            c_1_3 + c_3_3,
            c_0_2 - c_1_2 + c_2_2 + c_5_2,
            c_0_3 - c_1_3 + c_2_3 + c_5_3
        });
    }
    #endregion

    #region Shared block helpers

    /// <summary>
    /// Copies the half-open block rows [r1, r2) x columns [c1, c2) of <paramref name="input"/>
    /// into a new (r2 - r1) x (c2 - c1) matrix.
    /// </summary>
    private static Matrix CopyBlock(Matrix input, int r1, int r2, int c1, int c2)
    {
        int rows = r2 - r1;
        int cols = c2 - c1;
        Matrix res = new Matrix(rows, cols);
        for (int i = 0; i < rows; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                res[i, j] = input[r1 + i, c1 + j];
            }
        }
        return res;
    }

    /// <summary>
    /// Builds a (2*half) x (2*half) matrix from four half x half quadrants
    /// laid out as [c11 c12; c21 c22].
    /// </summary>
    private static Matrix Assemble(Matrix c11, Matrix c12, Matrix c21, Matrix c22, int half)
    {
        Matrix result = new Matrix(2 * half, 2 * half);
        for (int i = 0; i < half; i++)
        {
            for (int j = 0; j < half; j++)
            {
                result[i, j] = c11[i, j];
                result[i, j + half] = c12[i, j];
                result[i + half, j] = c21[i, j];
                result[i + half, j + half] = c22[i, j];
            }
        }
        return result;
    }
    #endregion

    #region Strassen-based "divide and conquer" multiplication (one level of splitting)

    /// <summary>
    /// Multiplies the len x len block at (<paramref name="r1"/>, <paramref name="c1"/>) of
    /// <paramref name="A"/> by the block at the same position of <paramref name="B"/>,
    /// using one level of Strassen's split.
    /// </summary>
    /// <param name="A">Left matrix.</param>
    /// <param name="B">Right matrix.</param>
    /// <param name="len">Order of the square block to multiply; must be at least 1.</param>
    /// <param name="r1">Row offset of the block, applied to both <paramref name="A"/> and <paramref name="B"/>.</param>
    /// <param name="c1">Column offset of the block, applied to both <paramref name="A"/> and <paramref name="B"/>.</param>
    /// <returns>A new len x len matrix holding the block product.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="len"/> is less than 1.</exception>
    /// <remarks>
    /// The seven half-size products are computed with <c>Matrix</c>'s <c>*</c> operator rather
    /// than by recursing into this method. Odd-sized blocks cannot be split into equal quadrants
    /// and are multiplied directly.
    /// </remarks>
    public static Matrix Multipy(Matrix A, Matrix B, int len, int r1 = 0, int c1 = 0)
    {
        if (len < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(len), "Block size must be at least 1.");
        }
        if (len == 1)
        {
            // Read the element at the requested offset (previously always index 0).
            return new Matrix(1, 1, new double[1] { A[r1, c1] * B[r1, c1] });
        }

        // Block bounds are relative to the offsets; using bare 'len' here broke non-zero offsets.
        int r2 = r1 + len;
        int c2 = c1 + len;
        if (len % 2 != 0)
        {
            return CopyBlock(A, r1, r2, c1, c2) * CopyBlock(B, r1, r2, c1, c2);
        }

        int half = len / 2;
        int rMid = r1 + half;
        int cMid = c1 + half;

        // Quadrants: A block = [a b; c d], B block = [e f; g h].
        Matrix a = CopyBlock(A, r1, rMid, c1, cMid);
        Matrix b = CopyBlock(A, r1, rMid, cMid, c2);
        Matrix c = CopyBlock(A, rMid, r2, c1, cMid);
        Matrix d = CopyBlock(A, rMid, r2, cMid, c2);
        Matrix e = CopyBlock(B, r1, rMid, c1, cMid);
        Matrix f = CopyBlock(B, r1, rMid, cMid, c2);
        Matrix g = CopyBlock(B, rMid, r2, c1, cMid);
        Matrix h = CopyBlock(B, rMid, r2, cMid, c2);

        Matrix p1 = a * (f - h);
        Matrix p2 = (a + b) * h;
        Matrix p3 = (c + d) * e;
        Matrix p4 = d * (g - e);
        Matrix p5 = (a + d) * (e + h);
        Matrix p6 = (b - d) * (g + h);
        Matrix p7 = (a - c) * (e + f);

        Matrix r = ((p5 + p4) + p6) - p2; // top-left
        Matrix s = p1 + p2;               // top-right
        Matrix t = p3 + p4;               // bottom-left
        Matrix u = (p5 + p1) - (p3 + p7); // bottom-right

        return Assemble(r, s, t, u, half);
    }
    #endregion

    #region Fully recursive Strassen divide-and-conquer multiplication

    /// <summary>
    /// Multiplies two n x n matrices with Strassen's recursive divide-and-conquer method.
    /// </summary>
    /// <param name="n">Order of <paramref name="A"/> and <paramref name="B"/>; must be at least 1.</param>
    /// <param name="A">Left n x n operand.</param>
    /// <param name="B">Right n x n operand.</param>
    /// <returns>A new n x n matrix holding <c>A * B</c>.</returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="n"/> is less than 1.</exception>
    /// <remarks>
    /// Each level splits both operands into four quadrants and makes seven recursive
    /// half-size products. Recursion stops at order 2 or below, and at any odd order
    /// (which cannot be split into equal quadrants); there <c>Matrix</c>'s <c>*</c> operator is used.
    /// </remarks>
    public static Matrix Strassen(int n, Matrix A, Matrix B)
    {
        if (n < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(n), "Matrix order must be at least 1.");
        }
        // Without the odd-order check, n == 1 or any odd n eventually recursed with n == 0
        // forever, and odd n > 2 silently dropped the last row and column.
        if (n <= 2 || n % 2 != 0)
        {
            return A * B;
        }

        int half = n / 2;

        // Split A and B into four equal quadrants.
        Matrix A11 = CopyBlock(A, 0, half, 0, half);
        Matrix A12 = CopyBlock(A, 0, half, half, n);
        Matrix A21 = CopyBlock(A, half, n, 0, half);
        Matrix A22 = CopyBlock(A, half, n, half, n);
        Matrix B11 = CopyBlock(B, 0, half, 0, half);
        Matrix B12 = CopyBlock(B, 0, half, half, n);
        Matrix B21 = CopyBlock(B, half, n, 0, half);
        Matrix B22 = CopyBlock(B, half, n, half, n);

        Matrix M1 = Strassen(half, A11 + A22, B11 + B22);
        Matrix M2 = Strassen(half, A21 + A22, B11);
        Matrix M3 = Strassen(half, A11, B12 - B22);
        Matrix M4 = Strassen(half, A22, B21 - B11);
        Matrix M5 = Strassen(half, A11 + A12, B22);
        Matrix M6 = Strassen(half, A21 - A11, B11 + B12);
        Matrix M7 = Strassen(half, A12 - A22, B21 + B22);

        Matrix C11 = (M1 + M4) + (M7 - M5);
        Matrix C12 = M3 + M5;
        Matrix C21 = M2 + M4;
        Matrix C22 = (M1 - M2) + (M3 + M6);

        return Assemble(C11, C12, C21, C22, half);
    }
    #endregion
}
}