```
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import toml
import torch
import torchvision
import argparse
from tvm import relay
from tqdm import tqdm
import tvm
from model_rnnt import RNNT
from decoders import TransducerDecoder
from preprocessing import AudioPreprocessing
from dataset import AudioToTextDataLayer
from helpers import (process_evaluation_batch, process_evaluation_epoch,
                     add_blank_label, print_dict)

def get_args():
    """Parse commandline."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, required=True,
                        help='The rnnt model path (pytorch ckpt path)')
    parser.add_argument('--dataset-path', type=str, required=True,
                        help='The dataset path')
    parser.add_argument("--val_manifest", type=str, required=True,
                        help='relative path to evaluation dataset manifest 
file')
    parser.add_argument("--model_toml", type=str, default='configs/rnnt.toml',
                        help='relative model configuration path given dataset 
folder')
    parser.add_argument("--steps", default=100,
                        help='if not specified do evaluation on full dataset. '
                             'otherwise only evaluates the specified number of 
iterations for each worker', type=int)
    parser.add_argument('--batch_size', type=int, default=1,
                        help='The batch size for inference')
    parser.add_argument('--target', type=str, default='cuda',
                        help='The TVM target')
    parser.add_argument('--network', type=str, default='rnnt',
                        help='The name of the network')
    parser.add_argument('--cs', type=int, default=80,
                        help='The size of the calibration dataset')
    args = parser.parse_args()
    return args


class RNNTGreedyDecoder(TransducerDecoder):
    """A greedy transducer decoder.

    Args:
        blank_symbol: See `Decoder`.
        model: Model to use for prediction.
        max_symbols_per_step: The maximum number of symbols that can be added
            to a sequence in a single time step; if set to None then there is
            no limit.
        cutoff_prob: Skip to next step in search if current highest character
            probability is less than this.
    """

    def __init__(self, blank_index, model, mod=None, params=None,
                 target='cuda', max_symbols_per_step=30):
        super().__init__(blank_index, model)
        assert max_symbols_per_step is None or max_symbols_per_step > 0
        self.mod = mod
        self.params = params
        self.target = target
        self.ctx = tvm.context(target)
        self.max_symbols = max_symbols_per_step

    def decode(self, x, out_lens):
        """Returns a list of sentences given an input batch.

        Args:
            x: A tensor of size (batch, channels, features, seq_len)
                TODO was (seq_len, batch, in_features).
            out_lens: list of int representing the length of each output
                sequence.

        Returns:
            list containing batch number of sentences (strings).
        """
        with torch.no_grad():
            # Apply optional preprocessing

            # x_packed = torch.nn.utils.rnn.pack_padded_sequence(x, out_lens)
            # logits, out_lens = self._model.encode(x, out_lens)
            logits, out_lens = self._model.encoder(x, out_lens)
            # executor = relay.create_executor('graph', self.mod, self.ctx,
            #                                  self.target)
            # logits, out_lens = executor.evaluate(self.mod["main"])(x, out_lens)
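            # Note (assumption, not from the original post): to actually run
            # the commented-out TVM path, the torch tensors would need
            # converting first, e.g. executor.evaluate()(x.numpy(), **self.params),
            # with inputs fed under the names given to from_pytorch.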

            output = []
            for batch_idx in range(logits.size(0)):
                inseq = logits[batch_idx, :, :].unsqueeze(1)
                logitlen = out_lens[batch_idx]
                sentence = self._greedy_decode(inseq, logitlen)
                output.append(sentence)

        return output

    def _greedy_decode(self, x, out_len):
        training_state = self._model.training
        self._model.eval()

        device = x.device

        hidden = None
        label = []
        for time_idx in range(out_len):
            f = x[time_idx, :, :].unsqueeze(0)

            not_blank = True
            symbols_added = 0

            while not_blank and (self.max_symbols is None
                                 or symbols_added < self.max_symbols):
                g, hidden_prime = self._pred_step(
                    self._get_last_symb(label),
                    hidden,
                    device
                )
                logp = self._joint_step(f, g, log_normalize=False)[0, :]

                # get index k, of max prob
                v, k = logp.max(0)
                k = k.item()

                if k == self._blank_id:
                    not_blank = False
                else:
                    label.append(k)
                    hidden = hidden_prime
                symbols_added += 1

        self._model.train(training_state)
        return label

def get_rnnt_model(featurizer_config, model_definition, ctc_vocab, ckpt):
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )
    checkpoint = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    return model

def rnnt_model_to_tvm_mod(model):
    input_shape = (316, 1, 240)
    len_shape = (316,)
    t_audio_signal_e = torch.randn(input_shape)
    t_a_sig_length_e = torch.randn(len_shape)
    model.encoder = torch.jit.trace(
        model.encoder, (t_audio_signal_e, t_a_sig_length_e)).eval()

    # from_pytorch needs a list of (input name, shape) pairs; the names are
    # arbitrary and are only used to feed inputs later.
    shape_list = [('input0', input_shape), ('input1', len_shape)]
    mod, params = relay.frontend.from_pytorch(model.encoder, shape_list)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    return mod, params
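

def build_tvm_encoder(mod, params, target='cuda'):
    """A hedged sketch (not part of the original script) of compiling the
    imported encoder and wrapping it in a graph runtime module. The input
    names must match the ones passed to from_pytorch above."""
    from tvm.contrib import graph_runtime
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)
    ctx = tvm.context(target)
    runtime = graph_runtime.GraphModule(lib["default"](ctx))
    # Usage sketch: runtime.set_input('input0', audio.numpy()); runtime.run();
    # logits = runtime.get_output(0).asnumpy()
    return runtime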


def get_data_layer(args, featurizer_config, val_manifest, dataset_vocab):
    data_layer = AudioToTextDataLayer(
        dataset_dir=args.dataset_path,
        featurizer_config=featurizer_config,
        manifest_filepath=val_manifest,
        labels=dataset_vocab,
        batch_size=args.batch_size,
        pad_to_max=featurizer_config['pad_to'] == "max",
        shuffle=False)
    return data_layer

def get_audio_preprocessor(featurizer_config):
    audio_preprocessor = AudioPreprocessing(**featurizer_config)
    audio_preprocessor.featurizer.normalize = "per_feature"
    audio_preprocessor.eval()
    return audio_preprocessor

def get_eval_transforms(audio_preprocessor):
    eval_transforms = []
    eval_transforms.append(lambda xs: [*audio_preprocessor(xs[0:2]), *xs[2:]])
    # These are just some very confusing transposes, that's all.
    # BxFxT -> TxBxF
    eval_transforms.append(lambda xs: [xs[0].permute(2, 0, 1), *xs[1:]])
    eval_transforms = torchvision.transforms.Compose(eval_transforms)
    return eval_transforms


def evaluate(data_layer, audio_processor, encoderdecoder, greedy_decoder,
             labels, args):
    """Performs inference / evaluation.

    Args:
        data_layer: data layer object that holds the data loader
        audio_processor: data processing module
        encoderdecoder: acoustic model
        greedy_decoder: greedy decoder
        labels: list of labels forming the output vocabulary
        args: script input arguments
    """
    encoderdecoder.eval()
    with torch.no_grad():
        _global_var_dict = {
            'predictions': [],
            'transcripts': [],
            'logits': [],
        }

        for it, data in enumerate(tqdm(data_layer.data_iterator)):
            (t_audio_signal_e, t_a_sig_length_e,
             transcript_list, t_transcript_e,
             t_transcript_len_e) = audio_processor(data)

            t_predictions_e = greedy_decoder.decode(
                t_audio_signal_e, t_a_sig_length_e)

            values_dict = dict(
                predictions=[t_predictions_e],
                transcript=transcript_list,
                transcript_length=t_transcript_len_e,
            )
            process_evaluation_batch(
                values_dict, _global_var_dict, labels=labels)

            if args.steps is not None and it + 1 >= args.steps:
                break
        wer = process_evaluation_epoch(_global_var_dict)
        print("==========>>>>>>Evaluation WER: {0}\n".format(wer))


def main():
    args = get_args()
    model_definition = toml.load(args.model_toml)
    dataset_vocab = model_definition['labels']['labels']
    ctc_vocab = add_blank_label(dataset_vocab)
    featurizer_config = model_definition['input_eval']

    rnnt_model = get_rnnt_model(featurizer_config, model_definition, ctc_vocab,
                                args.ckpt)
    mod, params = rnnt_model_to_tvm_mod(rnnt_model)
    data_layer = get_data_layer(args, featurizer_config, args.val_manifest,
                                dataset_vocab)
    audio_preprocessor = get_audio_preprocessor(featurizer_config)
    eval_transforms = get_eval_transforms(audio_preprocessor)
    greedy_decoder = RNNTGreedyDecoder(len(ctc_vocab) - 1, rnnt_model,
                                       mod=mod, params=params, target=args.target)
    evaluate(data_layer=data_layer,
             audio_processor=eval_transforms,
             encoderdecoder=rnnt_model,
             greedy_decoder=greedy_decoder,
             labels=ctc_vocab,
             args=args)

if __name__ == '__main__':
    main()
```
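
For reference, here is a minimal, self-contained sketch of the `relay.frontend.from_pytorch` flow the script relies on. It traces a toy network in place of the RNN-T encoder (the module, input names, and shapes are illustrative, not from the original script) and checks the compiled output against PyTorch:

```python
import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

# Toy stand-in for the encoder; shapes are illustrative only.
model = torch.nn.Sequential(torch.nn.Linear(240, 64), torch.nn.ReLU()).eval()
dummy = torch.randn(1, 240)
scripted = torch.jit.trace(model, dummy)

# (name, shape) pairs tell from_pytorch how to bind the graph inputs.
shape_list = [('input0', (1, 240))]
mod, params = relay.frontend.from_pytorch(scripted, shape_list)

target, ctx = 'llvm', tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

runtime = graph_runtime.GraphModule(lib['default'](ctx))
runtime.set_input('input0', dummy.numpy())
runtime.run()
out = runtime.get_output(0).asnumpy()
np.testing.assert_allclose(out, scripted(dummy).detach().numpy(),
                           rtol=1e-4, atol=1e-4)
```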