> Thanks for the RFC. I have two questions:
> 
> 1. How to mark/set the color (i.e., attribute) of every operator?
> 2. It seems to me that if we register a casting checker instead of just a 
> label (color), then we can simplify the algorithm a lot. Taking the case 
> `A(green) - B(gray) - C(green)` as an example, if we could register a casting 
> rule of B as follows, then we just need one traverse to know if we need cast 
> around B:
>    ```
>    def amp_B(expr, args):
>        a = args[0]
>        if (a.dtype is float16):
>          return fp16
>        return fp32
>    ```
>    
>    After all, we only need the previous nodes to determine 1) whether to use 
> FP16 implementation, and 2) whether to insert casts. It seems to me that this 
> pass is similar to the layout conversion pass, which uses one traverse to 
> finish everything, so it might be possible for AMP too.

Yep, that is correct. This RFC has an initial PR here: 
https://github.com/apache/tvm/pull/8069.

To answer your questions:
1. See `DefaultFP16Colorer` in src/relay/transforms/fp32_to_fp16.h for the 
default approach. The only thing the pass needs, though, is a callable that 
takes a `CallNode*` and returns a color, so you could write your own colorer 
that does arbitrary things as long as it only looks at a single node at a time.

2. This is functionally what the PR linked above already does.
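
To make both answers concrete, here is a rough, self-contained Python sketch 
of the idea. It is not the code from the PR: `Node`, `Color`, the op sets, and 
`convert` are all hypothetical names invented for illustration, and the real 
pass works on Relay `CallNode`s in C++. It only shows a per-node colorer 
callable driving a single traversal that picks dtypes and records casts.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Dict, List, Tuple


class Color(Enum):
    GREEN = "green"  # always run in fp16
    GRAY = "gray"    # follow whatever the inputs are
    RED = "red"      # always run in fp32


@dataclass
class Node:
    name: str
    op: str
    inputs: List["Node"] = field(default_factory=list)


# Hypothetical op lists, chosen only for this example.
GREEN_OPS = {"conv2d", "dense"}
RED_OPS = {"softmax", "exp"}


def default_colorer(node: Node) -> Color:
    """Looks at a single node and returns its color."""
    if node.op in GREEN_OPS:
        return Color.GREEN
    if node.op in RED_OPS:
        return Color.RED
    return Color.GRAY


def convert(
    topo_order: List[Node],
    colorer: Callable[[Node], Color] = default_colorer,
) -> Tuple[Dict[str, str], List[Tuple[str, str, str]]]:
    """One traversal: choose each node's dtype and record the casts it needs."""
    out_dtype: Dict[str, str] = {}
    casts: List[Tuple[str, str, str]] = []  # (producer, consumer, target dtype)
    for node in topo_order:
        in_dtypes = [out_dtype[i.name] for i in node.inputs]
        color = colorer(node)
        if color is Color.GREEN:
            want = "float16"
        elif color is Color.RED:
            want = "float32"
        else:
            # Gray: use fp16 only when every input is already fp16.
            all_fp16 = bool(in_dtypes) and all(d == "float16" for d in in_dtypes)
            want = "float16" if all_fp16 else "float32"
        for inp, dtype in zip(node.inputs, in_dtypes):
            if dtype != want:
                casts.append((inp.name, node.name, want))
        out_dtype[node.name] = want
    return out_dtype, casts


# The A(green) - B(gray) - C(green) chain from the question.
x = Node("x", "var")
a = Node("A", "conv2d", [x])
b = Node("B", "relu", [a])
c = Node("C", "dense", [b])
dtypes, casts = convert([x, a, b, c])

# A custom colorer is just another callable, e.g. one that keeps everything fp32.
fp32_dtypes, fp32_casts = convert([x, a, b, c], colorer=lambda n: Color.RED)
```

In this toy version the only cast recorded for the chain is on the edge from 
`x` into `A`; `B` inherits fp16 from `A`, so nothing is inserted around it.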
