from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import brew


class AttentionType:
    Regular, Recurrent, Dot, SoftCoverage = tuple(range(4))


def s(scope, name):
    # Scope blob names manually: these blobs cross the recurrent-cell
    # boundary, so they need stable, fully qualified names.
    return "{}/{}".format(str(scope), str(name))
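
# For example (hypothetical values), s('decoder/attn', 'weights') yields
# 'decoder/attn/weights', so every attention blob created below stays unique
# when the same construction code is reused across decoder layers.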


# c_i = \sum_j w_{ij} s_j
def _calc_weighted_context(
    model,
    encoder_outputs_transposed,
    encoder_output_dim,
    attention_weights_3d,
    scope,
):
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = brew.batch_mat_mul(
        model,
        [encoder_outputs_transposed, attention_weights_3d],
        s(scope, 'attention_weighted_encoder_context'),
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s(scope, 'attention_weighted_encoder_context_old_shape'),
        ],
        shape=[1, -1, encoder_output_dim],
    )
    return attention_weighted_encoder_context
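
# For intuition, the batched matmul above computes, per batch entry b
# (a NumPy-style sketch, not part of the Caffe2 graph):
#
#     context[b] = encoder_outputs_transposed[b] @ attention_weights[b]
#
# i.e. an [encoder_output_dim, encoder_length] matrix applied to an
# [encoder_length, 1] column of weights, which is exactly c = \sum_j w_j s_j.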


def _calc_attention_weights(
    model,
    attention_logits_transposed,
    scope,
    encoder_lengths=None,
):
    if encoder_lengths is not None:
        attention_logits_transposed = model.net.SequenceMask(
            [attention_logits_transposed, encoder_lengths],
            ['masked_attention_logits'],
            mode='sequence',
        )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = brew.softmax(
        model,
        attention_logits_transposed,
        s(scope, 'attention_weights_3d'),
        axis=1,
    )
    return attention_weights_3d
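
# When encoder_lengths is given, SequenceMask overwrites logits past each
# sequence's length with a large negative fill value before the softmax, so
# padded encoder positions receive effectively zero attention weight.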


# e_{ij} = v^T tanh(alpha(h_{i-1}, s_j))
def _calc_attention_logits_from_sum_match(
    model,
    decoder_hidden_encoder_outputs_sum,
    encoder_output_dim,
    scope,
):
    # In-place Tanh over [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )
    # [encoder_length, batch_size, 1]
    attention_logits = brew.fc(
        model,
        decoder_hidden_encoder_outputs_sum,
        s(scope, 'attention_logits'),
        dim_in=encoder_output_dim,
        dim_out=1,
        axis=2,
    )
    # [batch_size, encoder_length, 1]
    attention_logits_transposed = brew.transpose(
        model,
        attention_logits,
        s(scope, 'attention_logits_transposed'),
        axes=[1, 0, 2],
    )
    return attention_logits_transposed
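
# The dim_out=1 FC along axis=2 plays the role of the learned scoring vector
# v: it collapses each [encoder_output_dim] match vector to a single scalar
# logit per (encoder position, batch entry) pair.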


# Applies the learned W x + b projection used by the sum-match networks above.
def _apply_fc_weight_for_sum_match(model, input, dim_in, dim_out, scope, name):
    output = brew.fc(
        model, input, s(scope, name), dim_in=dim_in, dim_out=dim_out, axis=2,
    )
    # Drop the leading singleton (timestep) dimension: [1, B, D] -> [B, D].
    output = model.net.Squeeze(output, output, dims=[0])
    return output


# Implements the recurrent attention ("RecAtt") of section 4.1 in
# http://arxiv.org/abs/1601.03317
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
    encoder_lengths=None,
):
    weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
        model=model,
        input=attention_weighted_encoder_context_t_prev,
        dim_in=encoder_output_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_prev_attention_context',
    )
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [
            weighted_prev_attention_context,
            weighted_decoder_hidden_state,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [
            weighted_encoder_outputs,
            decoder_hidden_encoder_outputs_sum_tmp,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]
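
# Unlike the regular attention below, the recurrent variant also feeds the
# previous step's attention context c_{t-1} through its own projection into
# the match network: the pre-tanh sum is W_e s_j + W_d h_t + W_c c_{t-1}
# rather than W_e s_j + W_d h_t.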


def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]
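
# A minimal usage sketch (hypothetical blob names and sizes; this builds the
# graph only, and feeding the input blobs is up to the caller):
#
#     from caffe2.python import model_helper
#
#     model = model_helper.ModelHelper(name='attention_demo')
#     context, weights, _ = apply_regular_attention(
#         model=model,
#         encoder_output_dim=256,
#         encoder_outputs_transposed='encoder_outputs_transposed',
#         weighted_encoder_outputs='weighted_encoder_outputs',
#         decoder_hidden_state_t='decoder_hidden_state_t',
#         decoder_hidden_state_dim=512,
#         scope='decoder',
#         encoder_lengths='encoder_lengths',
#     )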


def apply_dot_attention(
    model,
    encoder_output_dim,
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed,
    # [1, batch_size, decoder_hidden_state_dim]
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    if decoder_hidden_state_dim != encoder_output_dim:
        weighted_decoder_hidden_state = brew.fc(
            model,
            decoder_hidden_state_t,
            s(scope, 'weighted_decoder_hidden_state'),
            dim_in=decoder_hidden_state_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )
    else:
        weighted_decoder_hidden_state = decoder_hidden_state_t

    # [batch_size, encoder_output_dim]
    squeezed_weighted_decoder_hidden_state = model.net.Squeeze(
        weighted_decoder_hidden_state,
        s(scope, 'squeezed_weighted_decoder_hidden_state'),
        dims=[0],
    )
    # [batch_size, encoder_output_dim, 1]
    expanddims_squeezed_weighted_decoder_hidden_state = model.net.ExpandDims(
        squeezed_weighted_decoder_hidden_state,
        squeezed_weighted_decoder_hidden_state,
        dims=[2],
    )
    # [batch_size, encoder_length, 1]
    attention_logits_transposed = model.net.BatchMatMul(
        [
            encoder_outputs_transposed,
            expanddims_squeezed_weighted_decoder_hidden_state,
        ],
        s(scope, 'attention_logits'),
        trans_a=1,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, []
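
# Dot attention scores each encoder position by a plain inner product,
# e_{tj} = h_t . s_j (Luong-style), so no extra match network is needed; the
# single FC above only reconciles dimensions when decoder_hidden_state_dim
# differs from encoder_output_dim.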


def apply_soft_coverage_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths,
    coverage_t_prev,
    coverage_weights,
):
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
        broadcast=1,
    )
    # [batch_size, encoder_length]
    coverage_t_prev_2d = model.net.Squeeze(
        coverage_t_prev,
        s(scope, 'coverage_t_prev_2d'),
        dims=[0],
    )
    # [encoder_length, batch_size]
    coverage_t_prev_transposed = brew.transpose(
        model,
        coverage_t_prev_2d,
        s(scope, 'coverage_t_prev_transposed'),
    )
    # [encoder_length, batch_size, encoder_output_dim]
    scaled_coverage_weights = model.net.Mul(
        [coverage_weights, coverage_t_prev_transposed],
        s(scope, 'scaled_coverage_weights'),
        broadcast=1,
        axis=0,
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [decoder_hidden_encoder_outputs_sum_tmp, scaled_coverage_weights],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
    )
    # [batch_size, encoder_length, 1]
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    # [batch_size, encoder_length]
    attention_weights_2d = model.net.Squeeze(
        attention_weights_3d,
        s(scope, 'attention_weights_2d'),
        dims=[2],
    )
    # [batch_size, encoder_length]
    coverage_t = model.net.Add(
        [coverage_t_prev, attention_weights_2d],
        s(scope, 'coverage_t'),
        broadcast=1,
    )
    return (
        attention_weighted_encoder_context,
        attention_weights_3d,
        [decoder_hidden_encoder_outputs_sum],
        coverage_t,
    )
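
# The coverage state is a running sum of past attention distributions,
# coverage_t = coverage_{t-1} + a_t; scaled_coverage_weights feeds it back
# into the match network, giving the model a signal about which source
# positions have already been attended to.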