Hi, Jiali and I tried to compile an RNN-T PyTorch model with TVM, so we implemented LSTM conversion in /incubator-tvm/python/tvm/relay/frontend/pytorch.py:

```
def _lstm():
    def _lstm_cell(input, hidden, params):
        # One LSTM cell step: compute the four gates from the input and the
        # previous hidden state, then the next cell/hidden states.
        hx, cx = hidden[0], hidden[1]
        _w_ih, _w_hh, _b_ih, _b_hh = params[0], params[1], params[2], params[3]
        i2h = _op.nn.bias_add(_op.nn.dense(input, _w_ih), _b_ih, axis=-1)
        h2h = _op.nn.bias_add(_op.nn.dense(hx, _w_hh), _b_hh, axis=-1)
        gates = i2h + h2h
        slice_gates = _op.split(gates, indices_or_sections=4, axis=1)
        in_gate = _activation_map["sigmoid"](slice_gates[0])
        forget_gate = _activation_map["sigmoid"](slice_gates[1])
        cell_gate = _activation_map["tanh"](slice_gates[2])
        out_gate = _activation_map["sigmoid"](slice_gates[3])
        cy = forget_gate * cx + in_gate * cell_gate  # next cell state
        hy = out_gate * _activation_map["tanh"](cy)  # next hidden state
        return [hy, cy]

    def gather_params(_params, has_biases):
        # Group the flat parameter list into per-layer [w_ih, w_hh, b_ih, b_hh].
        res = []
        if has_biases:
            assert len(_params) % 4 == 0, "got an incorrect number of RNN parameters (bias)"
            for i in range(0, len(_params), 4):
                res.append([_params[i], _params[i + 1], _params[i + 2], _params[i + 3]])
        else:
            assert len(_params) % 2 == 0, "got an incorrect number of RNN parameters (no bias)"
            for i in range(0, len(_params), 2):
                zero = _expr.const(0, dtype="int32")
                res.append([_params[i], _params[i + 1], zero, zero])
        return res

    def unsqueeze_hidden(hiddens):
        return _op.transform.expand_dims(hiddens, int(0), 1)

    def full_layer(step_inputs, input_hidden, params):
        # Unroll one layer over every time step.
        step_outputs = []
        hidden = input_hidden
        for input in step_inputs:
            hidden = _lstm_cell(input, hidden, params)
            step_outputs.append(hidden[0])
        hidden[0] = unsqueeze_hidden(hidden[0])
        hidden[1] = unsqueeze_hidden(hidden[1])
        return step_outputs, hidden

    def apply_layer_stack(input, hiddens, weights, num_layers):
        # Feed the per-step outputs of each layer into the next layer.
        layer_input = input
        final_hiddens = []
        for i in range(num_layers):
            layer_output_outputs, layer_output_hidden = full_layer(
                layer_input, hiddens[i], weights[i])
            final_hiddens.append(layer_output_hidden)
            layer_input = layer_output_outputs
        layer_out = [unsqueeze_hidden(li) for li in layer_input]
        return layer_out, final_hiddens

    def _lstm_impl(input, params, hx, cx, num_layers):
        hx_shape = _infer_shape(hx)
        cx_shape = _infer_shape(cx)
        layer_hx = _op.split(hx, hx_shape[0], axis=0)
        layer_cx = _op.split(cx, cx_shape[0], axis=0)
        total_layers = len(layer_hx)
        hiddens = []
        for i in range(total_layers):
            hiddens.append([_op.squeeze(layer_hx[i], axis=[0]),
                            _op.squeeze(layer_cx[i], axis=[0])])
        res_output, res_hidden = apply_layer_stack(input, hiddens, params, num_layers)
        hy = [hidden[0] for hidden in res_hidden]
        cy = [hidden[1] for hidden in res_hidden]
        hy_res = _op.concatenate(hy, 0)
        cy_res = _op.concatenate(cy, 0)
        res_output_res = _op.concatenate(res_output, 0)
        return res_output_res, hy_res, cy_res

    def _impl(inputs, input_types):
        _input = inputs[0]  # Tensor
        shape = _infer_shape(_input)
        # Split the sequence into one tensor per time step.
        temp_input = _op.split(_input, indices_or_sections=shape[0], axis=0)
        input_list = [_op.squeeze(item, axis=[0]) for item in temp_input]
        hx = inputs[1]             # TensorList
        _params = inputs[2]        # TensorList
        has_biases = inputs[3]     # bool
        num_layers = inputs[4]     # int64_t
        dropout_p = inputs[5]      # double
        train = inputs[6]          # bool
        bidirectional = inputs[7]  # bool
        batch_first = inputs[8]    # bool
        assert len(hx) == 2, "lstm expects two hidden states"
        params = gather_params(_params, has_biases)
        return _lstm_impl(input_list, params, hx[0], hx[1], num_layers)

    return _impl
```

But we found that this generates a very long Relay expression, since the cell is unrolled over every time step of every layer, and it runs slowly. merrymercy suggested that we use control flow instead of unrolling the LSTM cell. Could you give us some advice, or some examples? Thanks very much.
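To make the question concrete, here is a minimal sketch of what we understand the suggestion to mean (our own simplification, not the frontend code above): a single layer without biases, fused gate weights, and made-up names (`step`, `h_var`, `lstm_cell`), where one cell body is looped over time steps with `tvm.relay.loops.while_loop` and `relay.take` instead of being unrolled:

```
import tvm
from tvm import relay
from tvm.relay.loops import while_loop

seq_len, batch, input_size, hidden_size = 16, 1, 64, 128
dtype = "float32"

# Graph inputs; biases are omitted to keep the sketch short.
inputs = relay.var("inputs", shape=(seq_len, batch, input_size), dtype=dtype)
w_ih = relay.var("w_ih", shape=(4 * hidden_size, input_size), dtype=dtype)
w_hh = relay.var("w_hh", shape=(4 * hidden_size, hidden_size), dtype=dtype)
h0 = relay.var("h0", shape=(batch, hidden_size), dtype=dtype)
c0 = relay.var("c0", shape=(batch, hidden_size), dtype=dtype)

def lstm_cell(x, h, c):
    # Same gate math as _lstm_cell above (i, f, g, o order), written once.
    gates = relay.nn.dense(x, w_ih) + relay.nn.dense(h, w_hh)
    g = relay.split(gates, indices_or_sections=4, axis=1)
    c_next = relay.sigmoid(g[1]) * c + relay.sigmoid(g[0]) * relay.tanh(g[2])
    h_next = relay.sigmoid(g[3]) * relay.tanh(c_next)
    return h_next, c_next

# Loop variables: a step counter plus the carried hidden/cell state.
step = relay.var("step", shape=(), dtype="int32")
h_var = relay.var("h", shape=(batch, hidden_size), dtype=dtype)
c_var = relay.var("c", shape=(batch, hidden_size), dtype=dtype)

def cond(step, h, c):
    return relay.less(step, relay.const(seq_len, "int32"))

def body(step, h, c):
    # take() selects the current time step inside the loop.
    x_t = relay.take(inputs, step, axis=0)
    h_next, c_next = lstm_cell(x_t, h, c)
    return step + relay.const(1, "int32"), h_next, c_next

loop = while_loop(cond, [step, h_var, c_var], body)
result = loop(relay.const(0, "int32"), h0, c0)  # Tuple(step, h_T, c_T)
func = relay.Function(
    [inputs, w_ih, w_hh, h0, c0],
    relay.Tuple([relay.TupleGetItem(result, 1), relay.TupleGetItem(result, 2)]))
mod = tvm.IRModule.from_expr(func)
```

As far as we understand, a loop like this has to run on the VM executor (`relay.create_executor("vm", ...)`) rather than the graph runtime, and collecting the per-step outputs would additionally need a `relay.Any()` dimension plus `relay.concatenate` inside the loop body, as in the auto-scheduling post linked below. Is this the right direction?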
@junrushao1994 @MarisaKirisame @jroesch You can refer to this post: https://discuss.tvm.apache.org/t/auto-scheduling-for-lstm-operator/8158/6