Hi all, we sent out an RFC for TVMScript metaprogramming at https://github.com/apache/tvm-rfcs/pull/79. For visibility, the section below is quoted from the RFC and describes the metaprogramming features we want to support. We would love to hear your thoughts and feedback!
[quote="yelite, post:1, topic:12969, full:true"]
### (F1) Template Metaprogramming

Users should be able to use variables from an outer scope inside a TVMScript function/class. The parsed result should be identical to that of the same function/class with each captured variable replaced by its value. For instance,

```python
@T.prim_func
def matmul(
    A: T.Buffer[(128, 128)],
) -> None:
    ...


def gen_matmul(n, m):
    @T.prim_func
    def f(A: T.Buffer[(n, m)]):
        ...

    return f


f = gen_matmul(n=128, m=128)  # `f` should be identical to `matmul`
```

This is already partially supported by https://github.com/apache/tvm/pull/11097, which allows using a `PrimExpr` captured from an outer function. With the new parser, we want to support this feature in more places and with more variable types.

### (F2) Rank-polymorphism

Users should be able to write a single function that handles input buffers of different ranks (different numbers of dimensions). For example, users should be able to write a generic function for broadcast add:

```python
def broadcast_add(a, b, c):
    @T.prim_func
    def f(
        A: T.BufferFrom(a),
        B: T.BufferFrom(b),
        C: T.BufferFrom(c),
    ) -> None:
        for i, i_a, i_b in T.some_broadcast_method(A.shape, B.shape):
            with T.block():
                C[*i] = A[*i_a] + B[*i_b]

    return f


broadcast_add(
    a=Buffer((128, 1), "float32"),
    b=Buffer((1, 128), "float32"),
    c=Buffer((128, 128), "float32"),
)
```

### (F3) Sugar: TE Compute in TIR

Users should be able to replace boilerplate code with a function call that is expanded into a large chunk of code during parsing. For example, we may want to use TE's compute-like syntax to replace nested loops:

```python
@T.prim_func
def te_compute_sugar(
    A: T.Buffer[(128, 128)],
    B: T.Buffer[(128, 128)],
) -> None:
    ...
    C = T.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
    ...


## expands to ====>


@T.prim_func
def te_compute_expanded(
    A: T.Buffer[(128, 128)],
    B: T.Buffer[(128, 128)],
) -> None:
    ...
    for i in range(128):
        for j in range(128):
            with T.block("..."):
                C[i, j] = A[i, j] + B[i, j]
    ...
```

### (F4) Interleave host program and TVMScript program to customize metaprogramming

As an escape hatch from writing code to be parsed (or evaluated) by the TVMScript parser, users should be able to write imperative code that constructs IR nodes directly and embed it inside regular TVMScript. This gives users the ultimate tool when TVMScript isn't expressive enough for their use cases. For example, at [python/tvm/topi/vision/nms.py#L380-L431](https://github.com/apache/tvm/blob/3cb4597ed48360e3f3d80161d1c03f833072d28e/python/tvm/topi/vision/nms.py#L380-L431), there are blocks of repetitive code that compute the coordinates of the four corners of a bounding box. This can be simplified as:

```python
# Before, without IRBuilder interleaving
@T.prim_func
def nms(...):
    ...
    for i in range(batch_size):
        ...
        a_l = min(
            output[batch_idx, box_a_idx, box_start_idx],
            output[batch_idx, box_a_idx, box_start_idx + 2],
        )
        a_t = min(
            output[batch_idx, box_a_idx, box_start_idx + 1],
            output[batch_idx, box_a_idx, box_start_idx + 3],
        )
        a_r = max(
            output[batch_idx, box_a_idx, box_start_idx],
            output[batch_idx, box_a_idx, box_start_idx + 2],
        )
        a_b = max(
            output[batch_idx, box_a_idx, box_start_idx + 1],
            output[batch_idx, box_a_idx, box_start_idx + 3],
        )
        ...
        for k in range(j):
            check_iou = ...
            ...
            if check_iou > 0:
                # b_l: left, b_t: top, b_r: right, b_b: bottom
                b_l = min(
                    output[batch_idx, box_b_idx, box_start_idx],
                    output[batch_idx, box_b_idx, box_start_idx + 2],
                )
                b_t = min(
                    output[batch_idx, box_b_idx, box_start_idx + 1],
                    output[batch_idx, box_b_idx, box_start_idx + 3],
                )
                b_r = max(
                    output[batch_idx, box_b_idx, box_start_idx],
                    output[batch_idx, box_b_idx, box_start_idx + 2],
                )
                b_b = max(
                    output[batch_idx, box_b_idx, box_start_idx + 1],
                    output[batch_idx, box_b_idx, box_start_idx + 3],
                )
                ...
```
```python
# With IRBuilder interleaving
from tvm.script import tir as T


def get_box_coordinates(output, batch_idx, box_idx, box_start_idx):
    """A regular Python function, executed by the Python interpreter rather than parsed as TVMScript."""
    box_l = T.min(
        output[batch_idx, box_idx, box_start_idx],
        output[batch_idx, box_idx, box_start_idx + 2],
    )  # type(box_l) is PrimExpr
    ...  # Repeat for other coordinates
    return box_l, box_t, box_r, box_b


@T.prim_func(capture=[get_box_coordinates])
def nms(...):
    ...
    for i in range(batch_size):
        ...
        a_l, a_t, a_r, a_b = get_box_coordinates(output, batch_idx, box_a_idx, box_start_idx)
        ...
        for k in range(j):
            check_iou = ...
            ...
            if check_iou > 0:
                b_l, b_t, b_r, b_b = get_box_coordinates(output, batch_idx, box_b_idx, box_start_idx)
                ...
```
[/quote]
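For (F1), the parser needs the *values* of variables such as `n` and `m` that the decorated function captures from its enclosing scope. A minimal, TVM-independent sketch of how that information can be obtained, assuming plain Python closures (this is our illustration, not a description of TVM's parser), is the standard `inspect.getclosurevars`. The decorator `show_captures` and the attribute `captured` are hypothetical names.

```python
import inspect


def show_captures(func):
    """Hypothetical decorator: attach the outer-scope values `func` uses.

    A TVMScript-style parser could use such a mapping to replace each
    captured name (here `n` and `m`) with its concrete value.
    """
    closure = inspect.getclosurevars(func)
    # `nonlocals`: variables captured from enclosing functions (e.g. n, m).
    # `globals`: module-level names the function body refers to.
    func.captured = {**closure.globals, **closure.nonlocals}
    return func


def gen_matmul(n, m):
    @show_captures
    def f(A):
        # Plain-Python stand-in for a TVMScript body that uses `n` and `m`.
        return (n, m)

    return f


f = gen_matmul(n=128, m=128)
print(f.captured["n"], f.captured["m"])  # 128 128
```

Because the decorator runs when `gen_matmul` is called, each call sees its own `n` and `m`, which is exactly what lets one generator produce differently shaped functions.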
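`T.some_broadcast_method` in (F2) is a placeholder. Assuming it follows NumPy-style broadcasting, one plausible behavior is to yield, for every output index `i`, the matching input indices `i_a` and `i_b`. The pure-Python sketch below (hypothetical names `broadcast_shape` and `broadcast_indices`, operating on plain shape tuples rather than TVM buffers) works for any rank, which is what would let a single `broadcast_add` cover all of them.

```python
import itertools
from typing import Iterator, Sequence, Tuple

Index = Tuple[int, ...]


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Index:
    """Output shape under NumPy-style broadcasting (dims aligned on the right)."""
    rank = max(len(a), len(b))
    # Left-pad the lower-rank shape with 1s so both shapes have `rank` dims.
    pa = (1,) * (rank - len(a)) + tuple(a)
    pb = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for da, db in zip(pa, pb):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast dims {da} and {db}")
        # A size-1 dim stretches to match the other one.
        out.append(db if da == 1 else da)
    return tuple(out)


def broadcast_indices(
    a: Sequence[int], b: Sequence[int]
) -> Iterator[Tuple[Index, Index, Index]]:
    """Yield (i, i_a, i_b): each output index and the input indices it reads."""
    out = broadcast_shape(a, b)

    def project(i: Index, shape: Sequence[int]) -> Index:
        # Drop the leading dims this input lacks; size-1 dims always read index 0.
        tail = i[len(out) - len(shape):]
        return tuple(0 if d == 1 else x for x, d in zip(tail, shape))

    for i in itertools.product(*(range(d) for d in out)):
        yield i, project(i, a), project(i, b)


# Scaled-down (F2) shapes: (2, 1) + (1, 3) broadcast to (2, 3).
for i, i_a, i_b in broadcast_indices((2, 1), (1, 3)):
    print(i, i_a, i_b)
```

Mixed ranks also work, e.g. `broadcast_indices((3,), (2, 3))` maps output index `(1, 2)` to `(2,)` and `(1, 2)`.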
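To make the (F3) expansion concrete, here is a pure-Python sketch, assuming the sugar is expanded by calling the user's lambda once on symbolic loop variables and wrapping the resulting expression in a loop nest. `Expr`, `Buf`, and `expand_compute` are hypothetical helpers that emit TVMScript-like source text; they are not TVM APIs, and a real parser would build IR nodes rather than strings. Only `+` is modeled.

```python
import inspect
from typing import Callable, Sequence


def _text(x) -> str:
    # Symbolic sub-expressions render as their text; Python numbers verbatim.
    return x.text if isinstance(x, Expr) else repr(x)


class Expr:
    """A tiny symbolic expression that records `+` as source text."""

    def __init__(self, text: str):
        self.text = text

    def __add__(self, other) -> "Expr":
        return Expr(f"{self.text} + {_text(other)}")


class Buf:
    """A named buffer whose subscripts produce symbolic loads like `A[i, j]`."""

    def __init__(self, name: str):
        self.name = name

    def __getitem__(self, idx) -> Expr:
        idx = idx if isinstance(idx, tuple) else (idx,)
        return Expr(f"{self.name}[{', '.join(_text(e) for e in idx)}]")


def expand_compute(out: str, shape: Sequence[int], fcompute: Callable[..., Expr]) -> str:
    """Expand `out = compute(shape, fcompute)` into TVMScript-like loop-nest text."""
    # Reuse the lambda's parameter names (e.g. `i`, `j`) as loop variables.
    names = list(inspect.signature(fcompute).parameters)
    if len(names) != len(shape):
        raise ValueError("need exactly one lambda parameter per dimension")
    # Call the lambda once on symbolic loop variables to capture its body.
    body = _text(fcompute(*(Expr(n) for n in names)))
    lines = [
        "    " * depth + f"for {name} in range({extent}):"
        for depth, (name, extent) in enumerate(zip(names, shape))
    ]
    pad = "    " * len(names)
    lines.append(f'{pad}with T.block("{out}"):')
    lines.append(f"{pad}    {out}[{', '.join(names)}] = {body}")
    return "\n".join(lines)


A, B = Buf("A"), Buf("B")
print(expand_compute("C", (128, 128), lambda i, j: A[i, j] + B[i, j]))
```

This prints the same loop nest as `te_compute_expanded` above, with the block named `"C"`. Reading the loop variable names off the lambda's signature keeps the expanded code readable.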
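The refactoring in (F4) is ordinary Python function extraction. As a runnable stand-in that needs no TVM, the sketch below evaluates the four-corner pattern from the "Before" code on concrete numbers. It assumes the builtin `min`/`max` in place of `T.min`/`T.max`, a dict keyed by `(batch, box, slot)` tuples in place of the `output` buffer, and (inferred from the min/max pairs) that the four slots hold two opposite corners in either order. In the RFC's version the helper would return `PrimExpr`s rather than numbers.

```python
from typing import Dict, Tuple

Key = Tuple[int, int, int]  # (batch, box, coordinate slot)


def get_box_coordinates(
    output: Dict[Key, float], batch_idx: int, box_idx: int, box_start_idx: int
) -> Tuple[float, float, float, float]:
    """Return (left, top, right, bottom) for one box, as in the "Before" code.

    Assumes slots box_start_idx..box_start_idx + 3 hold two opposite
    corners (x1, y1, x2, y2) in either order, hence the min/max pairs.
    """
    x1 = output[batch_idx, box_idx, box_start_idx]
    y1 = output[batch_idx, box_idx, box_start_idx + 1]
    x2 = output[batch_idx, box_idx, box_start_idx + 2]
    y2 = output[batch_idx, box_idx, box_start_idx + 3]
    return min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)


# One batch with two boxes; box 1 stores its corners in reversed order.
output: Dict[Key, float] = {}
for box, coords in enumerate([(0, 0, 4, 2), (5, 3, 1, 1)]):
    for slot, value in enumerate(coords):
        output[0, box, slot] = value

a_l, a_t, a_r, a_b = get_box_coordinates(output, 0, 0, 0)  # (0, 0, 4, 2)
b_l, b_t, b_r, b_b = get_box_coordinates(output, 0, 1, 0)  # (1, 1, 5, 3)
```

Both call sites now share one definition, so a fix to the coordinate logic only has to be made once.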