Neat! This is a really good explanation. I think I get most of what you are
explaining. (I'm still stuck on `reverse_compute_at`, which seems like a long
name and is still a bit too magical for me to understand.)
In terms of splitting / reordering, my goofy thought is that my favorite
construct for tensor abstractions is `vmap` in JAX. What makes programming
tensors so hard is keeping six tensor dimensions in your head at once. `vmap`
lets you "zoom" into one area, forget entirely about the outer dimension, and
focus on just that part.
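To make the "zoom in" idea concrete, here is a tiny JAX sketch (not TVM): the inner function is written for a single row, and `vmap` handles the outer dimension, so the inner code never has to mention it.
```python
import jax
import jax.numpy as jnp

def row_matmul(a_row, B):
    # The inner code only sees one row of A and all of B;
    # the outer (batch) dimension of A does not appear here at all.
    return a_row @ B  # (K,) @ (K, N) -> (N,)

# vmap lifts the row-level function over A's leading axis;
# in_axes=(0, None) maps over A's rows and broadcasts B.
matmul = jax.vmap(row_matmul, in_axes=(0, None))

A = jnp.ones((6, 4))
B = jnp.ones((4, 5))
C = matmul(A, B)  # shape (6, 5), equivalent to A @ B
```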
When writing TVM code for a matrix multiply with double buffering, I would
really like to 1) first decide on my tiling of the output, split it, and assign
it to blocks and threads, and then 2) write a separately scoped bit of code that
doesn't know about the outer construction at all. Ideally, I would create my
outer scope, vmap in, have all my buffers automatically "reverse_compute_at" /
vmapped, and then create my inner setup.
I don't know if this totally works, but this would be my ideal:
```python
l_o = split_out(C, l)  # only exposes the outer axis
n_o = split_out(C, n)
with prefix(l_o, n_o, tensors=[A, B], threads=[]) as s2:
    s2.cache_read(...)   # this cache read is now local, computed at here
    l, n = s2.axes(C)    # these are now the inner split axes
    m = s2.axes(A)       # A's outer axes are now invisible
    s2.reorder(n, l)     # only touches the visible axes
```
(Maybe instead of a `with` block this is an inner function, like in JAX; a sketch of that spelling follows.)
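For concreteness, the same hypothetical schedule written with an inner function instead of a `with` block. Every name here (`split_out`, `tile_map`, the `s2` handle) is made up for illustration, mirroring the sketch above, not an existing API:
```python
l_o = split_out(C, l)  # only exposes the outer axis
n_o = split_out(C, n)

def inner(s2):
    # s2 only sees the inner, "vmapped-in" view of the schedule.
    s2.cache_read(...)   # cache read is local, computed at this scope
    l, n = s2.axes(C)    # inner split axes
    m = s2.axes(A)       # A's outer axes are invisible
    s2.reorder(n, l)     # only reorders the visible axes

tile_map(inner, l_o, n_o, tensors=[A, B], threads=[])
```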