Hi, it's nice to see Strassen attracting attention again. May I ask which hardware you used, and how many cores?
Actually, it's easy to implement Strassen in TVM; I have tested this algorithm with two different implementations.

TE version:

```python
import tvm  # pre-0.7 TVM API: on newer versions these calls live under tvm.te

# Threshold below which we fall back to a direct GEMM (512 matches the
# setting used in the evaluation below), plus counters used only to
# generate unique stage names.
DIRECT_SIZE = 512
GEMM_COUNT = 0
GEMM_LEVEL = 0

def strassen_gemm(N):
    def gemm(A, B, N, name=""):
        global GEMM_COUNT
        if name != "":
            name += "G%d_" % GEMM_COUNT
            GEMM_COUNT += 1
        if N > DIRECT_SIZE:
            return strassen(A, B, N, name)
        else:
            return direct(A, B, N, name)

    def direct(A, B, N, name):
        k = tvm.reduce_axis((0, N), name="k")
        C = tvm.compute(A.shape, lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=k), name=name + "C")
        return C

    def split(A, new_n, ori_name="Matrix"):
        A11 = tvm.compute((new_n, new_n), lambda i, j: A[i][j], name=ori_name + "11")
        A12 = tvm.compute((new_n, new_n), lambda i, j: A[i][j + new_n], name=ori_name + "12")
        A21 = tvm.compute((new_n, new_n), lambda i, j: A[i + new_n][j], name=ori_name + "21")
        A22 = tvm.compute((new_n, new_n), lambda i, j: A[i + new_n][j + new_n], name=ori_name + "22")
        return A11, A12, A21, A22

    def sub(A, B, N, name):
        return tvm.compute((N, N), lambda i, j: A[i][j] - B[i][j], name=name)

    def add(A, B, N, name):
        return tvm.compute((N, N), lambda i, j: A[i][j] + B[i][j], name=name)

    def strassen(A, B, N, name):
        global GEMM_LEVEL
        new_n = N // 2
        A11, A12, A21, A22 = split(A, new_n, name + "A")
        B11, B12, B21, B22 = split(B, new_n, name + "B")
        S1 = sub(B12, B22, new_n, name + "S1")
        S2 = add(A11, A12, new_n, name + "S2")
        S3 = add(A21, A22, new_n, name + "S3")
        S4 = sub(B21, B11, new_n, name + "S4")
        S5 = add(A11, A22, new_n, name + "S5")
        S6 = add(B11, B22, new_n, name + "S6")
        S7 = sub(A12, A22, new_n, name + "S7")
        S8 = add(B21, B22, new_n, name + "S8")
        S9 = sub(A11, A21, new_n, name + "S9")
        S10 = add(B11, B12, new_n, name + "S10")
        level = GEMM_LEVEL
        GEMM_LEVEL += 1
        # The 7 recursive products of Strassen's algorithm.
        P1 = gemm(A11, S1, new_n, name + "L%d_" % level)
        P2 = gemm(S2, B22, new_n, name + "L%d_" % level)
        P3 = gemm(S3, B11, new_n, name + "L%d_" % level)
        P4 = gemm(A22, S4, new_n, name + "L%d_" % level)
        P5 = gemm(S5, S6, new_n, name + "L%d_" % level)
        P6 = gemm(S7, S8, new_n, name + "L%d_" % level)
        P7 = gemm(S9, S10, new_n, name + "L%d_" % level)
        C11 = tvm.compute((new_n, new_n),
                          lambda i, j: P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j],
                          name=name + "C11")
        C12 = add(P1, P2, new_n, name + "C12")
        C21 = add(P3, P4, new_n, name + "C21")
        C22 = tvm.compute((new_n, new_n),
                          lambda i, j: P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j],
                          name=name + "C22")
        # Stitch the four quadrants back into a single N x N result.
        C = tvm.compute((N, N),
                        lambda i, j: tvm.if_then_else(
                            i < new_n,
                            tvm.if_then_else(j < new_n, C11[i][j], C12[i][j - new_n]),
                            tvm.if_then_else(j < new_n, C21[i - new_n][j], C22[i - new_n][j - new_n])),
                        name=name + "C")
        return C

    A = tvm.placeholder((N, N), name="A")
    B = tvm.placeholder((N, N), name="B")
    C = gemm(A, B, N)
    sch = tvm.create_schedule(C.op)
    return sch, [A, B, C]
```
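A minimal sketch of how one can build and sanity-check this version (same pre-0.7 `tvm` API as above, naive default schedule, `numpy` used only for verification; building the fully unscheduled multi-stage graph can take a while):

```python
import numpy as np

N = 1024
sch, (A, B, C) = strassen_gemm(N)
func = tvm.build(sch, [A, B, C], target="llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(N, N)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(N, N)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((N, N), dtype="float32"), ctx)
func(a, b, c)

# Strassen accumulates more rounding error than a direct GEMM,
# so compare against numpy with a looser tolerance.
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() @ b.asnumpy(), rtol=1e-3)
```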
Relay version (I even tried merging the GEMMs of the 7 sub-matrices into a single `batch_matmul`):

```python
from tvm import relay

direct_size = 512  # below this size, strassen_merge switches to the merged batch_matmul

def strassen_gemm(N, K, M, max_level=1):
    # A: [N, K], B: [K, M], C: [N, M]
    def gemm(A, B, N, K, M, level):
        if level < max_level and N % 2 == 0 and K % 2 == 0 and M % 2 == 0:
            return strassen(A, B, N, K, M, level)
        else:
            return direct(A, B, N, K, M)

    def direct(A, B, N, K, M):
        # relay.nn.dense expects its weight in [M, K] layout, hence the transpose.
        return relay.nn.dense(A, relay.transpose(B, [1, 0]))

    def split(A, new_x, new_y):
        A11 = relay.strided_slice(A, [0, 0], [new_x, new_y])
        A12 = relay.strided_slice(A, [0, new_y], [new_x, new_y * 2])
        A21 = relay.strided_slice(A, [new_x, 0], [new_x * 2, new_y])
        A22 = relay.strided_slice(A, [new_x, new_y], [new_x * 2, new_y * 2])
        return A11, A12, A21, A22

    def strassen(A, B, N, K, M, level):
        new_n = N // 2
        new_k = K // 2
        new_m = M // 2
        A11, A12, A21, A22 = split(A, new_n, new_k)
        B11, B12, B21, B22 = split(B, new_k, new_m)
        S1 = B12 - B22
        P1 = gemm(A11, S1, new_n, new_k, new_m, level + 1)
        S2 = A11 + A12
        P2 = gemm(S2, B22, new_n, new_k, new_m, level + 1)
        C12 = P1 + P2
        S3 = A21 + A22
        P3 = gemm(S3, B11, new_n, new_k, new_m, level + 1)
        S4 = B21 - B11
        P4 = gemm(A22, S4, new_n, new_k, new_m, level + 1)
        C21 = P3 + P4
        S5 = A11 + A22
        S6 = B11 + B22
        P5 = gemm(S5, S6, new_n, new_k, new_m, level + 1)
        S7 = A12 - A22
        S8 = B21 + B22
        P6 = gemm(S7, S8, new_n, new_k, new_m, level + 1)
        C11 = P5 + P4 - P2 + P6
        S9 = A11 - A21
        S10 = B11 + B12
        P7 = gemm(S9, S10, new_n, new_k, new_m, level + 1)
        C22 = P5 + P1 - P3 - P7
        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    def strassen_merge(A, B, N, level):
        # Square-matrix variant of strassen(): below direct_size, the 7
        # sub-GEMMs are merged into one batch_matmul instead of 7 dense ops.
        new_n = N // 2
        A11, A12, A21, A22 = split(A, new_n, new_n)
        B11, B12, B21, B22 = split(B, new_n, new_n)
        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10 = B11 + B12
        if new_n > direct_size:
            P1 = gemm(A11, S1, new_n, new_n, new_n, level + 1)
            P2 = gemm(S2, B22, new_n, new_n, new_n, level + 1)
            P3 = gemm(S3, B11, new_n, new_n, new_n, level + 1)
            P4 = gemm(A22, S4, new_n, new_n, new_n, level + 1)
            P5 = gemm(S5, S6, new_n, new_n, new_n, level + 1)
            P6 = gemm(S7, S8, new_n, new_n, new_n, level + 1)
            P7 = gemm(S9, S10, new_n, new_n, new_n, level + 1)
        else:
            # Stack the 7 left and right operands into [7, new_n, new_n]
            # tensors and multiply them with a single batch_matmul.
            Merge_A = []
            for a in [A11, S2, S3, A22, S5, S7, S9]:
                Merge_A.append(relay.expand_dims(a, 0))
            Merge_A = relay.concatenate(Merge_A, 0)
            Merge_B = []
            for b in [S1, B22, B11, S4, S6, S8, S10]:
                Merge_B.append(relay.expand_dims(b, 0))
            Merge_B = relay.concatenate(Merge_B, 0)
            # batch_matmul expects B in [batch, M, K] layout, hence the transpose.
            Merge_C = relay.nn.batch_matmul(Merge_A, relay.transpose(Merge_B, [0, 2, 1]))
            ss = relay.split(Merge_C, 7)
            P1 = relay.reshape(ss[0], [new_n, new_n])
            P2 = relay.reshape(ss[1], [new_n, new_n])
            P3 = relay.reshape(ss[2], [new_n, new_n])
            P4 = relay.reshape(ss[3], [new_n, new_n])
            P5 = relay.reshape(ss[4], [new_n, new_n])
            P6 = relay.reshape(ss[5], [new_n, new_n])
            P7 = relay.reshape(ss[6], [new_n, new_n])
        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7
        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    A = relay.var("A", shape=(N, K))
    B = relay.var("B", shape=(K, M))
    C = gemm(A, B, N, K, M, 0)
    return A, B, C
```
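And a minimal sketch of compiling and running the Relay version (assuming a TVM build that provides `tvm.contrib.graph_executor`; older versions call it `graph_runtime` and have a slightly different creation API):

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor

N = K = M = 1024
A, B, C = strassen_gemm(N, K, M, max_level=1)
mod = tvm.IRModule.from_expr(relay.Function([A, B], C))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")

dev = tvm.cpu(0)
rt = graph_executor.GraphModule(lib["default"](dev))
a = np.random.uniform(size=(N, K)).astype("float32")
b = np.random.uniform(size=(K, M)).astype("float32")
rt.set_input("A", a)
rt.set_input("B", b)
rt.run()
out = rt.get_output(0).numpy()
np.testing.assert_allclose(out, a @ b, rtol=1e-3)
```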
In the end, the evaluated performance is not very good. Only in a 4-core `1024*1024*1024` case with `direct_size = 512` do I get better performance with Strassen. I think there are several reasons for this:

1. The TE version contains too many stages, which makes it hard to schedule, even with the auto-scheduling tool Ansor.
2. The Relay version contains some unnatural `slice` and `concat` ops, which are not friendly to memory access.
3. Ops tend to perform better on larger GEMMs. When we split a single GEMM into 7 sub-matrix GEMMs, those smaller GEMMs are likely to achieve lower GFLOPS.
4. MNN manages its memory access and compute threads well. It can even run the 7 sub-matrix GEMMs in parallel, while TVM does not support inter-op parallelism.
5. As for the Strassen algorithm itself, my understanding is that it does save computation in single-threaded execution (in theory it reduces the complexity from O(n^3) to roughly O(n^2.81); see the recurrence sketch below), but in a multi-threaded setting I don't think it will be as beneficial.

So my conclusions are:

1. Strassen should be more powerful with few CPU cores, e.g. on an ARM CPU with only 4 or 8 cores, which is exactly the target hardware of MNN. On an Intel CPU with more cores, I don't think we can benefit from Strassen.
2. MNN does have better memory/thread management since it's written directly in C. TVM does not seem able to do the same thing with codegen.
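To make the complexity claim in reason 5 concrete (this is just the textbook recurrence, not anything specific to either implementation above): a direct blocked GEMM does 8 half-size multiplications per level, while Strassen does 7 multiplications plus 18 half-size additions, so

$$
T_{\text{direct}}(n) = 8\,T(n/2) + \Theta(n^2) = \Theta(n^3),
\qquad
T_{\text{strassen}}(n) = 7\,T(n/2) + \Theta(n^2) = \Theta\!\left(n^{\log_2 7}\right) \approx \Theta\!\left(n^{2.81}\right).
$$

The multiply count shrinks by only 1/8 per recursion level, and the extra additions are memory-bound, which is consistent with Strassen only paying off at large sizes on a small number of cores.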