1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
19 linear_before_reset=
False,
22 super(GRUCell, self).__init__(**kwargs)
44 hidden_t_prev = states[0]
47 input_t_reset, input_t_update, input_t_output = model.net.Split(
52 self.
scope(
'input_t_reset'),
53 self.
scope(
'input_t_update'),
54 self.
scope(
'input_t_output'),
60 reset_gate_t = brew.fc(
63 self.
scope(
'reset_gate_t'),
68 update_gate_t = brew.fc(
71 self.
scope(
'update_gate_t'),
78 reset_gate_t = model.net.Sum(
79 [reset_gate_t, input_t_reset],
80 self.
scope(
'reset_gate_t')
82 reset_gate_t_sigmoid = model.net.Sigmoid(
84 self.
scope(
'reset_gate_t_sigmoid')
89 output_gate_fc = brew.fc(
92 self.
scope(
'output_gate_t'),
97 output_gate_t = model.net.Mul(
98 [reset_gate_t_sigmoid, output_gate_fc],
99 self.
scope(
'output_gate_t_mul')
102 modified_hidden_t_prev = model.net.Mul(
103 [reset_gate_t_sigmoid, hidden_t_prev],
104 self.
scope(
'modified_hidden_t_prev')
106 output_gate_t = brew.fc(
108 modified_hidden_t_prev,
109 self.
scope(
'output_gate_t'),
117 update_gate_t = model.net.Sum(
118 [update_gate_t, input_t_update],
119 self.
scope(
'update_gate_t'),
121 output_gate_t = model.net.Sum(
122 [output_gate_t, input_t_output],
123 self.
scope(
'output_gate_t_summed'),
127 gates_t, _gates_t_concat_dims = model.net.Concat(
134 self.
scope(
'gates_t'),
135 self.
scope(
'_gates_t_concat_dims'),
140 if seq_lengths
is not None:
141 inputs = [hidden_t_prev, gates_t, seq_lengths, timestep]
143 inputs = [hidden_t_prev, gates_t, timestep]
145 hidden_t = model.net.GRUUnit(
150 sequence_lengths=(seq_lengths
is not None),
152 model.net.AddExternalOutputs(hidden_t)
155 def prepare_input(self, model, input_blob):
def get_state_names(self):
    """Return the state names for this cell as a 1-tuple.

    Returns:
        tuple: a single-element tuple containing ``self.scope('hidden_t')``.
    """
    # The trailing comma is what makes this a tuple rather than a bare value.
    return (self.scope('hidden_t'),)
168 def get_output_dim(self):
# GRU is rnn_cell._LSTM with GRUCell pre-bound as its first positional
# argument; callers supply the remaining arguments.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)
def get_state_names(self)