10 """Container module with an encoder, a recurrent module, and a decoder.""" 12 def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers,
13 dropout=0.5, tie_weights=
False, batchsize=2):
14 super(RNNModel, self).__init__()
15 self.
drop = nn.Dropout(dropout)
16 self.
encoder = nn.Embedding(ntoken, ninp)
17 if rnn_type
in [
'LSTM',
'GRU']:
18 self.
rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
21 nonlinearity = {
'RNN_TANH':
'tanh',
'RNN_RELU':
'relu'}[rnn_type]
23 raise ValueError(
"""An invalid option for `--model` was supplied, 24 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
25 self.
rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
26 self.
decoder = nn.Linear(nhid, ntoken)
36 raise ValueError(
'When using the tied flag, nhid must be equal to emsize')
37 self.decoder.weight = self.encoder.weight
48 """Detach hidden states from their history.""" 49 if isinstance(h, torch.Tensor):
52 return tuple(RNNModel.repackage_hidden(v)
for v
in h)
54 def init_weights(self):
56 self.encoder.weight.data.uniform_(-initrange, initrange)
57 self.decoder.bias.data.fill_(0)
58 self.decoder.weight.data.uniform_(-initrange, initrange)
60 def forward(self, input, hidden):
62 output, hidden = self.
rnn(emb, hidden)
63 output = self.
drop(output)
64 decoded = self.
decoder(output.view(output.size(0) * output.size(1), output.size(2)))
65 self.
hidden = RNNModel.repackage_hidden(hidden)
66 return decoded.view(output.size(0), output.size(1), decoded.size(1))
68 def init_hidden(self, bsz):
69 weight = next(self.parameters()).data
71 return (weight.new(self.
nlayers, bsz, self.
nhid).zero_(),
74 return weight.new(self.
nlayers, bsz, self.
nhid).zero_()
def init_hidden(self, bsz)