locnd182644 opened a new issue, #18526:
URL: https://github.com/apache/tvm/issues/18526

   ### Description
   The error "Cannot convert from type `DLTensor*` to `ffi.Shape`" is raised when 
running the "Deploy PyTorch Models to Remote Devices with RPC" section of the 
"Cross Compilation and RPC" tutorial.
   
   ### Actual behavior
   ```
   RPCError: Error caught from RPC call:
   Cannot convert from type `DLTensor*` to `ffi.Shape`
   ```
   
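   ### Steps to reproduce
   
   Run the tutorial's `run_pytorch_model_via_rpc()` with `local_demo = False`, using the following PyTorch model:
   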
   ```
   class TorchMLP(torch.nn.Module):
           def __init__(self) -> None:
               super().__init__()
               self.net = torch.nn.Sequential(
                   torch.nn.Flatten(),
                   torch.nn.Linear(28 * 28, 128),
                   torch.nn.ReLU(),
                   torch.nn.Linear(128, 10),
               )
   
           def forward(self, data: torch.Tensor) -> torch.Tensor:
               return self.net(data)
   ```
   
   The model is imported into the following Relax module:
   
   ```
   # from tvm.script import ir as I
   # from tvm.script import relax as R
   
   @I.ir_module
   class Module:
       @R.function
       def main(data: R.Tensor((1, 1, 28, 28), dtype="float32"), 
p_net_1_weight: R.Tensor((128, 784), dtype="float32"), p_net_1_bias: 
R.Tensor((128,), dtype="float32"), p_net_3_weight: R.Tensor((10, 128), 
dtype="float32"), p_net_3_bias: R.Tensor((10,), dtype="float32")) -> 
R.Tuple(R.Tensor((1, 10), dtype="float32")):
           R.func_attr({"num_input": 1})
           with R.dataflow():
               lv: R.Tensor((1, 784), dtype="float32") = R.reshape(data, 
R.shape([1, 784]))
               lv1: R.Tensor((784, 128), dtype="float32") = 
R.permute_dims(p_net_1_weight, axes=None)
               lv2: R.Tensor((1, 128), dtype="float32") = R.matmul(lv, lv1, 
out_dtype="float32")
               lv3: R.Tensor((1, 128), dtype="float32") = R.add(lv2, 
p_net_1_bias)
               lv4: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv3)
               lv5: R.Tensor((128, 10), dtype="float32") = 
R.permute_dims(p_net_3_weight, axes=None)
               lv6: R.Tensor((1, 10), dtype="float32") = R.matmul(lv4, lv5, 
out_dtype="float32")
               lv7: R.Tensor((1, 10), dtype="float32") = R.add(lv6, 
p_net_3_bias)
               gv: R.Tuple(R.Tensor((1, 10), dtype="float32")) = (lv7,)
               R.output(gv)
           return gv
   
   ```
   ### Error log
   
   ```
   
   ---------------------------------------------------------------------------
   RPCError                                  Traceback (most recent call last)
   Cell In[4], line 2
         1 local_demo = False
   ----> 2 run_pytorch_model_via_rpc()
   
   Cell In[3], line 189
       186 remote_input = tvm.runtime.tensor(input_data, dev)
       188 # Run inference on remote device
   --> 189 output = vm["main"](remote_input, *remote_params)
       191 # Extract result (handle both tuple and single tensor outputs)
       192 if isinstance(output, tvm.ir.Array) and len(output) > 0:
   
   File python/tvm_ffi/cython/function.pxi:904, in 
tvm_ffi.core.Function.__call__()
   
   File ~/Programming/tvm/src/runtime/rpc/rpc_module.cc:141, in 
tvm::runtime::RPCWrappedFunc::operator()(tvm::ffi::PackedArgs, tvm::ffi::Any*) 
const()
       139   }
       140   auto set_return = [this, rv](ffi::PackedArgs args) { 
this->WrapRemoteReturnToValue(args, rv); };
   --> 141   sess_->CallFunc(handle_, ffi::PackedArgs(packed_args.data(), 
packed_args.size()), set_return);
       142 }
       143 
   
   File ~/Programming/tvm/src/runtime/rpc/rpc_endpoint.cc:1116, in 
tvm::runtime::RPCClientSession::CallFunc(void*, tvm::ffi::PackedArgs, 
std::function<void (tvm::ffi::PackedArgs)> const&)()
      1114 void CallFunc(PackedFuncHandle func, ffi::PackedArgs args,
      1115               const FEncodeReturn& fencode_return) final {
   -> 1116   endpoint_->CallFunc(func, args, fencode_return);
      1117 }
      1118 
   
   File ~/Programming/tvm/src/runtime/rpc/rpc_endpoint.cc:906, in 
tvm::runtime::RPCEndpoint::CallFunc(void*, tvm::ffi::PackedArgs, 
std::function<void (tvm::ffi::PackedArgs)>)()
       904   handler_->SendPackedSeq(args.data(), args.size(), true);
       905 
   --> 906   code = HandleUntilReturnEvent(true, encode_return);
       907   ICHECK(code == RPCCode::kReturn) << "code=" << 
RPCCodeToString(code);
       908 }
   
   File ~/Programming/tvm/src/runtime/rpc/rpc_endpoint.cc:746, in 
tvm::runtime::RPCEndpoint::HandleUntilReturnEvent(bool, std::function<void 
(tvm::ffi::PackedArgs)>)()
       744     }
       745   }
   --> 746   code = handler_->HandleNextEvent(client_mode, false, setreturn);
       747 }
       748 return code;
   
   File ~/Programming/tvm/src/runtime/rpc/rpc_endpoint.cc:134, in 
tvm::runtime::RPCEndpoint::EventHandler::HandleNextEvent(bool, bool, 
std::function<void (tvm::ffi::PackedArgs)>)()
       132 }
       133 case kProcessPacket: {
   --> 134   this->HandleProcessPacket(setreturn);
       135   break;
       136 }
   
   File ~/Programming/tvm/src/runtime/rpc/rpc_endpoint.cc:409, in 
tvm::runtime::RPCEndpoint::EventHandler::HandleProcessPacket(std::function<void 
(tvm::ffi::PackedArgs)>)()
       407 case RPCCode::kException:
       408 case RPCCode::kReturn: {
   --> 409   this->HandleReturn(code, setreturn);
       410   break;
       411 }
   
   File ~/Programming/tvm/src/runtime/rpc/rpc_endpoint.cc:473, in 
tvm::runtime::RPCEndpoint::EventHandler::HandleReturn(tvm::runtime::RPCCode, 
std::function<void (tvm::ffi::PackedArgs)>)()
       471     msg = "RPCError: Error caught from RPC call:\n" + msg;
       472   }
   --> 473   LOG(FATAL) << msg;
       474 }
       475 
   
   File ~/Programming/tvm/include/tvm/runtime/logging.h:321, in 
tvm::runtime::detail::LogFatal::~LogFatal()()
       319 #endif
       320   [[noreturn]] ~LogFatal() TVM_THROW_EXCEPTION {
   --> 321     GetEntry().Finalize();
       322     throw;
       323   }
   
   File ~/Programming/tvm/include/tvm/runtime/logging.h:337, in 
tvm::runtime::detail::LogFatal::Entry::Finalize()()
       335     }
       336     [[noreturn]] TVM_NO_INLINE dmlc::Error Finalize() 
TVM_THROW_EXCEPTION {
   --> 337       InternalError error(file_, lineno_, stream_.str());
       338 #if DMLC_LOG_BEFORE_THROW
       339       std::cerr << error.what() << std::endl;
   
   RPCError: Error caught from RPC call:
   Cannot convert from type `DLTensor*` to `ffi.Shape`
   ```
   
   ### Triage
   
   * needs-triage
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to