Is this what you're looking for?
```
import tvm
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class InputModule:
    @R.function
    def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
        z = R.add(x, y)
        return z


lib = tvm.relax.build(InputModule, target="llvm")
# The generated LLVM code lives in the first module imported by the executable
print(lib.mod.imported_modules[0].get_source())
```

You can also inspect the generated C code by targeting C:

```
lib = tvm.relax.build(InputModule, target="c")
print(lib.mod.imported_modules[0].get_source())
```

Relax IR has to be lowered before any LLVM code can be generated. `tvm.relax.build` does this lowering and returns a `VMExecutable`; the compiled code is the first module imported by its `mod`.

It is usually simpler to get LLVM code from TIR, which you can obtain by lowering Relax to TIR, for example with `tvm.relax.transform.LegalizeOps`:

```
import tvm
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class TirModule:
    @T.prim_func
    def add(
        x: T.Buffer((T.int64(3), T.int64(4)), "float32"),
        y: T.Buffer((T.int64(3), T.int64(4)), "float32"),
        T_add: T.Buffer((T.int64(3), T.int64(4)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(3), T.int64(4)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(x[v_ax0, v_ax1], y[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + y[v_ax0, v_ax1]


print(tvm.build(TirModule, target="c").get_source())  # or target="llvm" of course
```