Caffe2 - Python API
A deep learning, cross platform ML framework
gru_cell.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20 
21 import functools
22 from caffe2.python import brew, rnn_cell
23 
24 
26 
27  def __init__(
28  self,
29  input_size,
30  hidden_size,
31  forget_bias, # Currently unused! Values here will be ignored.
32  memory_optimization,
33  drop_states=False,
34  linear_before_reset=False,
35  **kwargs
36  ):
37  super(GRUCell, self).__init__(**kwargs)
38  self.input_size = input_size
39  self.hidden_size = hidden_size
40  self.forget_bias = float(forget_bias)
41  self.memory_optimization = memory_optimization
42  self.drop_states = drop_states
43  self.linear_before_reset = linear_before_reset
44 
    # Unlike LSTMCell, GRUCell needs the output of one gate to feed into another
    # (reset gate -> output gate).
    # So, much of the logic to calculate the reset gate output and the modified
    # output gate input is set here, in the graph definition.
    # The remaining logic lives in gru_unit_op.{h,cc}.
    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        """Add the ops for one GRU timestep to ``model.net``.

        Args:
            model: model helper whose ``net`` receives the ops. It is also
                passed to ``brew.fc`` for the hidden-to-hidden layers.
            input_t: input contribution for this step. It is split along
                axis 2 into three blobs (reset, update, output). It is
                presumably the output of ``prepare_input``, of width
                ``3 * hidden_size`` (TODO: confirm against the caller in
                rnn_cell).
            seq_lengths: optional sequence-lengths blob. When not None it is
                fed to GRUUnit and ``sequence_lengths=True`` is set.
            states: sequence of state blobs. Only ``states[0]``, the previous
                hidden state, is read.
            timestep: timestep blob, forwarded to GRUUnit.
            extra_inputs: not used in this method.

        Returns:
            A 1-tuple holding the hidden-state blob produced by GRUUnit. The
            blob is named by ``get_state_names()``.
        """
        hidden_t_prev = states[0]

        # Split input tensors to get inputs for each gate.
        input_t_reset, input_t_update, input_t_output = model.net.Split(
            [
                input_t,
            ],
            [
                self.scope('input_t_reset'),
                self.scope('input_t_update'),
                self.scope('input_t_output'),
            ],
            axis=2,
        )

        # Fully connected layers for reset and update gates.
        # Both project the previous hidden state (hidden_size -> hidden_size).
        reset_gate_t = brew.fc(
            model,
            hidden_t_prev,
            self.scope('reset_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )
        update_gate_t = brew.fc(
            model,
            hidden_t_prev,
            self.scope('update_gate_t'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        # Calculating the modified hidden state going into output gate.
        # Reset gate pre-activation = hidden contribution + input contribution.
        # It is written in place to the same 'reset_gate_t' blob name.
        reset_gate_t = model.net.Sum(
            [reset_gate_t, input_t_reset],
            self.scope('reset_gate_t')
        )
        reset_gate_t_sigmoid = model.net.Sigmoid(
            reset_gate_t,
            self.scope('reset_gate_t_sigmoid')
        )

        # `self.linear_before_reset = True` matches cudnn semantics:
        #   True:  output pre-activation = sigmoid(r) * FC(h_prev)
        #   False: output pre-activation = FC(sigmoid(r) * h_prev)
        if self.linear_before_reset:
            output_gate_fc = brew.fc(
                model,
                hidden_t_prev,
                self.scope('output_gate_t'),
                dim_in=self.hidden_size,
                dim_out=self.hidden_size,
                axis=2,
            )
            output_gate_t = model.net.Mul(
                [reset_gate_t_sigmoid, output_gate_fc],
                self.scope('output_gate_t_mul')
            )
        else:
            modified_hidden_t_prev = model.net.Mul(
                [reset_gate_t_sigmoid, hidden_t_prev],
                self.scope('modified_hidden_t_prev')
            )
            output_gate_t = brew.fc(
                model,
                modified_hidden_t_prev,
                self.scope('output_gate_t'),
                dim_in=self.hidden_size,
                dim_out=self.hidden_size,
                axis=2,
            )

        # Add input contributions to update and output gate.
        # We already (in-place) added input contributions to the reset gate.
        update_gate_t = model.net.Sum(
            [update_gate_t, input_t_update],
            self.scope('update_gate_t'),
        )
        output_gate_t = model.net.Sum(
            [output_gate_t, input_t_output],
            self.scope('output_gate_t_summed'),
        )

        # Join the three gate pre-activations along axis 2. The input
        # contributions were already added above. The reset slice is the
        # pre-sigmoid sum, and the reset gate has already been applied above.
        # NOTE(review): presumably GRUUnit only consumes the update and output
        # slices; confirm in gru_unit_op.
        gates_t, _gates_t_concat_dims = model.net.Concat(
            [
                reset_gate_t,
                update_gate_t,
                output_gate_t,
            ],
            [
                self.scope('gates_t'),
                self.scope('_gates_t_concat_dims'),
            ],
            axis=2,
        )

        # seq_lengths is included in GRUUnit's inputs only when it is given.
        # The `sequence_lengths` argument below mirrors that choice
        # (presumably so the op knows which input layout to expect).
        if seq_lengths is not None:
            inputs = [hidden_t_prev, gates_t, seq_lengths, timestep]
        else:
            inputs = [hidden_t_prev, gates_t, timestep]

        # The output blob name comes from get_state_names(). forget_bias is
        # forwarded even though the constructor comment marks it as unused.
        hidden_t = model.net.GRUUnit(
            inputs,
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        # Mark the new hidden state as an external output of model.net.
        model.net.AddExternalOutputs(hidden_t)
        return (hidden_t,)
169 
170  def prepare_input(self, model, input_blob):
171  return brew.fc(
172  model,
173  input_blob,
174  self.scope('i2h'),
175  dim_in=self.input_size,
176  dim_out=3 * self.hidden_size,
177  axis=2,
178  )
179 
180  def get_state_names(self):
181  return (self.scope('hidden_t'),)
182 
183  def get_output_dim(self):
184  return self.hidden_size
185 
186 
# Convenience builder that pre-binds GRUCell as the first positional argument
# of rnn_cell._LSTM. Calls to GRU(...) therefore take the same remaining
# arguments as that helper.
# NOTE(review): _LSTM's signature is not visible here. It presumably takes the
# cell class first; confirm in rnn_cell.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)