locnd182644 opened a new pull request, #18583:
URL: https://github.com/apache/tvm/pull/18583

   ## Issue 1: Without Dim
   ### Summary:
   In the _sum function (BaseFXGraphImporter), retrieve_args yields args[1] = [] when torch.sum is called without dim. The empty list is still passed to relax.op.sum as the axis, so no axis is reduced and the result is incorrect.
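
One way to picture the fix is a small normalization step applied to the dim argument before it reaches relax.op.sum. This is only a sketch and not the code from this PR: normalize_sum_axis is a hypothetical helper, and it assumes that a missing or empty dim should be treated as "reduce over all axes".

```python
from typing import List, Optional, Sequence, Union


def normalize_sum_axis(dim: Union[None, int, Sequence[int]]) -> Optional[List[int]]:
    """Hypothetical helper: map the FX `dim` argument to a Relax-style `axis`.

    Assumption: a missing or empty `dim` means "reduce over every axis",
    which is expressed as axis=None rather than axis=[].
    """
    if dim is None:
        return None
    if isinstance(dim, int):
        return [dim]
    axes = list(dim)
    # An empty list would otherwise be forwarded as "reduce over no axes".
    return axes if axes else None
```

With this helper, the args[1] = [] case from the summary becomes axis=None, while an explicit dim such as [1] is passed through unchanged.
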
   ### Steps to Reproduce
   - Module
   ```
   class SumWithoutDim(nn.Module):
       def forward(self, x):
           return torch.sum(x)
   ```
   - Converted Relax module (current behavior)
   ```
   @I.ir_module
   class Module:
       @R.function
       def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
           with R.dataflow():
               lv: R.Tensor((2, 3), dtype="float32") = R.sum(x, axis=[], keepdims=False)
               gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
               R.output(gv)
           return gv
   ```
   - Result:
   
   Input: tensor([[1., 1., 1.], [1., 1., 1.]])
   Torch output: tensor(6.)
   Torch output shape: torch.Size([])
   TVM output: [[1. 1. 1.]  [1. 1. 1.]]
   TVM output shape: (2, 3)
   ### Expected
   ```
   @I.ir_module
   class Module:
       @R.function
       def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
           with R.dataflow():
               lv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
               gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
               R.output(gv)
           return gv
   ```
   - Result: TVM output: 6.0; TVM output shape: ()
   
   ## Issue 2: Keep Dim
   ### Summary:
   In the _sum function (BaseFXGraphImporter), keepdim was previously read only from node.kwargs and was not passed to relax.op.sum. It is now also read from args[2] and passed to relax.op.sum.
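
The lookup described above can be sketched as follows. This is an illustration under stated assumptions, not the PR's implementation: get_keepdim is a hypothetical helper, args is assumed to be laid out as (input, dim, keepdim), and the default of False follows torch.sum.

```python
from typing import Any, Mapping, Sequence


def get_keepdim(args: Sequence[Any], kwargs: Mapping[str, Any]) -> bool:
    """Hypothetical helper: read keepdim positionally first, then by keyword."""
    if len(args) > 2:
        # keepdim was given as the third positional argument.
        return bool(args[2])
    # Otherwise fall back to the keyword form, defaulting to False.
    return bool(kwargs.get("keepdim", False))
```

The returned value would then be forwarded as keepdims to relax.op.sum instead of being dropped.
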
   ### Steps to Reproduce
   - Module
   ```
   class SumKeepDim(nn.Module):
       def forward(self, x):
           return torch.sum(x, dim=1, keepdim=True)
   ```
   - Converted Relax module (current behavior)
   ```
   @I.ir_module
   class Module:
       @R.function
       def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2,), dtype="float32")):
           with R.dataflow():
               lv: R.Tensor((2,), dtype="float32") = R.sum(x, axis=[1], keepdims=False)
               gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
               R.output(gv)
           return gv
   
   ```
   - Result:
   
   Input: tensor([[1., 1., 1.], [1., 1., 1.]])
   Torch output: tensor([[3.], [3.]])
   Torch output shape: torch.Size([2, 1])
   TVM VM output: [3. 3.]
   TVM VM output shape: (2,)
   ### Expected
   ```
   @I.ir_module
   class Module:
       @R.function
       def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 1), dtype="float32")):
           with R.dataflow():
               lv: R.Tensor((2, 1), dtype="float32") = R.sum(x, axis=[1], keepdims=True)
               gv: R.Tuple(R.Tensor((2, 1), dtype="float32")) = (lv,)
               R.output(gv)
           return gv
   ```
   - Result: TVM output: [[3.] [3.]]; TVM output shape: (2, 1)
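
The output shapes reported in both issues can be reproduced with a small pure-Python model of reduction shape inference. This is a sketch modelled on the shapes listed above, not on TVM's implementation: sum_output_shape is a hypothetical function, and it assumes axis=None reduces every axis while axis=[] reduces none.

```python
from typing import Optional, Sequence, Tuple


def sum_output_shape(
    shape: Sequence[int],
    axis: Optional[Sequence[int]],
    keepdims: bool = False,
) -> Tuple[int, ...]:
    """Hypothetical model of the result shape of a sum reduction."""
    ndim = len(shape)
    if axis is None:
        reduced = set(range(ndim))  # reduce over every axis
    else:
        reduced = {a % ndim for a in axis}  # normalize negative axes
    if keepdims:
        # Reduced axes are kept with extent 1.
        return tuple(1 if i in reduced else d for i, d in enumerate(shape))
    # Reduced axes are dropped from the shape.
    return tuple(d for i, d in enumerate(shape) if i not in reduced)


print(sum_output_shape((2, 3), axis=[]))                   # (2, 3)
print(sum_output_shape((2, 3), axis=None))                 # ()
print(sum_output_shape((2, 3), axis=[1], keepdims=False))  # (2,)
print(sum_output_shape((2, 3), axis=[1], keepdims=True))   # (2, 1)
```

The first and third lines match the incorrect shapes from the reproduction steps, and the second and fourth match the expected shapes.
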

