Caffe2 - Python API
A deep learning, cross platform ML framework
attention.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package attention
# Module caffe2.python.attention
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import brew


class AttentionType:
    Regular, Recurrent, Dot, SoftCoverage = tuple(range(4))

def s(scope, name):
    # We have to manually scope due to our internal/external blob
    # relationships.
    return "{}/{}".format(str(scope), str(name))

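# For example (illustrative call, not from the original file):
# s('decoder/rnn', 'attention_weights_3d') returns
# 'decoder/rnn/attention_weights_3d'.
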
# c_i = \sum_j w_{ij}\textbf{s}_j
def _calc_weighted_context(
    model,
    encoder_outputs_transposed,
    encoder_output_dim,
    attention_weights_3d,
    scope,
):
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = brew.batch_mat_mul(
        model,
        [encoder_outputs_transposed, attention_weights_3d],
        s(scope, 'attention_weighted_encoder_context'),
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s(scope, 'attention_weighted_encoder_context_old_shape'),
        ],
        shape=[1, -1, encoder_output_dim],
    )
    return attention_weighted_encoder_context

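# Shape flow in _calc_weighted_context, derived from the ops above:
# [batch_size, encoder_output_dim, encoder_length] x [batch_size, encoder_length, 1]
# -> [batch_size, encoder_output_dim, 1], then reshaped to the time-major
# [1, batch_size, encoder_output_dim] layout used for a single decoder step.
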
# Calculate a softmax over the passed-in attention energy logits
def _calc_attention_weights(
    model,
    attention_logits_transposed,
    scope,
    encoder_lengths=None,
):
    if encoder_lengths is not None:
        attention_logits_transposed = model.net.SequenceMask(
            [attention_logits_transposed, encoder_lengths],
            ['masked_attention_logits'],
            mode='sequence',
        )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = brew.softmax(
        model,
        attention_logits_transposed,
        s(scope, 'attention_weights_3d'),
        engine='CUDNN',
        axis=1,
    )
    return attention_weights_3d

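# When encoder_lengths is given, SequenceMask overwrites logits at padded
# positions with a large negative fill value, so the softmax over axis=1 (the
# encoder_length axis of the [batch_size, encoder_length, 1] logits) assigns
# those positions near-zero weight.
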
# e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j)
def _calc_attention_logits_from_sum_match(
    model,
    decoder_hidden_encoder_outputs_sum,
    encoder_output_dim,
    scope,
):
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )

    # [encoder_length, batch_size, 1]
    attention_logits = brew.fc(
        model,
        decoder_hidden_encoder_outputs_sum,
        s(scope, 'attention_logits'),
        dim_in=encoder_output_dim,
        dim_out=1,
        axis=2,
        freeze_bias=True,
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = brew.transpose(
        model,
        attention_logits,
        s(scope, 'attention_logits_transposed'),
        axes=[1, 0, 2],
    )
    return attention_logits_transposed

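# In the formula above, the tanh is applied elementwise to the precomputed sum
# \alpha(h_{i-1}, s_j), and the 1-output FC over axis=2 plays the role of the
# dot product with v, producing one logit per (encoder position, batch item);
# the transpose then reorders the logits into the batch-major layout that the
# softmax in _calc_attention_weights expects.
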
# \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b)
def _apply_fc_weight_for_sum_match(
    model,
    input,
    dim_in,
    dim_out,
    scope,
    name,
):
    output = brew.fc(
        model,
        input,
        s(scope, name),
        dim_in=dim_in,
        dim_out=dim_out,
        axis=2,
    )
    output = model.net.Squeeze(
        output,
        output,
        dims=[0],
    )
    return output

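# The FC over axis=2 applies W^\alpha to the last dimension of a
# [1, batch_size, dim_in] input, and the Squeeze drops the leading singleton
# time dimension, leaving a [batch_size, dim_out] blob.
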
# Implement RecAtt as described in section 4.1 of http://arxiv.org/abs/1601.03317
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
    encoder_lengths=None,
):
    weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
        model=model,
        input=attention_weighted_encoder_context_t_prev,
        dim_in=encoder_output_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_prev_attention_context',
    )

    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [
            weighted_prev_attention_context,
            weighted_decoder_hidden_state,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [
            weighted_encoder_outputs,
            decoder_hidden_encoder_outputs_sum_tmp,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]

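# Relative to apply_regular_attention below, apply_recurrent_attention also
# projects the previous step's attention context and adds it into the sum that
# feeds the tanh; this extra term is what makes the attention "recurrent" in
# the RecAtt formulation cited above.
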
def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )

    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]

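# apply_regular_attention is the additive-attention path: weighted_encoder_outputs
# holds a precomputed projection of the encoder states, the decoder hidden state
# is projected to the same dimension and broadcast-added over encoder positions,
# and the result goes through tanh, the v^T projection, a (masked) softmax, and
# the weighted sum over encoder states.
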
def apply_dot_attention(
    model,
    encoder_output_dim,
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed,
    # [1, batch_size, decoder_state_dim]
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    if decoder_hidden_state_dim != encoder_output_dim:
        weighted_decoder_hidden_state = brew.fc(
            model,
            decoder_hidden_state_t,
            s(scope, 'weighted_decoder_hidden_state'),
            dim_in=decoder_hidden_state_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )
    else:
        weighted_decoder_hidden_state = decoder_hidden_state_t

    # [batch_size, encoder_output_dim]
    squeezed_weighted_decoder_hidden_state = model.net.Squeeze(
        weighted_decoder_hidden_state,
        s(scope, 'squeezed_weighted_decoder_hidden_state'),
        dims=[0],
    )

    # [batch_size, encoder_output_dim, 1]
    expanddims_squeezed_weighted_decoder_hidden_state = model.net.ExpandDims(
        squeezed_weighted_decoder_hidden_state,
        squeezed_weighted_decoder_hidden_state,
        dims=[2],
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = model.net.BatchMatMul(
        [
            encoder_outputs_transposed,
            expanddims_squeezed_weighted_decoder_hidden_state,
        ],
        s(scope, 'attention_logits'),
        trans_a=1,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, []

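# In the dot-attention case the logits are plain dot products: with trans_a=1,
# BatchMatMul computes, per batch item, encoder_outputs^T (shape
# [encoder_length, encoder_output_dim]) times the (optionally projected) decoder
# state column vector, giving one logit per encoder position.
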
def apply_soft_coverage_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths,
    coverage_t_prev,
    coverage_weights,
):

    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    # [batch_size, encoder_length]
    coverage_t_prev_2d = model.net.Squeeze(
        coverage_t_prev,
        s(scope, 'coverage_t_prev_2d'),
        dims=[0],
    )
    # [encoder_length, batch_size]
    coverage_t_prev_transposed = brew.transpose(
        model,
        coverage_t_prev_2d,
        s(scope, 'coverage_t_prev_transposed'),
    )

    # [encoder_length, batch_size, encoder_output_dim]
    scaled_coverage_weights = model.net.Mul(
        [coverage_weights, coverage_t_prev_transposed],
        s(scope, 'scaled_coverage_weights'),
        broadcast=1,
        axis=0,
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [decoder_hidden_encoder_outputs_sum, scaled_coverage_weights],
        decoder_hidden_encoder_outputs_sum,
    )

    # [batch_size, encoder_length, 1]
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )

    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )

    # [batch_size, encoder_length]
    attention_weights_2d = model.net.Squeeze(
        attention_weights_3d,
        s(scope, 'attention_weights_2d'),
        dims=[2],
    )

    coverage_t = model.net.Add(
        [coverage_t_prev, attention_weights_2d],
        s(scope, 'coverage_t'),
        broadcast=1,
    )

    return (
        attention_weighted_encoder_context,
        attention_weights_3d,
        [decoder_hidden_encoder_outputs_sum],
        coverage_t,
    )
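

if __name__ == '__main__':
    # Minimal NumPy sketch (illustrative only; not part of the Caffe2 ops above)
    # of the additive-attention math that apply_regular_attention assembles,
    # assuming projection matrices W_enc, W_dec and vector v are already given.
    import numpy as np

    encoder_length, batch_size, encoder_dim, decoder_dim = 5, 2, 4, 3
    rng = np.random.RandomState(0)
    encoder_outputs = rng.randn(encoder_length, batch_size, encoder_dim)
    decoder_hidden = rng.randn(batch_size, decoder_dim)
    W_enc = rng.randn(encoder_dim, encoder_dim)
    W_dec = rng.randn(decoder_dim, encoder_dim)
    v = rng.randn(encoder_dim)

    # e_{ij} = v^T tanh(W_enc s_j + W_dec h_i)  -> [encoder_length, batch_size]
    summed = np.matmul(encoder_outputs, W_enc) + np.matmul(decoder_hidden, W_dec)
    logits = np.matmul(np.tanh(summed), v)
    # softmax over encoder positions            -> [encoder_length, batch_size]
    weights = np.exp(logits - logits.max(axis=0))
    weights /= weights.sum(axis=0)
    # c_i = \sum_j w_{ij} s_j                   -> [batch_size, encoder_dim]
    context = np.einsum('jb,jbd->bd', weights, encoder_outputs)
    assert context.shape == (batch_size, encoder_dim)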