```
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import toml
import torch
import torchvision
import argparse
from tvm import relay
from tqdm import tqdm
import tvm
from model_rnnt import RNNT
from decoders import TransducerDecoder
from preprocessing import AudioPreprocessing
from dataset import AudioToTextDataLayer
from helpers import (process_evaluation_batch, process_evaluation_epoch,
                     add_blank_label, print_dict)

def get_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, required=True,
                        help='The RNN-T model path (PyTorch ckpt path)')
    parser.add_argument('--dataset-path', type=str, required=True,
                        help='The dataset path')
    parser.add_argument("--val_manifest", type=str, required=True,
                        help='relative path to the evaluation dataset '
                             'manifest file')
    parser.add_argument("--model_toml", type=str, default='configs/rnnt.toml',
                        help='relative model configuration path given the '
                             'dataset folder')
    parser.add_argument("--steps", type=int, default=100,
                        help='if not specified, do evaluation on the full '
                             'dataset; otherwise only evaluate the specified '
                             'number of iterations for each worker')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='The batch size to use for inference')
    parser.add_argument('--target', type=str, default='cuda',
                        help='The target in TVM')
    parser.add_argument('--network', type=str, default='rnnt',
                        help='The name of the network')
    parser.add_argument('--cs', type=int, default=80,
                        help='The size of the calibration dataset')
    args = parser.parse_args()
    return args

class RNNTGreedyDecoder(TransducerDecoder):
    """A greedy transducer decoder.

    Args:
        blank_index: See `TransducerDecoder`.
        model: Model to use for prediction.
        mod: Optional TVM Relay module for the encoder.
        params: Optional TVM parameters for `mod`.
        target: TVM target to run on.
        max_symbols_per_step: The maximum number of symbols that can be added
            to a sequence in a single time step; if set to None, there is
            no limit.
    """
    def __init__(self, blank_index, model, mod=None, params=None,
                 target='cuda', max_symbols_per_step=30):
        super().__init__(blank_index, model)
        assert max_symbols_per_step is None or max_symbols_per_step > 0
        self.mod = mod
        self.params = params
        self.target = target
        self.ctx = tvm.context(target)
        self.max_symbols = max_symbols_per_step
    def decode(self, x, out_lens):
        """Returns a list of sentences given an input batch.

        Args:
            x: A tensor of size (batch, channels, features, seq_len)
                TODO was (seq_len, batch, in_features).
            out_lens: list of ints representing the length of each output
                sequence.

        Returns:
            list containing batch number of sentences (strings).
        """
        with torch.no_grad():
            # Apply optional preprocessing
            # x_packed = torch.nn.utils.rnn.pack_padded_sequence(x, out_lens)

            # Run the encoder in PyTorch; the commented-out lines below show
            # the alternative path through the TVM graph executor.
            logits, out_lens = self._model.encoder(x, out_lens)
            # executor = relay.create_executor('graph', self.mod, self.ctx,
            #                                  self.target)
            # logits, out_lens = executor.evaluate(self.mod["main"])(x,
            #                                                        out_lens)

            output = []
            for batch_idx in range(logits.size(0)):
                inseq = logits[batch_idx, :, :].unsqueeze(1)
                logitlen = out_lens[batch_idx]
                sentence = self._greedy_decode(inseq, logitlen)
                output.append(sentence)

        return output
    def _greedy_decode(self, x, out_len):
        training_state = self._model.training
        self._model.eval()

        device = x.device
        hidden = None
        label = []
        for time_idx in range(out_len):
            f = x[time_idx, :, :].unsqueeze(0)

            not_blank = True
            symbols_added = 0
            while not_blank and (self.max_symbols is None or
                                 symbols_added < self.max_symbols):
                g, hidden_prime = self._pred_step(
                    self._get_last_symb(label),
                    hidden,
                    device
                )
                logp = self._joint_step(f, g, log_normalize=False)[0, :]

                # get the index k of the max probability
                v, k = logp.max(0)
                k = k.item()

                if k == self._blank_id:
                    not_blank = False
                else:
                    label.append(k)
                    hidden = hidden_prime
                symbols_added += 1

        self._model.train(training_state)
        return label

def get_rnnt_model(featurizer_config, model_definition, ctc_vocab, ckpt):
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )
    checkpoint = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    return model

def rnnt_model_to_tvm_mod(model):
    # Trace the encoder with dummy inputs of a fixed shape, then convert the
    # traced module to a Relay module.
    input_shape = (316, 1, 240)
    len_shape = (316,)
    t_audio_signal_e = torch.randn(input_shape)
    t_a_sig_length_e = torch.randn(len_shape)
    model.encoder = torch.jit.trace(
        model.encoder, (t_audio_signal_e, t_a_sig_length_e)).eval()
    # from_pytorch needs a list of (name, shape) pairs; the names below are
    # chosen here and bind to the traced graph inputs in order.
    shape_list = [('input0', input_shape), ('input1', len_shape)]
    mod, params = relay.frontend.from_pytorch(model.encoder, shape_list)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    return mod, params
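
# The helper below is an illustrative sketch, not part of the original post:
# it shows how the converted encoder could be compiled and executed with TVM,
# replacing the PyTorch encoder call in RNNTGreedyDecoder.decode(). It assumes
# the 'input0'/'input1' names chosen in rnnt_model_to_tvm_mod above, two graph
# outputs (logits and lengths), and the graph-runtime API of the TVM release
# that still provides tvm.context.
def run_encoder_with_tvm(mod, params, t_audio_signal_e, t_a_sig_length_e,
                         target='cuda'):
    from tvm.contrib import graph_runtime
    # Compile the Relay module for the requested target.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    ctx = tvm.context(target)
    m = graph_runtime.GraphModule(lib["default"](ctx))
    # Feed the same inputs the traced encoder expects.
    m.set_input('input0', tvm.nd.array(t_audio_signal_e.numpy(), ctx))
    m.set_input('input1', tvm.nd.array(t_a_sig_length_e.numpy(), ctx))
    m.run()
    logits = torch.from_numpy(m.get_output(0).asnumpy())
    out_lens = torch.from_numpy(m.get_output(1).asnumpy())
    return logits, out_lens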

def get_data_layer(args, featurizer_config, val_manifest, dataset_vocab):
    data_layer = AudioToTextDataLayer(
        dataset_dir=args.dataset_path,
        featurizer_config=featurizer_config,
        manifest_filepath=val_manifest,
        labels=dataset_vocab,
        batch_size=args.batch_size,
        pad_to_max=featurizer_config['pad_to'] == "max",
        shuffle=False)
    return data_layer

def get_audio_preprocessor(featurizer_config):
    audio_preprocessor = AudioPreprocessing(**featurizer_config)
    audio_preprocessor.featurizer.normalize = "per_feature"
    audio_preprocessor.eval()
    return audio_preprocessor

def get_eval_transforms(audio_preprocessor):
    eval_transforms = []
    eval_transforms.append(lambda xs: [*audio_preprocessor(xs[0:2]), *xs[2:]])
    # Rearrange the feature tensor layout: BxFxT -> TxBxF
    eval_transforms.append(lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]])
    eval_transforms = torchvision.transforms.Compose(eval_transforms)
    return eval_transforms

def eval(data_layer, audio_processor, encoderdecoder, greedy_decoder, labels,
         args):
    """Performs inference/evaluation.

    Args:
        data_layer: data layer object that holds the data loader
        audio_processor: data processing module
        encoderdecoder: acoustic model
        greedy_decoder: greedy decoder
        labels: list of labels serving as the output vocabulary
        args: script input arguments
    """
    encoderdecoder.eval()
    with torch.no_grad():
        _global_var_dict = {
            'predictions': [],
            'transcripts': [],
            'logits': [],
        }
        for it, data in enumerate(tqdm(data_layer.data_iterator)):
            (t_audio_signal_e, t_a_sig_length_e,
             transcript_list, t_transcript_e,
             t_transcript_len_e) = audio_processor(data)
            t_predictions_e = greedy_decoder.decode(
                t_audio_signal_e, t_a_sig_length_e)
            values_dict = dict(
                predictions=[t_predictions_e],
                transcript=transcript_list,
                transcript_length=t_transcript_len_e,
            )
            process_evaluation_batch(
                values_dict, _global_var_dict, labels=labels)
            if args.steps is not None and it + 1 >= args.steps:
                break
        wer = process_evaluation_epoch(_global_var_dict)
        print("==========>>>>>>Evaluation WER: {0}\n".format(wer))

def main():
    args = get_args()
    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
    ctc_vocab = add_blank_label(dataset_vocab)
    featurizer_config = model_definition['input_eval']

    rnnt_model = get_rnnt_model(featurizer_config, model_definition,
                                ctc_vocab, args.ckpt)
    mod, params = rnnt_model_to_tvm_mod(rnnt_model)
    data_layer = get_data_layer(args, featurizer_config, args.val_manifest,
                                dataset_vocab)
    audio_preprocessor = get_audio_preprocessor(featurizer_config)
    eval_transforms = get_eval_transforms(audio_preprocessor)
    greedy_decoder = RNNTGreedyDecoder(len(ctc_vocab) - 1, rnnt_model,
                                       mod=None, params=None,
                                       target=args.target)
    eval(data_layer=data_layer,
         audio_processor=eval_transforms,
         encoderdecoder=rnnt_model,
         greedy_decoder=greedy_decoder,
         labels=ctc_vocab,
         args=args)


if __name__ == '__main__':
    main()
```
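
For reference, a typical invocation of this script might look like the following. This is only a sketch: the script name and all paths are placeholders rather than values from the original post, so substitute your own checkpoint, dataset, and manifest locations.

```
python rnnt_tvm_eval.py \
    --ckpt /path/to/rnnt.pt \
    --dataset-path /path/to/LibriSpeech \
    --val_manifest dev-clean-wav.json \
    --model_toml configs/rnnt.toml \
    --batch_size 1 \
    --target cuda
```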