Thanks for the RFC. I have two questions:

1. How do we mark/set the color (i.e., attribute) of each operator?
2. It seems to me that if we register a casting checker instead of just a label (color), we could simplify the algorithm a lot. Taking `A(green) - B(gray) - C(green)` as an example, if we could register a casting rule for B as follows, a single traversal would be enough to know whether we need casts around B:
```python
def amp_B(expr, args):
    """Decide the precision of B from its first input (the output of A)."""
    a = args[0]
    # If the producer already emits FP16, run B in FP16 so no cast is needed.
    if a.dtype == "float16":
        return "float16"
    # Otherwise keep B in FP32.
    return "float32"
```

After all, we only need the preceding nodes to determine 1) whether to use the FP16 implementation and 2) whether to insert casts. This pass looks similar to the layout conversion pass, which finishes everything in a single traversal, so the same approach might work for AMP too.
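To make the single-traversal idea concrete, here is a minimal sketch on a toy graph. It is not TVM code: `Node`, `RULES`, `register_rule`, `always_fp16`, `follow_first_input` and `run_amp` are hypothetical names for illustration only. It assumes "green" means "always run in FP16", "gray" means "follow the producer's dtype" (as `amp_B` does), and that nodes are visited in topological order, so every input's dtype is already decided when a node is reached.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List

FP16, FP32 = "float16", "float32"


@dataclass
class Node:
    op: str
    inputs: List["Node"] = field(default_factory=list)
    dtype: str = FP32  # output dtype, decided by the pass


# Hypothetical registry: op name -> rule(input dtypes) -> compute dtype.
RULES: Dict[str, Callable[[List[str]], str]] = {}


def register_rule(op: str, rule: Callable[[List[str]], str]) -> None:
    RULES[op] = rule


def always_fp16(in_dtypes: List[str]) -> str:
    # "green": always worth running in FP16.
    return FP16


def follow_first_input(in_dtypes: List[str]) -> str:
    # "gray": FP16 only if the first producer is already FP16 (like amp_B).
    return FP16 if in_dtypes and in_dtypes[0] == FP16 else FP32


def run_amp(topo_order: List[Node]) -> List[tuple]:
    """One forward pass: pick each node's dtype, cast mismatched inputs."""
    casts = []
    for node in topo_order:
        target = RULES[node.op]([inp.dtype for inp in node.inputs])
        new_inputs = []
        for inp in node.inputs:
            if inp.dtype != target:
                # Wrap the input in a cast node to the chosen dtype.
                new_inputs.append(Node("cast", [inp], dtype=target))
                casts.append((inp.op, node.op, target))
            else:
                new_inputs.append(inp)
        node.inputs = new_inputs
        node.dtype = target
    return casts


if __name__ == "__main__":
    register_rule("input", lambda d: FP32)
    register_rule("A", always_fp16)
    register_rule("B", follow_first_input)
    register_rule("C", always_fp16)
    x = Node("input")
    a = Node("A", [x])
    b = Node("B", [a])
    c = Node("C", [b])
    print(run_amp([x, a, b, c]))  # [('input', 'A', 'float16')]
```

For `A(green) - B(gray) - C(green)` the only cast lands in front of A: B sees an FP16 producer and stays in FP16, so no casts are needed around B. If B's producer were FP32 instead, B would stay in FP32 and a single cast would appear before C. The sketch inserts one cast per consumer edge; a real pass would likely reuse a cast shared by several consumers.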