KEKE046 opened a new pull request, #327:
URL: https://github.com/apache/tvm-ffi/pull/327
When training neural networks, it is beneficial to free tensors as early as possible once they are no longer used. In CPython, as long as there is no reference cycle, tensors are freed immediately when their reference count drops to zero (see the PyTorch blog post [Finding and Removing Reference Cycles](https://pytorch.org/blog/understanding-gpu-memory-2/)).

However, our TVM FFI error-handling path creates reference cycles that keep intermediate tensors (and the entire call chain) alive much longer than necessary. These objects can then only be reclaimed by the cyclic garbage collector, which delays freeing, increases memory pressure, and hurts training throughput.

## Problem

The following code reproduces the problem:

1. We allocate a tensor in `foo`.
2. We call `Map.get` with a non-existent key, which raises a `KeyError` that is caught inside `Map`.
3. The caught `KeyError` creates a reference cycle, so the whole call chain, including the tensor `a`, can only be freed by the cyclic garbage collector.

```py
import torch
import tvm_ffi
import gc
from torch.utils.viz._cycles import warn_tensor_cycles

m = tvm_ffi.Map({'a': 1})

def foo():
    a = torch.tensor([1], device='cuda')
    # getting a non-existent key raises a KeyError internally, which Map catches
    _tmp = m.get('b', 0)

foo()

# warn when the GC frees a reference cycle that holds CUDA memory
remove = warn_tensor_cycles()
gc.collect()
remove()
```

The following figure shows the call chain and the reference cycles:

<img width="60%" alt="image" src="https://github.com/user-attachments/assets/ba477a5a-eca5-4e6b-a9db-395da9ac49c5" />

The following figure shows the object reference graph reported by torch's `warn_tensor_cycles`:

<img width="3624" height="512" alt="image" src="https://github.com/user-attachments/assets/e2b8b211-9e7a-4a46-b500-425f09f71d75" />

## Approach

As the figures show, three variables, `frame`, `py_error`, and `tb`, are at the core of the reference cycle. To break the cycle, we can explicitly delete these variables in the loop here:

https://github.com/apache/tvm-ffi/blob/f7e09d6a96b54554190bae0d7ba9ff7a6e9a109e/python/tvm_ffi/error.py#L124-L136

We can use `try ... finally ...` to delete these variables on every exit path, right after the return value has been computed:
```py
try:
    tb = py_error.__traceback__
    for filename, lineno, func in _parse_backtrace(backtrace):
        tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func)
    return py_error.with_traceback(tb)
finally:
    # The return value is computed before this block runs; deleting the
    # locals ensures this frame no longer references the error or its traceback.
    del py_error, tb
```
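To make the mechanism concrete without depending on torch or tvm_ffi, here is a minimal pure-Python sketch, assuming CPython semantics. `Payload`, `make_error`, and `payload_alive_after_discard` are hypothetical names that only mimic the shape of the error-translation helper and of `Map.get`; they are not part of tvm_ffi. A weak reference to `Payload` (standing in for a tensor) shows whether it is freed by reference counting or only by the cyclic GC.

```py
import gc
import weakref


class Payload:
    """Hypothetical stand-in for a large intermediate tensor."""


def make_error(clean_up):
    """Hypothetical stand-in for the error-translation helper: it keeps the
    error and its traceback in the locals `py_error` and `tb`."""
    payload = Payload()  # lives in this frame, like a tensor in the call chain
    try:
        try:
            raise KeyError("b")
        except KeyError as exc:
            py_error = exc
        # The traceback references this frame, and this frame's locals
        # reference py_error and tb: frame -> py_error -> tb -> frame.
        tb = py_error.__traceback__
        return py_error.with_traceback(tb), weakref.ref(payload)
    finally:
        if clean_up:
            # Dropping the locals breaks the cycle; the return value was
            # already computed, so the caller still receives the error.
            del py_error, tb


def payload_alive_after_discard(clean_up):
    """Mimic Map.get: create the error, then discard it like a caught KeyError.
    Returns True if the payload is still alive before any GC run."""
    gc.disable()  # keep the cyclic GC from running in between
    try:
        err, payload_ref = make_error(clean_up)
        del err
        alive = payload_ref() is not None
    finally:
        gc.enable()
    gc.collect()  # a cycle, if any, is reclaimed here
    assert payload_ref() is None
    return alive


print(payload_alive_after_discard(clean_up=False))  # True: kept alive by the cycle
print(payload_alive_after_discard(clean_up=True))   # False: freed by refcounting
```

With `clean_up=False` the payload survives until `gc.collect()`, which is the delayed freeing described in the Problem section; with `clean_up=True` the `finally` block drops `py_error` and `tb`, so the payload is freed as soon as the caller discards the error.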
