6     def __init__(self, model, batch_first):
     7         super(RnnModelWithPackedSequence, self).__init__()
    11     def forward(self, input, *args):
    12         args, seq_lengths = args[:-1], args[-1]
    13         input = rnn_utils.pack_padded_sequence(input, seq_lengths, self.
batch_first)
    14         rets = self.
model(input, *args)
    15         ret, rets = rets[0], rets[1:]
    16         ret, _ = rnn_utils.pad_packed_sequence(ret, self.
batch_first)
    17         return tuple([ret] + list(rets))