KEKE046 opened a new pull request, #327:
URL: https://github.com/apache/tvm-ffi/pull/327

   When training neural networks, it is beneficial to free tensors as early as possible once they are no longer used. In CPython, as long as there is no reference cycle, a tensor is freed immediately when its reference count drops to zero.
   
   Torch blog: [Finding and Removing Reference Cycles](https://pytorch.org/blog/understanding-gpu-memory-2/)
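
The difference between reference counting and cycle collection can be seen in a minimal pure-Python sketch (no torch or tvm_ffi involved). `Payload` is a hypothetical stand-in for a tensor, a `weakref` is used only to observe when the object is freed, and automatic GC is disabled so that only the explicit `gc.collect()` call can reclaim the cycle.

```python
import gc
import weakref


class Payload:
    """Hypothetical stand-in for a tensor; weakref shows when it is freed."""


def refcount_only():
    obj = Payload()
    ref = weakref.ref(obj)
    del obj  # last strong reference dropped: CPython frees it right away
    return ref() is None


def with_cycle():
    obj = Payload()
    obj.self_ref = obj  # reference cycle: obj -> obj.__dict__ -> obj
    ref = weakref.ref(obj)
    del obj
    freed_before_gc = ref() is None  # still alive: refcount never hits zero
    gc.collect()  # only the cyclic garbage collector can reclaim it
    return freed_before_gc, ref() is None


gc.disable()  # keep automatic GC from running between the steps above
try:
    print(refcount_only())  # True
    print(with_cycle())     # (False, True)
finally:
    gc.enable()
```

Without the cycle, dropping the last reference frees the object on the spot; with it, the object survives until the collector runs.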
   
   However, our TVM FFI error-handling path creates reference cycles that keep intermediate tensors (and the entire call chain) alive much longer than necessary. Such objects are only reclaimed when Python's cyclic garbage collector runs, which makes collection slower, increases memory pressure, and hurts training throughput.
   
   ## Problem
   
   The following code reproduces the problem:
   1. We allocate a tensor `a` in `foo`.
   2. We call `Map.get` with a non-existent key, which raises a `KeyError` that `Map.get` catches internally.
   3. Catching this `KeyError` leaves a reference cycle behind, so the whole call chain, including `a`, can only be freed by the garbage collector.
   
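   ```py
   import torch
   import tvm_ffi
   import gc
   from torch.utils.viz._cycles import warn_tensor_cycles
   
   m = tvm_ffi.Map({'a': 1})
   
   def foo():
       a = torch.tensor([1], device='cuda')
       # Map.get on a missing key raises a KeyError internally and catches it
       _tmp = m.get('b', 0)
   
   # `a` should be freed when foo() returns, but the cycle keeps it alive
   foo()
   
   # Report garbage cycles that hold CUDA memory; the returned callable
   # uninstalls the hook
   remove = warn_tensor_cycles()
   gc.collect()
   remove()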
   ```
   
   The following figure demonstrates the call chain and reference cycles:
   
   <img width="60%" alt="image" src="https://github.com/user-attachments/assets/ba477a5a-eca5-4e6b-a9db-395da9ac49c5" />
   
   The following figure shows the object reference graph reported by torch's `warn_tensor_cycles`:
   
   <img width="3624" height="512" alt="image" src="https://github.com/user-attachments/assets/e2b8b211-9e7a-4a46-b500-425f09f71d75" />
   
   ## Approach
   
   As shown in the figure, three variables, `frame`, `py_error`, and `tb`, form the core of the reference cycle. To break the cycle, we can explicitly delete these variables once they are no longer needed.
   
   
https://github.com/apache/tvm-ffi/blob/f7e09d6a96b54554190bae0d7ba9ff7a6e9a109e/python/tvm_ffi/error.py#L124-L136
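
A minimal pure-Python sketch of the same pattern, assuming a hypothetical `convert_error` helper (not tvm_ffi code) that, like the error path above, keeps the exception and its traceback in local variables. After the function returns, its frame, `py_error`, and `tb` still all reference one another:

```python
def convert_error():
    # Hypothetical stand-in for the error-conversion helper; not tvm_ffi code.
    py_error = KeyError("b")
    try:
        raise py_error  # raising records this frame in py_error.__traceback__
    except KeyError:
        pass
    tb = py_error.__traceback__
    return py_error, tb


py_error, tb = convert_error()
frame = tb.tb_frame

# frame -> py_error -> tb -> frame, and frame -> tb directly
print(frame.f_code.co_name)                    # convert_error
print(frame.f_locals["py_error"] is py_error)  # True
print(frame.f_locals["tb"] is tb)              # True
print(py_error.__traceback__ is tb)            # True
```

The frame also keeps its caller's frame alive through `f_back`, which is how a cycle like this pins the whole call chain.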
   
   We can wrap the body in `try ... finally` so that `py_error` and `tb` are deleted as the function returns, after the return value has been computed:
   
   ```py
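       try:
           tb = py_error.__traceback__
           # Append one traceback entry per frame parsed from `backtrace`
           for filename, lineno, func in _parse_backtrace(backtrace):
               tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func)
           return py_error.with_traceback(tb)
       finally:
           # Runs after the return value is computed: dropping these locals
           # breaks the frame -> py_error -> tb -> frame cycle
           del py_error, tb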
   ```
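
To check that a `try ... finally` with `del` lets reference counting reclaim the caller's objects again, here is a minimal pure-Python sketch. `Payload`, `convert_error_leaky`, `convert_error_fixed`, and `map_get` are hypothetical stand-ins (not tvm_ffi code) for the tensor, the error-conversion helper before and after the change, and `Map.get`.

```python
import gc
import weakref


class Payload:
    """Hypothetical stand-in for the CUDA tensor `a` allocated in foo()."""


def convert_error_leaky():
    # Hypothetical helper without the fix: py_error and tb stay in the frame.
    py_error = KeyError("b")
    try:
        raise py_error
    except KeyError:
        pass
    tb = py_error.__traceback__
    return py_error.with_traceback(tb)


def convert_error_fixed():
    # Same helper with the fix: locals are deleted as the function returns.
    py_error = KeyError("b")
    try:
        try:
            raise py_error
        except KeyError:
            pass
        tb = py_error.__traceback__
        return py_error.with_traceback(tb)
    finally:
        del py_error, tb


def map_get(convert_error, default):
    # Hypothetical stand-in for Map.get: raise, catch, return the default.
    try:
        raise convert_error()
    except KeyError:
        return default


def foo(convert_error):
    a = Payload()
    map_get(convert_error, 0)
    return weakref.ref(a)


def freed_without_gc(convert_error):
    """Return True if `a` is freed by refcounting alone when foo() returns."""
    gc.disable()
    try:
        ref = foo(convert_error)
        alive = ref() is not None
        gc.collect()  # clean up any cycle so the next run starts fresh
        return not alive
    finally:
        gc.enable()


print(freed_without_gc(convert_error_leaky))  # False
print(freed_without_gc(convert_error_fixed))  # True
```

With the leaky helper, `a` outlives `foo()` until `gc.collect()` runs; with the `del` in `finally`, it is freed as soon as `foo()` returns.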

