Hi, I've copied and pasted four scripts. The RNN is defined in `rnn.py` and the RNN-T in `model_rnnt.py`. I tried to import the RNN-T model into TVM; please see the first script. In the function `get_rnnt_model`, I load the pre-trained RNN-T model. In the function `rnnt_model_to_tvm_mod`, I try to convert it into a TVM module.

```python
import torch
from tvm import relay

from model_rnnt import RNNT  # RNNT is defined in model_rnnt.py


def get_rnnt_model(featurizer_config, model_definition, ctc_vocab, ckpt):
    # Build the RNN-T network from its config and load the pre-trained weights on CPU.
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )
    checkpoint = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()  # inference mode before tracing
    return model

def rnnt_model_to_tvm_mod(model):
    # Encoder input is time-major: (time steps, batch, features).
    input_shape = (316, 1, 240)
    # One length per batch element: (1,) is a tuple, whereas (316) is just the int 316.
    len_shape = (1,)
    t_audio_signal_e = torch.randn(input_shape)
    # Sequence lengths are integers, not random floats.
    t_a_sig_length_e = torch.full(len_shape, input_shape[0], dtype=torch.int32)
    traced_encoder = torch.jit.trace(model.encoder, (t_audio_signal_e, t_a_sig_length_e)).eval()
    # from_pytorch needs one (name, shape) or (name, (shape, dtype)) entry per graph input, in order;
    # passing None gives the frontend no input information.
    input_infos = [
        ("audio_signal", input_shape),
        ("audio_length", (len_shape, "int32")),
    ]
    mod, params = relay.frontend.from_pytorch(traced_encoder, input_infos)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    return mod, params
```
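A note on the shapes handed to the tracer and to `from_pytorch`: in Python, `(316)` is just the integer `316`, so `torch.randn((316))` produces a 1-D tensor of 316 random floats rather than one length per batch element. The sketch below builds the per-input list in plain Python, assuming the encoder takes a time-major `(time, batch, features)` signal plus one integer length per batch element, and using the `(name, (shape, dtype))` form that recent TVM releases accept. The helper `make_encoder_input_infos` and the input names are hypothetical.

```python
def make_encoder_input_infos(time_steps, batch, features):
    """Hypothetical helper: input descriptions for the two encoder inputs.

    Assumes a time-major audio tensor of shape (time, batch, features) and
    one int32 sequence length per batch element, i.e. a lengths tensor of
    shape (batch,).
    """
    audio_shape = (time_steps, batch, features)
    # Trailing comma: (batch,) is a 1-tuple, whereas (batch) is a plain int.
    length_shape = (batch,)
    return [
        ("audio_signal", audio_shape),
        ("audio_length", (length_shape, "int32")),
    ]


# (316) is an int, (316,) is a tuple.
print(type((316)).__name__, type((316,)).__name__)  # int tuple
print(make_encoder_input_infos(316, 1, 240))
# [('audio_signal', (316, 1, 240)), ('audio_length', ((1,), 'int32'))]
```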