Hi everyone,
I’m working on a custom accelerator backend for TVM and ran into a scheduling /
memory issue that I’m not sure how to model correctly.
#### Hardware background
On our hardware we have:
* A **matrix compute unit**
* A **vector compute unit**
* A **shared on-chip buffer** (local buffer) that both units can access
In a tiled **matrix–vector** computation, the idea is:
1. The matrix unit computes one **tile** of data and writes it into the local
buffer.
2. The vector unit immediately consumes that same tile from the **same local
buffer**, without going through global memory.
So in hardware there is **no local → global → local roundtrip** for that
intermediate tile.
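To make the intent concrete, here is a toy numpy stand-in for the per-tile dataflow we are after (the tile size and buffer names are just illustrative, shapes match the example module below):
```
import numpy as np

# Toy model of the intended dataflow: the "matrix unit" writes one output tile
# into the shared on-chip buffer, and the "vector unit" consumes that same
# buffer directly, so the intermediate tile never goes through global memory.
x = np.random.rand(1, 48).astype("float32")
w = np.random.rand(48, 64).astype("float32")
bias = np.random.rand(64).astype("float32")
out = np.empty((1, 64), dtype="float32")

TILE = 16
local_buf = np.empty((1, TILE), dtype="float32")  # stands in for the shared local buffer
for j in range(0, 64, TILE):
    local_buf[:] = x @ w[:, j:j + TILE]                                  # matrix unit -> local buffer
    out[:, j:j + TILE] = np.maximum(local_buf + bias[j:j + TILE], 0.0)   # vector unit reads the same buffer
```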
#### What I see in TVM
In my current TVM schedule / lowering, the behavior looks more like this:
1. Matrix compute writes the tile to a **local buffer**.
2. TVM then copies that local buffer **back to global memory**.
3. Before the vector compute stage, TVM **copies the data from global to
another local buffer** for the vector unit to use.
In other words, the intermediate result always gets stored to global memory and
then reloaded, even though on our hardware a single shared local buffer would be
enough. Here is a minimal repro of the schedule:
```
import tvm
from tvm import relax
from tvm import tir
from tvm.relax.frontend import nn
from tvm.relax.transform import LegalizeOps, AnnotateTIROpPattern
import numpy as np
from tvm.relax.expr_functor import PyExprMutator


class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(48, 64)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        return x


mod, param_spec = NNModule().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, 48), "float32")}}
)

pipeline = tvm.transform.Sequential([
    LegalizeOps(),
    AnnotateTIROpPattern(),
    relax.transform.FuseOps(fuse_opt_level=4),
    relax.transform.FuseTIR(),
])
mod_lowered = pipeline(mod)

sch = tvm.tir.Schedule(mod_lowered)
mod_lowered.show()

sch.work_on("fused_matmul_add_relu")
block_add = sch.get_block("T_add")
sch.compute_inline(block_add)

# move the vector calculation right after the matrix calculation
matmul_block = sch.get_block("matmul")
compute_block = sch.get_block("compute")

# 1. tile the matmul output and reduction axes
i0, i1, k = sch.get_loops(matmul_block)
i1_outer, i1_inner = sch.split(i1, factors=[None, 16])
k_outer, k_inner = sch.split(k, factors=[None, 16])
sch.reorder(i0, i1_outer, k_outer, k_inner, i1_inner)

# 2. tile the elementwise compute block and attach it under the matmul's outer tile loop
i0_c, i1_c = sch.get_loops(compute_block)
i1_outer_c, i1_inner_c = sch.split(i1_c, factors=[None, 16])
sch.reverse_compute_at(compute_block, i1_outer)
sch.mod.show()

# add local/shared buffers for the matmul
x_shared = sch.cache_read(matmul_block, 0, "shared")   # x buffer
w_shared = sch.cache_read(matmul_block, 1, "shared")   # weight buffer
matmul_local = sch.cache_write(matmul_block, 0, "local")

# 3. cache the compute (vector) block's operands and result
compute_block = sch.get_block("compute")
# sch.cache_read(compute_block, 0, "local", [matmul_local])
bias_local = sch.cache_read(compute_block, 1, "local")  # bias
result_local = sch.cache_write(compute_block, 0, "local")

# 4. attach the cache stages to the matmul's outer tile loop
i0, i1_0, k_0, k_1, i1_1 = sch.get_loops(matmul_block)
sch.compute_at(x_shared, i1_0)
sch.compute_at(w_shared, i1_0)
sch.compute_at(bias_local, i1_0)
sch.mod.show()
```
The scheduled TIR is shown below. The intermediate tile still takes a detour through the global-scope buffer `matmul_intermediate`:
* `matmul_intermediate[v0, v1] = matmul_intermediate_local[v0, v1]` copies the tile back out of the local buffer, and
* the `compute` block then reads `matmul_intermediate` instead of `matmul_intermediate_local`.
```
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul_add_relu(
        x: T.Buffer((T.int64(1), T.int64(48)), "float32"),
        permute_dims: T.Buffer((T.int64(48), T.int64(64)), "float32"),
        fc1_bias: T.Buffer((T.int64(64),), "float32"),
        compute_intermediate: T.Buffer((T.int64(1), T.int64(64)), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(64)))
        x_shared = T.alloc_buffer((T.int64(1), T.int64(48)), scope="shared")
        permute_dims_shared = T.alloc_buffer((T.int64(48), T.int64(64)), scope="shared")
        matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(64)), scope="local")
        fc1_bias_local = T.alloc_buffer((T.int64(64),), scope="local")
        compute_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(64)), scope="local")
        for i0, i1_0 in T.grid(T.int64(1), T.int64(4)):
            for ax0 in range(T.int64(48)):
                with T.block("x_shared"):
                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                    v1 = T.axis.spatial(T.int64(48), ax0)
                    T.reads(x[v0, v1])
                    T.writes(x_shared[v0, v1])
                    x_shared[v0, v1] = x[v0, v1]
            for ax0, ax1 in T.grid(T.int64(48), T.int64(16)):
                with T.block("permute_dims_shared"):
                    v0 = T.axis.spatial(T.int64(48), ax0)
                    v1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax1)
                    T.reads(permute_dims[v0, v1])
                    T.writes(permute_dims_shared[v0, v1])
                    permute_dims_shared[v0, v1] = permute_dims[v0, v1]
            for k_0, k_1, i1_1 in T.grid(T.int64(3), T.int64(16), T.int64(16)):
                with T.block("matmul"):
                    v_i0 = T.axis.spatial(T.int64(1), i0)
                    v_i1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + i1_1)
                    v_k = T.axis.reduce(T.int64(48), k_0 * T.int64(16) + k_1)
                    T.reads(x_shared[v_i0, v_k], permute_dims_shared[v_k, v_i1])
                    T.writes(matmul_intermediate_local[v_i0, v_i1])
                    with T.init():
                        matmul_intermediate_local[v_i0, v_i1] = T.float32(0.0)
                    matmul_intermediate_local[v_i0, v_i1] = matmul_intermediate_local[v_i0, v_i1] + x_shared[v_i0, v_k] * permute_dims_shared[v_k, v_i1]
            for ax0, ax1 in T.grid(T.int64(1), T.int64(16)):
                with T.block("matmul_intermediate_local"):
                    v0 = T.axis.spatial(T.int64(1), i0 + ax0)
                    v1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax1)
                    T.reads(matmul_intermediate_local[v0, v1])
                    T.writes(matmul_intermediate[v0, v1])
                    matmul_intermediate[v0, v1] = matmul_intermediate_local[v0, v1]
            for ax0 in range(T.int64(16)):
                with T.block("fc1.bias_local"):
                    v0 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax0)
                    T.reads(fc1_bias[v0])
                    T.writes(fc1_bias_local[v0])
                    fc1_bias_local[v0] = fc1_bias[v0]
            for ax0 in range(T.int64(16)):
                with T.block("compute"):
                    v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
                    v_i1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax0)
                    T.reads(matmul_intermediate[v_i0, v_i1], fc1_bias_local[v_i1])
                    T.writes(compute_intermediate_local[v_i0, v_i1])
                    compute_intermediate_local[v_i0, v_i1] = T.max(matmul_intermediate[v_i0, v_i1] + fc1_bias_local[v_i1], T.float32(0.0))
        for ax0, ax1 in T.grid(T.int64(1), T.int64(64)):
            with T.block("compute_intermediate_local"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(compute_intermediate_local[v0, v1])
                T.writes(compute_intermediate[v0, v1])
                compute_intermediate[v0, v1] = compute_intermediate_local[v0, v1]
```
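For reference, this is the kind of TIR I would like to end up with. It is a hand-written, simplified sketch (the shared caches for `x`, the weights, and the bias are omitted), not something produced by the schedule above: the `compute` block consumes `matmul_intermediate_local` directly and no global-scope intermediate buffer is allocated at all.
```
# Hand-written sketch of the desired result -- illustrative only, not generated by TVM.
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class Desired:
    @T.prim_func(private=True)
    def fused_matmul_add_relu(
        x: T.Buffer((T.int64(1), T.int64(48)), "float32"),
        permute_dims: T.Buffer((T.int64(48), T.int64(64)), "float32"),
        fc1_bias: T.Buffer((T.int64(64),), "float32"),
        compute_intermediate: T.Buffer((T.int64(1), T.int64(64)), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        # the only intermediate buffer lives in the shared on-chip scope
        matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(64)), scope="local")
        for i0, i1_0 in T.grid(T.int64(1), T.int64(4)):
            for k_0, k_1, i1_1 in T.grid(T.int64(3), T.int64(16), T.int64(16)):
                with T.block("matmul"):
                    v_i0 = T.axis.spatial(T.int64(1), i0)
                    v_i1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + i1_1)
                    v_k = T.axis.reduce(T.int64(48), k_0 * T.int64(16) + k_1)
                    with T.init():
                        matmul_intermediate_local[v_i0, v_i1] = T.float32(0.0)
                    matmul_intermediate_local[v_i0, v_i1] = (
                        matmul_intermediate_local[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
                    )
            for ax0 in range(T.int64(16)):
                with T.block("compute"):
                    v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
                    v_i1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax0)
                    # the vector op reads the tile straight from the local buffer
                    compute_intermediate[v_i0, v_i1] = T.max(
                        matmul_intermediate_local[v_i0, v_i1] + fc1_bias[v_i1], T.float32(0.0)
                    )
```
Is there a way to schedule (or lower) the fused function so that the intermediate tile stays in the shared local buffer like this, without the local → global → local copies?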