You can first optimize the TIR primitive function independently using `ms.tir_integration.tune_tir`, then merge the optimized version back into the high-level Relax module.
```py
import os
import tvm
import numpy as np
import tvm.meta_schedule as ms
from tvm import relax
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script import ir as I

@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(
        VAL: T.handle,
        VEC: T.handle,
        OUT: T.handle,
    ):
        val = T.match_buffer(VAL, (37,), "float64")
        vec = T.match_buffer(VEC, (11,), "float64")
        out = T.match_buffer(OUT, (11,), "float64")
        for j in T.serial(2):
            for i in T.serial(2):
                with T.block("db0"):
                    T.init()
                    out[i + 0] += val[0 + j * 2 + i] * vec[j + 0]
        for j in T.serial(1):
            for i in T.serial(2):
                with T.block("db1"):
                    T.init()
                    out[i + 0] += val[4 + j * 2 + i] * vec[j + 5]
        for j in T.serial(3):
            for i in T.serial(3):
                with T.block("db3"):
                    T.init()
                    out[i + 2] += val[6 + j * 3 + i] * vec[j + 2]
        for j in T.serial(2):
            for i in T.serial(1):
                with T.block("db5"):
                    T.init()
                    out[i + 5] += val[15 + j * 1 + i] * vec[j + 0]
        for j in T.serial(3):
            for i in T.serial(1):
                with T.block("db8"):
                    T.init()
                    out[i + 5] += val[21 + j * 1 + i] * vec[j + 6]

    @R.function
    def main(
        val: R.Tensor(("v",), dtype="float64"),
        vec: R.Tensor(("k",), dtype="float64"),
    ):
        cls = Module
        out = R.call_tir(cls.dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out

mod = Module
# The TIR function that will be tuned
dense_loop_tir = mod["dense_loop"]
target = tvm.target.Target("llvm -num-cores=1")
this_dir = os.path.dirname(os.path.abspath(__file__))
work_dir = os.path.join(this_dir, "tuning_logs")
# Tune the TIR function
database = ms.tir_integration.tune_tir(
    mod=dense_loop_tir,
    target=target,
    work_dir=work_dir,
    max_trials_global=64,
    num_trials_per_iter=16,
)
if database is None:
    raise ValueError("Database is None!")
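# Note (assumption): the tuning records are written as JSON files under
# `work_dir`, so a previously tuned database can usually be reloaded with
# `ms.database.JSONDatabase` instead of re-tuning from scratch, e.g.:
#   database = ms.database.JSONDatabase(work_dir=work_dir)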
# Compile the TIR function with the tuned database to a tir.Schedule
sch = ms.tir_integration.compile_tir(database, dense_loop_tir, target)
if sch is None:
    print("No valid schedule found!")
else:
    sch.mod.show()
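# Optionally, inspect the trace of schedule primitives that MetaSchedule
# applied (assumption: the returned tir.Schedule carries its trace):
#   print(sch.trace)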
# Replace the optimized TIR prim func back into the original module.
# In `sch.mod`, the optimized TIR prim func is stored as `sch.mod["main"]`.
optimized_tir = sch.mod["main"]
new_mod = tvm.IRModule({"dense_loop": optimized_tir, "main": mod["main"]})
new_mod.show()
# Build new module
new_mod = relax.transform.LegalizeOps()(new_mod)
ex = relax.build(new_mod, target=target)
vm = relax.VirtualMachine(ex, tvm.cpu())
# Prepare Data
val_np = np.random.rand(37).astype("float64")
vec_np = np.random.rand(11).astype("float64")
val_tvm = tvm.nd.array(val_np, device=tvm.cpu())
vec_tvm = tvm.nd.array(vec_np, device=tvm.cpu())
# Execute
output_tvm = vm["main"](val_tvm, vec_tvm)
# Output
output_np = output_tvm.numpy()
print("Output shape:", output_np.shape) # out_sinfo: (11,)
print("Output values:", output_np)
```
The output:
```py
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(val: T.Buffer((37,), "float64"), vec: T.Buffer((11,), "float64"), out: T.Buffer((11,), "float64")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db0"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 4)
                    T.reads(val[(j_i_fused_0 * 64 + j_i_fused_1) // 2 * 2 + (j_i_fused_0 * 64 + j_i_fused_1) % 2], vec[(j_i_fused_0 * 64 + j_i_fused_1) // 2])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] + val[(j_i_fused_0 * 64 + j_i_fused_1) // 2 * 2 + (j_i_fused_0 * 64 + j_i_fused_1) % 2] * vec[(j_i_fused_0 * 64 + j_i_fused_1) // 2]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db1"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 2)
                    T.reads(val[4 + T.Mul(0, 2) + (j_i_fused_0 * 64 + j_i_fused_1) % 2], vec[T.Add(0, 5)])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] + val[4 + T.Mul(0, 2) + (j_i_fused_0 * 64 + j_i_fused_1) % 2] * vec[T.Add(0, 5)]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db3"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 9)
                    T.reads(val[6 + (j_i_fused_0 * 64 + j_i_fused_1) // 3 * 3 + (j_i_fused_0 * 64 + j_i_fused_1) % 3], vec[(j_i_fused_0 * 64 + j_i_fused_1) // 3 + 2])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2] + val[6 + (j_i_fused_0 * 64 + j_i_fused_1) // 3 * 3 + (j_i_fused_0 * 64 + j_i_fused_1) % 3] * vec[(j_i_fused_0 * 64 + j_i_fused_1) // 3 + 2]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db5"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 2)
                    T.reads(val[T.Add(15 + (j_i_fused_0 * 64 + j_i_fused_1), 0)], vec[j_i_fused_0 * 64 + j_i_fused_1])
                    T.writes(out[T.Add(0, 5)])
                    with T.init():
                        out[T.Add(0, 5)] = out[T.Add(0, 5)] + val[T.Add(15 + (j_i_fused_0 * 64 + j_i_fused_1), 0)] * vec[j_i_fused_0 * 64 + j_i_fused_1]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db8"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 3)
                    T.reads(val[T.Add(21 + (j_i_fused_0 * 64 + j_i_fused_1), 0)], vec[j_i_fused_0 * 64 + j_i_fused_1 + 6])
                    T.writes(out[T.Add(0, 5)])
                    with T.init():
                        out[T.Add(0, 5)] = out[T.Add(0, 5)] + val[T.Add(21 + (j_i_fused_0 * 64 + j_i_fused_1), 0)] * vec[j_i_fused_0 * 64 + j_i_fused_1 + 6]
                    T.evaluate(0)

    @R.function
    def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), dtype="float64")) -> R.Tensor((11,), dtype="float64"):
        v = T.int64()
        k = T.int64()
        out = R.call_tir(dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out

Output shape: (11,)
Output values: [1.56512320e-001 2.21357108e-001 1.09694842e+000 7.23903158e-001
 1.02699930e+000 1.66537513e+000 0.00000000e+000 4.64222119e-310
 4.24399158e-314 0.00000000e+000 0.00000000e+000]
```
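To see how much the tuned schedule helps in isolation, the optimized prim func can also be built and timed on its own. This is only a sketch: it assumes the `sch` and `target` from the `compile_tir` step above and uses the standard `time_evaluator` utility; the trial inputs are freshly generated.

```py
import numpy as np
import tvm

# Assumption: `sch` and `target` come from the compile_tir step above.
func = tvm.build(sch.mod, target=target)  # build just the tuned prim func
dev = tvm.cpu()
val_tvm = tvm.nd.array(np.random.rand(37).astype("float64"), dev)
vec_tvm = tvm.nd.array(np.random.rand(11).astype("float64"), dev)
out_tvm = tvm.nd.array(np.zeros(11, dtype="float64"), dev)
# Average over repeated runs to get a stable measurement
evaluator = func.time_evaluator(func.entry_name, dev, number=100)
print("mean runtime (s):", evaluator(val_tvm, vec_tvm, out_tvm).mean)
```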
As another optimization strategy, the Relax module can be tuned end-to-end
using the methodology demonstrated here:
https://github.com/apache/tvm/blob/7707496a6601796393557087d2bef3d2f5513a34/docs/how_to/tutorials/e2e_opt_model.py#L87-L107
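For completeness, here is a minimal sketch of that end-to-end flow, assuming the `static_shape_tuning` pipeline used in the linked tutorial; the `mod` variable and the trial budget are illustrative, not taken from this thread:

```py
import tvm
from tvm import relax

# Assumption: `mod` is a Relax IRModule like the one above, and `target`
# is the same LLVM target; total_trials is an illustrative budget.
target = tvm.target.Target("llvm -num-cores=1")
mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=64)(mod)

# Build and run the tuned module as before
ex = relax.build(mod, target=target)
vm = relax.VirtualMachine(ex, tvm.cpu())
```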