Caffe2 - Python API
A deep learning, cross-platform ML framework
attention.py
## @package attention
# Module caffe2.python.attention
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import brew

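# Enumeration of the attention variants implemented in this module; each
# value corresponds to one of the apply_*_attention functions below.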
class AttentionType:
    Regular, Recurrent, Dot, SoftCoverage = tuple(range(4))


def s(scope, name):
    # We have to manually scope due to our internal/external blob
    # relationships.
    return "{}/{}".format(str(scope), str(name))


# c_i = \sum_j w_{ij}\textbf{s}_j
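# Batch-multiplies the transposed encoder outputs
# [batch_size, encoder_output_dim, encoder_length] with the attention weights
# [batch_size, encoder_length, 1] and reshapes the product to
# [1, batch_size, encoder_output_dim].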
def _calc_weighted_context(
    model,
    encoder_outputs_transposed,
    encoder_output_dim,
    attention_weights_3d,
    scope,
):
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = brew.batch_mat_mul(
        model,
        [encoder_outputs_transposed, attention_weights_3d],
        s(scope, 'attention_weighted_encoder_context'),
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s(scope, 'attention_weighted_encoder_context_old_shape'),
        ],
        shape=[1, -1, encoder_output_dim],
    )
    return attention_weighted_encoder_context


# Calculate a softmax over the passed-in attention energy logits
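# When encoder_lengths is provided, logits past each sequence's length are
# masked with SequenceMask before the softmax so that padded encoder positions
# do not receive attention. Output shape: [batch_size, encoder_length, 1].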
def _calc_attention_weights(
    model,
    attention_logits_transposed,
    scope,
    encoder_lengths=None,
):
    if encoder_lengths is not None:
        attention_logits_transposed = model.net.SequenceMask(
            [attention_logits_transposed, encoder_lengths],
            ['masked_attention_logits'],
            mode='sequence',
        )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = brew.softmax(
        model,
        attention_logits_transposed,
        s(scope, 'attention_weights_3d'),
        engine='CUDNN',
        axis=1,
    )
    return attention_weights_3d


# e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j)
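# Applies tanh to the summed projections, reduces each position to a scalar
# energy with a single-output FC layer (bias frozen), and transposes the
# result from [encoder_length, batch_size, 1] to
# [batch_size, encoder_length, 1].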
def _calc_attention_logits_from_sum_match(
    model,
    decoder_hidden_encoder_outputs_sum,
    encoder_output_dim,
    scope,
):
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )

    # [encoder_length, batch_size, 1]
    attention_logits = brew.fc(
        model,
        decoder_hidden_encoder_outputs_sum,
        s(scope, 'attention_logits'),
        dim_in=encoder_output_dim,
        dim_out=1,
        axis=2,
        freeze_bias=True,
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = brew.transpose(
        model,
        attention_logits,
        s(scope, 'attention_logits_transposed'),
        axes=[1, 0, 2],
    )
    return attention_logits_transposed


# \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b)
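# Applies an FC projection along axis 2 to `input` and squeezes the leading
# singleton dimension, yielding a [batch_size, dim_out] blob that callers
# broadcast-add against the per-position encoder projections.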
def _apply_fc_weight_for_sum_match(
    model,
    input,
    dim_in,
    dim_out,
    scope,
    name,
):
    output = brew.fc(
        model,
        input,
        s(scope, name),
        dim_in=dim_in,
        dim_out=dim_out,
        axis=2,
    )
    output = model.net.Squeeze(
        output,
        output,
        dims=[0],
    )
    return output


# Implements RecAtt as described in section 4.1 of
# http://arxiv.org/abs/1601.03317
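# In addition to the current decoder hidden state, the previous attention
# context (attention_weighted_encoder_context_t_prev) is projected and added
# into the match sum. Returns the attention-weighted context, the attention
# weights [batch_size, encoder_length, 1], and a list containing
# decoder_hidden_encoder_outputs_sum.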
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
    encoder_lengths=None,
):
    weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
        model=model,
        input=attention_weighted_encoder_context_t_prev,
        dim_in=encoder_output_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_prev_attention_context',
    )

    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [
            weighted_prev_attention_context,
            weighted_decoder_hidden_state,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [
            weighted_encoder_outputs,
            decoder_hidden_encoder_outputs_sum_tmp,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]


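# Additive attention: the decoder hidden state is projected to
# encoder_output_dim, broadcast-added to the precomputed
# weighted_encoder_outputs, and scored with the sum-match helpers above.
# Returns the attention-weighted context, the attention weights
# [batch_size, encoder_length, 1], and a list containing
# decoder_hidden_encoder_outputs_sum.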
def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )

    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]


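# Dot-product attention: the decoder hidden state (projected to
# encoder_output_dim if the dimensions differ) is batch-multiplied directly
# against the encoder outputs to form the attention logits, so no
# weighted_encoder_outputs blob is required. Returns the attention-weighted
# context, the attention weights [batch_size, encoder_length, 1], and an
# empty list.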
def apply_dot_attention(
    model,
    encoder_output_dim,
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed,
    # [1, batch_size, decoder_state_dim]
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    if decoder_hidden_state_dim != encoder_output_dim:
        weighted_decoder_hidden_state = brew.fc(
            model,
            decoder_hidden_state_t,
            s(scope, 'weighted_decoder_hidden_state'),
            dim_in=decoder_hidden_state_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )
    else:
        weighted_decoder_hidden_state = decoder_hidden_state_t

    # [batch_size, decoder_state_dim]
    squeezed_weighted_decoder_hidden_state = model.net.Squeeze(
        weighted_decoder_hidden_state,
        s(scope, 'squeezed_weighted_decoder_hidden_state'),
        dims=[0],
    )

    # [batch_size, decoder_state_dim, 1]
    expanddims_squeezed_weighted_decoder_hidden_state = model.net.ExpandDims(
        squeezed_weighted_decoder_hidden_state,
        squeezed_weighted_decoder_hidden_state,
        dims=[2],
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = model.net.BatchMatMul(
        [
            encoder_outputs_transposed,
            expanddims_squeezed_weighted_decoder_hidden_state,
        ],
        s(scope, 'attention_logits'),
        trans_a=1,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, []


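# Soft-coverage attention: the additive match sum is extended with a coverage
# term. coverage_t_prev scales coverage_weights per encoder position, the
# scaled term is added into the match sum, and the resulting attention weights
# are accumulated into coverage_t, which is returned along with the context,
# the weights, and a list containing decoder_hidden_encoder_outputs_sum.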
def apply_soft_coverage_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths,
    coverage_t_prev,
    coverage_weights,
):

    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
        broadcast=1,
    )
    # [batch_size, encoder_length]
    coverage_t_prev_2d = model.net.Squeeze(
        coverage_t_prev,
        s(scope, 'coverage_t_prev_2d'),
        dims=[0],
    )
    # [encoder_length, batch_size]
    coverage_t_prev_transposed = brew.transpose(
        model,
        coverage_t_prev_2d,
        s(scope, 'coverage_t_prev_transposed'),
    )

    # [encoder_length, batch_size, encoder_output_dim]
    scaled_coverage_weights = model.net.Mul(
        [coverage_weights, coverage_t_prev_transposed],
        s(scope, 'scaled_coverage_weights'),
        broadcast=1,
        axis=0,
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [decoder_hidden_encoder_outputs_sum_tmp, scaled_coverage_weights],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )

    # [batch_size, encoder_length]
    attention_weights_2d = model.net.Squeeze(
        attention_weights_3d,
        s(scope, 'attention_weights_2d'),
        dims=[2],
    )

    coverage_t = model.net.Add(
        [coverage_t_prev, attention_weights_2d],
        s(scope, 'coverage_t'),
        broadcast=1,
    )

    return (
        attention_weighted_encoder_context,
        attention_weights_3d,
        [decoder_hidden_encoder_outputs_sum],
        coverage_t,
    )
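

# ---------------------------------------------------------------------------
# Usage sketch (not part of attention.py): a minimal, hypothetical example of
# wiring apply_regular_attention into a ModelHelper net. The blob names
# ('encoder_outputs_transposed', 'weighted_encoder_outputs',
# 'decoder_hidden_state_t', 'encoder_lengths') and the dimensions below are
# placeholders; in practice these blobs are produced by the encoder and
# decoder networks (see caffe2.python.rnn_cell for the real call sites).

from caffe2.python import model_helper
from caffe2.python.attention import apply_regular_attention

model = model_helper.ModelHelper(name="attention_sketch")
context, attention_weights_3d, recompute_blobs = apply_regular_attention(
    model=model,
    encoder_output_dim=256,
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed='encoder_outputs_transposed',
    # [encoder_length, batch_size, encoder_output_dim]
    weighted_encoder_outputs='weighted_encoder_outputs',
    # [1, batch_size, decoder_hidden_state_dim]
    decoder_hidden_state_t='decoder_hidden_state_t',
    decoder_hidden_state_dim=512,
    scope='decoder',
    encoder_lengths='encoder_lengths',
)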