> Thanks for the RFC. I have two questions:
>
> 1. How to mark/set the color (i.e., attribute) of every operator?
> 2. It seems to me that if we register a casting checker instead of just a
> label (color), then we can simplify the algorithm a lot. Taking the case
> `A(green) - B(gray) - C(green)` as an example, if we could register a casting
> rule of B as follows, then we just need one traverse to know if we need cast
> around B:
> ```
> def amp_B(expr, args):
>     a = args[0]
>     if (a.dtype is float16):
>         return fp16
>     return fp32
> ```
>
> After all, we only need the previous nodes to determine 1) whether to use
> FP16 implementation, and 2) whether to insert casts. It seems to me that this
> pass is similar to the layout conversion pass, which uses one traverse to
> finish everything, so it might be possible for AMP too.
Yep, that is correct. This RFC has an initial PR here: https://github.com/apache/tvm/pull/8069. To answer your questions:

1. `src/relay/transforms/fp32_to_fp16.h`: `DefaultFP16Colorer` is the default way. However, the only thing we need is a callable that takes a `CallNode*` and returns a `Color`, so you can write your own colorer that does arbitrary things, as long as it only looks at a single node at a time.
2. This is functionally what the linked PR does.
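To make the "colorer + single traversal" idea concrete, here is a minimal plain-Python sketch. It is not the code in the PR and does not use any TVM API: `Node`, `Color`, `make_default_colorer`, and `convert` are hypothetical names for illustration only. It assumes GREEN means "always fp16", RED means "always fp32", and GRAY means "stay in fp16 only if every input is already fp16" (a generalization of the quoted `amp_B` rule to all arguments), and that graph inputs start in fp32. A colorer is any callable that inspects one node and returns its color; one post-order walk then decides each node's dtype and records where casts are needed.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Dict, List, Tuple


class Color(Enum):
    GREEN = "green"  # assumption: always run in fp16
    GRAY = "gray"    # assumption: follow the dtype of the inputs
    RED = "red"      # assumption: always run in fp32


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can be dict keys
class Node:
    """Hypothetical toy graph node, standing in for a Relay CallNode."""
    name: str
    op: str
    inputs: List["Node"] = field(default_factory=list)


# A colorer only looks at a single node and returns its Color.
Colorer = Callable[[Node], Color]


def make_default_colorer(table: Dict[str, Color], default: Color = Color.RED) -> Colorer:
    """Hypothetical table-driven colorer: look up the op name, else use `default`."""
    def colorer(node: Node) -> Color:
        return table.get(node.op, default)
    return colorer


def convert(outputs: List[Node], colorer: Colorer):
    """One memoized post-order traversal that picks dtypes and records casts."""
    dtype: Dict[Node, str] = {}
    casts: List[Tuple[str, str, str]] = []  # (producer, consumer, target dtype)

    def visit(node: Node) -> str:
        if node in dtype:  # already decided (shared subexpression)
            return dtype[node]
        in_types = [visit(i) for i in node.inputs]
        if not node.inputs:
            dtype[node] = "float32"  # assumption: graph inputs start in fp32
            return dtype[node]
        color = colorer(node)
        if color is Color.GREEN:
            out = "float16"
        elif color is Color.RED:
            out = "float32"
        else:  # GRAY: only the already-visited inputs decide
            out = "float16" if all(t == "float16" for t in in_types) else "float32"
        # Any input whose dtype differs from this node's dtype needs a cast.
        for inp, t in zip(node.inputs, in_types):
            if t != out:
                casts.append((inp.name, node.name, out))
        dtype[node] = out
        return out

    for o in outputs:
        visit(o)
    return dtype, casts


# Usage: x -> A(conv2d, green) -> B(relu, gray) -> C(dense, green) -> D(softmax, red)
x = Node("x", "var")
a = Node("A", "conv2d", [x])
b = Node("B", "relu", [a])
c = Node("C", "dense", [b])
d = Node("D", "softmax", [c])

colorer = make_default_colorer(
    {"conv2d": Color.GREEN, "dense": Color.GREEN, "relu": Color.GRAY}
)
dtypes, casts = convert([d], colorer)
print({n.name: t for n, t in dtypes.items()})
print(casts)
```

For the `A(green) - B(gray) - C(green)` part of the chain, the only cast is on the way into A; B stays in fp16 because its input is already fp16, and a second cast appears only before the RED softmax. Any callable works as the colorer, for example `convert([d], lambda n: Color.RED)` keeps everything in fp32 with no casts. The recursive walk is kept short for readability; a real pass over deep graphs would want an iterative traversal.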