Caffe2 - Python API
A deep learning, cross platform ML framework
rnn.py
1 import torch.cuda
2 import torch.backends.cudnn as cudnn
3 
4 
def get_cudnn_mode(mode):
    """Translate an RNN mode name into the matching cuDNN RNN-mode constant.

    Args:
        mode: one of ``'RNN_RELU'``, ``'RNN_TANH'``, ``'LSTM'`` or ``'GRU'``.

    Returns:
        The corresponding ``cudnn.CUDNN_*`` constant.

    Raises:
        ValueError: if ``mode`` is not one of the supported names. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    # Pairs of (mode name, cudnn attribute name). Compare with == (rather than
    # a dict lookup) so unhashable inputs still reach the "Unknown mode" error.
    # Look up the cudnn attribute only for the matching mode, as the original
    # if/elif chain did.
    for name, attr in (('RNN_RELU', 'CUDNN_RNN_RELU'),
                       ('RNN_TANH', 'CUDNN_RNN_TANH'),
                       ('LSTM', 'CUDNN_LSTM'),
                       ('GRU', 'CUDNN_GRU')):
        if mode == name:
            return getattr(cudnn, attr)
    raise ValueError("Unknown mode: {}".format(mode))
16 
17 
# NB: This class is no longer strictly required (the dropout state could in fact
# be serialized for even better reproducibility), but it is kept for backwards
# compatibility with old models.
class Unserializable(object):
    """Wrap a value that is intentionally discarded when the wrapper is pickled.

    The wrapped value is returned by :meth:`get`. After a pickle round trip the
    wrapped value is ``None``.
    """

    def __init__(self, inner):
        self.inner = inner

    def get(self):
        """Return the wrapped value (``None`` after unpickling)."""
        wrapped = self.inner
        return wrapped

    def __getstate__(self):
        # A non-empty placeholder is pickled instead of {}: Python 2 does not
        # call __setstate__ when the pickled state is falsy.
        placeholder = "<unserializable>"
        return placeholder

    def __setstate__(self, _state):
        # The pickled state is ignored; the wrapped value is never restored.
        self.inner = None
36 
37 
def init_dropout_state(dropout, train, dropout_seed, dropout_state):
    """Return the cached cuDNN dropout state for the current CUDA device.

    The state is cached in ``dropout_state`` under ``'desc_<device index>'``,
    wrapped in ``Unserializable``. It is (re)built only when the key is missing
    or the cached value is ``None``; an existing non-``None`` state is reused
    as-is, whatever ``dropout``/``train`` are on this call.

    Args:
        dropout: dropout probability; it is only used when ``train`` is truthy.
        train: training flag. When falsy, the effective dropout is 0.
        dropout_seed: seed forwarded to ``torch._cudnn_init_dropout_state``.
        dropout_state: mutable mapping used as the per-device cache.

    Returns:
        The cached dropout-state tensor. When the effective dropout is 0, this
        is ``None``.
    """
    desc_key = 'desc_' + str(torch.cuda.current_device())
    needs_init = (desc_key not in dropout_state
                  or dropout_state[desc_key].get() is None)
    if needs_init:
        effective_p = dropout if train else 0
        if effective_p == 0:
            # Dropout is disabled, so no cuDNN state buffer is allocated.
            state_tensor = None
        else:
            state_tensor = torch._cudnn_init_dropout_state(
                effective_p,
                train,
                dropout_seed,
                self_ty=torch.uint8,
                device=torch.device('cuda'))
        dropout_state[desc_key] = Unserializable(state_tensor)
    return dropout_state[desc_key].get()
def current_device()
Definition: __init__.py:349