## Background and Motivation
TVM is an end-to-end deep learning compiler with two levels of IR and
optimization. TVM translates models from popular DL frameworks into Relay and
optimizes the computation graph. It then lowers each graph node into Tensor
Expression (TE), performs another round of function-level optimization, and
finally lowers the result into TIR to generate backend code. In brief, the
current workflow is:
```
TF/PyTorch/ONNX -> Relay -> TE (based on TE schedule) -> TIR -> C++/CUDA
```
Currently, low-level optimization is done through TE scheduling, which has
several limitations:
- It relies on an auxiliary data structure, the schedule tree. Schedule
primitives operate on the schedule tree rather than directly on TIR, which
makes the scheduling result less intuitive.
- All primitives are coupled through the schedule tree representation, so it is
hard to add a new primitive while also ensuring that the transformation is
correct.
- Support for high-dimensional instructions and tensorization is limited. A TE
stage cannot be scheduled further after tensorization, and the description of
tensor intrinsics is not user-friendly.

## Introduction
TensorIR is a brand new low-level IR with full scheduling support. Here are
some of its key features and novel techniques.

### Core data structure: Block and Block Realize
To generalize high-dimensional tensor expressions, we introduce a new
statement structure, *Block*. A block can wrap any part of the IR to provide
isolation. A block is *the minimal unit* for scheduling and tensorization, and
it stores all the fundamental information of a computation: the block
iteration vars, the regions the block reads and writes, the buffers allocated
inside the block, and the body statement. A block declares its computation in
terms of its iteration vars and their types (spatial or reduction).
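To make the list of fields concrete, here is a toy Python model of a block and of the realization that places it inside a loop nest. The class and field names are hypothetical illustrations, not TVM's actual C++ `Block`/`BlockRealize` nodes. We assume, based on the `tir.bind` calls in the scheduled examples, that a block realize binds each block iteration var to an expression over the surrounding loop vars.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class IterVar:
    name: str
    extent: int
    kind: str  # "spatial" or "reduce", mirroring tir.reduce_axis in the examples


@dataclass
class BufferRegion:
    buffer: str
    ranges: List[Tuple[int, int]]  # one (start, extent) pair per dimension


@dataclass
class Block:
    # Hypothetical toy model, not TVM's real Block node.
    name: str
    iter_vars: List[IterVar]
    reads: List[BufferRegion]
    writes: List[BufferRegion]
    body: str  # stand-in for the body statement
    alloc_buffers: List[str] = field(default_factory=list)


@dataclass
class BlockRealize:
    # Binds each block iter var to an expression over the enclosing loop vars,
    # like the tir.bind(...) lines in the scheduled examples.
    bindings: Dict[str, str]
    block: Block


full = [(0, 1024), (0, 1024)]
gemm = Block(
    name="C",
    iter_vars=[IterVar("vi", 1024, "spatial"), IterVar("vj", 1024, "spatial"),
               IterVar("vk", 1024, "reduce")],
    reads=[BufferRegion("A", full), BufferRegion("B", full)],
    writes=[BufferRegion("C", full)],
    body="C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]",
)
realize = BlockRealize({"vi": "i0_outer*32 + i0_inner",
                        "vj": "i1_outer*32 + i1_inner",
                        "vk": "i2_outer*4 + i2_inner"}, gemm)
```

Keeping the read/write regions on the block itself is what lets a scheduler reason about one block in isolation, without inspecting its body.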

### Workload Example
1. First, we define the workload (GEMM in this example) with hybrid script.
Note that hybrid script directly generates a **TIR** program rather than TE
stages. (By default, the surrounding loop nest is auto-completed from the block
definition.)
```python
@tvm.hybrid.script
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    C = tir.match_buffer(c, (1024, 1024), "float32")
    A = tir.match_buffer(a, (1024, 1024), "float32")
    B = tir.match_buffer(b, (1024, 1024), "float32")
    reducer = tir.comm_reducer(lambda x, y: x + y, tir.float32(0))

    with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
        reducer.step(C[vi, vj], A[vi, vk] * B[vk, vj])
```

2. Then create the schedule from **TIR** and tile the loops using primitives.
```python
s = tir.create_schedule(matmul)
update = s.get_block("C")
i, j, k = s.get_axes(update)
i_o, i_i = s.split(i, bn)
j_o, j_i = s.split(j, bn)
k_o, k_i = s.split(k, 4)
s.reorder(i_o, j_o, k_o, k_i, i_i, j_i)
```
Result
```python
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # function attr dict
    B = tir.match_buffer(b, [1024, 1024])
    A = tir.match_buffer(a, [1024, 1024])
    C = tir.match_buffer(c, [1024, 1024])
    reducer = tir.comm_reducer(lambda x, y: x + y, tir.float32(0))
    # body
    for i0_outer, i1_outer, i2_outer, i2_inner, i0_inner, i1_inner in tir.grid(32, 32, 256, 4, 32, 32):
        with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C") as [vi, vj, vk]:
            tir.bind(vi, ((i0_outer*32) + i0_inner))
            tir.bind(vj, ((i1_outer*32) + i1_inner))
            tir.bind(vk, ((i2_outer*4) + i2_inner))
            reducer.step(C[vi, vj], (A[vi, vk]*B[vk, vj]))
```

3. Vectorize the innermost loop and decompose the reduction into an
initialization block (`C_init`) and an update block (`C_update`).
```python
s.vectorize(j_i)
s.decompose_reduction(update, j_o)
```
Result
```python
@tvm.hybrid.script
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # function attr dict
    B = tir.match_buffer(b, [1024, 1024])
    A = tir.match_buffer(a, [1024, 1024])
    C = tir.match_buffer(c, [1024, 1024])
    # body
    for i0_outer in range(0, 32):
        # C_init: initialize this 32 x 1024 row tile before the reduction loops
        for i1_outer_init, i0_inner_init in tir.grid(32, 32):
            for i1_inner_init in range(0, 32, annotation={"loop_type": "vectorize"}):
                with tir.block([1024, 1024], "C_init") as [vi_init, vj_init]:
                    tir.bind(vi_init, ((i0_outer*32) + i0_inner_init))
                    tir.bind(vj_init, ((i1_outer_init*32) + i1_inner_init))
                    C[vi_init, vj_init] = tir.float32(0)
        # C_update: the reduction body, now free of initialization logic
        for i1_outer, i2_outer, i2_inner, i0_inner in tir.grid(32, 256, 4, 32):
            for i1_inner in range(0, 32, annotation={"loop_type": "vectorize"}):
                with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "C_update") as [vi, vj, vk]:
                    tir.bind(vi, ((i0_outer*32) + i0_inner))
                    tir.bind(vj, ((i1_outer*32) + i1_inner))
                    tir.bind(vk, ((i2_outer*4) + i2_inner))
                    C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vk, vj]))
```

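4. Build, run, and time it
```python
import numpy as np
import tvm
import tvm.testing
target, ctx = "llvm", tvm.cpu(0)  # build for and run on the host CPU
a = tvm.nd.array(np.random.rand(1024, 1024).astype("float32"), ctx)
b = tvm.nd.array(np.random.rand(1024, 1024).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), ctx)
build_func = tvm.build(s.func, target=target)
build_func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)
evaluator = build_func.time_evaluator(build_func.entry_name, ctx, number=1)
print("matmul: %f ms" % (evaluator(a, b, c).mean * 1e3))
```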

In this example, we change the IR imperatively, one primitive at a time, rather
than deferring every transformation to the end as TE does. This lets users and
developers see directly what each scheduling step does. Moreover, at every
stage of the schedule we get a *verifiable* IR, which is a major improvement for
both user experience and correctness proofs.
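Because every intermediate IR is a complete program, each step of the walkthrough can be checked on its own. The NumPy sketch below (plain Python loops, not TVM code) mirrors the loop structure of steps 1 to 3: the reference semantics of block "C" with its `x + y` reducer and identity `0`, the tiled nest after split and reorder, and the init/update pair after `decompose_reduction`, with the vectorized `j_i` loop written as an array slice. It assumes square inputs whose size is divisible by the tile sizes, as 1024 is by 32 and 4.

```python
import numpy as np


def reference_matmul(A, B, combine=lambda x, y: x + y, identity=0.0):
    """Block "C" from step 1: every (vi, vj, vk) folds A[vi, vk] * B[vk, vj] into C[vi, vj]."""
    n, k = A.shape
    m = B.shape[1]
    C = np.full((n, m), identity, dtype=A.dtype)  # the reducer starts from its identity
    for vi in range(n):
        for vj in range(m):
            for vk in range(k):
                C[vi, vj] = combine(C[vi, vj], A[vi, vk] * B[vk, vj])
    return C


def tiled_matmul(A, B, bn=32, bk=4):
    """Step 2: loop order (i_o, j_o, k_o, k_i, i_i, j_i), same bindings as the tir.bind lines."""
    n = A.shape[0]  # assumes square inputs with n divisible by bn and bk
    C = np.zeros((n, n), dtype=A.dtype)
    for i_o in range(n // bn):
        for j_o in range(n // bn):
            for k_o in range(n // bk):
                for k_i in range(bk):
                    for i_i in range(bn):
                        for j_i in range(bn):
                            vi, vj, vk = i_o * bn + i_i, j_o * bn + j_i, k_o * bk + k_i
                            C[vi, vj] += A[vi, vk] * B[vk, vj]
    return C


def decomposed_matmul(A, B, bn=32, bk=4):
    """Step 3: C_init and C_update under i_o; the vectorized j_i loop becomes a slice."""
    n = A.shape[0]
    C = np.empty((n, n), dtype=A.dtype)  # uninitialized on purpose: C_init must fill it
    for i_o in range(n // bn):
        for j_o in range(n // bn):  # C_init
            for i_i in range(bn):
                C[i_o * bn + i_i, j_o * bn:(j_o + 1) * bn] = 0.0
        for j_o in range(n // bn):  # C_update
            js = slice(j_o * bn, (j_o + 1) * bn)
            for k_o in range(n // bk):
                for k_i in range(bk):
                    vk = k_o * bk + k_i
                    for i_i in range(bn):
                        vi = i_o * bn + i_i
                        C[vi, js] += A[vi, vk] * B[vk, js]
    return C
```

At a small size (for example 16 x 16 inputs with `bn=4, bk=4`), all three functions agree with `A @ B`, which is the kind of per-step check a verifiable IR makes possible.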

## Key Features
### Independent scheduling based on IR itself
Unlike TE scheduling, TensorIR has a complete set of schedule algorithms that
work on the IR itself and need no schedule tree or any other auxiliary data
structure. We will introduce a brand new set of schedule primitives that is
fully backward compatible with the TE schedule.
This simplifies the compilation workflow and its concepts:
```
TF/PyTorch/ONNX -> Relay -> TIR -> schedule -> TIR -> schedule -> TIR -> C++/CUDA
```
Scheduling no longer introduces stages: rather than lowering a schedule into
TIR, we mutate the TIR directly. This also enables sequential scheduling, i.e.,
scheduling a single workload several times.


### Stronger Expressiveness and Optimization Ability
TE has limited expressiveness because each stage is defined by
`Stage = te.compute(lambda expr)`, while TensorIR is a full C++-like IR in which
you can write any program. Although not all programs can be scheduled,
TensorIR can still optimize more workloads than TE.

One workload that benefits is concatenation:
#### TE
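```python
from tvm import te, tir

A0, A1, A2 = (te.placeholder((10,), name=n) for n in ("A0", "A1", "A2"))
# Every output element picks its source through nested if_then_else branches
B = te.compute((30,), lambda i: tir.if_then_else(
    i < 10, A0[i], tir.if_then_else(i < 20, A1[i - 10], A2[i - 20])))
```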
#### TensorIR
```python
with tir.block([10]) as vi:
    B[vi] = A0[vi]
with tir.block([10]) as vi:
    B[vi + 10] = A1[vi]
with tir.block([10]) as vi:
    B[vi + 20] = A2[vi]
```
The key improvement is performance: the TensorIR version eliminates the `if`
branches entirely, which is impossible to achieve with a TE schedule.
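To see that the two formulations are interchangeable, the NumPy sketch below (an illustration of the semantics, not TVM code) writes both as plain loops: the first branches on every output element like the TE compute, the second performs three branch-free copies like the three TensorIR blocks. Both produce the same array as `np.concatenate`.

```python
import numpy as np


def concat_with_branches(A0, A1, A2):
    """Mirrors the TE version: one loop over the output, a branch per element."""
    B = np.empty(30, dtype=A0.dtype)
    for i in range(30):
        if i < 10:
            B[i] = A0[i]
        elif i < 20:
            B[i] = A1[i - 10]
        else:
            B[i] = A2[i - 20]
    return B


def concat_branch_free(A0, A1, A2):
    """Mirrors the TensorIR version: three independent copies, no per-element branch."""
    B = np.empty(30, dtype=A0.dtype)
    for vi in range(10):
        B[vi] = A0[vi]
    for vi in range(10):
        B[vi + 10] = A1[vi]
    for vi in range(10):
        B[vi + 20] = A2[vi]
    return B
```

The second form has straight-line loop bodies, so each copy can be vectorized or parallelized independently.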

### Memory and Execution Scope
Hardware accelerators, led by GPUs, increasingly use hierarchical
architectures, including a memory hierarchy (global, shared, and local/wmma on
NVIDIA GPUs) and an execution hierarchy (SM, warp, and thread on NVIDIA GPUs).
TVM defines the memory hierarchy, and TensorIR provides the corresponding native
hierarchical blocks and execution scopes. Different execution scopes can access
different memory scopes.

TensorIR natively supports hierarchy checks. During scheduling, we check memory
accesses and thread bindings, including the validation of warp-level
instructions (wmma). The following is an example of the GPU hierarchy.

```python
for bx in range(0, 32, annotation={"loop_type": "blockIdx.x"}):
    with block(exec_scope="gpu_block"):
        for ty in range(0, 32, annotation={"loop_type": "threadIdx.y"}):
            with block(exec_scope="gpu_warp"):
                for tx in range(0, 32, annotation={"loop_type": "threadIdx.x"}):
                    with block(exec_scope="gpu_thread"):
                        # one element per thread
                        A[bx * 1024 + ty * 32 + tx] = B[bx * 1024 + ty * 32 + tx] + 1
```
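As a rough sketch of what such a check could look like, the Python below validates a toy program in two ways: execution scopes must nest in the order block, warp, thread, and each memory access must be allowed from its execution scope. The rule table (`MEM_ACCESS`) and all function names are assumptions made for illustration; TensorIR's actual rules and implementation may differ.

```python
# Assumed ordering of GPU execution scopes, outermost first (hypothetical).
EXEC_LEVEL = {"gpu_block": 0, "gpu_warp": 1, "gpu_thread": 2}

# Assumed rule table: which execution scopes may access each memory scope.
MEM_ACCESS = {
    "global": {"gpu_block", "gpu_warp", "gpu_thread"},
    "shared": {"gpu_block", "gpu_warp", "gpu_thread"},
    "wmma.matrix_a": {"gpu_warp"},  # warp-collective fragment
    "local": {"gpu_thread"},
}


def check_nesting(scopes):
    """Scopes listed outermost to innermost must strictly descend the hierarchy."""
    levels = [EXEC_LEVEL[s] for s in scopes]
    return all(outer < inner for outer, inner in zip(levels, levels[1:]))


def check_access(exec_scope, mem_scope):
    return exec_scope in MEM_ACCESS[mem_scope]


def check_program(scopes, accesses):
    """accesses: (exec_scope, mem_scope) pairs. Returns a list of error messages."""
    errors = []
    if not check_nesting(scopes):
        errors.append("bad nesting: " + " > ".join(scopes))
    for exec_scope, mem_scope in accesses:
        if not check_access(exec_scope, mem_scope):
            errors.append(f"{exec_scope} cannot access {mem_scope}")
    return errors
```

Run during scheduling, a check of this shape would reject, for example, a transformation that moves a wmma fragment access into a per-thread block.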

### High-dimension Scheduling and Tensorization
As more and more backends provide tensor operators and instructions, the
limitations of the TE schedule become apparent: it is hard to tensorize a
complex stage. TensorIR chooses the block (a wrapper around a high-dimensional
computation) as the minimal schedule unit, so we can natively tensorize a
sub-program and even continue scheduling after tensorization.
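A NumPy analogy of tensorizing a sub-program: the inner 16 x 16 x 16 tile of a matmul is replaced by a single call to a "tensor instruction", while the outer loops stay ordinary loops that could still be reordered or split. `intrin_matmul_16x16x16` is a hypothetical stand-in (for example, for a WMMA-style instruction), not a real TVM intrinsic.

```python
import numpy as np

TILE = 16  # assumed shape of the hypothetical 16 x 16 x 16 matmul intrinsic


def intrin_matmul_16x16x16(c_tile, a_tile, b_tile):
    """Stand-in for a hardware tensor instruction: c_tile += a_tile @ b_tile in one call."""
    c_tile += a_tile @ b_tile  # c_tile is a view, so this updates the caller's array


def tensorized_matmul(A, B):
    n = A.shape[0]  # assumes square inputs with n divisible by TILE
    C = np.zeros((n, n), dtype=A.dtype)
    # The outer loops remain ordinary loops; only the inner tile is "tensorized".
    for io in range(0, n, TILE):
        for jo in range(0, n, TILE):
            for ko in range(0, n, TILE):
                intrin_matmul_16x16x16(
                    C[io:io + TILE, jo:jo + TILE],
                    A[io:io + TILE, ko:ko + TILE],
                    B[ko:ko + TILE, jo:jo + TILE],
                )
    return C
```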

### Decouple Primitives and Correctness Proof
Since every primitive directly rewrites the AST (it works like a
`StmtMutator`), primitives are naturally decoupled from one another.
Developers can add a schedule primitive as easily as adding a pass. It is also
easy to ensure that the scheduled program behaves identically to the original
by proving the correctness of each primitive and checking the intermediate IR
during scheduling.
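The sketch below models a schedule primitive as a pure IR-to-IR function on a toy loop nest, in the spirit of a `StmtMutator`: `split` returns a new nest and new iter-var bindings without touching its input. Because each step yields a complete IR, it can be checked, here by enumerating the block iteration points before and after the rewrite. The toy IR, `split`, and `iteration_set` are hypothetical and much simpler than TensorIR's real data structures.

```python
from itertools import product

# Toy IR (hypothetical): a perfect loop nest is a list of (loop_var, extent) pairs,
# plus a mapping from each block iter var to a function of the loop-var values.


def split(loops, bindings, var, factor):
    """Rewrite loop `var` into var_o (outer) and var_i (inner); returns a NEW nest and bindings."""
    idx = [name for name, _ in loops].index(var)
    extent = loops[idx][1]
    assert extent % factor == 0, "this sketch only handles perfect splits"
    outer, inner = var + "_o", var + "_i"
    new_loops = loops[:idx] + [(outer, extent // factor), (inner, factor)] + loops[idx + 1:]

    def rebind(f):
        # Substitute var := outer * factor + inner before evaluating the old binding.
        return lambda env: f({**env, var: env[outer] * factor + env[inner]})

    return new_loops, {k: rebind(f) for k, f in bindings.items()}


def iteration_set(loops, bindings):
    """Every block iteration point the nest executes, in execution order."""
    names = [name for name, _ in loops]
    points = []
    for values in product(*(range(extent) for _, extent in loops)):
        env = dict(zip(names, values))
        points.append(tuple(f(env) for f in bindings.values()))
    return points


loops = [("i", 8), ("j", 8)]
bindings = {"vi": lambda env: env["i"], "vj": lambda env: env["j"]}
new_loops, new_bindings = split(loops, bindings, "i", 4)
```

Applying `split` again to the result (say on `j`) is the toy analogue of sequential scheduling, and the same iteration-set check can be repeated after every step.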

### Round-trip Python Hybrid Syntax (Already upstream)
Hybrid script enables developers to write TensorIR in Python, one of the most
popular languages. It also provides a way to store the IR during or after
scheduling. For details, see
https://discuss.tvm.apache.org/t/rfc-hybrid-script-support-for-tir/7516


## Migration Plan
1. Upstream the TensorIR data structure.
1. Upstream the TensorIR schedule primitives.
1. Support AutoTVM/Ansor on TensorIR.

Co-authors: @spectrometerHBH @tqchen @junrushao1994

I appreciate the discussion and advice from @merrymercy.

---
Discussion thread: https://discuss.tvm.apache.org/t/rfc-tensorir-a-schedulable-ir-for-tvm/7872/1