def __init__(self, model, batch_first):
    """Wrap ``model`` so it can be driven with padded (unpacked) input.

    Args:
        model: Callable invoked by ``forward`` with the packed input
            followed by any extra positional arguments.
        batch_first: Value passed through to both
            ``rnn_utils.pack_padded_sequence`` and
            ``rnn_utils.pad_packed_sequence`` in ``forward``.
    """
    super(RnnModelWithPackedSequence, self).__init__()
    # forward() reads both attributes; without storing them here the
    # first call would fail with AttributeError.
    self.model = model
    self.batch_first = batch_first
def forward(self, input, *args):
    """Pack ``input``, run the wrapped model, and unpack its first output.

    The last positional argument in ``args`` is taken as the sequence
    lengths; everything before it is forwarded to ``self.model``
    unchanged, after the packed input.

    Args:
        input: Padded batch handed to ``rnn_utils.pack_padded_sequence``.
        *args: Extra model arguments followed by ``seq_lengths``.

    Returns:
        A tuple whose first element is the model's first output after
        ``rnn_utils.pad_packed_sequence``, followed by the model's
        remaining outputs in their original order.

    Raises:
        IndexError: If no positional argument is given after ``input``.
    """
    # The lengths always ride in the final positional slot.
    seq_lengths = args[-1]
    model_args = args[:-1]

    packed = rnn_utils.pack_padded_sequence(
        input, seq_lengths, self.batch_first
    )
    outputs = self.model(packed, *model_args)

    # Only the first output is a packed sequence that needs padding back;
    # the lengths returned alongside it are not used.
    padded, _ = rnn_utils.pad_packed_sequence(outputs[0], self.batch_first)
    return (padded,) + tuple(outputs[1:])