You can first optimize the TIR primitive function on its own using
`ms.tir_integration.tune_tir`, then merge the optimized version back into the
high-level Relax module.

  ```py
  import os
  import tvm
  import numpy as np
  import tvm.meta_schedule as ms
  
  from tvm import relax
  from tvm.script import relax as R
  from tvm.script import tir as T
  from tvm.script import ir as I
  
  
  @I.ir_module
  class Module:
      @T.prim_func
      def dense_loop(
          VAL: T.handle,
          VEC: T.handle,
          OUT: T.handle,
      ):
          val = T.match_buffer(VAL, (37,), "float64")
          vec = T.match_buffer(VEC, (11,), "float64")
          out = T.match_buffer(OUT, (11,), "float64")
  
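          # Each loop nest below accumulates one dense sub-block of `val`,
          # multiplied by a slice of `vec`, into a slice of `out`.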
          for j in T.serial(2):
              for i in T.serial(2):
                  with T.block("db0"):
                      T.init()
                      out[i + 0] += val[0 + j * 2 + i] * vec[j + 0]
          for j in T.serial(1):
              for i in T.serial(2):
                  with T.block("db1"):
                      T.init()
                      out[i + 0] += val[4 + j * 2 + i] * vec[j + 5]
          for j in T.serial(3):
              for i in T.serial(3):
                  with T.block("db3"):
                      T.init()
                      out[i + 2] += val[6 + j * 3 + i] * vec[j + 2]
          for j in T.serial(2):
              for i in T.serial(1):
                  with T.block("db5"):
                      T.init()
                      out[i + 5] += val[15 + j * 1 + i] * vec[j + 0]
          for j in T.serial(3):
              for i in T.serial(1):
                  with T.block("db8"):
                      T.init()
                      out[i + 5] += val[21 + j * 1 + i] * vec[j + 6]
  
      @R.function
      def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), 
dtype="float64")):
          cls = Module
          out = R.call_tir(cls.dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
          return out
  
  
  mod = Module
  # The TIR function that will be tuned
  dense_loop_tir = mod["dense_loop"]
  
  target = tvm.target.Target("llvm -num-cores=1")
  this_dir = os.path.dirname(os.path.abspath(__file__))
  work_dir = os.path.join(this_dir, "tuning_logs")
  
  # Tune the TIR function
  database = ms.tir_integration.tune_tir(
      mod=dense_loop_tir,
      target=target,
      work_dir=work_dir,
      max_trials_global=64,
      num_trials_per_iter=16,
  )
  if database is None:
      raise ValueError("Database is None!")
  
  # Compile the TIR function with the tuned database to a tir.Schedule
  sch = ms.tir_integration.compile_tir(database, dense_loop_tir, target)
  if sch is None:
      print("No valid schedule found!")
  else:
      sch.mod.show()
  
      # Replace the optimized TIR PrimFunc back into the original module.
      # In `sch.mod`, the optimized TIR PrimFunc is stored as `sch.mod["main"]`.
      optimized_tir = sch.mod["main"]
      new_mod = tvm.IRModule({"dense_loop": optimized_tir, "main": mod["main"]})
      new_mod.show()
  
      # Build new module
      new_mod = relax.transform.LegalizeOps()(new_mod)
      ex = relax.build(new_mod, target=target)
      vm = relax.VirtualMachine(ex, tvm.cpu())
  
      # Prepare Data
      val_np = np.random.rand(37).astype("float64")
      vec_np = np.random.rand(11).astype("float64")
      val_tvm = tvm.nd.array(val_np, device=tvm.cpu())
      vec_tvm = tvm.nd.array(vec_np, device=tvm.cpu())
  
      # Execute
      output_tvm = vm["main"](val_tvm, vec_tvm)
  
      # Output
      output_np = output_tvm.numpy()
      print("Output shape:", output_np.shape)  # out_sinfo: (11,)
      print("Output values:", output_np)
  ```
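
As a side note, the `database` returned by `tune_tir` can also be inspected directly before compiling, via `Database.get_all_tuning_records` (a minimal sketch; the mean-run-time print is just for illustration):

```py
# Minimal sketch: inspect what the tuner recorded, reusing the
# `database` object from the script above.
records = database.get_all_tuning_records()
print(f"Collected {len(records)} tuning record(s)")
for rec in records:
    if rec.run_secs:  # measured run times for this candidate
        mean_secs = sum(s.value for s in rec.run_secs) / len(rec.run_secs)
        print("mean run time (s):", mean_secs)
```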

The output:
```py
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def dense_loop(val: T.Buffer((37,), "float64"), vec: T.Buffer((11,), "float64"), out: T.Buffer((11,), "float64")):
        T.func_attr({"tir.noalias": True})
        # with T.block("root"):
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db0"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 4)
                    T.reads(val[(j_i_fused_0 * 64 + j_i_fused_1) // 2 * 2 + (j_i_fused_0 * 64 + j_i_fused_1) % 2], vec[(j_i_fused_0 * 64 + j_i_fused_1) // 2])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] + val[(j_i_fused_0 * 64 + j_i_fused_1) // 2 * 2 + (j_i_fused_0 * 64 + j_i_fused_1) % 2] * vec[(j_i_fused_0 * 64 + j_i_fused_1) // 2]
                    T.evaluate(0)
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db1"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 2)
                    T.reads(val[4 + T.Mul(0, 2) + (j_i_fused_0 * 64 + j_i_fused_1) % 2], vec[T.Add(0, 5)])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 2] + val[4 + T.Mul(0, 2) + (j_i_fused_0 * 64 + j_i_fused_1) % 2] * vec[T.Add(0, 5)]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db3"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 9)
                    T.reads(val[6 + (j_i_fused_0 * 64 + j_i_fused_1) // 3 * 3 + (j_i_fused_0 * 64 + j_i_fused_1) % 3], vec[(j_i_fused_0 * 64 + j_i_fused_1) // 3 + 2])
                    T.writes(out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2])
                    with T.init():
                        out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2] = out[(j_i_fused_0 * 64 + j_i_fused_1) % 3 + 2] + val[6 + (j_i_fused_0 * 64 + j_i_fused_1) // 3 * 3 + (j_i_fused_0 * 64 + j_i_fused_1) % 3] * vec[(j_i_fused_0 * 64 + j_i_fused_1) // 3 + 2]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db5"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 2)
                    T.reads(val[T.Add(15 + (j_i_fused_0 * 64 + j_i_fused_1), 0)], vec[j_i_fused_0 * 64 + j_i_fused_1])
                    T.writes(out[T.Add(0, 5)])
                    with T.init():
                        out[T.Add(0, 5)] = out[T.Add(0, 5)] + val[T.Add(15 + (j_i_fused_0 * 64 + j_i_fused_1), 0)] * vec[j_i_fused_0 * 64 + j_i_fused_1]
                    T.evaluate(0)
        for j_i_fused_0 in T.parallel(1):
            for j_i_fused_1 in T.vectorized(64):
                with T.block("db8"):
                    T.where(j_i_fused_0 * 64 + j_i_fused_1 < 3)
                    T.reads(val[T.Add(21 + (j_i_fused_0 * 64 + j_i_fused_1), 0)], vec[j_i_fused_0 * 64 + j_i_fused_1 + 6])
                    T.writes(out[T.Add(0, 5)])
                    with T.init():
                        out[T.Add(0, 5)] = out[T.Add(0, 5)] + val[T.Add(21 + (j_i_fused_0 * 64 + j_i_fused_1), 0)] * vec[j_i_fused_0 * 64 + j_i_fused_1 + 6]
                    T.evaluate(0)

    @R.function
    def main(val: R.Tensor(("v",), dtype="float64"), vec: R.Tensor(("k",), 
dtype="float64")) -> R.Tensor((11,), dtype="float64"):
        v = T.int64()
        k = T.int64()
        out = R.call_tir(dense_loop, (val, vec), out_sinfo=R.Tensor((11,), dtype="float64"))
        return out

Output shape: (11,)
Output values: [1.56512320e-001 2.21357108e-001 1.09694842e+000 7.23903158e-001
 1.02699930e+000 1.66537513e+000 0.00000000e+000 4.64222119e-310
 4.24399158e-314 0.00000000e+000 0.00000000e+000]
```
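
If you want to sanity-check the numbers, the accumulation pattern in `dense_loop` is easy to mirror in plain NumPy (a sketch; `dense_loop_ref` is a hypothetical helper that transcribes the loop nests above, and it assumes `out` starts from zero):

```py
import numpy as np


def dense_loop_ref(val: np.ndarray, vec: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy mirror of `dense_loop`, assuming `out` starts zeroed."""
    out = np.zeros(11, dtype="float64")
    for j in range(2):  # block "db0"
        for i in range(2):
            out[i] += val[j * 2 + i] * vec[j]
    for j in range(1):  # block "db1"
        for i in range(2):
            out[i] += val[4 + j * 2 + i] * vec[j + 5]
    for j in range(3):  # block "db3"
        for i in range(3):
            out[i + 2] += val[6 + j * 3 + i] * vec[j + 2]
    for j in range(2):  # block "db5"
        for i in range(1):
            out[i + 5] += val[15 + j + i] * vec[j]
    for j in range(3):  # block "db8"
        for i in range(1):
            out[i + 5] += val[21 + j + i] * vec[j + 6]
    return out
```

Comparing `output_np` against `dense_loop_ref(val_np, vec_np)` with `np.testing.assert_allclose` quickly shows whether the merged module still computes what the original loops intended.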

As another optimization strategy, the whole Relax module can be tuned end to
end, following the methodology demonstrated in the e2e optimization tutorial:
https://github.com/apache/tvm/blob/7707496a6601796393557087d2bef3d2f5513a34/docs/how_to/tutorials/e2e_opt_model.py#L87-L107
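
For reference, the core of that tutorial amounts to applying the built-in `static_shape_tuning` pipeline to the whole module (a sketch; as the name suggests, the pipeline expects static shapes, so the symbolic `v`/`k` dimensions in `main` would need to be fixed first, and the trial budget here is deliberately small):

```py
import tvm
from tvm import relax

TOTAL_TRIALS = 64  # the tutorial uses a much larger budget

target = tvm.target.Target("llvm -num-cores=1")

# Tune every TIR function in the module and apply the best found schedules.
mod = relax.get_pipeline(
    "static_shape_tuning", target=target, total_trials=TOTAL_TRIALS
)(Module)

ex = relax.build(mod, target=target)
vm = relax.VirtualMachine(ex, tvm.cpu())
```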
