PyTorch does support forward hooks for `torch.jit.trace(...)`.

For details, see https://github.com/pytorch/pytorch/issues/34329 and
https://github.com/pytorch/pytorch/pull/49544.

For usage examples, see this test file in the PyTorch repository:
https://github.com/pytorch/pytorch/blob/5c23888953d277041b341d38dcd5b2d891619ba4/test/jit/test_hooks.py
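Here is a minimal sketch of how a forward hook interacts with `torch.jit.trace`. It assumes a recent PyTorch release, and `Net` and `double_output` are illustrative names, not taken from the linked test. The hook runs while the model is being traced, so the tensor ops it performs are recorded into the traced graph like ordinary ops. Removing the hook from the eager model afterwards does not remove those ops from the traced module.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Tiny two-layer model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def double_output(module, inputs, output):
    # A forward hook that returns a value replaces the module's output.
    return output * 2


torch.manual_seed(0)
net = Net().eval()
x = torch.randn(3, 4)

handle = net.fc1.register_forward_hook(double_output)
eager_out = net(x)                   # hook runs in eager mode
traced = torch.jit.trace(net, (x,))  # hook also runs while tracing, so `* 2` is recorded
handle.remove()                      # detach the hook from the eager model only

traced_out = traced(x)  # still applies the recorded `* 2`
plain_out = net(x)      # eager model, now without the hook
print(torch.allclose(traced_out, eager_out))
```

Because a trace records only the tensor ops that actually executed, a hook that transforms tensors is captured. Python side effects are not captured.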

I personally think a hook mechanism is useful: it would be convenient to get
intermediate outputs for debugging (and for cases like checking quantization
accuracy, as you mentioned). PyTorch itself supports this feature; however, it
seems that we can't do the same thing in TVM for now. Let me explain a bit:

To actually get the intermediate result, one way is to simply `print` the
intermediate tensor in the hook. You can use `torch.jit.trace` to compile a
PyTorch model with a print call inside a hook. However, TVM will then give you
an error saying that some operators are not implemented:
```
The following operators are not implemented: ['prim::Print']
```
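For reference, `prim::Print` is the TorchScript node that a Python `print` call becomes when the code is compiled by the scripting frontend (`torch.jit.script`). The sketch below compiles an illustrative function, `add_one_and_print` (not from the original post), and inspects its graph so the node is visible. It shows only where the node comes from. It does not show how TVM's PyTorch frontend handles it; that is what the error above reports.

```python
import torch


@torch.jit.script
def add_one_and_print(x: torch.Tensor) -> torch.Tensor:
    # In scripted code, this `print` call is compiled into a prim::Print node.
    print(x)
    return x + 1


graph_text = str(add_one_and_print.graph)
print("prim::Print" in graph_text)
```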

Another way is to create a Python class like this:

```python
from typing import List, Tuple

import torch
import torch.nn as nn


class HookRecorder:
    def __init__(self):
        self.recorder = dict()  # layer name -> list of intermediate outputs
        self.handlers = list()  # RemovableHandle objects returned by PyTorch

    def _register_hooker(self, name):
        self.recorder[name] = list()

        def named_hooker(module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor):
            # Store this layer's output every time the module runs forward.
            self.recorder[name].append(output)

        return named_hooker

    def register_hookers(self, target_sub_modules: List[nn.Module], layer_names: List[str]):
        for module, layer_name in zip(target_sub_modules, layer_names):
            handler = module.register_forward_hook(self._register_hooker(layer_name))
            # Keep every handle (not only the last one) so all hooks can be removed.
            self.handlers.append(handler)

    def remove_handlers(self):
        for handler in self.handlers:
            handler.remove()
        self.handlers.clear()

    def __del__(self):
        self.remove_handlers()


# `net` is your model (with a `conv2` submodule) and `input` is a sample input tensor.
hook = HookRecorder()
hook.register_hookers([net.conv2], ["conv2"])
out = net(input)
print(hook.recorder)
```

This way, we can indeed get the intermediate values from the Python class.
However, this cannot be compiled by `torch.jit.trace`.
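Here is a minimal sketch of why the recorder approach does not survive tracing. It assumes a recent PyTorch release, and `Tiny`, `recorder`, and `record_hook` are hypothetical names standing in for the class above. The hook does run while `torch.jit.trace` executes the model. However, appending to a Python list is a side effect, not a tensor op, so it is not part of the traced graph. As a result, calling the traced module never touches the recorder.

```python
import torch
import torch.nn as nn


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))


recorder = []  # hypothetical stand-in for HookRecorder.recorder


def record_hook(module, inputs, output):
    # Python side effect only: stash the intermediate tensor, return nothing.
    recorder.append(output.detach())


net = Tiny().eval()
net.conv.register_forward_hook(record_hook)
x = torch.randn(1, 1, 8, 8)

traced = torch.jit.trace(net, (x,))
calls_during_trace = len(recorder)  # hook ran (possibly more than once) while tracing

traced(x)
traced(x)
calls_after_traced_runs = len(recorder)  # unchanged: the append is not in the graph

net(x)  # the eager model still records
print(calls_during_trace, calls_after_traced_runs, len(recorder))
```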

---
[Visit Topic](https://discuss.tvm.apache.org/t/pytorch-register-forward-hook-support/11036/7) to respond.