@jonso here's a little Relay pass I've been using to downcast a model from FP32
to FP16. The target shouldn't really matter, since you apply this pass before
compiling to CUDA.

```
from tvm.ir import IRModule
from tvm.relay import transform as _transform
from tvm.relay import cast, Call, Constant, Function, TupleGetItem, Var
from tvm.relay.expr_functor import ExprMutator

def downcast_fp16(func):
    """Downcast an fp32 Relay function to fp16.

    Parameters
    ----------
    func : Function
        The original fp32 function.

    Returns
    -------
    Function
        The function after downcasting to half-precision floating point.
    """
    # get_valid_counts and non_max_suppression do not support fp16,
    # so keep them in fp32 via this filter list.
    filter_list = ['vision.get_valid_counts', 'vision.non_max_suppression']
    class DowncastMutator(ExprMutator):
        """Downcast to fp16 mutator"""
        def visit_call(self, call):
            dtype = 'float32' if call.op.name in filter_list else 'float16'
            new_fn = self.visit(call.op)
            # Collect the original dtypes so int32 outputs (e.g. valid counts) can be preserved
            type_list = []
            if call.op.name in filter_list:
                # For nms
                for arg in call.args:
                    if isinstance(arg, TupleGetItem) and isinstance(arg.tuple_value, Call):
                        tuple_types = arg.tuple_value.checked_type.fields
                        type_list.append(tuple_types[arg.index].dtype)
                if call.op.name == 'vision.get_valid_counts':
                    tuple_types = call.checked_type.fields
                    for cur_type in tuple_types:
                        type_list.append(cur_type.dtype)

            args = [self.visit(arg) for arg in call.args]
            new_args = []
            for arg_idx, arg in enumerate(args):
                if isinstance(arg, (Var, Constant)):
                    new_args.append(cast(arg, dtype=dtype))
                elif call.op.name in filter_list:
                    # Keep int32 tensors (e.g. the valid count) in their original dtype
                    if isinstance(arg, TupleGetItem) and type_list[arg_idx] == 'int32':
                        new_args.append(arg)
                    else:
                        new_args.append(cast(arg, dtype=dtype))
                else:
                    new_args.append(arg)
            # nms ran in fp32, so cast its output back down to fp16 for downstream ops
            if call.op.name in filter_list and call.op.name != 'vision.get_valid_counts':
                return cast(Call(new_fn, new_args, call.attrs), dtype='float16')
            return Call(new_fn, new_args, call.attrs)

    class UpcastMutator(ExprMutator):
        """Upcast the final output back to fp32."""
        def visit_call(self, call):
            # No recursion here, so only the outermost call (the network
            # output) is cast back to fp32.
            return cast(call, dtype='float32')

    def infer_type(expr):
        """Infer the type of an intermediate node in the Relay graph."""
        mod = IRModule.from_expr(expr)
        mod = _transform.InferType()(mod)
        entry = mod["main"]
        return entry if isinstance(expr, Function) else entry.body

    func = infer_type(func)  # visit_call reads checked_type, so run InferType first
    downcast_pass = DowncastMutator()
    func = downcast_pass.visit(func)
    upcast_pass = UpcastMutator()
    func = upcast_pass.visit(func)
    func = infer_type(func)
    return func
```
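
To use it, run the pass on the function before building. Here's a minimal sketch: the toy dense network and the `x`/`w` names are just placeholders (not from a real model), and `relay.build`'s return value varies a bit between TVM versions:

```
import tvm
from tvm import relay

# Placeholder fp32 network, just to exercise the pass
x = relay.var("x", shape=(1, 16), dtype="float32")
w = relay.var("w", shape=(16, 16), dtype="float32")
func = relay.Function([x, w], relay.nn.dense(x, w))

func = downcast_fp16(func)  # fp16 compute, fp32 output
mod = tvm.IRModule.from_expr(func)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda")
```

Note that the inputs stay fp32 and the result is cast back to fp32 at the end, so callers don't need to change anything; only the intermediate compute runs in fp16.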