Bootstrap

Spark中的矩阵乘法源码分析

前言:
    矩阵乘法在数据挖掘/机器学习中是常用的计算步骤,并且在大数据计算中,shuffle过程是不可避免的,矩阵乘法的不同计算方式shuffle的数据量都不相同。通过对矩阵乘法不同计算方式的深入学习,希望能够对大数据算法实现的shuffle过程优化有所启发。网上有很多分布式矩阵乘法相关的文章和论文,但是鲜有对Spark中分布式矩阵乘法的分析。本文针对Spark中分布式矩阵乘法的实现进行必要的说明讨论。

分布式矩阵乘法原理:
    矩阵乘法计算可以分为内积法和外积法。根据实现颗粒度的不同,也可以分为普通矩阵实现和分块矩阵实现。
矩阵乘法计算公式:

公式1为内积法计算公式,即使用A矩阵的行和B矩阵的列进行向量内积计算,每次计算矩阵C中的一个元素。
公式2为外积法计算公式,即使用A矩阵的列和B矩阵的行进行向量外积计算,每次计算一个n×k的矩阵,再将计算出的m个矩阵相加,得到矩阵C。
分布式矩阵乘法内积法实现:
    内积要求每次能够计算矩阵C中的一个元素,由于C中每个元素的计算是相互独立的,所以这个计算过程能够并发执行,由于C中有n×k个元素,所以能够支持的最大并发量是n×k。在计算C中每一个元素的时候,都需要用到A中某一行的m个元素和B中某一行的m个元素,这个过程需要从A中shuffle m个元素和B中shuffle m个元素到C的计算节点上,所以使用内积法直接计算矩阵的乘法,最多需要shuffle的元素个数为2×m×n×k个(因为C有n×k个元素),每个计算节点一次最少要计算C中的1个元素,所以每个计算节点内存最少要能够存储2m的数据。但是在大矩阵的运算中,通常并发度不可能开发n×k的程度,所以事实上需要shuffle的数据会远小于2×m×n×k个,因为由于并发量远小于C中的元素个数,所以会在同一个计算节点上计算C的多个元素值,这个时候,这个节点上的数据就能重复使用,不需要为每个元素的计算都shuffle 2m的数据量。这也是使用内积发计算分块矩阵的时候,shuffle量会大大减少的原因。算法shuffle过程如下:

    这个算法的 缺点显而易见,shuffle的数据过多。在分布式计算系统中,shuffle过程是影响系统性能的重要因素。但是这种算法在A或B其中一个矩阵是小矩阵的时候是有明显的 优势的。例如:如果B是小矩阵,矩阵B可以broadcast到分布式矩阵A的每一个节点上,这么在计算C矩阵的时候,A不需要再进行shuffle操作,能够充分的利用数据本地性进行计算。
分布式矩阵乘法外积法实现:
    向量外积计算公式如下:

    结合外积的计算公式,从上面公式2可以看出,矩阵乘法的外积公式是先计算m个n×k的矩阵,再将m个矩阵相加得到矩阵C。在这个过程中m个矩阵的计算过程是相互独立的。所以能够支持的最大并发量是m。在计算m个矩阵的任意一个的时候,都需要用到A中某一列的n个元素和B中某一行的k个元素,所以在计算m个矩阵的时候,最多需要移动m×(n+k)个元素。在计算完m个矩阵之后,需要将m个矩阵加和。矩阵加法是对应位置元素相加,所以在最后计算C中每一个元素的时候,需要将m个矩阵对应位置的数据shuffle到某一个计算节点进行加和,所以在加和过程中,需要shuffle的最大数据量是m×n×k(因为m个矩阵,每个矩阵有n×k个元素)。每个计算节点最少要计算m个矩阵中的1个,所以计算节点最少需要能够存储n+k个元素的内存,由于每个节点输出元素量为n×k,这个数据量在大规模矩阵中是相当大的,单个计算节点内存很难存储下来,所以每个节点计算的输出结果通常是达到一定量后存储到磁盘上,但是如果是在大规模稀疏矩阵中,m个矩阵中每个矩阵中有值的个数通常都会远小于n×k个,所以第二次shuffle的数据量通常远小于n×k。算法shuffle过程如下:

     这个算法的 缺点就是,在分布式大矩阵中,如果不是稀疏矩阵,计算出来的中间矩阵会非常大,要在单个计算节点完全使用内存计算中间矩阵基本不可能,需要使用磁盘辅助存储中间矩阵的计算结果。算法的 优势在于计算稀疏矩阵和虽然AB都是大规模矩阵,但是计算结果是个小矩阵的时候,这两种情况每个中间矩阵都能够完全存储在内存中,就会比较快。
分块矩阵的分布式乘法实现:
    分块矩阵乘法的过程和不分块的计算过程差不多,也可以使用内积法和外积法实现。使用分块矩阵实现分布式矩阵乘法的好处主要有两个,一个是减少shuffle过程的数据量,另一个是分块矩阵在每个小块在本地计算的时候能够调用现有矩阵计算包,成熟的矩阵计算包通常来说计算效率都比自己实现的好,例如矩阵计算的常用包BLAS包。
分块矩阵的基本计算公式如下(来自维基百科: https://zh.wikipedia.org/wiki/%E5%88%86%E5%A1%8A%E7%9F%A9%E9%99%A3):

    以上公式的计算复杂度为O(n^3),除了以上公式之外1969年Strassen利用分治算法将分块矩阵乘法的计算复杂度降低到O(n^l
;