Hi everyone,
    
I’m working on a custom accelerator backend for TVM and ran into a scheduling / 
memory issue that I’m not sure how to model correctly.

#### Hardware background

On our hardware we have:

* A **matrix compute unit**
* A **vector compute unit**
* A **shared on-chip buffer** (local buffer) that both units can access

In a tiled **matrix–vector** computation, the idea is:

1. The matrix unit computes one **tile** of data and writes it into the local 
buffer.
2. The vector unit immediately consumes that same tile from the **same local 
buffer**, without going through global memory.

So in hardware there is **no local → global → local roundtrip** for that 
intermediate tile.
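
To make the intent concrete, here is a minimal NumPy sketch of the per-tile dataflow we want (illustration only; the shapes match the `Linear(48, 64)` + ReLU example further down, and the tile width of 16 is just an example):

```
import numpy as np

# Intended per-tile dataflow (sketch): the matmul tile lives only in the
# shared on-chip buffer and is consumed there by the vector stage.
x = np.random.rand(1, 48).astype("float32")
w = np.random.rand(48, 64).astype("float32")   # weight, already transposed
b = np.random.rand(64).astype("float32")
out = np.empty((1, 64), dtype="float32")

TILE = 16
for t in range(64 // TILE):
    cols = slice(t * TILE, (t + 1) * TILE)
    # matrix unit: compute one output tile into the shared local buffer
    tile_local = x @ w[:, cols]
    # vector unit: read that tile straight from the same local buffer
    out[:, cols] = np.maximum(tile_local + b[cols], 0.0)
```
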
#### What I see in TVM

In my current TVM schedule / lowering, the behavior looks more like this:

1. Matrix compute writes the tile to a **local buffer**.
2. TVM then copies that local buffer **back to global memory**.
3. Before the vector compute stage, TVM **copies the data from global to 
another local buffer** for the vector unit to use.

In other words, the intermediate result always gets stored to global memory and 
then reloaded, even though in our hardware a single shared local buffer is 
enough.
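
In the same NumPy terms as the sketch above (reusing `x`, `w`, `b`, `out`, and `TILE` from it), what the current lowering effectively does per tile is roughly this, distilled from the scheduled TIR at the end of this post:

```
# Same tiled computation, but with a round trip through a global staging
# buffer between the matrix and vector stages.
matmul_global = np.empty((1, 64), dtype="float32")   # global-memory scratch

for t in range(64 // TILE):
    cols = slice(t * TILE, (t + 1) * TILE)
    tile_local = x @ w[:, cols]                  # matmul result in local buffer
    matmul_global[:, cols] = tile_local          # local -> global   (unwanted copy)
    tile_back = matmul_global[:, cols]           # global -> vector stage (unwanted reload)
    out[:, cols] = np.maximum(tile_back + b[cols], 0.0)
```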

```
import tvm
from tvm import relax
from tvm.relax.frontend import nn
from tvm.relax.transform import LegalizeOps, AnnotateTIROpPattern

class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(48, 64)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        return x

mod, param_spec = NNModule().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor((1, 48), "float32")}}
)

pipeline = tvm.transform.Sequential([
    LegalizeOps(),                              # lower Relax ops to TIR
    AnnotateTIROpPattern(),                     # annotate op patterns for fusion
    relax.transform.FuseOps(fuse_opt_level=4),  # fuse matmul + add + relu
    relax.transform.FuseTIR(),                  # merge the fused group into one PrimFunc
])

mod_lowered = pipeline(mod)
sch = tvm.tir.Schedule(mod_lowered)
mod_lowered.show()

# work on the fused PrimFunc and inline the bias add into the relu ("compute") block
sch.work_on("fused_matmul_add_relu")
block_add = sch.get_block("T_add")
sch.compute_inline(block_add)

# grab the matrix (matmul) and vector (add + relu) blocks
matmul_block = sch.get_block("matmul")
compute_block = sch.get_block("compute")
# tile the matmul: 16-wide output tiles, reduction split into chunks of 16
i0, i1, k = sch.get_loops(matmul_block)
i1_outer, i1_inner = sch.split(i1, factors=[None, 16])
k_outer, k_inner = sch.split(k, factors=[None, 16])
sch.reorder(i0, i1_outer, k_outer, k_inner, i1_inner)
# tile the vector block the same way and move it under the matmul's outer tile loop
i0_c, i1_c = sch.get_loops(compute_block)
i1_outer_c, i1_inner_c = sch.split(i1_c, factors=[None, 16])
sch.reverse_compute_at(compute_block, i1_outer)
sch.mod.show()


# stage the matmul inputs in shared memory and its result in a local buffer
x_shared = sch.cache_read(matmul_block, 0, "shared")   # x tile
w_shared = sch.cache_read(matmul_block, 1, "shared")   # weight tile
matmul_local = sch.cache_write(matmul_block, 0, "local")

# stage the vector block's bias input and its output in local buffers
compute_block = sch.get_block("compute")
# sch.cache_read(compute_block, 0, "local", [matmul_local])
bias_local = sch.cache_read(compute_block, 1, "local")   # bias
result_local = sch.cache_write(compute_block, 0, "local")

# attach the data-movement blocks to the outer tile loop
i0, i1_0, k_0, k_1, i1_1 = sch.get_loops(matmul_block)
sch.compute_at(x_shared, i1_0)
sch.compute_at(w_shared, i1_0)
sch.compute_at(bias_local, i1_0)

sch.mod.show()
```
The scheduled TIR is shown below. In it, the matmul tile is written back from the local buffer to the global `matmul_intermediate` buffer, and the `compute` block then reads from that global buffer:

* `matmul_intermediate[v0, v1] = matmul_intermediate_local[v0, v1]`

```
@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul_add_relu(
        x: T.Buffer((T.int64(1), T.int64(48)), "float32"),
        permute_dims: T.Buffer((T.int64(48), T.int64(64)), "float32"),
        fc1_bias: T.Buffer((T.int64(64),), "float32"),
        compute_intermediate: T.Buffer((T.int64(1), T.int64(64)), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(64)))
        x_shared = T.alloc_buffer((T.int64(1), T.int64(48)), scope="shared")
        permute_dims_shared = T.alloc_buffer((T.int64(48), T.int64(64)), scope="shared")
        matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(64)), scope="local")
        fc1_bias_local = T.alloc_buffer((T.int64(64),), scope="local")
        compute_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(64)), scope="local")
        for i0, i1_0 in T.grid(T.int64(1), T.int64(4)):
            for ax0 in range(T.int64(48)):
                with T.block("x_shared"):
                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
                    v1 = T.axis.spatial(T.int64(48), ax0)
                    T.reads(x[v0, v1])
                    T.writes(x_shared[v0, v1])
                    x_shared[v0, v1] = x[v0, v1]
            for ax0, ax1 in T.grid(T.int64(48), T.int64(16)):
                with T.block("permute_dims_shared"):
                    v0 = T.axis.spatial(T.int64(48), ax0)
                    v1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax1)
                    T.reads(permute_dims[v0, v1])
                    T.writes(permute_dims_shared[v0, v1])
                    permute_dims_shared[v0, v1] = permute_dims[v0, v1]
            for k_0, k_1, i1_1 in T.grid(T.int64(3), T.int64(16), T.int64(16)):
                with T.block("matmul"):
                    v_i0 = T.axis.spatial(T.int64(1), i0)
                    v_i1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + i1_1)
                    v_k = T.axis.reduce(T.int64(48), k_0 * T.int64(16) + k_1)
                    T.reads(x_shared[v_i0, v_k], permute_dims_shared[v_k, v_i1])
                    T.writes(matmul_intermediate_local[v_i0, v_i1])
                    with T.init():
                        matmul_intermediate_local[v_i0, v_i1] = T.float32(0.0)
                    matmul_intermediate_local[v_i0, v_i1] = matmul_intermediate_local[v_i0, v_i1] + x_shared[v_i0, v_k] * permute_dims_shared[v_k, v_i1]
            for ax0, ax1 in T.grid(T.int64(1), T.int64(16)):
                with T.block("matmul_intermediate_local"):
                    v0 = T.axis.spatial(T.int64(1), i0 + ax0)
                    v1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax1)
                    T.reads(matmul_intermediate_local[v0, v1])
                    T.writes(matmul_intermediate[v0, v1])
                    matmul_intermediate[v0, v1] = matmul_intermediate_local[v0, v1]
            for ax0 in range(T.int64(16)):
                with T.block("fc1.bias_local"):
                    v0 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax0)
                    T.reads(fc1_bias[v0])
                    T.writes(fc1_bias_local[v0])
                    fc1_bias_local[v0] = fc1_bias[v0]
            for ax0 in range(T.int64(16)):
                with T.block("compute"):
                    v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
                    v_i1 = T.axis.spatial(T.int64(64), i1_0 * T.int64(16) + ax0)
                    T.reads(matmul_intermediate[v_i0, v_i1], fc1_bias_local[v_i1])
                    T.writes(compute_intermediate_local[v_i0, v_i1])
                    compute_intermediate_local[v_i0, v_i1] = T.max(matmul_intermediate[v_i0, v_i1] + fc1_bias_local[v_i1], T.float32(0.0))
        for ax0, ax1 in T.grid(T.int64(1), T.int64(64)):
            with T.block("compute_intermediate_local"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(compute_intermediate_local[v0, v1])
                T.writes(compute_intermediate[v0, v1])
                compute_intermediate[v0, v1] = compute_intermediate_local[v0, v1]
```
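
What I am looking for is a way to express at the schedule level that the matmul tile should stay in the on-chip buffer and be consumed there by the `compute` block, so that both the copy back to `matmul_intermediate` and the read from global memory go away. As a guess (just a sketch of what I have in mind, I have not verified this is the intended approach), something like replacing the `cache_write` on the matmul output with `set_scope` on the intermediate buffer:

```
# Hypothetical sketch, not verified: give matmul's output buffer a local scope
# directly, instead of cache_write + copy-back, so "compute" reads the tile
# from the on-chip buffer without a global round trip.
sch2 = tvm.tir.Schedule(mod_lowered)
sch2.work_on("fused_matmul_add_relu")
sch2.compute_inline(sch2.get_block("T_add"))
sch2.set_scope(sch2.get_block("matmul"), 0, "local")
```

Is something along these lines (plus `compute_at`/`reverse_compute_at` for the per-tile ordering) the recommended way to model a buffer shared between the matrix and vector units, or is there a better mechanism to avoid the extra global-memory copies?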