Hi, it's nice to see Strassen has attracted attention again. I would like to know which hardware you used and how many cores.

Actually, it's easy to implement Strassen in TVM, and I have tested this algorithm with two different implementations.

TE version:
```python
import tvm

# Globals used only to give every compute stage a unique name and to
# track the recursion depth; DIRECT_SIZE is the threshold below which
# we fall back to a direct GEMM (I used 512 in my experiments).
GEMM_COUNT = 0
GEMM_LEVEL = 0
DIRECT_SIZE = 512


def strassen_gemm(N):
    def gemm(A, B, N, name=""):
        # Recurse with Strassen while the matrix is large enough,
        # otherwise emit a direct (naive) GEMM stage.
        global GEMM_COUNT
        if name != "":
            name += "G%d_" % GEMM_COUNT
            GEMM_COUNT += 1
        if N > DIRECT_SIZE:
            return strassen(A, B, N, name)
        else:
            return direct(A, B, N, name)

    def direct(A, B, N, name):
        k = tvm.reduce_axis((0, N), name="k")
        C = tvm.compute(A.shape,
                        lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=k),
                        name=name + 'C')
        return C

    def split(A, new_n, ori_name="Matrix"):
        # Extract the four (new_n, new_n) quadrants as separate stages.
        A11 = tvm.compute((new_n, new_n),
            lambda i, j: A[i][j], name=ori_name+"11")
        A12 = tvm.compute((new_n, new_n),
            lambda i, j: A[i][j+new_n], name=ori_name+"12")
        A21 = tvm.compute((new_n, new_n),
            lambda i, j: A[i+new_n][j], name=ori_name+"21")
        A22 = tvm.compute((new_n, new_n),
            lambda i, j: A[i+new_n][j+new_n], name=ori_name+"22")
        return A11, A12, A21, A22

    def sub(A, B, N, name):
        C = tvm.compute((N, N),
            lambda i, j: A[i][j] - B[i][j], name=name)
        return C

    def add(A, B, N, name):
        C = tvm.compute((N, N),
            lambda i, j: A[i][j] + B[i][j], name=name)
        return C

    def strassen(A, B, N, name):
        global GEMM_LEVEL
        new_n = int(N / 2)

        A11, A12, A21, A22 = split(A, new_n, name+"A")
        B11, B12, B21, B22 = split(B, new_n, name+"B")

        # Strassen's ten sums/differences of the input quadrants.
        S1 = sub(B12, B22, new_n, name+"S1")
        S2 = add(A11, A12, new_n, name+"S2")
        S3 = add(A21, A22, new_n, name+"S3")
        S4 = sub(B21, B11, new_n, name+"S4")
        S5 = add(A11, A22, new_n, name+"S5")
        S6 = add(B11, B22, new_n, name+"S6")
        S7 = sub(A12, A22, new_n, name+"S7")
        S8 = add(B21, B22, new_n, name+"S8")
        S9 = sub(A11, A21, new_n, name+"S9")
        S10 = add(B11, B12, new_n, name+"S10")

        level = GEMM_LEVEL
        GEMM_LEVEL += 1
        # The seven recursive half-size products.
        P1 = gemm(A11, S1, new_n, name+"L%d_"%level)
        P2 = gemm(S2, B22, new_n, name+"L%d_"%level)
        P3 = gemm(S3, B11, new_n, name+"L%d_"%level)
        P4 = gemm(A22, S4, new_n, name+"L%d_"%level)
        P5 = gemm(S5, S6, new_n, name+"L%d_"%level)
        P6 = gemm(S7, S8, new_n, name+"L%d_"%level)
        P7 = gemm(S9, S10, new_n, name+"L%d_"%level)

        C11 = tvm.compute((new_n, new_n),
                lambda i, j: P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j],
                name=name+"C11")
        C12 = add(P1, P2, new_n, name+"C12")
        C21 = add(P3, P4, new_n, name+"C21")
        C22 = tvm.compute((new_n, new_n),
                lambda i, j: P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j],
                name=name+"C22")

        # Stitch the four quadrants back into one (N, N) output.
        C = tvm.compute((N, N),
                lambda i, j: tvm.if_then_else(i < new_n,
                    tvm.if_then_else(j < new_n, C11[i][j], C12[i][j-new_n]),
                    tvm.if_then_else(j < new_n, C21[i-new_n][j], C22[i-new_n][j-new_n])),
                name=name+"C")
        return C

    A = tvm.placeholder((N, N), name="A")
    B = tvm.placeholder((N, N), name="B")
    C = gemm(A, B, N)
    sch = tvm.create_schedule(C.op)
    return sch, [A, B, C]
```
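For completeness, here is a minimal sketch of how one might build and time the TE version (my own test harness, not from the original experiments; it assumes the pre-0.7 top-level TVM API used above and float32 inputs):

```python
import numpy as np
import tvm

N = 1024
sch, args = strassen_gemm(N)       # args == [A, B, C]
func = tvm.build(sch, args, target="llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(N, N).astype("float32"), ctx)
b = tvm.nd.array(np.random.rand(N, N).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((N, N), dtype="float32"), ctx)

# Average over several runs to get a stable timing.
evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
print("strassen TE: %.3f ms" % (evaluator(a, b, c).mean * 1e3))
```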

Relay version (I even tried an implementation that merges the 7 sub-matrix GEMMs into a single `batch_matmul`):
```python
from tvm import relay

# direct_size is only used by the strassen_merge variant below.
def strassen_gemm(N, K, M, max_level=1, direct_size=512):
    # A [N, K]
    # B [K, M]
    # C [N, M]
    def gemm(A, B, N, K, M, level):
        # Recurse while we have budget left and all dims are even;
        # otherwise emit a plain dense op.
        if (level < max_level and N % 2 == 0 and
                K % 2 == 0 and M % 2 == 0):
            return strassen(A, B, N, K, M, level)
        else:
            return direct(A, B, N, K, M)

    def direct(A, B, N, K, M):
        # relay.nn.dense computes A * W^T, so pass B transposed.
        C = relay.nn.dense(A, relay.transpose(B, [1, 0]))
        return C

    def split(A, new_x, new_y):
        # Slice out the four quadrants of a (2*new_x, 2*new_y) matrix.
        A11 = relay.strided_slice(A, [0, 0], [new_x, new_y])
        A12 = relay.strided_slice(A, [0, new_y], [new_x, new_y*2])
        A21 = relay.strided_slice(A, [new_x, 0], [new_x*2, new_y])
        A22 = relay.strided_slice(A, [new_x, new_y], [new_x*2, new_y*2])
        return A11, A12, A21, A22

    def strassen(A, B, N, K, M, level):
        new_n = int(N / 2)
        new_k = int(K / 2)
        new_m = int(M / 2)

        A11, A12, A21, A22 = split(A, new_n, new_k)
        B11, B12, B21, B22 = split(B, new_k, new_m)

        S1 = B12 - B22
        P1 = gemm(A11, S1, new_n, new_k, new_m, level+1)
        S2 = A11 + A12
        P2 = gemm(S2, B22, new_n, new_k, new_m, level+1)
        C12 = P1 + P2

        S3 = A21 + A22
        P3 = gemm(S3, B11, new_n, new_k, new_m, level+1)
        S4 = B21 - B11
        P4 = gemm(A22, S4, new_n, new_k, new_m, level+1)
        C21 = P3 + P4

        S5 = A11 + A22
        S6 = B11 + B22
        P5 = gemm(S5, S6, new_n, new_k, new_m, level+1)
        S7 = A12 - A22
        S8 = B21 + B22
        P6 = gemm(S7, S8, new_n, new_k, new_m, level+1)
        C11 = P5 + P4 - P2 + P6

        S9 = A11 - A21
        S10 = B11 + B12
        P7 = gemm(S9, S10, new_n, new_k, new_m, level+1)
        C22 = P5 + P1 - P3 - P7

        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    def strassen_merge(A, B, N, level):
        # Experimental square-matrix variant: at the last level, the 7
        # sub-matrix GEMMs are batched into one batch_matmul.
        new_n = int(N / 2)

        A11, A12, A21, A22 = split(A, new_n, new_n)
        B11, B12, B21, B22 = split(B, new_n, new_n)

        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10 = B11 + B12

        if new_n > direct_size:
            P1 = gemm(A11, S1, new_n, new_n, new_n, level+1)
            P2 = gemm(S2, B22, new_n, new_n, new_n, level+1)
            P3 = gemm(S3, B11, new_n, new_n, new_n, level+1)
            P4 = gemm(A22, S4, new_n, new_n, new_n, level+1)
            P5 = gemm(S5, S6, new_n, new_n, new_n, level+1)
            P6 = gemm(S7, S8, new_n, new_n, new_n, level+1)
            P7 = gemm(S9, S10, new_n, new_n, new_n, level+1)
        else:
            # Stack the 7 LHS operands into a [7, new_n, new_n] tensor.
            Merge_A = []
            for a in [A11, S2, S3, A22, S5, S7, S9]:
                Merge_A.append(relay.expand_dims(a, 0))
            Merge_A = relay.concatenate(Merge_A, 0)

            # Same for the 7 RHS operands.
            Merge_B = []
            for b in [S1, B22, B11, S4, S6, S8, S10]:
                Merge_B.append(relay.expand_dims(b, 0))
            Merge_B = relay.concatenate(Merge_B, 0)

            # batch_matmul computes X * Y^T, hence the transpose.
            Merge_C = relay.nn.batch_matmul(Merge_A, relay.transpose(Merge_B, [0, 2, 1]))
            ss = relay.split(Merge_C, 7)
            P1 = relay.reshape(ss[0], [new_n, new_n])
            P2 = relay.reshape(ss[1], [new_n, new_n])
            P3 = relay.reshape(ss[2], [new_n, new_n])
            P4 = relay.reshape(ss[3], [new_n, new_n])
            P5 = relay.reshape(ss[4], [new_n, new_n])
            P6 = relay.reshape(ss[5], [new_n, new_n])
            P7 = relay.reshape(ss[6], [new_n, new_n])

        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7

        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    A = relay.var("A", shape=(N, K))
    B = relay.var("B", shape=(K, M))
    C = gemm(A, B, N, K, M, 0)
    return A, B, C
```
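And a matching sketch for building and running the Relay version (again my own harness, assuming the TVM 0.7-era Relay build and graph runtime APIs):

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

N = K = M = 1024
A, B, C = strassen_gemm(N, K, M, max_level=1)
mod = tvm.IRModule.from_expr(relay.Function([A, B], C))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")

ctx = tvm.cpu(0)
rt = graph_runtime.GraphModule(lib["default"](ctx))
rt.set_input("A", np.random.rand(N, K).astype("float32"))
rt.set_input("B", np.random.rand(K, M).astype("float32"))
rt.run()
out = rt.get_output(0).asnumpy()
```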

In the end, the measured performance is not so good. Only in a 4-core `1024*1024*1024` case with `direct_size = 512` did I get better performance with Strassen than with a direct GEMM.

I think there are several reasons for this:
1. The TE version contains too many stages, which makes it hard to schedule, even with the auto-scheduling tool Ansor.
2. The Relay version contains some unnatural `slice` and `concat` ops, which are not so friendly to memory access.
3. GEMM tends to perform better at larger sizes. When we split a single GEMM into 7 sub-matrix GEMMs, these smaller GEMMs are likely to achieve lower GFLOPS.
4. MNN manages its memory access and compute threads well. It can even run the 7 sub-matrix GEMMs in parallel, while TVM does not support inter-op parallelism.
5. As for the Strassen algorithm itself, in my understanding it does save computation in single-threaded execution (in theory the complexity drops from O(n^3) to O(n^log2(7)) ≈ O(n^2.81); see the sketch after this list), but when we take it to a multi-threaded setting I think it will not be so beneficial.
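To make point 5 concrete, a quick back-of-the-envelope FLOP count (my own sketch, not a measurement): each Strassen level trades 8 half-size multiplications for 7, at the cost of 18 extra half-size matrix additions, so the theoretical saving per level is modest:

```python
# Rough FLOP model: a direct n x n GEMM costs 2*n^3 FLOPs; one Strassen
# level does 7 half-size GEMMs plus 18 half-size additions/subtractions
# (10 for S1..S10 and 8 for combining C11..C22).
def direct_flops(n):
    return 2 * n ** 3

def strassen_flops(n, levels):
    if levels == 0 or n % 2 != 0:
        return direct_flops(n)
    half = n // 2
    return 7 * strassen_flops(half, levels - 1) + 18 * half * half

for levels in range(4):
    ratio = strassen_flops(1024, levels) / direct_flops(1024)
    print("levels=%d -> %.3f of direct FLOPs" % (levels, ratio))
# levels=0 -> 1.000, 1 -> ~0.877, 2 -> ~0.772, 3 -> ~0.683
```

So one or two levels only cut the arithmetic by roughly 12-23%, which is easily eaten up by the extra memory traffic of the additions, slices and concats.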

So my conclusions are:
1. Strassen should be more powerful with few CPU cores, e.g. on an ARM CPU with only 4 or 8 cores, which is exactly the target hardware of MNN. On an Intel CPU with more cores, I don't think we can benefit from Strassen.
2. MNN does have better memory/thread management, since it is hand-written C/C++. TVM does not seem able to do the same thing through codegen.




