Parallel Matrix Multiplication Algorithms

References

This lecture is based on Demmel's Lecture 9 (part 2).
Matrix multiplication is a very regular and basic computation. Why is matrix multiplication more fundamental than matrix-vector multiplication or vector operations such as saxpy? The reason is that matrix multiplication offers much more opportunity to exploit locality than those simpler operations: its ratio of floating-point operations to memory accesses grows with n, as the following table shows.

    Operation        Definition     Flops   Memory        Flops/access
                                            accesses      ratio (q)
    ------------------------------------------------------------------
    saxpy            y = a*x + y    2n      3n + 1         ~2/3
    matrix-vector    y = A*x + y    2n^2    n^2 + 3n       ~2
    matrix-matrix    C = A*B + C    2n^3    4n^2            n/2
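
These counts are easy to check numerically. A quick Python sketch (the flop and memory-access counts are taken directly from the table; n = 1000 is an arbitrary example size):

    def ratios(n):
        # (name, flops, memory accesses), as in the table above
        ops = [("saxpy",         2*n,    3*n + 1),
               ("matrix-vector", 2*n**2, n**2 + 3*n),
               ("matrix-matrix", 2*n**3, 4*n**2)]
        for name, flops, mem in ops:
            print(f"{name:13s}  q = {flops/mem:.2f}")

    ratios(1000)   # q ~ 0.67, ~2.00, 500.00 (= n/2)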

Through the discussion of parallel matrix multiplication algorithms, we will illustrate the typical ways of partitioning a matrix and apply the performance analysis we have learned to the parallel algorithms.

Assuming the time unit is the time for one basic floating-point operation, the serial time for matrix multiplication is 2*n^3.
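
For reference, this is the cost of the classic serial triple loop. A minimal Python sketch (numpy is used only to hold the matrices; the explicit loops make the n^3 multiply-add pairs, i.e. 2*n^3 flops, visible):

    import numpy as np

    def matmul_serial(A, B, C):
        # C = C + A*B: n^3 iterations, each one multiply and one add.
        n = A.shape[0]
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    C[i, j] += A[i, k] * B[k, j]
        return C

    n = 4
    A, B = np.random.rand(n, n), np.random.rand(n, n)
    assert np.allclose(matmul_serial(A, B, np.zeros((n, n))), A @ B)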

The performance of a parallel algorithm depends on:

  • The interconnection network, e.g., bus, ring, hypercube, etc.
  • Data layout (partition of matrices). There are two basic data layouts: 1D blocked and 2D blocked.
  • 1D blocked data layout

    In this layout, a matrix is partitioned into p strips (block columns). Each processor has one strip. The general formula for matrix multiplication is

        C(i) = C(i) + SUM_{j=0}^{p-1} A(j)*B(j,i)
    
    Since each processor has only one block column of A but needs all of A to compute C(i), each processor has to send its block column of A to all the other processors.

    Algorithm (1D blocked, bus without broadcast, synchronous send and receive, with barrier)

        C(myRank) = C(myRank) + A(myRank)*B(myRank,myRank);
        for i=0 to p-1
          for j=0 to p-1 except i
            if (myRank==i) send A(i) to Pj; end if;
            if (myRank==j)
              recv A(i) from Pi;
              C(myRank) = C(myRank) + A(i)*B(i,myRank);
            end if
            barrier();
          end for
        end for
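
    As a concrete (if inefficient) rendering, here is a minimal mpi4py sketch of this algorithm. It is ours, not from the lecture: it assumes p divides n, each process stores its block columns of A, B, and C as n x (n/p) arrays, and the names Aloc, Bloc, Cloc, Bblk are illustrative.

        from mpi4py import MPI
        import numpy as np

        comm = MPI.COMM_WORLD
        rank, p = comm.Get_rank(), comm.Get_size()

        n = 8 * p                       # illustrative size; p divides n
        nloc = n // p                   # width of one block column

        # Each process owns one block column of A, B, and C.
        Aloc = np.random.rand(n, nloc)
        Bloc = np.random.rand(n, nloc)
        Cloc = np.zeros((n, nloc))

        def Bblk(i):
            # B(i, myRank): block i of the local block column of B.
            return Bloc[i*nloc:(i+1)*nloc, :]

        Cloc += Aloc @ Bblk(rank)       # diagonal term first

        buf = np.empty((n, nloc))
        for i in range(p):
            for j in range(p):
                if j == i:
                    continue
                if rank == i:
                    comm.Send(Aloc, dest=j)   # P_i sends A(i) to P_j
                if rank == j:
                    comm.Recv(buf, source=i)  # P_j receives A(i)
                    Cloc += buf @ Bblk(i)
                comm.Barrier()                # serializes the transfers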
    
    The computation time in the innermost loop is

        2*n*(n/p)^2 = 2*n^3/p^2, ignoring the lower-order terms.
    
    The communication time in the innermost loop is

        T_s + n*(n/p)*T_w = T_s + (n^2/p)*T_w.
    
    Thus the total time is

        T(p,n) = (p*(p-1) + 1)*(2*n^3/p^2 + n^2/p) ------- computation
               + p*(p-1)*(T_s + (n^2/p)*T_w) ------------- communication
               ~ 2*n^3 + p^2*T_s + p*n^2*T_w
    
    This algorithm is even slower than the serial one: because of the barrier, only one processor works at a time.

    Obviously, the efficiency

             T(1,n)       1
        E = ---------- < ---
             p*T(p,n)     p
    
    So, the more processors we use, the lower the efficiency, no matter what we do with the problem size. This is a poor parallel algorithm.

    Now, we try to overlap computation and communication by removing the barrier. With the barrier removed, the above algorithm can be rewritten as follows.

    Algorithm (1D blocked, bus without broadcast, synchronous send and receive, without a barrier)

        C(myRank) = C(myRank) + A(myRank)*B(myRank,myRank);
        for i=0 to myRank-1
          recv A(i) from Pi;
          C(myRank) = C(myRank) + A(i)*B(i,myRank);
        end for;
        for i=0 to p-1 except myRank
          send A(myRank) to Pi;
        end for;
        for i=myRank+1 to p-1
          recv A(i) from Pi;
          C(myRank) = C(myRank) + A(i)*B(i,myRank);
        end for
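
    The same data layout, with the barrier removed; again a minimal mpi4py sketch under the same illustrative assumptions as before (blocking Send/Recv stand in for the synchronous send and receive of the pseudocode):

        from mpi4py import MPI
        import numpy as np

        comm = MPI.COMM_WORLD
        rank, p = comm.Get_rank(), comm.Get_size()

        n = 8 * p; nloc = n // p
        Aloc = np.random.rand(n, nloc)
        Bloc = np.random.rand(n, nloc)
        Cloc = np.zeros((n, nloc))
        Bblk = lambda i: Bloc[i*nloc:(i+1)*nloc, :]

        Cloc += Aloc @ Bblk(rank)       # diagonal term first

        buf = np.empty((n, nloc))
        for i in range(rank):           # lower ranks send before we do
            comm.Recv(buf, source=i)
            Cloc += buf @ Bblk(i)
        for i in range(p):              # send A(myRank) to everyone else
            if i != rank:
                comm.Send(Aloc, dest=i)
        for i in range(rank + 1, p):    # then drain the higher ranks
            comm.Recv(buf, source=i)
            Cloc += buf @ Bblk(i)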
    
    Question: Is this algorithm equivalent to the previous one?

    Denote the per-step communication and computation times by

        cm (communication time) = T_s + (n^2/p)*T_w
        ar (computation time) = 2*n^3/p^2
    
    When communication is relatively slow, so that cm >= ar, the execution looks like this (shown for p = 4):

      P0: |ar| cm | cm | cm | cm |ar|      cm |ar|           cm |ar|
      P1: |ar| cm |ar|        cm | cm | cm |    cm |ar|           cm |ar|
      P2: |ar|      cm |ar|        cm |ar| cm | cm | cm |              cm |ar|
      P3: |ar|           cm |ar|        cm |ar|      cm |ar| cm | cm | cm |
    
        Total time T(p,n) = p*(p-1)*cm + 3*ar
                          ~ p^2*cm
                          >= p^2*ar = 2*n^3
    
    Again, this is a poor algorithm.

    Note that when n>>p, cm is much smaller than ar; in other words, the communication-to-computation ratio cm/ar is small.

    Assuming cm/ar <= 1/(p-2), the picture becomes (again p = 4):

      P0: |  ar |cm|cm|cm|cm|  ar |   cm|  ar |      cm|  ar |
      P1: |  ar |cm|  ar |cm|cm|cm|      cm|  ar |      cm|  ar |
      P2: |  ar |   cm|  ar |cm|  ar |cm|cm|cm|            cm|  ar |
      P3: |  ar |      cm|  ar |cm|  ar |   cm|  ar |cm|cm|cm|
    
            Total time T(p,n) = p*(p-1)*cm + 3*ar
                              <= (p*(p-1)/(p-2) + 3)*ar
                              ~ 2*n^3/p
    
    We get almost perfect speedup.

    In general,

        T(p,n) = p*(p-1)*cm + 3*ar

    and the efficiency is

                     2*n^3
        E = -----------------------
             p*(p*(p-1)*cm + 3*ar)

    Recall that ar = 2*n^3/p^2, so 2*n^3 = p^2*ar and

                       p
        E = ---------------------
             p*(p-1)*(cm/ar) + 3
    
    If the ratio cm/ar >= 1, then E is at most about 1/p; if cm/ar is small, say cm/ar <= 1/p, then E ~ 1. (The formula only makes sense up to E = 1: each processor always spends at least p*ar = 2*n^3/p on computation, so the efficiency can never exceed 1.) In other words, when n >> p, the efficiency is good.
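
    The formula is easy to explore numerically. A small Python sketch (the T_s and T_w values are illustrative, not measured; the cap at 1 reflects the remark above):

        def efficiency_1d(p, n, Ts, Tw):
            # E = p / (p*(p-1)*(cm/ar) + 3), capped at 1
            cm = Ts + (n**2 / p) * Tw   # per-step communication time
            ar = 2 * n**3 / p**2        # per-step computation time
            return min(1.0, p / (p * (p - 1) * (cm / ar) + 3))

        for n in (32, 128, 512):        # fixed p = 16, growing n
            print(n, round(efficiency_1d(16, n, Ts=100.0, Tw=1.0), 3))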

    Now, let us check the isoefficiency function. The overhead is

        T_o = p*T(p,n) - 2*n^3
            = p^2*(p-1)*cm + 3*p*ar - 2*n^3
            = p^2*(p-1)*T_s + p*(p-1)*n^2*T_w + (6/p - 2)*n^3
    
    Let T_o = const*2*n^3. Assuming n>>p (otherwise we would not consider isoefficiency since the efficiency is low), we get

                 p^2*(p-1)         p*(p-1)         3
        const = -----------*T_s + ---------*T_w + --- - 1
                   2*n^3             2*n           p
    
    which implies that n = O(p^2): the T_w term dominates, and keeping p*(p-1)/(2*n)*T_w bounded requires n to grow like p^2. In summary, when n >> p the efficiency is good; to maintain the efficiency, n must increase at the rate of p^2.

    Algorithm (1D blocked, bus with broadcast)

        C(myRank) = C(myRank) + A(myRank)*B(myRank, myRank);
        for i=0 to p-1
            broadcast A(i) from Pi (Pi sends, the others receive);
            if (i != myRank)
                C(myRank) = C(myRank) + A(i)*B(i, myRank);
            end if
        end for.
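
    A minimal mpi4py sketch of this version, under the same illustrative data layout as before; Bcast plays the role of the bus broadcast, and every process (including the root P_i) participates in each call:

        from mpi4py import MPI
        import numpy as np

        comm = MPI.COMM_WORLD
        rank, p = comm.Get_rank(), comm.Get_size()

        n = 8 * p; nloc = n // p
        Aloc = np.random.rand(n, nloc)
        Bloc = np.random.rand(n, nloc)
        Cloc = np.zeros((n, nloc))
        Bblk = lambda i: Bloc[i*nloc:(i+1)*nloc, :]

        Cloc += Aloc @ Bblk(rank)       # diagonal term first

        buf = np.empty((n, nloc))
        for i in range(p):
            if rank == i:
                buf[:] = Aloc           # the root loads its own block
            comm.Bcast(buf, root=i)     # broadcast A(i) from P_i
            if i != rank:
                Cloc += buf @ Bblk(i)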
    
      p0:  | ar |  b0  |       b1  | ar |  b2  | ar |  b3  | ar |
      p1:  | ar |  b0  | ar |  b1  |       b2  | ar |  b3  | ar |
      p2:  | ar |  b0  | ar |  b1  | ar |  b2  |       b3  | ar |
      p3:  | ar |  b0  | ar |  b1  | ar |  b2  | ar |  b3  |
    
    The time and efficiency are

        T(p,n) = p*(ar + cm)
    
                 2*n^3             1
        E = --------------- = -----------
             p^2*(ar + cm)     1 + cm/ar
    
    When n>>p, the ratio cm/ar is small and E is close to 1 (nearly perfect speedup). For the isoefficiency, the overhead is

        T_o = p^2*cm = p^2*T_s + p*n^2*T_w
    
    Let T_o = const*2*n^3; we get n = O(p). To maintain the efficiency, n must increase at the same rate as p.

    Algorithm (1D blocked, ring)

        C(myRank) = C(myRank) + A(myRank)*B(myRank,myRank);
        outbuf = A(myRank);
        for i=1 to p-1
            send outbuf to processor (myRank+1) mod p and recv
            from processor (myRank-1) mod p in inbuf;
            C(myRank) = C(myRank) + inbuf*B((myRank-i) mod p, myRank);
            outbuf = inbuf;
        end for.
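
    A minimal mpi4py sketch of the ring version (same illustrative layout as before; Sendrecv_replace performs the simultaneous send-right/receive-from-left of the pseudocode):

        from mpi4py import MPI
        import numpy as np

        comm = MPI.COMM_WORLD
        rank, p = comm.Get_rank(), comm.Get_size()

        n = 8 * p; nloc = n // p
        Aloc = np.random.rand(n, nloc)
        Bloc = np.random.rand(n, nloc)
        Cloc = np.zeros((n, nloc))
        Bblk = lambda i: Bloc[i*nloc:(i+1)*nloc, :]

        Cloc += Aloc @ Bblk(rank)       # diagonal term first

        buf = Aloc.copy()               # outbuf/inbuf rolled into one
        right, left = (rank + 1) % p, (rank - 1) % p
        for i in range(1, p):
            comm.Sendrecv_replace(buf, dest=right, source=left)
            Cloc += buf @ Bblk((rank - i) % p)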
    
        T(p,n) = p*ar ------------------------- computation
               + (p-1)*cm --------------------- communication
    
                            2*n^3
        Efficiency = ---------------------
                      p*(p*ar + (p-1)*cm)
                               1
                   = -----------------------
                      1 + ((p-1)/p)*(cm/ar)
    
    Again, the efficiency depends on the ratio cm/ar. As with the bus with broadcast, the overhead for the isoefficiency analysis is

         T_o = p*(p*ar + (p-1)*cm) - 2*n^3
    
    Recalling that ar = 2*n^3/p^2 and cm = T_s + (n^2/p)*T_w,

        T_o = p*(p-1)*(T_s + (n^2/p)*T_w)
    
    Let T_o = const*2*n^3; again we get n = O(p).

    Cannon's algorithm on a 2D mesh

        Initial skewing;
        C = C + A*B;
        for i=1 to s-1
            left-circular-shift each row;
            up-circular-shift each column;
            C = C + A*B;
        end for.
    
    where s = sqrt(p). For example, with p = 9 (s = 3), C(1,1) in processor P(1,1) is computed as follows.

        C(1,1) = A(1,2)*B(2,1) + A(1,0)*B(0,1) + A(1,1)*B(1,1)
    
    In general, in P(i,j),

        C(i,j) = SUM_{k=0}^{s-1} A(i, (i+j+k) mod s)*B((i+j+k) mod s, j)
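
    A minimal mpi4py sketch of Cannon's algorithm, with one nb x nb block per process on an s x s periodic Cartesian grid (run with p a perfect square; nb is an illustrative block size, so n = s*nb):

        from mpi4py import MPI
        import numpy as np

        comm = MPI.COMM_WORLD
        p = comm.Get_size()
        s = int(round(p ** 0.5))        # s x s process grid, p = s^2

        cart = comm.Create_cart([s, s], periods=[True, True])
        i, j = cart.Get_coords(cart.Get_rank())

        nb = 4                          # illustrative block size
        A = np.random.rand(nb, nb)      # local blocks A(i,j), B(i,j)
        B = np.random.rand(nb, nb)
        C = np.zeros((nb, nb))

        def shift(mat, dim, disp):
            # Circular shift along one grid dimension (disp < 0: left/up).
            src, dst = cart.Shift(dim, disp)
            cart.Sendrecv_replace(mat, dest=dst, source=src)

        shift(A, 1, -i)                 # skewing: row i of A left by i
        shift(B, 0, -j)                 # skewing: column j of B up by j

        C += A @ B
        for _ in range(1, s):
            shift(A, 1, -1)             # left-circular-shift each row
            shift(B, 0, -1)             # up-circular-shift each column
            C += A @ B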
    
    For the execution time, the initial skewing takes 2*((s/2)*cm) (a circular shift of at most s/2 positions for A, and again for B), followed by ar of computation, and then the loop takes (s-1)*(2*cm + ar). In this case, cm = T_s + (n^2/p)*T_w and ar = 2*(n/s)^3 = 2*n^3/p^(3/2). The total time is

        T(p,n) = (3s-2)*cm + s*ar
    
    and

                               1
        Efficiency = -----------------------
                      1 + (3 - 2/s)*(cm/ar)
    
    This is similar to the ring. For the isoefficiency function, the overhead is T_o = p*(3s - 2)*cm. Setting T_o = const*2*n^3, we get

        p*(3*p^(1/2) - 2)*T_s + (3*p^(1/2) - 2)*n^2*T_w = const*2*n^3
    
    and n = O(p^(1/2)): to maintain efficiency, n only needs to grow at the rate of sqrt(p), the weakest requirement among the algorithms discussed here.