Caffe2 - Python API
A deep learning, cross-platform ML framework
rnn_model_with_packed_sequence.py
1 from torch import nn
2 from torch.nn.utils import rnn as rnn_utils
3 
4 
class RnnModelWithPackedSequence(nn.Module):
    """Run a wrapped recurrent module on a padded batch by packing it first.

    ``forward`` takes the padded input, any extra positional arguments for the
    wrapped module, and, as its LAST positional argument, the per-sequence
    lengths. The input is packed with those lengths, the wrapped module is
    called on the packed data, and the first element of its result is unpacked
    back into a padded tensor. Every remaining element of the result is
    returned unchanged.
    """

    def __init__(self, model, batch_first):
        """Remember the module to wrap and the tensor layout flag.

        Args:
            model: module invoked as ``model(packed_input, *extra_args)``.
                Its return value is indexed: element 0 is unpacked and the
                other elements are passed through as-is.
            batch_first: layout flag handed to both the packing and the
                unpacking utilities.
        """
        super(RnnModelWithPackedSequence, self).__init__()
        self.model = model
        self.batch_first = batch_first

    def forward(self, input, *args):
        """Pack ``input``, run the wrapped module, and unpack its first output.

        Args:
            input: padded batch to be packed.
            *args: extra positional arguments for the wrapped module, followed
                by the sequence lengths as the final element. Calling with no
                extra arguments at all raises ``IndexError``, because the
                lengths are read from ``args[-1]``.

        Returns:
            A tuple whose first item is the padded form of the module's first
            output. The module's remaining outputs follow in their original
            order.
        """
        # The lengths always travel as the final positional argument; whatever
        # precedes them belongs to the wrapped module.
        lengths = args[-1]
        extra_args = args[:-1]
        # NOTE(review): with pack_padded_sequence's default enforce_sorted,
        # lengths are expected in decreasing order (per PyTorch docs) — confirm
        # for the torch version in use.
        packed = rnn_utils.pack_padded_sequence(
            input, lengths, batch_first=self.batch_first
        )
        outputs = self.model(packed, *extra_args)
        padded, _unused_lengths = rnn_utils.pad_packed_sequence(
            outputs[0], batch_first=self.batch_first
        )
        return (padded,) + tuple(outputs[1:])