model_rnnt.py
```
# Copyright (c) 2019, Myrtle Software Limited. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#           http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch

from rnn import rnn
from rnn import StackTime


class RNNT(torch.nn.Module):
    """A Recurrent Neural Network Transducer (RNN-T).

    The model owns three sub-modules:

    * ``encoder``: an ``Encoder`` applied to the acoustic features.
    * ``prediction``: a ``ModuleDict`` holding an embedding (``"embed"``) and
      a recurrent network (``"dec_rnn"``) applied to the label history.
    * ``joint_net``: a feed-forward network applied to the concatenation of
      encoder and prediction outputs.

    Args:
        rnnt: Mapping with the model hyper-parameters. The keys read are
            ``encoder_n_hidden``, ``encoder_pre_rnn_layers``,
            ``encoder_post_rnn_layers``, ``encoder_stack_time_factor``,
            ``pred_n_hidden``, ``pred_rnn_layers``, ``joint_n_hidden``,
            ``forget_gate_bias``, ``rnn_type``, ``dropout`` and, optionally,
            ``norm`` (``None`` is used when the key is absent).
            ``forget_gate_bias`` is the total initialized value of the bias
            used in the forget gate; set to None to use PyTorch's default
            initialisation.
            (See: http://proceedings.mlr.press/v37/jozefowicz15.pdf)
        num_classes: Number of output symbols (inc blank).
        **kwargs: Input-feature configuration. If ``no_featurizer`` is true,
            ``in_features`` gives the encoder input size directly. Otherwise
            ``feature_config`` must be a mapping and the input size is
            ``feature_config['features']`` multiplied by
            ``feature_config.get("frame_splicing", 1)``.

    Raises:
        TypeError: If ``rnnt`` is not provided.
    """

    def __init__(self, rnnt=None, num_classes=1, **kwargs):
        super().__init__()
        if rnnt is None:
            # Fail early with a readable message instead of an opaque
            # "'NoneType' object is not subscriptable".
            raise TypeError("RNNT requires an `rnnt` configuration mapping")

        if kwargs.get("no_featurizer", False):
            in_features = kwargs.get("in_features")
        else:
            feat_config = kwargs.get("feature_config")
            # This may be useful in the future, for MLPerf
            # configuration.
            in_features = feat_config['features'] * \
                feat_config.get("frame_splicing", 1)

        # Kept alongside `pred_n_hidden` for backward compatibility.
        self._pred_n_hidden = rnnt['pred_n_hidden']

        self.encoder_n_hidden = rnnt["encoder_n_hidden"]
        self.encoder_pre_rnn_layers = rnnt["encoder_pre_rnn_layers"]
        self.encoder_post_rnn_layers = rnnt["encoder_post_rnn_layers"]

        self.pred_n_hidden = rnnt["pred_n_hidden"]
        self.pred_rnn_layers = rnnt["pred_rnn_layers"]

        # "norm" is the only optional key of the configuration.
        norm = rnnt["norm"] if "norm" in rnnt else None

        self.encoder = Encoder(in_features,
                               rnnt["encoder_n_hidden"],
                               rnnt["encoder_pre_rnn_layers"],
                               rnnt["encoder_post_rnn_layers"],
                               rnnt["forget_gate_bias"],
                               norm,
                               rnnt["rnn_type"],
                               rnnt["encoder_stack_time_factor"],
                               rnnt["dropout"],
                               )

        self.prediction = self._predict(
            num_classes,
            rnnt["pred_n_hidden"],
            rnnt["pred_rnn_layers"],
            rnnt["forget_gate_bias"],
            norm,
            rnnt["rnn_type"],
            rnnt["dropout"],
        )

        self.joint_net = self._joint_net(
            num_classes,
            rnnt["pred_n_hidden"],
            rnnt["encoder_n_hidden"],
            rnnt["joint_n_hidden"],
            rnnt["dropout"],
        )

    def _predict(self, vocab_size, pred_n_hidden, pred_rnn_layers,
                 forget_gate_bias, norm, rnn_type, dropout):
        """Build the prediction network.

        Returns:
            A ``ModuleDict`` with an ``"embed"`` embedding of
            ``vocab_size - 1`` entries of width ``pred_n_hidden`` and a
            ``"dec_rnn"`` recurrent network built by ``rnn`` whose input and
            hidden sizes are both ``pred_n_hidden``.
        """
        layers = torch.nn.ModuleDict({
            "embed": torch.nn.Embedding(vocab_size - 1, pred_n_hidden),
            "dec_rnn": rnn(
                rnn=rnn_type,
                input_size=pred_n_hidden,
                hidden_size=pred_n_hidden,
                num_layers=pred_rnn_layers,
                norm=norm,
                forget_gate_bias=forget_gate_bias,
                dropout=dropout,
            ),
        })
        return layers

    def _joint_net(self, vocab_size, pred_n_hidden, enc_n_hidden,
                   joint_n_hidden, dropout):
        """Build the joint network.

        Returns:
            A ``Sequential`` of ``Linear(pred_n_hidden + enc_n_hidden,
            joint_n_hidden)``, ``ReLU``, an optional ``Dropout`` (only when
            ``dropout`` is truthy) and ``Linear(joint_n_hidden, vocab_size)``.
        """
        layers = [
            torch.nn.Linear(pred_n_hidden + enc_n_hidden, joint_n_hidden),
            torch.nn.ReLU(),
        ]
        if dropout:
            layers.append(torch.nn.Dropout(p=dropout))
        layers.append(torch.nn.Linear(joint_n_hidden, vocab_size))
        return torch.nn.Sequential(*layers)

    # NOTE: Perhaps what is really needed is to provide a value for `state`.
    # But why is it not possible to just specify a type for abstract
    # interpretation? That is what is really wanted here.
    # We really want two "states" here...
    def forward(self, batch, state=None):
        """Unused training entry point; always raises.

        Args:
            batch: ((x, y), (x_lens, y_lens)); not inspected.
            state: Not inspected.

        Raises:
            RuntimeError: Always. Use the ``encoder`` sub-module together
                with ``predict`` and ``joint`` instead.
        """
        raise RuntimeError(
            "RNNT::forward is not currently used. "
            "It corresponds to training, where your entire output sequence "
            "is known before hand.")

    def predict(self, y, state=None, add_sos=True):
        """Run the prediction network on a label history.

        B - batch size
        U - label length
        H - Hidden dimension size
        L - Number of decoder layers = 2

        Args:
            y: One of
                * a tensor of label indexes of shape (B, U);
                * a ``PackedSequence`` of label indexes (teacher-forced
                  training mode). It is embedded and unpacked to a padded
                  (B, U, H) tensor whose padded positions are zero vectors;
                * anything else (e.g. ``None``) for inference mode, in which
                  case a single all-zero input step of shape (B, 1, H) is
                  used, with B taken from ``state`` (1 when ``state`` is
                  None).
            state: Initial recurrent state handed to the decoder RNN, or
                None.
            add_sos: If true, prepend an all-zero "start of sequence" step.

        Returns:
            Tuple (g, hid) where:
                g: (B, U + 1, H) when ``add_sos`` is true, otherwise the
                    sequence length of the input is kept.
                hid: (h, c) where h is the final sequence hidden state and c is
                    the final cell state:
                        h (tensor), shape (L, B, H)
                        c (tensor), shape (L, B, H)
        """
        embed = self.prediction["embed"]
        if isinstance(y, torch.Tensor):
            # (B, U) -> (B, U, H)
            y = embed(y)
        elif isinstance(y, torch.nn.utils.rnn.PackedSequence):
            # Teacher-forced training mode.
            # PackedSequence is a namedtuple: `_replace` returns a new object
            # and leaves `y` untouched, so the result has to be rebound.
            y = y._replace(data=embed(y.data))
            # The code below works on dense (B, U, H) tensors, so unpack.
            y, _ = torch.nn.utils.rnn.pad_packed_sequence(y, batch_first=True)
        else:
            # inference mode
            B = 1 if state is None else state[0].size(1)
            # Follow the device and dtype of the model parameters.
            reference = self.joint_net[0].weight
            y = torch.zeros(
                (B, 1, self.pred_n_hidden),
                device=reference.device,
                dtype=reference.dtype,
            )

        # prepend blank "start of sequence" symbol
        if add_sos:
            B, _, H = y.shape
            start = torch.zeros((B, 1, H), device=y.device, dtype=y.dtype)
            y = torch.cat([start, y], dim=1).contiguous()   # (B, U + 1, H)

        y = y.transpose(0, 1)   # (U + 1, B, H)
        g, hid = self.prediction["dec_rnn"](y, state)
        g = g.transpose(0, 1)   # (B, U + 1, H)
        return g, hid

    def joint(self, f, g):
        """Combine encoder and prediction outputs into logits.

        Args:
            f: Encoder output of shape (B, T, H).
            g: Prediction network output of shape (B, U + 1, H2).

        Returns:
            The output of ``joint_net`` applied to every (t, u) pair, of
            shape (B, T, U + 1, K) where K is the output size of the last
            joint layer.
        """
        # Combine the input states and the output states
        B, T, H = f.shape
        B, U_, H2 = g.shape

        # Broadcast both inputs over the (T, U + 1) grid; `expand` creates
        # views, the only copy is made by `cat`.
        f = f.unsqueeze(dim=2).expand((B, T, U_, H))    # (B, T, U + 1, H)
        g = g.unsqueeze(dim=1).expand((B, T, U_, H2))   # (B, T, U + 1, H2)

        inp = torch.cat([f, g], dim=3)   # (B, T, U + 1, H + H2)
        return self.joint_net(inp)



class Encoder(torch.nn.Module):
    """RNN-T encoder: a pre-RNN, a time-stacking step and a post-RNN.

    Args:
        in_features: Input size of the first recurrent stack.
        encoder_n_hidden: Hidden size of both recurrent stacks.
        encoder_pre_rnn_layers: Number of layers of the first stack.
        encoder_post_rnn_layers: Number of layers of the second stack.
        forget_gate_bias: Forwarded to ``rnn`` for both stacks.
        norm: Forwarded to ``rnn`` for both stacks.
        rnn_type: Forwarded to ``rnn`` (as ``rnn=``) for both stacks.
        encoder_stack_time_factor: Factor handed to ``StackTime``; the second
            stack expects inputs of width
            ``encoder_stack_time_factor * encoder_n_hidden``.
        dropout: Forwarded to ``rnn`` for both stacks.
    """

    def __init__(self, in_features, encoder_n_hidden,
                 encoder_pre_rnn_layers, encoder_post_rnn_layers,
                 forget_gate_bias, norm, rnn_type, encoder_stack_time_factor,
                 dropout):
        super().__init__()

        # Settings common to the two recurrent stacks.
        shared = dict(
            rnn=rnn_type,
            hidden_size=encoder_n_hidden,
            norm=norm,
            forget_gate_bias=forget_gate_bias,
            dropout=dropout,
        )

        # Sub-modules are created in the order pre_rnn, stack_time, post_rnn
        # so that registration order stays the same.
        self.pre_rnn = rnn(
            input_size=in_features,
            num_layers=encoder_pre_rnn_layers,
            **shared,
        )
        self.stack_time = StackTime(factor=encoder_stack_time_factor)

        stacked_width = encoder_stack_time_factor * encoder_n_hidden
        self.post_rnn = rnn(
            input_size=stacked_width,
            num_layers=encoder_post_rnn_layers,
            norm_first_rnn=True,
            **shared,
        )

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        """Encode a batch of feature sequences.

        Args:
            x: Input features, fed to ``pre_rnn`` with no initial state.
            x_lens: Sequence lengths, passed through ``stack_time``.

        Returns:
            Tuple of the ``post_rnn`` output with its first two dimensions
            swapped, and the lengths returned by ``stack_time``.
        """
        pre_out, _ = self.pre_rnn(x, None)
        stacked, stacked_lens = self.stack_time(pre_out, x_lens)
        post_out, _ = self.post_rnn(stacked, None)
        return post_out.transpose(0, 1), stacked_lens






def label_collate(labels):
    """Collates the label inputs for the rnn-t prediction network.

    If `labels` is already a torch.Tensor it is only cast to int64.
    Otherwise every sequence is copied into one row of a zero-padded
    matrix.

    Args:
        labels: A list or tuple of label index sequences, or a torch.Tensor.

    Returns:
        A padded int64 torch.Tensor of shape (batch, max_seq_len). An empty
        list or tuple yields an empty tensor of shape (0, 0).

    Raises:
        ValueError: If `labels` is neither a tensor, a list nor a tuple.
    """

    if isinstance(labels, torch.Tensor):
        return labels.type(torch.int64)
    if not isinstance(labels, (list, tuple)):
        raise ValueError(
            f"`labels` should be a list or tensor not {type(labels)}"
        )

    batch_size = len(labels)
    # `default` keeps an empty batch from raising inside `max`.
    max_len = max((len(seq) for seq in labels), default=0)

    # Build the buffer as int64 directly so no second conversion is needed;
    # positions past the end of a sequence stay 0.
    padded = np.zeros((batch_size, max_len), dtype=np.int64)
    for row, seq in enumerate(labels):
        padded[row, :len(seq)] = seq

    return torch.from_numpy(padded)
```





---
[Visit 
Topic](https://discuss.tvm.apache.org/t/import-rnn-t-pytorch-model-into-tvm/7874/6)
 to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click 
here](https://discuss.tvm.apache.org/email/unsubscribe/7744e63c8ab2438b276e808229cef758a83d3010b48b0980fa2df8b3fea85150).

Reply via email to