```
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse

import toml
import torch
import torchvision
from tqdm import tqdm

import tvm
from tvm import relay

from dataset import AudioToTextDataLayer
from decoders import TransducerDecoder
from helpers import (add_blank_label, process_evaluation_batch,
                     process_evaluation_epoch)
from model_rnnt import RNNT
from preprocessing import AudioPreprocessing

def get_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, required=True,
                        help='The RNN-T model path (PyTorch ckpt path)')
    parser.add_argument('--dataset-path', type=str, required=True,
                        help='The dataset path')
    parser.add_argument('--val_manifest', type=str, required=True,
                        help='relative path to the evaluation dataset manifest file')
    parser.add_argument('--model_toml', type=str, default='configs/rnnt.toml',
                        help='relative model configuration path given the dataset folder')
    parser.add_argument('--steps', type=int, default=100,
                        help='if not specified, evaluate on the full dataset; '
                             'otherwise only evaluate the specified number of '
                             'iterations for each worker')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='The batch size for inference')
    parser.add_argument('--target', type=str, default='cuda',
                        help='The target in TVM')
    parser.add_argument('--network', type=str, default='rnnt',
                        help='The name of the network')
    parser.add_argument('--cs', type=int, default=80,
                        help='The number of calibration samples')
    return parser.parse_args()


class RNNTGreedyDecoder(TransducerDecoder):
    """A greedy transducer decoder.

    Args:
        blank_index: See `Decoder`.
        model: Model to use for prediction.
        mod: Optional Relay module for the encoder.
        params: Optional Relay parameters for the encoder.
        target: TVM target string.
        max_symbols_per_step: The maximum number of symbols that can be
            added to a sequence in a single time step; if set to None
            there is no limit.
    """

    def __init__(self, blank_index, model, mod=None, params=None,
                 target='cuda', max_symbols_per_step=30):
        super().__init__(blank_index, model)
        assert max_symbols_per_step is None or max_symbols_per_step > 0
        self.mod = mod
        self.params = params
        self.target = target
        self.ctx = tvm.context(target)
        self.max_symbols = max_symbols_per_step

    def decode(self, x, out_lens):
        """Returns a list of sentences given an input batch.

        Args:
            x: A tensor of size (seq_len, batch, in_features).
            out_lens: list of int representing the length of each
                output sequence.

        Returns:
            list containing batch number of sentences (strings).
        """
        with torch.no_grad():
            # Apply optional preprocessing
            # x_packed = torch.nn.utils.rnn.pack_padded_sequence(x, out_lens)
            # logits, out_lens = self._model.encode(x, out_lens)
            logits, out_lens = self._model.encoder(x, out_lens)
            # Alternative: run the encoder through the converted Relay module.
            # executor = relay.create_executor('graph', self.mod, self.ctx, self.target)
            # logits, out_lens = executor.evaluate(self.mod["main"])(x, out_lens)

            output = []
            for batch_idx in range(logits.size(0)):
                inseq = logits[batch_idx, :, :].unsqueeze(1)
                logitlen = out_lens[batch_idx]
                sentence = self._greedy_decode(inseq, logitlen)
                output.append(sentence)

        return output

    def _greedy_decode(self, x, out_len):
        training_state = self._model.training
        self._model.eval()

        device = x.device

        hidden = None
        label = []
        for time_idx in range(out_len):
            f = x[time_idx, :, :].unsqueeze(0)

            not_blank = True
            symbols_added = 0

            while not_blank and (self.max_symbols is None or
                                 symbols_added < self.max_symbols):
                g, hidden_prime = self._pred_step(
                    self._get_last_symb(label),
                    hidden,
                    device
                )
                logp = self._joint_step(f, g, log_normalize=False)[0, :]

                # get the index k of the max-probability symbol
                v, k = logp.max(0)
                k = k.item()

                if k == self._blank_id:
                    not_blank = False
                else:
                    label.append(k)
                    hidden = hidden_prime
                symbols_added += 1

        self._model.train(training_state)
        return label


def get_rnnt_model(featurizer_config, model_definition, ctc_vocab, ckpt):
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )
    checkpoint = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    return model


def rnnt_model_to_tvm_mod(model):
    input_shape = (316, 1, 240)  # (seq_len, batch, features)
    len_shape = (316,)
    t_audio_signal_e = torch.randn(input_shape)
    t_a_sig_length_e = torch.randn(len_shape)
    model.encoder = torch.jit.trace(
        model.encoder, (t_audio_signal_e, t_a_sig_length_e)).eval()
    # from_pytorch expects the traced module plus a list of
    # (input_name, shape) pairs; the names are labels matched
    # positionally to the traced graph's inputs.
    shape_list = [('input0', input_shape), ('input1', len_shape)]
    mod, params = relay.frontend.from_pytorch(model.encoder, shape_list)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    return mod, params


def get_data_layer(args, featurizer_config, val_manifest, dataset_vocab):
    data_layer = AudioToTextDataLayer(
        dataset_dir=args.dataset_path,
        featurizer_config=featurizer_config,
        manifest_filepath=val_manifest,
        labels=dataset_vocab,
        batch_size=args.batch_size,
        pad_to_max=featurizer_config['pad_to'] == "max",
        shuffle=False)
    return data_layer


def get_audio_preprocessor(featurizer_config):
    audio_preprocessor = AudioPreprocessing(**featurizer_config)
    audio_preprocessor.featurizer.normalize = "per_feature"
    audio_preprocessor.eval()
    return audio_preprocessor

def get_eval_transforms(audio_preprocessor):
    eval_transforms = []
    eval_transforms.append(lambda xs: [*audio_preprocessor(xs[0:2]), *xs[2:]])
    # Transpose the feature tensor: BxFxT -> TxBxF.
    eval_transforms.append(lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]])
    return torchvision.transforms.Compose(eval_transforms)


def evaluate(data_layer, audio_processor, encoderdecoder, greedy_decoder,
             labels, args):
    """Performs inference / evaluation.

    Args:
        data_layer: data layer object that holds the data loader
        audio_processor: data processing module
        encoderdecoder: acoustic model
        greedy_decoder: greedy decoder
        labels: list of labels as output vocabulary
        args: script input arguments
    """
    encoderdecoder.eval()
    with torch.no_grad():
        _global_var_dict = {
            'predictions': [],
            'transcripts': [],
            'logits': [],
        }

        for it, data in enumerate(tqdm(data_layer.data_iterator)):
            (t_audio_signal_e, t_a_sig_length_e, transcript_list,
             t_transcript_e, t_transcript_len_e) = audio_processor(data)

            t_predictions_e = greedy_decoder.decode(
                t_audio_signal_e, t_a_sig_length_e)

            values_dict = dict(
                predictions=[t_predictions_e],
                transcript=transcript_list,
                transcript_length=t_transcript_len_e,
            )
            process_evaluation_batch(
                values_dict, _global_var_dict, labels=labels)

            if args.steps is not None and it + 1 >= args.steps:
                break

        wer = process_evaluation_epoch(_global_var_dict)
        print("==========>>>>>>Evaluation WER: {0}\n".format(wer))


def main():
    args = get_args()

    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
    ctc_vocab = add_blank_label(dataset_vocab)
    featurizer_config = model_definition['input_eval']

    rnnt_model = get_rnnt_model(featurizer_config, model_definition,
                                ctc_vocab, args.ckpt)
    mod, params = rnnt_model_to_tvm_mod(rnnt_model)

    data_layer = get_data_layer(args, featurizer_config, args.val_manifest,
                                dataset_vocab)
    audio_preprocessor = get_audio_preprocessor(featurizer_config)
    eval_transforms = get_eval_transforms(audio_preprocessor)
    greedy_decoder = RNNTGreedyDecoder(len(ctc_vocab) - 1, rnnt_model,
                                       mod=mod, params=params,
                                       target=args.target)

    evaluate(data_layer=data_layer,
             audio_processor=eval_transforms,
             encoderdecoder=rnnt_model,
             greedy_decoder=greedy_decoder,
             labels=ctc_vocab,
             args=args)


if __name__ == '__main__':
    main()
```
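
For reference, once `from_pytorch` succeeds, the converted encoder can be compiled and run standalone. Below is a minimal sketch (not part of the script above) using the TVM 0.7-era graph runtime; the input names `input0`/`input1` are the placeholder labels passed to `relay.frontend.from_pytorch` in `rnnt_model_to_tvm_mod`:

```
import numpy as np

import tvm
from tvm import relay
from tvm.contrib import graph_runtime


def build_encoder(mod, params, target='cuda'):
    # Compile the Relay module; opt_level=3 enables the standard fusion passes.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    return graph_runtime.GraphModule(lib["default"](tvm.context(target)))


# Usage sketch, with the same dummy shapes used for tracing; `mod` and
# `params` are assumed to come from rnnt_model_to_tvm_mod above.
# rt = build_encoder(mod, params, target='cuda')
# rt.set_input('input0', np.random.randn(316, 1, 240).astype('float32'))
# rt.set_input('input1', np.random.randn(316).astype('float32'))
# rt.run()
# logits = rt.get_output(0).asnumpy()
```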
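Before wiring the TVM path into `RNNTGreedyDecoder.decode`, a quick parity check against the traced PyTorch encoder is also useful. Again only a sketch: it reuses `build_encoder` from above and assumes the encoder returns `(logits, lengths)` as it does in `decode`:

```
import numpy as np
import torch

# Assumes rnnt_model, mod, params from the script above.
x = torch.randn(316, 1, 240)
x_len = torch.randn(316)  # same dummy lengths as used for tracing

torch_logits, _ = rnnt_model.encoder(x, x_len)

rt = build_encoder(mod, params, target='cuda')
rt.set_input('input0', x.numpy())
rt.set_input('input1', x_len.numpy())
rt.run()
tvm_logits = rt.get_output(0).asnumpy()

np.testing.assert_allclose(torch_logits.detach().numpy(), tvm_logits,
                           rtol=1e-3, atol=1e-3)
```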