5 def get_cudnn_mode(mode):
7 return cudnn.CUDNN_RNN_RELU
8 elif mode ==
'RNN_TANH':
9 return cudnn.CUDNN_RNN_TANH
11 return cudnn.CUDNN_LSTM
13 return cudnn.CUDNN_GRU
15 raise Exception(
"Unknown mode: {}".format(mode))
23 def __init__(self, inner):
29 def __getstate__(self):
32 return "<unserializable>" 34 def __setstate__(self, state):
38 def init_dropout_state(dropout, train, dropout_seed, dropout_state):
40 dropout_p = dropout
if train
else 0
41 if (dropout_desc_name
not in dropout_state)
or (dropout_state[dropout_desc_name].get()
is None):
45 dropout_state[dropout_desc_name] =
Unserializable(torch._cudnn_init_dropout_state(
50 device=torch.device(
'cuda')))
51 dropout_ts = dropout_state[dropout_desc_name].get()