Hi, I have copied and pasted four scripts. The RNN is defined in `rnn.py` and the RNN-T model in `model_rnnt.py`.
I tried to import the RNN-T model into TVM; please see the first script.
In the function `get_rnnt_model`, I load the pre-trained RNN-T model.
In the function `rnnt_model_to_tvm_mod`, I try to convert its encoder into a TVM Relay module.
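```
import torch
from tvm import relay

from model_rnnt import RNNT  # RNNT is defined in model_rnnt.py


def get_rnnt_model(featurizer_config, model_definition, ctc_vocab, ckpt):
    """Build the RNNT model and load the pre-trained weights on CPU."""
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )
    checkpoint = torch.load(ckpt, map_location="cpu")
    # strict=False silently skips mismatched keys; inspect the returned
    # missing/unexpected keys if the weights do not seem to load.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    return model


def rnnt_model_to_tvm_mod(model):
    # Assumed encoder input layout: (time, batch, features) = (316, 1, 240).
    input_shape = (316, 1, 240)
    len_shape = (1,)  # one length per batch element; (316) is a plain int
    t_audio_signal_e = torch.randn(input_shape)
    # Sequence lengths should be integers, not random floats.
    t_a_sig_length_e = torch.full(len_shape, input_shape[0], dtype=torch.int64)

    # Trace only the encoder; from_pytorch expects a TorchScript module.
    model.encoder = torch.jit.trace(
        model.encoder, (t_audio_signal_e, t_a_sig_length_e)).eval()

    # from_pytorch needs (name, shape) pairs in the same order as the
    # traced inputs; passing None gives it no input information.
    shape_list = [("audio_signal", input_shape), ("audio_lengths", len_shape)]
    mod, params = relay.frontend.from_pytorch(model.encoder, shape_list)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    return mod, params
```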
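Passing `input_shapes=None` leaves `from_pytorch` without any input names or shapes, and `len_shape = (316)` is the int `316` rather than a one-element tuple. The pure-Python sketch below builds the `(name, shape)` list that `relay.frontend.from_pytorch` takes as its second argument. `make_input_infos`, `audio_signal` and `audio_lengths` are hypothetical names chosen for illustration, and the shapes assume the (time, batch, features) encoder layout used in the script.

```python
from typing import List, Sequence, Tuple, Union

Shape = Tuple[int, ...]


def make_input_infos(
    named_shapes: Sequence[Tuple[str, Union[int, Sequence[int]]]]
) -> List[Tuple[str, Shape]]:
    """Normalize (name, shape) pairs into a list of (name, tuple) pairs.

    A bare int such as (316) becomes the 1-D shape (316,). Order is kept,
    because it has to match the order of the traced module's inputs.
    (Hypothetical helper, not part of TVM.)
    """
    infos: List[Tuple[str, Shape]] = []
    seen = set()
    for name, shape in named_shapes:
        if name in seen:
            raise ValueError(f"duplicate input name: {name!r}")
        seen.add(name)
        dims = (shape,) if isinstance(shape, int) else tuple(shape)
        if not dims or any(not isinstance(d, int) or d <= 0 for d in dims):
            raise ValueError(f"invalid shape for {name!r}: {shape!r}")
        infos.append((name, dims))
    return infos


# (316) is just the int 316; (316,) is a one-element tuple.
assert isinstance((316), int) and isinstance((316,), tuple)

shape_list = make_input_infos([
    ("audio_signal", (316, 1, 240)),  # assumed (time, batch, features)
    ("audio_lengths", 1),             # assumed: one length per batch element
])
print(shape_list)
```

The resulting list can then be passed positionally, e.g. `relay.frontend.from_pytorch(model.encoder, shape_list)`. Keeping the helper free of TVM and PyTorch makes it easy to check the shapes before tracing.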
