Caffe2 - Python API
A deep learning, cross-platform ML framework
beam_search.py
## @package beam_search
# Module caffe2.python.models.seq2seq.beam_search
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import namedtuple
from caffe2.python import core
import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util
from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper


class BeamSearchForwardOnly(object):
    """
    Class generalizing forward beam search for seq2seq models.

    Also provides types to specify the recurrent structure of decoding:

    StateConfig:
        initial_value: blob providing the value of the state at the first step
        state_prev_link: LinkConfig describing how the recurrent step receives
            input from the global state blob in each step
        state_link: LinkConfig describing how the step writes (produces new
            state) to the global state blob in each step

    LinkConfig:
        blob: blob connecting the global state blob to the step application
        offset: offset from the beginning of the global blob for the link,
            in the time dimension
        window: width of the global blob to read/write in the time dimension
    """

    LinkConfig = namedtuple('LinkConfig', ['blob', 'offset', 'window'])

    StateConfig = namedtuple(
        'StateConfig',
        ['initial_value', 'state_prev_link', 'state_link'],
    )

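    # Illustrative sketch (not part of the original source): how a caller
    # might describe one recurrent decoder state with these types. The blob
    # names below (initial_hidden, hidden_t_prev, hidden_t) are hypothetical;
    # offsets/windows follow the pattern used for the built-in states below.
    #
    #     hidden_state_config = BeamSearchForwardOnly.StateConfig(
    #         initial_value=initial_hidden,
    #         state_prev_link=BeamSearchForwardOnly.LinkConfig(
    #             blob=hidden_t_prev, offset=0, window=1,
    #         ),
    #         state_link=BeamSearchForwardOnly.LinkConfig(
    #             blob=hidden_t, offset=1, window=1,
    #         ),
    #     )
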
    def __init__(
        self,
        beam_size,
        model,
        eos_token_id,
        go_token_id=seq2seq_util.GO_ID,
        post_eos_penalty=None,
    ):
        self.beam_size = beam_size
        self.model = model
        self.step_model = Seq2SeqModelHelper(
            name='step_model',
            param_model=self.model,
        )
        self.go_token_id = go_token_id
        self.eos_token_id = eos_token_id
        self.post_eos_penalty = post_eos_penalty

        (
            self.timestep,
            self.scores_t_prev,
            self.tokens_t_prev,
            self.hypo_t_prev,
            self.attention_t_prev,
        ) = self.step_model.net.AddExternalInputs(
            'timestep',
            'scores_t_prev',
            'tokens_t_prev',
            'hypo_t_prev',
            'attention_t_prev',
        )
        tokens_t_prev_int32 = self.step_model.net.Cast(
            self.tokens_t_prev,
            'tokens_t_prev_int32',
            to=core.DataType.INT32,
        )
        self.tokens_t_prev_int32_flattened, _ = self.step_model.net.Reshape(
            [tokens_t_prev_int32],
            [tokens_t_prev_int32, 'input_t_int32_old_shape'],
            shape=[1, -1],
        )

    def get_step_model(self):
        return self.step_model

    def get_previous_tokens(self):
        return self.tokens_t_prev_int32_flattened

    def get_timestep(self):
        return self.timestep

    # TODO: make attentions a generic state
    # data_dependencies is a list of blobs that the operator should wait for
    # before beginning execution. This ensures that ops are run in the correct
    # order when the RecurrentNetwork op is embedded in a DAGNet, for example.
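    # (Illustration, not part of the original source: data_dependencies would
    # typically hold blobs produced earlier in the outer model's net that the
    # step net reads, e.g. encoder output blobs, so a DAG scheduler runs the
    # RecurrentNetwork op only after those blobs have been computed.)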
    def apply(
        self,
        inputs,
        length,
        log_probs,
        attentions,
        state_configs,
        data_dependencies,
        word_rewards=None,
        possible_translation_tokens=None,
        go_token_id=None,
    ):
        ZERO = self.model.param_init_net.ConstantFill(
            [],
            'ZERO',
            shape=[1],
            value=0,
            dtype=core.DataType.INT32,
        )
        on_initial_step = self.step_model.net.EQ(
            [ZERO, self.timestep],
            'on_initial_step',
        )

        if self.post_eos_penalty is not None:
            eos_token = self.model.param_init_net.ConstantFill(
                [],
                'eos_token',
                shape=[self.beam_size],
                value=self.eos_token_id,
                dtype=core.DataType.INT32,
            )
            finished_penalty = self.model.param_init_net.ConstantFill(
                [],
                'finished_penalty',
                shape=[1],
                value=float(self.post_eos_penalty),
                dtype=core.DataType.FLOAT,
            )
            ZERO_FLOAT = self.model.param_init_net.ConstantFill(
                [],
                'ZERO_FLOAT',
                shape=[1],
                value=0.0,
                dtype=core.DataType.FLOAT,
            )
            finished_penalty = self.step_model.net.Conditional(
                [on_initial_step, ZERO_FLOAT, finished_penalty],
                'possible_finished_penalty',
            )

            tokens_t_flat = self.step_model.net.FlattenToVec(
                self.tokens_t_prev,
                'tokens_t_flat',
            )
            tokens_t_flat_int = self.step_model.net.Cast(
                tokens_t_flat,
                'tokens_t_flat_int',
                to=core.DataType.INT32,
            )

            predecessor_is_eos = self.step_model.net.EQ(
                [tokens_t_flat_int, eos_token],
                'predecessor_is_eos',
            )
            predecessor_is_eos_float = self.step_model.net.Cast(
                predecessor_is_eos,
                'predecessor_is_eos_float',
                to=core.DataType.FLOAT,
            )
            predecessor_is_eos_penalty = self.step_model.net.Mul(
                [predecessor_is_eos_float, finished_penalty],
                'predecessor_is_eos_penalty',
                broadcast=1,
            )

            log_probs = self.step_model.net.Add(
                [log_probs, predecessor_is_eos_penalty],
                'log_probs_penalized',
                broadcast=1,
                axis=0,
            )

        # [beam_size, beam_size]
        best_scores_per_hypo, best_tokens_per_hypo = self.step_model.net.TopK(
            log_probs,
            ['best_scores_per_hypo', 'best_tokens_per_hypo_indices'],
            k=self.beam_size,
        )
        if possible_translation_tokens:
            # [beam_size, beam_size]
            best_tokens_per_hypo = self.step_model.net.Gather(
                [possible_translation_tokens, best_tokens_per_hypo],
                ['best_tokens_per_hypo'],
            )

        # [beam_size]
        scores_t_prev_squeezed, _ = self.step_model.net.Reshape(
            self.scores_t_prev,
            ['scores_t_prev_squeezed', 'scores_t_prev_old_shape'],
            shape=[self.beam_size],
        )
        # [beam_size, beam_size]
        output_scores = self.step_model.net.Add(
            [best_scores_per_hypo, scores_t_prev_squeezed],
            'output_scores',
            broadcast=1,
            axis=0,
        )
        if word_rewards is not None:
            # [beam_size, beam_size]
            word_rewards_for_best_tokens_per_hypo = self.step_model.net.Gather(
                [word_rewards, best_tokens_per_hypo],
                'word_rewards_for_best_tokens_per_hypo',
            )
            # [beam_size, beam_size]
            output_scores = self.step_model.net.Add(
                [output_scores, word_rewards_for_best_tokens_per_hypo],
                'output_scores',
            )
        # [beam_size * beam_size]
        output_scores_flattened, _ = self.step_model.net.Reshape(
            [output_scores],
            [output_scores, 'output_scores_old_shape'],
            shape=[-1],
        )
        MINUS_ONE_INT32 = self.model.param_init_net.ConstantFill(
            [],
            'MINUS_ONE_INT32',
            value=-1,
            shape=[1],
            dtype=core.DataType.INT32,
        )
        BEAM_SIZE = self.model.param_init_net.ConstantFill(
            [],
            'beam_size',
            shape=[1],
            value=self.beam_size,
            dtype=core.DataType.INT32,
        )

        # current_beam_size (predecessor states from previous step)
        # is 1 on first step (so we just need beam_size scores),
        # and beam_size subsequently (so we need all beam_size * beam_size
        # scores)
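        # (Illustration, not part of the original source: with beam_size = 2,
        # the first step keeps only the first 2 of the 4 flattened scores
        # (slice_end = 2), while later steps keep all 4 (slice_end = -1).)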
        slice_end = self.step_model.net.Conditional(
            [on_initial_step, BEAM_SIZE, MINUS_ONE_INT32],
            ['slice_end'],
        )

        # [current_beam_size * beam_size]
        output_scores_flattened_slice = self.step_model.net.Slice(
            [output_scores_flattened, ZERO, slice_end],
            'output_scores_flattened_slice',
        )
        # [1, current_beam_size * beam_size]
        output_scores_flattened_slice, _ = self.step_model.net.Reshape(
            output_scores_flattened_slice,
            [
                output_scores_flattened_slice,
                'output_scores_flattened_slice_old_shape',
            ],
            shape=[1, -1],
        )
        # [1, beam_size]
        scores_t, best_indices = self.step_model.net.TopK(
            output_scores_flattened_slice,
            ['scores_t', 'best_indices'],
            k=self.beam_size,
        )
        BEAM_SIZE_64 = self.model.param_init_net.Cast(
            BEAM_SIZE,
            'BEAM_SIZE_64',
            to=core.DataType.INT64,
        )
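        # Each entry of best_indices points into the row-major flattening of
        # the [hypothesis, candidate] score matrix, so integer division by
        # beam_size recovers the index of the predecessor hypothesis.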
        # [1, beam_size]
        hypo_t_int32 = self.step_model.net.Div(
            [best_indices, BEAM_SIZE_64],
            'hypo_t_int32',
            broadcast=1,
        )
        hypo_t = self.step_model.net.Cast(
            hypo_t_int32,
            'hypo_t',
            to=core.DataType.FLOAT,
        )

        # [beam_size, encoder_length, 1]
        attention_t = self.step_model.net.Gather(
            [attentions, hypo_t_int32],
            'attention_t',
        )
        # [1, beam_size, encoder_length]
        attention_t, _ = self.step_model.net.Reshape(
            attention_t,
            [attention_t, 'attention_t_old_shape'],
            shape=[1, self.beam_size, -1],
        )
        # [beam_size * beam_size]
        best_tokens_per_hypo_flatten, _ = self.step_model.net.Reshape(
            best_tokens_per_hypo,
            [
                'best_tokens_per_hypo_flatten',
                'best_tokens_per_hypo_old_shape',
            ],
            shape=[-1],
        )
        tokens_t_int32 = self.step_model.net.Gather(
            [best_tokens_per_hypo_flatten, best_indices],
            'tokens_t_int32',
        )
        tokens_t = self.step_model.net.Cast(
            tokens_t_int32,
            'tokens_t',
            to=core.DataType.FLOAT,
        )

        def choose_state_per_hypo(state_config):
            state_flattened, _ = self.step_model.net.Reshape(
                state_config.state_link.blob,
                [
                    state_config.state_link.blob,
                    state_config.state_link.blob + '_old_shape',
                ],
                shape=[self.beam_size, -1],
            )
            state_chosen_per_hypo = self.step_model.net.Gather(
                [state_flattened, hypo_t_int32],
                str(state_config.state_link.blob) + '_chosen_per_hypo',
            )
            return self.StateConfig(
                initial_value=state_config.initial_value,
                state_prev_link=state_config.state_prev_link,
                state_link=self.LinkConfig(
                    blob=state_chosen_per_hypo,
                    offset=state_config.state_link.offset,
                    window=state_config.state_link.window,
                ),
            )
        state_configs = [choose_state_per_hypo(c) for c in state_configs]
        initial_scores = self.model.param_init_net.ConstantFill(
            [],
            'initial_scores',
            shape=[1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        if go_token_id:
            initial_tokens = self.model.net.Copy(
                [go_token_id],
                'initial_tokens',
            )
        else:
            initial_tokens = self.model.param_init_net.ConstantFill(
                [],
                'initial_tokens',
                shape=[1],
                value=float(self.go_token_id),
                dtype=core.DataType.FLOAT,
            )

        initial_hypo = self.model.param_init_net.ConstantFill(
            [],
            'initial_hypo',
            shape=[1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        encoder_inputs_flattened, _ = self.model.net.Reshape(
            inputs,
            ['encoder_inputs_flattened', 'encoder_inputs_old_shape'],
            shape=[-1],
        )
        init_attention = self.model.net.ConstantFill(
            encoder_inputs_flattened,
            'init_attention',
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        state_configs = state_configs + [
            self.StateConfig(
                initial_value=initial_scores,
                state_prev_link=self.LinkConfig(self.scores_t_prev, 0, 1),
                state_link=self.LinkConfig(scores_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=initial_tokens,
                state_prev_link=self.LinkConfig(self.tokens_t_prev, 0, 1),
                state_link=self.LinkConfig(tokens_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=initial_hypo,
                state_prev_link=self.LinkConfig(self.hypo_t_prev, 0, 1),
                state_link=self.LinkConfig(hypo_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=init_attention,
                state_prev_link=self.LinkConfig(self.attention_t_prev, 0, 1),
                state_link=self.LinkConfig(attention_t, 1, 1),
            ),
        ]
        fake_input = self.model.net.ConstantFill(
            length,
            'beam_search_fake_input',
            input_as_shape=True,
            extra_shape=[self.beam_size, 1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        all_inputs = (
            [fake_input] +
            self.step_model.params +
            [state_config.initial_value for state_config in state_configs] +
            data_dependencies
        )
        forward_links = []
        recurrent_states = []
        for state_config in state_configs:
            state_name = str(state_config.state_prev_link.blob) + '_states'
            recurrent_states.append(state_name)
            forward_links.append((
                state_config.state_prev_link.blob,
                state_name,
                state_config.state_prev_link.offset,
                state_config.state_prev_link.window,
            ))
            forward_links.append((
                state_config.state_link.blob,
                state_name,
                state_config.state_link.offset,
                state_config.state_link.window,
            ))
        link_internal, link_external, link_offset, link_window = (
            zip(*forward_links)
        )
        all_outputs = [
            str(s) + '_all'
            for s in [scores_t, tokens_t, hypo_t, attention_t]
        ]
        results = self.model.net.RecurrentNetwork(
            all_inputs,
            all_outputs + ['step_workspaces'],
            param=[all_inputs.index(p) for p in self.step_model.params],
            alias_src=[
                str(s) + '_states'
                for s in [
                    self.scores_t_prev,
                    self.tokens_t_prev,
                    self.hypo_t_prev,
                    self.attention_t_prev,
                ]
            ],
            alias_dst=all_outputs,
            alias_offset=[0] * 4,
            recurrent_states=recurrent_states,
            initial_recurrent_state_ids=[
                all_inputs.index(state_config.initial_value)
                for state_config in state_configs
            ],
            link_internal=[str(l) for l in link_internal],
            link_external=[str(l) for l in link_external],
            link_offset=link_offset,
            link_window=link_window,
            backward_link_internal=[],
            backward_link_external=[],
            backward_link_offset=[],
            step_net=self.step_model.net.Proto(),
            timestep=str(self.timestep),
            outputs_with_grads=[],
            enable_rnn_executor=1,
            rnn_executor_debug=0,
        )
        score_t_all, tokens_t_all, hypo_t_all, attention_t_all = results[:4]

        output_token_beam_list = self.model.net.Cast(
            tokens_t_all,
            'output_token_beam_list',
            to=core.DataType.INT32,
        )
        output_prev_index_beam_list = self.model.net.Cast(
            hypo_t_all,
            'output_prev_index_beam_list',
            to=core.DataType.INT32,
        )
        output_score_beam_list = self.model.net.Alias(
            score_t_all,
            'output_score_beam_list',
        )
        output_attention_weights_beam_list = self.model.net.Alias(
            attention_t_all,
            'output_attention_weights_beam_list',
        )

        return (
            output_token_beam_list,
            output_prev_index_beam_list,
            output_score_beam_list,
            output_attention_weights_beam_list,
        )
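
Usage sketch (illustrative, not part of the original source): the names build_decoder_step, encoder_inputs, and max_output_length below are hypothetical placeholders for whatever step-net builder and blobs a caller already has; only the BeamSearchForwardOnly constructor, get_step_model(), get_previous_tokens(), get_timestep(), and apply() come from the module listed above.

from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper
from caffe2.python.models.seq2seq.beam_search import BeamSearchForwardOnly

model = Seq2SeqModelHelper(name='translate')
beam_search = BeamSearchForwardOnly(
    beam_size=6,
    model=model,
    eos_token_id=2,  # placeholder EOS id; use the vocabulary's actual id
)
step_model = beam_search.get_step_model()

# Hypothetical helper: builds one decoder step inside step_model and returns
# the blobs and configs that apply() expects.
log_probs, attentions, state_configs, data_dependencies = build_decoder_step(
    step_model,
    beam_search.get_previous_tokens(),
    beam_search.get_timestep(),
)

(
    output_token_beam_list,
    output_prev_index_beam_list,
    output_score_beam_list,
    output_attention_weights_beam_list,
) = beam_search.apply(
    inputs=encoder_inputs,        # placeholder: encoder input blob
    length=max_output_length,     # placeholder: blob holding the decode length
    log_probs=log_probs,
    attentions=attentions,
    state_configs=state_configs,
    data_dependencies=data_dependencies,
)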