mshr-h commented on code in PR #18516:
URL: https://github.com/apache/tvm/pull/18516#discussion_r2567863387
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -42,6 +42,37 @@ def verify_model(
tvm.ir.assert_structural_equal(mod, expected)
+def verify_model_numerically(torch_model, example_args, rtol=1e-4, atol=1e-5):
Review Comment:
The default `rtol` and `atol` should match those of `tvm.testing.assert_allclose`.
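One way to follow this suggestion without hard-coding the numbers is to let the helper's tolerances default to `None` and forward only the ones the caller actually sets. That way the underlying assertion's own defaults apply and the two cannot drift apart. The sketch below assumes a hypothetical `check_close` helper and uses `np.testing.assert_allclose` as a stand-in comparator so it runs without TVM. In a TVM test you would pass `tvm.testing.assert_allclose` as `assert_fn` instead.

```python
import numpy as np


def check_close(actual, desired, assert_fn=np.testing.assert_allclose, rtol=None, atol=None):
    """Hypothetical helper: compare `actual` against `desired` with `assert_fn`.

    Tolerances left as None are not forwarded, so `assert_fn`'s own defaults
    apply (e.g. pass tvm.testing.assert_allclose in a TVM test).
    """
    tolerances = {}
    if rtol is not None:
        tolerances["rtol"] = rtol
    if atol is not None:
        tolerances["atol"] = atol
    assert_fn(actual, desired, **tolerances)
```

Copying the literal default values into the signature also works. The sentinel approach only saves you from updating the helper if the comparator's defaults ever change.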
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -7654,74 +7647,40 @@ def main(
def test_lstm():
Review Comment:
Please decorate it with `tvm.testing.utils.requires_llvm`
https://github.com/apache/tvm/blob/main/python/tvm/testing/utils.py#L842
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -42,6 +42,37 @@ def verify_model(
tvm.ir.assert_structural_equal(mod, expected)
+def verify_model_numerically(torch_model, example_args, rtol=1e-4, atol=1e-5):
+ """Verify model by comparing numerical outputs between PyTorch and TVM."""
+ with torch.no_grad():
+ pytorch_output = torch_model(*example_args)
+
+ exported_program = export(torch_model, args=example_args)
+ mod = from_exported_program(exported_program)
+ target = tvm.target.Target("llvm")
+ ex = relax.build(mod, target)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ tvm_args = [tvm.runtime.tensor(arg.numpy()) for arg in example_args]
+ tvm_output = vm["main"](*tvm_args)
+
+ if hasattr(tvm_output, "numpy"):
+ tvm_output_np = tvm_output.numpy()
+ else:
+ tvm_output_np = tvm_output[0].numpy()
+
+ pytorch_output_np = (
+ pytorch_output.numpy()
+ if isinstance(pytorch_output, torch.Tensor)
+ else pytorch_output[0].numpy()
+ )
+
+ assert (
+ pytorch_output_np.shape == tvm_output_np.shape
+ ), f"Shape mismatch: PyTorch {pytorch_output_np.shape} vs TVM {tvm_output_np.shape}"
+ np.testing.assert_allclose(pytorch_output_np, tvm_output_np, rtol=rtol, atol=atol)
Review Comment:
Use `tvm.testing.assert_allclose` instead.
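Going by its current source in `python/tvm/testing/utils.py`, `tvm.testing.assert_allclose(actual, desired, rtol, atol)` appears to check that the two shapes agree before delegating to `np.testing.assert_allclose`. Switching to it would therefore also make the hand-written shape `assert` redundant. The elementwise check is `|actual - desired| <= atol + rtol * |desired|`, so the trusted PyTorch reference belongs in the `desired` slot. The sketch below approximates that assumed behavior using NumPy alone so it runs without TVM. `assert_allclose_sketch` is a hypothetical name, not a TVM API, and its default tolerances are illustrative.

```python
import numpy as np


def assert_allclose_sketch(actual, desired, rtol=1e-7, atol=1e-7):
    """Hypothetical stand-in approximating tvm.testing.assert_allclose (assumed behavior).

    Passes when |actual - desired| <= atol + rtol * |desired| elementwise,
    so `desired` should be the trusted reference (here, the PyTorch output).
    """
    actual = np.asanyarray(actual)
    desired = np.asanyarray(desired)
    # Check shapes explicitly: np.testing.assert_allclose lets a scalar
    # broadcast against an array, which would hide a shape bug.
    if actual.shape != desired.shape:
        raise AssertionError(f"Shape mismatch: actual {actual.shape} vs desired {desired.shape}")
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
```

In the helper under review, the shape `assert` and the NumPy call would then collapse into a single call such as `tvm.testing.assert_allclose(tvm_output_np, pytorch_output_np, rtol=rtol, atol=atol)`.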
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.