Writing out some of my thoughts, to see if there's a way to express the
constraints while only using existing TIR features. The main goals would be as
follows.
1. Allow simplification of expressions based on the values present in the
padding.
2. Allow local simplifications to take advantage of non-local constraints,
without requiring a full end-to-end analysis.
3. Specify the non-local constraints in some deducible manner that doesn't
impose a runtime performance penalty.
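As a concrete instance of goal 1 (a plain-Python analogue rather than TIR, using a hypothetical 14-element buffer packed into a 4x4 tile): once the padding is known to hold zeros, a reduction over the tile no longer needs the bounds predicate.

```python
# Plain-Python analogue (not TIR): 14 meaningful values packed into a 4x4 tile.
data = list(range(1, 15))
tile = [[0] * 4 for _ in range(4)]
for io in range(4):
    for ii in range(4):
        idx = 4 * io + ii
        tile[io][ii] = data[idx] if idx < 14 else 0  # pad value = 0

# With the predicate: only reduce over the in-bounds elements.
predicated = sum(
    tile[io][ii] for io in range(4) for ii in range(4) if 4 * io + ii < 14
)

# Knowing the padding is zero, the predicate can be dropped entirely.
unpredicated = sum(tile[io][ii] for io in range(4) for ii in range(4))

assert predicated == unpredicated == sum(data)
```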
Next, working through various options for how the constraints could be stored.
The examples below sketch out how each option would apply to the element-wise
operation that starts as follows.
```python
@T.prim_func
def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    for i in T.serial(14):
        B[i] = 2 * A[i]
```
1. Apply layout transforms on local caches. Here, the full lifetime of a
buffer is known. All TIR optimizations are done prior to hoisting the cache and
layout transformation into the graph level.
   - For read caches, pad value is whatever gets conditionally written to the
padding while generating it. In the example below, `AC` could be recognized as
being padded.
```python
@T.prim_func
def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    AC = T.alloc_buffer([4, 4], "int32")
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            AC[io, ii] = A[4 * io + ii]
        else:
            AC[io, ii] = 0
    for i in T.serial(14):
        B[i] = 2 * AC[i // 4, i % 4]
```
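A quick plain-Python check of the read-cache version (same shapes, same indexing), confirming the transformed loop nest reproduces the original `B[i] = 2 * A[i]`:

```python
# Plain-Python analogue of the read-cache transform above.
A = list(range(14))
AC = [[0] * 4 for _ in range(4)]
for io in range(4):
    for ii in range(4):
        if 4 * io + ii < 14:
            AC[io][ii] = A[4 * io + ii]
        else:
            AC[io][ii] = 0  # the recognizable pad value

B = [0] * 14
for i in range(14):
    B[i] = 2 * AC[i // 4][i % 4]

assert B == [2 * a for a in A]
```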
   - For write caches, pad value is whatever is in the padding after the last
write to the cache. In the example below, `BC` could be recognized as being padded.
```python
@T.prim_func
def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
    BC = T.alloc_buffer([4, 4], "int32")
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            BC[io, ii] = 2 * A[4 * io + ii]
        else:
            BC[io, ii] = 0
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = BC[io, ii]
```
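Checking the write-cache version the same way (indexing the scatter as `B[4 * io + ii]`):

```python
# Plain-Python analogue of the write-cache transform above.
A = list(range(14))
BC = [[0] * 4 for _ in range(4)]
for io in range(4):
    for ii in range(4):
        if 4 * io + ii < 14:
            BC[io][ii] = 2 * A[4 * io + ii]
        else:
            BC[io][ii] = 0  # pad value left in place after the last write

B = [0] * 14
for io in range(4):
    for ii in range(4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = BC[io][ii]

assert B == [2 * a for a in A]
```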
- Downside, either of the `else` statements could be eliminated as a no-op,
since they don't contribute to the output `B` value. After that elimination,
there wouldn't be any way to reconstruct the pad value.
2. When hoisting an allocation+transformation, write the pad value to the
buffer at the start of the function from which it was hoisted. This way, the pad
value can still be used in local reasoning.
- No change needed in producers, since they would already write the pad
value to the buffer.
- For consumers, would be represented as writing `pad_value` into the
padding at the start of the function.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        if 4 * io + ii >= 14:
            AC[io, ii] = 0
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
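In plain Python (emulating a producer that left arbitrary values in the padding), the consumer's initial loop establishes the invariant locally, so later reasoning can rely on the padding being zero:

```python
# Emulate a producer-filled AC, with garbage in the padding.
AC = [[0] * 4 for _ in range(4)]
for io in range(4):
    for ii in range(4):
        AC[io][ii] = 4 * io + ii if 4 * io + ii < 14 else 99

# Consumer starts by (re)writing the pad value...
for io in range(4):
    for ii in range(4):
        if 4 * io + ii >= 14:
            AC[io][ii] = 0

# ...so the padding invariant holds locally, whatever the producer did.
B = [0] * 14
for io in range(4):
    for ii in range(4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io][ii]

assert B == [2 * i for i in range(14)]
assert AC[3][2] == 0 and AC[3][3] == 0
```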
- Downside, repeated unnecessary effort at the beginning of each consumer.
Avoiding it with this representation would require knowing that the producer
had written `pad_value` already, which is exactly the information we're trying
to avoid.
3. When hoisting an allocation+transformation, write the pad value to the
buffer at the start of the function from which it was hoisted, and write
`T.undef()` at the end. This way, the pad value can still be used in local
reasoning, and no-op removal can remove the repeated writes when lowering.
- No change needed in producers, since they would already write the pad
value to the buffer.
- For consumers, would be like option 2, but with an additional write of
`T.undef()` at the end of the function. When lowering, the write of
`T.undef()` would allow the first write to be removed as a no-op because it is
overwritten. The `T.undef()` can then be removed as described in the RFC.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        if 4 * io + ii >= 14:
            AC[io, ii] = 0
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
    for io, ii in T.grid(4, 4):
        if 4 * io + ii >= 14:
            AC[io, ii] = T.undef()
```
   - Downside, no way to distinguish between "can assume the pad value is zero"
and "can overwrite the pad value at will". Because the padding ends in
`T.undef()`, arbitrary writes to the padding could be inserted as no-ops.
- Downside, wouldn't actually simplify out in cases where the pad value is
used. The first in a pair of repeated writes to the same location can only be
removed if there are no reads between the writes. After using the pad value to
eliminate `if 4 * io + ii < 14` from the compute, the dummy loop that writes
the padding could no longer be removed.
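The second downside can be stated as a small rule. A toy model (plain Python, hypothetical helper name) of when the first of a pair of writes is removable:

```python
# Toy model of the no-op-removal rule: the first of two writes to the
# same location is dead only if no read of that location sits between them.
def first_write_is_dead(events, loc):
    """events: list of ("write" | "read", location) in program order."""
    seen_first_write = False
    for kind, ev_loc in events:
        if ev_loc != loc:
            continue
        if kind == "write":
            if seen_first_write:
                return True  # second write reached with no intervening read
            seen_first_write = True
        elif kind == "read" and seen_first_write:
            return False  # the first write is observed, so it must stay
    return False

# Padding untouched by the consumer: the initial pad write can be removed.
assert first_write_is_dead([("write", "pad"), ("write", "pad")], "pad")

# After simplification *uses* the pad value, a read intervenes and the
# initial write can no longer be removed.
assert not first_write_is_dead(
    [("write", "pad"), ("read", "pad"), ("write", "pad")], "pad"
)
```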
4. Use `AssertStmt` in a loop to declare known information about the buffers.
- No change needed in producers, since the pad value is already written out.
- For consumers, would have an initial loop that asserts the pad value is
correct.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    for io, ii in T.grid(4, 4):
        if 4 * io + ii >= 14:
            assert AC[io, ii] == 0, "padding"
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
- Downside, assert statements have target-dependent handling. In
`CodeGenLLVM` and `CodeGenSPIRV`, they are treated as no-ops. In `CodeGenCPU`
and `CodeGenC`, they generate asserts. In `CodeGenCUDA`, they aren't handled
at all and would error out.
Could work around this with a lowering pass, but identifying these
conditions would require having a special string in the message, and packing
structured data into strings makes me wary.
5. Use `AssertStmt` with implicitly-defined variables to declare known
information about the buffers.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    a = T.var("int32")
    b = T.var("int32")
    assert (
        AC[a, b] == 0 or (4 * a + b < 14) or (a < 0) or (a >= 4) or (b < 0)
        or (b >= 4)
    ), "padding"
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
   - Can apply to clamped texture memory, since the variables in the assertion
aren't restricted to the buffer's bounds.
- Would need to recognize specific pattern of `BufferLoad` being used to
define variables used in constraint.
   - The implicitly-defined variables can be written in current TIR, and since
they are never bound to a value, the assertion couldn't make it into
generated code at runtime.
- Downside, implicitly-defined variables are something of a red flag.
6. Store constraints in the function attributes, either as a dictionary or as a
structured object.
```python
@T.prim_func
def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
    T.func_attr(
        "buffer_constraints",
        [
            {
                "buffer": AC,
                "predicate": lambda io, ii: 4 * io + ii < 14,
                "pad_value": lambda io, ii: 0,
            },
        ],
    )
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
- Downside, requires transformations that change a buffer to be aware that
other structures will also need to be replaced.
- Downside, requires simplifications to either be passed the entire
`PrimFunc`, or to be explicitly passed the `"buffer_constraints"` list.
   - Downside, would break expectations of `IRMutatorWithAnalyzer`. The current
entry points, which accept any `Stmt` or `Expr`, would additionally need access
to the `"buffer_constraints"` list.
7. Store constraints in the `Buffer` object, either as a dictionary or as a
structured object.
```python
@T.prim_func
def func(ac: T.handle, B: T.Buffer[14, "int32"]):
    AC = T.match_buffer(
        ac,
        shape=(4, 4),
        dtype="int32",
        constraints=[
            T.BufferConstraints(
                predicate=lambda io, ii: 4 * io + ii < 14,
                pad_value=0,
            )
        ],
    )
    for io, ii in T.grid(4, 4):
        if 4 * io + ii < 14:
            B[4 * io + ii] = 2 * AC[io, ii]
```
- Downside, introduces additional data structure in TIR.
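The shape of that extra data structure, and how a simplifier might consult it, could be sketched as follows (plain Python; the names `BufferConstraint`, `predicate`, and `pad_value` are hypothetical, not an existing TVM API):

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical structure attached to a Buffer; field names are a sketch.
@dataclass
class BufferConstraint:
    predicate: Callable[[int, int], bool]  # True for in-bounds indices
    pad_value: Callable[[int, int], int]   # value stored in the padding

AC_constraint = BufferConstraint(
    predicate=lambda io, ii: 4 * io + ii < 14,
    pad_value=lambda io, ii: 0,
)

# A simplifier holding the Buffer could fold any load it can prove
# out-of-bounds to the known pad value, without touching memory.
def fold_load(buf, constraint, io, ii):
    if not constraint.predicate(io, ii):
        return constraint.pad_value(io, ii)
    return buf[io][ii]

tile = [[1] * 4 for _ in range(4)]
assert fold_load(tile, AC_constraint, 3, 3) == 0  # padding folds to 0
assert fold_load(tile, AC_constraint, 0, 0) == 1  # in-bounds reads memory
```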
--
Reply to this email directly or view it on GitHub:
https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163620046