Hello there. The idea is essentially the same as the existing IR pass described in
https://discuss.tvm.ai/t/discussion-new-ir-pass-proposal-combineparalleldense/3813
by @jonso. Many sequential network structures perform a group of matmul
operations on the same input tensor, such as:
- gate projections on state within GRU/LSTM
- Q/K/V projections on input within transformer layer
Thanks to the `CombineParallelDense` pass, such operations can be combined to better
utilize the performance of matmul kernels.
The currently implemented strategy transforms the multiple matmul ops into a single
batched matmul op:
- before:
  Y_1: [M, N] = matmul(X: [M, K], W_1: [K, N]), ..., Y_B: [M, N] = matmul(X, W_B: [K, N])
- after:
  Y: [B, M, N] = batch_matmul(stack(X, ..., X), stack(W_1, ..., W_B))
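For concreteness, here is a minimal NumPy sketch of this strategy (the shapes are made up for illustration; `np.matmul` on 3-D arrays behaves like a batch_matmul):

```python
import numpy as np

M, K, N, B = 4, 8, 16, 3
X = np.random.randn(M, K).astype("float32")
Ws = [np.random.randn(K, N).astype("float32") for _ in range(B)]

# before: B separate matmuls on the same input
outs = [X @ W for W in Ws]                  # each [M, N]

# after: one batched matmul over the stacked inputs and weights
Xs = np.stack([X] * B)                      # [B, M, K]
Wb = np.stack(Ws)                           # [B, K, N]
Y = np.matmul(Xs, Wb)                       # [B, M, N]

assert all(np.allclose(Y[b], outs[b], atol=1e-5) for b in range(B))
```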
However, there seems to be another, simpler choice: combine them into one plain
matmul instead of a batched matmul. This also works even when the output channel
sizes differ:
- before:
  Y_1: [M, N_1] = matmul(X: [M, K], W_1: [K, N_1]), ..., Y_B: [M, N_B] = matmul(X, W_B: [K, N_B])
- after:
  Y: [M, N_1 + N_2 + ... + N_B] = matmul(X, concat(W_1, ..., W_B))
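A corresponding NumPy sketch of the proposed form, with differing output channel sizes (again, the shapes are placeholders):

```python
import numpy as np

M, K = 4, 8
Ns = [16, 16, 32]                            # different output channel sizes
X = np.random.randn(M, K).astype("float32")
Ws = [np.random.randn(K, n).astype("float32") for n in Ns]

# before: separate matmuls on the same input
outs = [X @ W for W in Ws]

# after: one matmul against the weights concatenated along the output axis
W_all = np.concatenate(Ws, axis=1)           # [K, N_1 + ... + N_B]
Y = X @ W_all                                # [M, N_1 + ... + N_B]

# recover each branch's output by slicing along the last axis
splits = np.cumsum(Ns)[:-1]
for out, y in zip(outs, np.split(Y, splits, axis=1)):
    assert np.allclose(out, y, atol=1e-5)
```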
Since matmul and batch_matmul have different op implementations, the performance of
the combined op may differ. The output layouts are also different, which may affect
the performance of downstream ops.
We can compare matmul against an equivalent batch_matmul with a fixed LHS matrix.
Using cuBLAS as a reference, I find that a single cublasSgemm call is significantly
faster than cublasSgemmStridedBatched in certain circumstances with small B
(typically B = 3).
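I am not including the exact benchmark here, but a rough way to approximate the comparison from Python is via CuPy, whose 2-D and batched 3-D matmuls are backed by cuBLAS GEMM kernels (the sizes below are placeholders, not the ones measured above, and the dispatch to strided-batched GEMM is an assumption about CuPy internals):

```python
import cupy as cp

M, K, N, B = 128, 1024, 1024, 3
X = cp.random.randn(M, K, dtype=cp.float32)
W = cp.random.randn(B, K, N, dtype=cp.float32)

def time_gpu(fn, iters=100):
    # simple CUDA-event timer; the first call serves as warm-up
    fn()
    start, end = cp.cuda.Event(), cp.cuda.Event()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    end.synchronize()
    return cp.cuda.get_elapsed_time(start, end) / iters  # milliseconds

W_flat = cp.concatenate([W[b] for b in range(B)], axis=1)  # [K, B*N], one wide GEMM
Xb = cp.stack([X] * B)                                     # [B, M, K]

print("single matmul :", time_gpu(lambda: X @ W_flat), "ms")
print("batched matmul:", time_gpu(lambda: cp.matmul(Xb, W)), "ms")
```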
The proposed strategy could be added as an option to the current CombineParallelDense
pass, and I think the basic implementation logic will closely resemble
`CombineParallelConv2d`. The CombineParallelDense pass could then select the better of
the two strategies to get more performance benefit.
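To make the intended rewrite concrete, here is a hand-written Relay sketch of the before/after IR for a Q/K/V-style pattern (this is not the pass itself, just the target form; the shapes and variable names are made up, and note that `nn.dense` takes weights in [units, K] layout, so the combination concatenates along axis 0):

```python
from tvm import relay

M, K, N1, N2, N3 = 16, 64, 64, 64, 128

x = relay.var("x", shape=(M, K), dtype="float32")
w1 = relay.var("w1", shape=(N1, K), dtype="float32")
w2 = relay.var("w2", shape=(N2, K), dtype="float32")
w3 = relay.var("w3", shape=(N3, K), dtype="float32")

# before: three parallel dense ops on the same input
before = relay.Tuple([relay.nn.dense(x, w) for w in (w1, w2, w3)])

# after (proposed): one dense on the concatenated weight, then slice the output
w_all = relay.concatenate([w1, w2, w3], axis=0)           # [N1+N2+N3, K]
y = relay.nn.dense(x, w_all)                              # [M, N1+N2+N3]
y1 = relay.strided_slice(y, begin=[0, 0], end=[M, N1])
y2 = relay.strided_slice(y, begin=[0, N1], end=[M, N1 + N2])
y3 = relay.strided_slice(y, begin=[0, N1 + N2], end=[M, N1 + N2 + N3])
after = relay.Tuple([y1, y2, y3])

print(relay.Function([x, w1, w2, w3], after))
```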