Caffe2 - Python API
A deep learning, cross-platform ML framework
gru_cell.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 import functools
7 from caffe2.python import brew, rnn_cell
8 
9 
11 
12  def __init__(
13  self,
14  input_size,
15  hidden_size,
16  forget_bias, # Currently unused! Values here will be ignored.
17  memory_optimization,
18  drop_states=False,
19  linear_before_reset=False,
20  **kwargs
21  ):
22  super(GRUCell, self).__init__(**kwargs)
23  self.input_size = input_size
24  self.hidden_size = hidden_size
25  self.forget_bias = float(forget_bias)
26  self.memory_optimization = memory_optimization
27  self.drop_states = drop_states
28  self.linear_before_reset = linear_before_reset
29 
    # Unlike LSTMCell, GRUCell needs the output of one gate to feed into another
    # (reset gate -> output gate). So much of the logic that computes the reset
    # gate output and the modified output-gate input is set up here, in the
    # graph definition. The remaining logic lives in gru_unit_op.{h,cc}.
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        """Add the ops for one GRU timestep to ``model.net``.

        Args:
            model: model helper; ops are appended to ``model.net`` and it is
                passed to ``brew.fc`` for the hidden-state FC layers.
            input_t: this timestep's input contribution. It is split along
                axis 2 into reset, update and output parts.
            seq_lengths: optional sequence-lengths blob. When not None it is
                appended to the GRUUnit inputs and ``sequence_lengths=True``
                is set on the op.
            states: recurrent state blobs; only ``states[0]`` (the previous
                hidden state) is read.
            timestep: timestep blob forwarded to GRUUnit.
            extra_inputs: accepted but not used by this method.

        Returns:
            A 1-tuple holding the new hidden-state blob produced by GRUUnit.
        """
        # NOTE(review): every op below works on axis 2, so blobs presumably
        # have features on axis 2, e.g. (1, batch, dim) — TODO confirm.
        hidden_t_prev = states[0]

        # Split input tensors to get inputs for each gate.
        input_t_reset, input_t_update, input_t_output = model.net.Split(
            [
                input_t,
            ],
            [
                self.scope('input_t_reset'),
                self.scope('input_t_update'),
                self.scope('input_t_output'),
            ],
            axis=2,
        )

        # Fully connected layers for reset and update gates.
        reset_gate_t = brew.fc(
            model,
            hidden_t_prev,
            self.scope('reset_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )
        update_gate_t = brew.fc(
            model,
            hidden_t_prev,
            self.scope('update_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        # Calculating the modified hidden state going into output gate.
        # The Sum writes back into the same 'reset_gate_t' blob the FC
        # produced (in-place); the sigmoid of it is the reset gate r.
        reset_gate_t = model.net.Sum(
            [reset_gate_t, input_t_reset],
            self.scope('reset_gate_t')
        )
        reset_gate_t_sigmoid = model.net.Sigmoid(
            reset_gate_t,
            self.scope('reset_gate_t_sigmoid')
        )

        # `self.linear_before_reset = True` matches cudnn semantics
        if self.linear_before_reset:
            # r * FC(h_prev): reset applied after the hidden-to-output FC.
            output_gate_fc = brew.fc(
                model,
                hidden_t_prev,
                self.scope('output_gate_t'),
                dim_in=self.hidden_size,
                dim_out=self.hidden_size,
                axis=2,
            )
            output_gate_t = model.net.Mul(
                [reset_gate_t_sigmoid, output_gate_fc],
                self.scope('output_gate_t_mul')
            )
        else:
            # FC(r * h_prev): reset applied to the hidden state before the FC.
            modified_hidden_t_prev = model.net.Mul(
                [reset_gate_t_sigmoid, hidden_t_prev],
                self.scope('modified_hidden_t_prev')
            )
            output_gate_t = brew.fc(
                model,
                modified_hidden_t_prev,
                self.scope('output_gate_t'),
                dim_in=self.hidden_size,
                dim_out=self.hidden_size,
                axis=2,
            )

        # Add input contributions to update and output gate.
        # We already (in-place) added input contributions to the reset gate.
        update_gate_t = model.net.Sum(
            [update_gate_t, input_t_update],
            self.scope('update_gate_t'),
        )
        output_gate_t = model.net.Sum(
            [output_gate_t, input_t_output],
            self.scope('output_gate_t_summed'),
        )

        # Join the three gate pre-activations (input contributions already
        # added above) along axis 2 for GRUUnit.
        # NOTE(review): the reset slice is the pre-sigmoid sum and has already
        # been consumed above; presumably GRUUnit only uses the update/output
        # slices — confirm in gru_unit_op.
        gates_t, _gates_t_concat_dims = model.net.Concat(
            [
                reset_gate_t,
                update_gate_t,
                output_gate_t,
            ],
            [
                self.scope('gates_t'),
                self.scope('_gates_t_concat_dims'),
            ],
            axis=2,
        )

        # GRUUnit's input list differs depending on whether sequence lengths
        # are provided; the `sequence_lengths` arg below tells the op which.
        if seq_lengths is not None:
            inputs = [hidden_t_prev, gates_t, seq_lengths, timestep]
        else:
            inputs = [hidden_t_prev, gates_t, timestep]

        # The output blob is named after this cell's state name. forget_bias
        # is forwarded even though the constructor documents it as ignored.
        hidden_t = model.net.GRUUnit(
            inputs,
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        model.net.AddExternalOutputs(hidden_t)
        return (hidden_t,)
154 
155  def prepare_input(self, model, input_blob):
156  return brew.fc(
157  model,
158  input_blob,
159  self.scope('i2h'),
160  dim_in=self.input_size,
161  dim_out=3 * self.hidden_size,
162  axis=2,
163  )
164 
165  def get_state_names(self):
166  return (self.scope('hidden_t'),)
167 
168  def get_output_dim(self):
169  return self.hidden_size
170 
171 
# Public GRU layer builder: functools.partial pre-binds GRUCell as the first
# positional argument of rnn_cell._LSTM, so callers supply the remaining
# arguments exactly as they would to rnn_cell._LSTM.
# NOTE(review): despite its name, _LSTM presumably acts as the generic
# recurrent-layer builder parameterized by cell class — confirm in rnn_cell.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)