Caffe2 - Python API
A deep learning, cross platform ML framework
beam_search.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package beam_search
17 # Module caffe2.python.models.seq2seq.beam_search
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from collections import namedtuple
24 from caffe2.python import core
25 import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
26 from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper
27 
28 
class BeamSearchForwardOnly(object):
    """
    Generalized forward-only beam search for seq2seq models.

    Decoding recurrence is described with two small record types:

    StateConfig:
        initial_value: blob providing the state value at the first step_model
        state_prev_link: LinkConfig describing how the recurrent step reads
            its input from the global state blob at each step
        state_link: LinkConfig describing how the step writes the newly
            produced state back into the global state blob at each step

    LinkConfig:
        blob: blob connecting the global state blob to the step application
        offset: offset from the beginning of the global blob for the link,
            in the time dimension
        window: width of the global blob to read/write, in the time dimension
    """

    # (blob, offset, window) triple wiring a step blob to the global state.
    LinkConfig = namedtuple('LinkConfig', ['blob', 'offset', 'window'])

    # Full description of one recurrent state: its initial value plus the
    # read (state_prev_link) and write (state_link) wiring.
    StateConfig = namedtuple(
        'StateConfig',
        ['initial_value', 'state_prev_link', 'state_link'],
    )
54 
55  def __init__(
56  self,
57  beam_size,
58  model,
59  eos_token_id,
60  go_token_id=seq2seq_util.GO_ID,
61  post_eos_penalty=None,
62  ):
63  self.beam_size = beam_size
64  self.model = model
66  name='step_model',
67  param_model=self.model,
68  )
69  self.go_token_id = go_token_id
70  self.eos_token_id = eos_token_id
71  self.post_eos_penalty = post_eos_penalty
72 
73  (
74  self.timestep,
75  self.scores_t_prev,
76  self.tokens_t_prev,
77  self.hypo_t_prev,
78  self.attention_t_prev,
79  ) = self.step_model.net.AddExternalInputs(
80  'timestep',
81  'scores_t_prev',
82  'tokens_t_prev',
83  'hypo_t_prev',
84  'attention_t_prev',
85  )
86  tokens_t_prev_int32 = self.step_model.net.Cast(
87  self.tokens_t_prev,
88  'tokens_t_prev_int32',
89  to=core.DataType.INT32,
90  )
91  self.tokens_t_prev_int32_flattened, _ = self.step_model.net.Reshape(
92  [tokens_t_prev_int32],
93  [tokens_t_prev_int32, 'input_t_int32_old_shape'],
94  shape=[1, -1],
95  )
96 
97  def get_step_model(self):
98  return self.step_model
99 
100  def get_previous_tokens(self):
101  return self.tokens_t_prev_int32_flattened
102 
103  def get_timestep(self):
104  return self.timestep
105 
106  # TODO: make attentions a generic state
107  # data_dependencies is a list of blobs that the operator should wait for
108  # before beginning execution. This ensures that ops are run in the correct
109  # order when the RecurrentNetwork op is embedded in a DAGNet, for ex.
110  def apply(
111  self,
112  inputs,
113  length,
114  log_probs,
115  attentions,
116  state_configs,
117  data_dependencies,
118  word_rewards=None,
119  possible_translation_tokens=None,
120  go_token_id=None,
121  ):
122  ZERO = self.model.param_init_net.ConstantFill(
123  [],
124  'ZERO',
125  shape=[1],
126  value=0,
127  dtype=core.DataType.INT32,
128  )
129  on_initial_step = self.step_model.net.EQ(
130  [ZERO, self.timestep],
131  'on_initial_step',
132  )
133 
134  if self.post_eos_penalty is not None:
135  eos_token = self.model.param_init_net.ConstantFill(
136  [],
137  'eos_token',
138  shape=[self.beam_size],
139  value=self.eos_token_id,
140  dtype=core.DataType.INT32,
141  )
142  finished_penalty = self.model.param_init_net.ConstantFill(
143  [],
144  'finished_penalty',
145  shape=[1],
146  value=float(self.post_eos_penalty),
147  dtype=core.DataType.FLOAT,
148  )
149  ZERO_FLOAT = self.model.param_init_net.ConstantFill(
150  [],
151  'ZERO_FLOAT',
152  shape=[1],
153  value=0.0,
154  dtype=core.DataType.FLOAT,
155  )
156  finished_penalty = self.step_model.net.Conditional(
157  [on_initial_step, ZERO_FLOAT, finished_penalty],
158  'possible_finished_penalty',
159  )
160 
161  # [beam_size]
162  hypo_t_flat = self.step_model.net.FlattenToVec(
163  self.hypo_t_prev,
164  'hypo_t_flat',
165  )
166  hypo_t_flat_int = self.step_model.net.Cast(
167  hypo_t_flat,
168  'hypo_t_flat_int',
169  to=core.DataType.INT32,
170  )
171 
172  tokens_t_flat = self.step_model.net.FlattenToVec(
173  self.tokens_t_prev,
174  'tokens_t_flat',
175  )
176  tokens_t_flat_int = self.step_model.net.Cast(
177  tokens_t_flat,
178  'tokens_t_flat_int',
179  to=core.DataType.INT32,
180  )
181 
182  predecessor_tokens = self.step_model.net.Gather(
183  [tokens_t_flat_int, hypo_t_flat_int],
184  'predecessor_tokens',
185  )
186  predecessor_is_eos = self.step_model.net.EQ(
187  [predecessor_tokens, eos_token],
188  'predecessor_is_eos',
189  )
190  predecessor_is_eos_float = self.step_model.net.Cast(
191  predecessor_is_eos,
192  'predecessor_is_eos_float',
193  to=core.DataType.FLOAT,
194  )
195  predecessor_is_eos_penalty = self.step_model.net.Mul(
196  [predecessor_is_eos_float, finished_penalty],
197  'predecessor_is_eos_penalty',
198  broadcast=1,
199  )
200 
201  log_probs = self.step_model.net.Add(
202  [log_probs, predecessor_is_eos_penalty],
203  'log_probs_penalized',
204  broadcast=1,
205  axis=0,
206  )
207 
208  # [beam_size, beam_size]
209  best_scores_per_hypo, best_tokens_per_hypo = self.step_model.net.TopK(
210  log_probs,
211  ['best_scores_per_hypo', 'best_tokens_per_hypo_indices'],
212  k=self.beam_size,
213  )
214  if possible_translation_tokens:
215  # [beam_size, beam_size]
216  best_tokens_per_hypo = self.step_model.net.Gather(
217  [possible_translation_tokens, best_tokens_per_hypo],
218  ['best_tokens_per_hypo']
219  )
220 
221  # [beam_size]
222  scores_t_prev_squeezed, _ = self.step_model.net.Reshape(
223  self.scores_t_prev,
224  ['scores_t_prev_squeezed', 'scores_t_prev_old_shape'],
225  shape=[self.beam_size],
226  )
227  # [beam_size, beam_size]
228  output_scores = self.step_model.net.Add(
229  [best_scores_per_hypo, scores_t_prev_squeezed],
230  'output_scores',
231  broadcast=1,
232  axis=0,
233  )
234  if word_rewards is not None:
235  # [beam_size, beam_size]
236  word_rewards_for_best_tokens_per_hypo = self.step_model.net.Gather(
237  [word_rewards, best_tokens_per_hypo],
238  'word_rewards_for_best_tokens_per_hypo',
239  )
240  # [beam_size, beam_size]
241  output_scores = self.step_model.net.Add(
242  [output_scores, word_rewards_for_best_tokens_per_hypo],
243  'output_scores',
244  )
245  # [beam_size * beam_size]
246  output_scores_flattened, _ = self.step_model.net.Reshape(
247  [output_scores],
248  [output_scores, 'output_scores_old_shape'],
249  shape=[-1],
250  )
251  MINUS_ONE_INT32 = self.model.param_init_net.ConstantFill(
252  [],
253  'MINUS_ONE_INT32',
254  value=-1,
255  shape=[1],
256  dtype=core.DataType.INT32,
257  )
258  BEAM_SIZE = self.model.param_init_net.ConstantFill(
259  [],
260  'beam_size',
261  shape=[1],
262  value=self.beam_size,
263  dtype=core.DataType.INT32,
264  )
265 
266  # current_beam_size (predecessor states from previous step)
267  # is 1 on first step (so we just need beam_size scores),
268  # and beam_size subsequently (so we need all beam_size * beam_size
269  # scores)
270  slice_end = self.step_model.net.Conditional(
271  [on_initial_step, BEAM_SIZE, MINUS_ONE_INT32],
272  ['slice_end'],
273  )
274 
275  # [current_beam_size * beam_size]
276  output_scores_flattened_slice = self.step_model.net.Slice(
277  [output_scores_flattened, ZERO, slice_end],
278  'output_scores_flattened_slice',
279  )
280  # [1, current_beam_size * beam_size]
281  output_scores_flattened_slice, _ = self.step_model.net.Reshape(
282  output_scores_flattened_slice,
283  [
284  output_scores_flattened_slice,
285  'output_scores_flattened_slice_old_shape',
286  ],
287  shape=[1, -1],
288  )
289  # [1, beam_size]
290  scores_t, best_indices = self.step_model.net.TopK(
291  output_scores_flattened_slice,
292  ['scores_t', 'best_indices'],
293  k=self.beam_size,
294  )
295  BEAM_SIZE_64 = self.model.param_init_net.Cast(
296  BEAM_SIZE,
297  'BEAM_SIZE_64',
298  to=core.DataType.INT64,
299  )
300  # [1, beam_size]
301  hypo_t_int32 = self.step_model.net.Div(
302  [best_indices, BEAM_SIZE_64],
303  'hypo_t_int32',
304  broadcast=1,
305  )
306  hypo_t = self.step_model.net.Cast(
307  hypo_t_int32,
308  'hypo_t',
309  to=core.DataType.FLOAT,
310  )
311 
312  # [beam_size, encoder_length, 1]
313  attention_t = self.step_model.net.Gather(
314  [attentions, hypo_t_int32],
315  'attention_t',
316  )
317  # [1, beam_size, encoder_length]
318  attention_t, _ = self.step_model.net.Reshape(
319  attention_t,
320  [attention_t, 'attention_t_old_shape'],
321  shape=[1, self.beam_size, -1],
322  )
323  # [beam_size * beam_size]
324  best_tokens_per_hypo_flatten, _ = self.step_model.net.Reshape(
325  best_tokens_per_hypo,
326  [
327  'best_tokens_per_hypo_flatten',
328  'best_tokens_per_hypo_old_shape',
329  ],
330  shape=[-1],
331  )
332  tokens_t_int32 = self.step_model.net.Gather(
333  [best_tokens_per_hypo_flatten, best_indices],
334  'tokens_t_int32',
335  )
336  tokens_t = self.step_model.net.Cast(
337  tokens_t_int32,
338  'tokens_t',
339  to=core.DataType.FLOAT,
340  )
341 
342  def choose_state_per_hypo(state_config):
343  state_flattened, _ = self.step_model.net.Reshape(
344  state_config.state_link.blob,
345  [
346  state_config.state_link.blob,
347  state_config.state_link.blob + '_old_shape',
348  ],
349  shape=[self.beam_size, -1],
350  )
351  state_chosen_per_hypo = self.step_model.net.Gather(
352  [state_flattened, hypo_t_int32],
353  str(state_config.state_link.blob) + '_chosen_per_hypo',
354  )
355  return self.StateConfig(
356  initial_value=state_config.initial_value,
357  state_prev_link=state_config.state_prev_link,
358  state_link=self.LinkConfig(
359  blob=state_chosen_per_hypo,
360  offset=state_config.state_link.offset,
361  window=state_config.state_link.window,
362  )
363  )
364  state_configs = [choose_state_per_hypo(c) for c in state_configs]
365  initial_scores = self.model.param_init_net.ConstantFill(
366  [],
367  'initial_scores',
368  shape=[1],
369  value=0.0,
370  dtype=core.DataType.FLOAT,
371  )
372  if go_token_id:
373  initial_tokens = self.model.net.Copy(
374  [go_token_id],
375  'initial_tokens',
376  )
377  else:
378  initial_tokens = self.model.param_init_net.ConstantFill(
379  [],
380  'initial_tokens',
381  shape=[1],
382  value=float(self.go_token_id),
383  dtype=core.DataType.FLOAT,
384  )
385 
386  initial_hypo = self.model.param_init_net.ConstantFill(
387  [],
388  'initial_hypo',
389  shape=[1],
390  value=0.0,
391  dtype=core.DataType.FLOAT,
392  )
393  encoder_inputs_flattened, _ = self.model.net.Reshape(
394  inputs,
395  ['encoder_inputs_flattened', 'encoder_inputs_old_shape'],
396  shape=[-1],
397  )
398  init_attention = self.model.net.ConstantFill(
399  encoder_inputs_flattened,
400  'init_attention',
401  value=0.0,
402  dtype=core.DataType.FLOAT,
403  )
404  state_configs = state_configs + [
405  self.StateConfig(
406  initial_value=initial_scores,
407  state_prev_link=self.LinkConfig(self.scores_t_prev, 0, 1),
408  state_link=self.LinkConfig(scores_t, 1, 1),
409  ),
410  self.StateConfig(
411  initial_value=initial_tokens,
412  state_prev_link=self.LinkConfig(self.tokens_t_prev, 0, 1),
413  state_link=self.LinkConfig(tokens_t, 1, 1),
414  ),
415  self.StateConfig(
416  initial_value=initial_hypo,
417  state_prev_link=self.LinkConfig(self.hypo_t_prev, 0, 1),
418  state_link=self.LinkConfig(hypo_t, 1, 1),
419  ),
420  self.StateConfig(
421  initial_value=init_attention,
422  state_prev_link=self.LinkConfig(self.attention_t_prev, 0, 1),
423  state_link=self.LinkConfig(attention_t, 1, 1),
424  ),
425  ]
426  fake_input = self.model.net.ConstantFill(
427  length,
428  'beam_search_fake_input',
429  input_as_shape=True,
430  extra_shape=[self.beam_size, 1],
431  value=0.0,
432  dtype=core.DataType.FLOAT,
433  )
434  all_inputs = (
435  [fake_input] +
436  self.step_model.params +
437  [state_config.initial_value for state_config in state_configs] +
438  data_dependencies
439  )
440  forward_links = []
441  recurrent_states = []
442  for state_config in state_configs:
443  state_name = str(state_config.state_prev_link.blob) + '_states'
444  recurrent_states.append(state_name)
445  forward_links.append((
446  state_config.state_prev_link.blob,
447  state_name,
448  state_config.state_prev_link.offset,
449  state_config.state_prev_link.window,
450  ))
451  forward_links.append((
452  state_config.state_link.blob,
453  state_name,
454  state_config.state_link.offset,
455  state_config.state_link.window,
456  ))
457  link_internal, link_external, link_offset, link_window = (
458  zip(*forward_links)
459  )
460  all_outputs = [
461  str(s) + '_all'
462  for s in [scores_t, tokens_t, hypo_t, attention_t]
463  ]
464  results = self.model.net.RecurrentNetwork(
465  all_inputs,
466  all_outputs + ['step_workspaces'],
467  param=[all_inputs.index(p) for p in self.step_model.params],
468  alias_src=[
469  str(s) + '_states'
470  for s in [
471  self.scores_t_prev,
472  self.tokens_t_prev,
473  self.hypo_t_prev,
474  self.attention_t_prev,
475  ]
476  ],
477  alias_dst=all_outputs,
478  alias_offset=[0] * 4,
479  recurrent_states=recurrent_states,
480  initial_recurrent_state_ids=[
481  all_inputs.index(state_config.initial_value)
482  for state_config in state_configs
483  ],
484  link_internal=[str(l) for l in link_internal],
485  link_external=[str(l) for l in link_external],
486  link_offset=link_offset,
487  link_window=link_window,
488  backward_link_internal=[],
489  backward_link_external=[],
490  backward_link_offset=[],
491  step_net=self.step_model.net.Proto(),
492  timestep=str(self.timestep),
493  outputs_with_grads=[],
494  enable_rnn_executor=1,
495  rnn_executor_debug=0
496  )
497  score_t_all, tokens_t_all, hypo_t_all, attention_t_all = results[:4]
498 
499  output_token_beam_list = self.model.net.Cast(
500  tokens_t_all,
501  'output_token_beam_list',
502  to=core.DataType.INT32,
503  )
504  output_prev_index_beam_list = self.model.net.Cast(
505  hypo_t_all,
506  'output_prev_index_beam_list',
507  to=core.DataType.INT32,
508  )
509  output_score_beam_list = self.model.net.Alias(
510  score_t_all,
511  'output_score_beam_list',
512  )
513  output_attention_weights_beam_list = self.model.net.Alias(
514  attention_t_all,
515  'output_attention_weights_beam_list',
516  )
517 
518  return (
519  output_token_beam_list,
520  output_prev_index_beam_list,
521  output_score_beam_list,
522  output_attention_weights_beam_list,
523  )