from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from collections import namedtuple

from caffe2.python import core
from caffe2.python.models.seq2seq import seq2seq_util
from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper


class BeamSearchForwardOnly(object):
    """
    Class generalizing forward beam search for seq2seq models.

    Also provides types to specify the recurrent structure of decoding:

    StateConfig:
        initial_value: blob providing value of state at first step_model
        state_prev_link: LinkConfig describing how recurrent step receives
            input from global state blob in each step
        state_link: LinkConfig describing how step writes (produces new state)
            to global state blob in each step

    LinkConfig:
        blob: blob connecting global state blob to step application
        offset: offset from beginning of global blob for link in time dimension
        window: width of global blob to read/write in time dimension
    """

    LinkConfig = namedtuple('LinkConfig', ['blob', 'offset', 'window'])

    StateConfig = namedtuple(
        'StateConfig',
        ['initial_value', 'state_prev_link', 'state_link'],
    )
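
    # Illustrative sketch, not part of the original module: a caller could
    # declare a recurrent decoder state roughly like this, where
    # `initial_hidden`, `hidden_t_prev` and `hidden_t` are hypothetical blobs
    # created by the caller. Reading the previous value uses offset 0 and
    # writing the new value uses offset 1, each with a window of 1 timestep,
    # which is the same pattern apply() uses for its built-in states below.
    #
    #   hidden_state = BeamSearchForwardOnly.StateConfig(
    #       initial_value=initial_hidden,
    #       state_prev_link=BeamSearchForwardOnly.LinkConfig(
    #           blob=hidden_t_prev, offset=0, window=1,
    #       ),
    #       state_link=BeamSearchForwardOnly.LinkConfig(
    #           blob=hidden_t, offset=1, window=1,
    #       ),
    #   )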

    def __init__(
        self,
        beam_size,
        model,
        eos_token_id,
        go_token_id=seq2seq_util.GO_ID,
        post_eos_penalty=None,
    ):
        self.beam_size = beam_size
        self.model = model
        self.step_model = Seq2SeqModelHelper(
            name='step_model',
            param_model=self.model,
        )
        self.go_token_id = go_token_id
        self.eos_token_id = eos_token_id
        self.post_eos_penalty = post_eos_penalty

        (
            self.timestep,
            self.scores_t_prev,
            self.tokens_t_prev,
            self.hypo_t_prev,
            self.attention_t_prev,
        ) = self.step_model.net.AddExternalInputs(
            'timestep',
            'scores_t_prev',
            'tokens_t_prev',
            'hypo_t_prev',
            'attention_t_prev',
        )
        tokens_t_prev_int32 = self.step_model.net.Cast(
            self.tokens_t_prev,
            'tokens_t_prev_int32',
            to=core.DataType.INT32,
        )
        self.tokens_t_prev_int32_flattened, _ = self.step_model.net.Reshape(
            [tokens_t_prev_int32],
            [tokens_t_prev_int32, 'input_t_int32_old_shape'],
            shape=[1, -1],
        )

    def get_step_model(self):
        return self.step_model

    def get_previous_tokens(self):
        return self.tokens_t_prev_int32_flattened

    def get_timestep(self):
        return self.timestep
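
    # apply() assembles a single beam-search step on self.step_model and then
    # unrolls it with one RecurrentNetwork operator on the main model,
    # producing per-step beam lists of tokens, predecessor hypothesis indices,
    # scores and attention weights.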
    def apply(
        self,
        inputs,
        length,
        log_probs,
        attentions,
        state_configs,
        data_dependencies,
        word_rewards=None,
        possible_translation_tokens=None,
        go_token_id=None,
    ):
        ZERO = self.model.param_init_net.ConstantFill(
            [],
            'ZERO',
            shape=[1],
            value=0,
            dtype=core.DataType.INT32,
        )
        on_initial_step = self.step_model.net.EQ(
            [ZERO, self.timestep],
            'on_initial_step',
        )

        if self.post_eos_penalty is not None:
            eos_token = self.model.param_init_net.ConstantFill(
                [],
                'eos_token',
                shape=[self.beam_size],
                value=self.eos_token_id,
                dtype=core.DataType.INT32,
            )
            finished_penalty = self.model.param_init_net.ConstantFill(
                [],
                'finished_penalty',
                shape=[1],
                value=float(self.post_eos_penalty),
                dtype=core.DataType.FLOAT,
            )
            ZERO_FLOAT = self.model.param_init_net.ConstantFill(
                [],
                'ZERO_FLOAT',
                shape=[1],
                value=0.0,
                dtype=core.DataType.FLOAT,
            )
            finished_penalty = self.step_model.net.Conditional(
                [on_initial_step, ZERO_FLOAT, finished_penalty],
                'possible_finished_penalty',
            )

            tokens_t_flat = self.step_model.net.FlattenToVec(
                self.tokens_t_prev,
                'tokens_t_flat',
            )
            tokens_t_flat_int = self.step_model.net.Cast(
                tokens_t_flat,
                'tokens_t_flat_int',
                to=core.DataType.INT32,
            )
            predecessor_is_eos = self.step_model.net.EQ(
                [tokens_t_flat_int, eos_token],
                'predecessor_is_eos',
            )
            predecessor_is_eos_float = self.step_model.net.Cast(
                predecessor_is_eos,
                'predecessor_is_eos_float',
                to=core.DataType.FLOAT,
            )
            predecessor_is_eos_penalty = self.step_model.net.Mul(
                [predecessor_is_eos_float, finished_penalty],
                'predecessor_is_eos_penalty',
                broadcast=1,
            )
            log_probs = self.step_model.net.Add(
                [log_probs, predecessor_is_eos_penalty],
                'log_probs_penalized',
                broadcast=1,
                axis=0,
            )
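
        # Beam expansion: for each hypothesis kept from the previous step,
        # select the beam_size best next tokens and their (penalized)
        # log-probabilities.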
        # [beam_size, beam_size]
        best_scores_per_hypo, best_tokens_per_hypo = self.step_model.net.TopK(
            log_probs,
            ['best_scores_per_hypo', 'best_tokens_per_hypo_indices'],
            k=self.beam_size,
        )
        if possible_translation_tokens:
            # [beam_size, beam_size]
            best_tokens_per_hypo = self.step_model.net.Gather(
                [possible_translation_tokens, best_tokens_per_hypo],
                ['best_tokens_per_hypo'],
            )

        # [beam_size]
        scores_t_prev_squeezed, _ = self.step_model.net.Reshape(
            self.scores_t_prev,
            ['scores_t_prev_squeezed', 'scores_t_prev_old_shape'],
            shape=[self.beam_size],
        )
        # [beam_size, beam_size]
        output_scores = self.step_model.net.Add(
            [best_scores_per_hypo, scores_t_prev_squeezed],
            'output_scores',
            broadcast=1,
            axis=0,
        )
        if word_rewards is not None:
            # [beam_size, beam_size]
            word_rewards_for_best_tokens_per_hypo = self.step_model.net.Gather(
                [word_rewards, best_tokens_per_hypo],
                'word_rewards_for_best_tokens_per_hypo',
            )
            # [beam_size, beam_size]
            output_scores = self.step_model.net.Add(
                [output_scores, word_rewards_for_best_tokens_per_hypo],
                'output_scores',
            )

        # [beam_size * beam_size]
        output_scores_flattened, _ = self.step_model.net.Reshape(
            [output_scores],
            [output_scores, 'output_scores_old_shape'],
            shape=[-1],
        )
        MINUS_ONE_INT32 = self.model.param_init_net.ConstantFill(
            [],
            'MINUS_ONE_INT32',
            shape=[1],
            value=-1,
            dtype=core.DataType.INT32,
        )
        BEAM_SIZE = self.model.param_init_net.ConstantFill(
            [],
            'beam_size',
            shape=[1],
            value=self.beam_size,
            dtype=core.DataType.INT32,
        )

        # current_beam_size (the number of hypotheses carried over from the
        # previous step) is 1 on the initial step, so only the first beam_size
        # scores are valid there; on later steps all beam_size * beam_size
        # scores are valid.
        slice_end = self.step_model.net.Conditional(
            [on_initial_step, BEAM_SIZE, MINUS_ONE_INT32],
            ['slice_end'],
        )

        # [current_beam_size * beam_size]
        output_scores_flattened_slice = self.step_model.net.Slice(
            [output_scores_flattened, ZERO, slice_end],
            'output_scores_flattened_slice',
        )
        # [1, current_beam_size * beam_size]
        output_scores_flattened_slice, _ = self.step_model.net.Reshape(
            output_scores_flattened_slice,
            [
                output_scores_flattened_slice,
                'output_scores_flattened_slice_old_shape',
            ],
            shape=[1, -1],
        )
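
        # The TopK below runs over the flattened [hypothesis, token] grid, so
        # each selected index encodes both the predecessor hypothesis
        # (index / beam_size) and which of its candidate tokens was chosen.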
        # [1, beam_size]
        scores_t, best_indices = self.step_model.net.TopK(
            output_scores_flattened_slice,
            ['scores_t', 'best_indices'],
            k=self.beam_size,
        )
        BEAM_SIZE_64 = self.model.param_init_net.Cast(
            BEAM_SIZE,
            'BEAM_SIZE_64',
            to=core.DataType.INT64,
        )
        # [1, beam_size]
        hypo_t_int32 = self.step_model.net.Div(
            [best_indices, BEAM_SIZE_64],
            'hypo_t_int32',
            broadcast=1,
        )
        hypo_t = self.step_model.net.Cast(
            hypo_t_int32,
            'hypo_t',
            to=core.DataType.FLOAT,
        )

        # [beam_size, encoder_length, 1]
        attention_t = self.step_model.net.Gather(
            [attentions, hypo_t_int32],
            'attention_t',
        )
        # [1, beam_size, encoder_length]
        attention_t, _ = self.step_model.net.Reshape(
            attention_t,
            [attention_t, 'attention_t_old_shape'],
            shape=[1, self.beam_size, -1],
        )
        # [beam_size * beam_size]
        best_tokens_per_hypo_flatten, _ = self.step_model.net.Reshape(
            best_tokens_per_hypo,
            [
                'best_tokens_per_hypo_flatten',
                'best_tokens_per_hypo_old_shape',
            ],
            shape=[-1],
        )
        tokens_t_int32 = self.step_model.net.Gather(
            [best_tokens_per_hypo_flatten, best_indices],
            'tokens_t_int32',
        )
        tokens_t = self.step_model.net.Cast(
            tokens_t_int32,
            'tokens_t',
            to=core.DataType.FLOAT,
        )
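
        # Every recurrent state supplied by the caller has to be reordered
        # along the hypothesis dimension, so that each surviving hypothesis
        # continues from the state of the predecessor it was expanded from.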
        def choose_state_per_hypo(state_config):
            state_flattened, _ = self.step_model.net.Reshape(
                state_config.state_link.blob,
                [
                    state_config.state_link.blob,
                    state_config.state_link.blob + '_old_shape',
                ],
                shape=[self.beam_size, -1],
            )
            state_chosen_per_hypo = self.step_model.net.Gather(
                [state_flattened, hypo_t_int32],
                str(state_config.state_link.blob) + '_chosen_per_hypo',
            )
            return self.StateConfig(
                initial_value=state_config.initial_value,
                state_prev_link=state_config.state_prev_link,
                state_link=self.LinkConfig(
                    blob=state_chosen_per_hypo,
                    offset=state_config.state_link.offset,
                    window=state_config.state_link.window,
                ),
            )

        state_configs = [choose_state_per_hypo(c) for c in state_configs]

        initial_scores = self.model.param_init_net.ConstantFill(
            [],
            'initial_scores',
            shape=[1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        if go_token_id:
            initial_tokens = self.model.net.Copy(
                [go_token_id],
                'initial_tokens',
            )
        else:
            initial_tokens = self.model.param_init_net.ConstantFill(
                [],
                'initial_tokens',
                shape=[1],
                value=float(self.go_token_id),
                dtype=core.DataType.FLOAT,
            )
        initial_hypo = self.model.param_init_net.ConstantFill(
            [],
            'initial_hypo',
            shape=[1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        encoder_inputs_flattened, _ = self.model.net.Reshape(
            inputs,
            ['encoder_inputs_flattened', 'encoder_inputs_old_shape'],
            shape=[-1],
        )
        init_attention = self.model.net.ConstantFill(
            encoder_inputs_flattened,
            'init_attention',
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        state_configs = state_configs + [
            self.StateConfig(
                initial_value=initial_scores,
                state_prev_link=self.LinkConfig(self.scores_t_prev, 0, 1),
                state_link=self.LinkConfig(scores_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=initial_tokens,
                state_prev_link=self.LinkConfig(self.tokens_t_prev, 0, 1),
                state_link=self.LinkConfig(tokens_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=initial_hypo,
                state_prev_link=self.LinkConfig(self.hypo_t_prev, 0, 1),
                state_link=self.LinkConfig(hypo_t, 1, 1),
            ),
            self.StateConfig(
                initial_value=init_attention,
                state_prev_link=self.LinkConfig(self.attention_t_prev, 0, 1),
                state_link=self.LinkConfig(attention_t, 1, 1),
            ),
        ]

        fake_input = self.model.net.ConstantFill(
            length,
            'beam_search_fake_input',
            input_as_shape=True,
            extra_shape=[self.beam_size, 1],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )
        all_inputs = (
            [fake_input] +
            self.step_model.params +
            [state_config.initial_value for state_config in state_configs] +
            data_dependencies
        )
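
        # RecurrentNetwork is wired explicitly: every state gets a global
        # "<blob>_states" tensor spanning all timesteps, and a pair of forward
        # links exposes its previous-step and current-step windows to the
        # step net.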
        forward_links = []
        recurrent_states = []
        for state_config in state_configs:
            state_name = str(state_config.state_prev_link.blob) + '_states'
            recurrent_states.append(state_name)
            forward_links.append((
                state_config.state_prev_link.blob,
                state_name,
                state_config.state_prev_link.offset,
                state_config.state_prev_link.window,
            ))
            forward_links.append((
                state_config.state_link.blob,
                state_name,
                state_config.state_link.offset,
                state_config.state_link.window,
            ))
        link_internal, link_external, link_offset, link_window = (
            zip(*forward_links)
        )
        all_outputs = [
            str(s) + '_all'
            for s in [scores_t, tokens_t, hypo_t, attention_t]
        ]
        results = self.model.net.RecurrentNetwork(
            all_inputs,
            all_outputs + ['step_workspaces'],
            param=[all_inputs.index(p) for p in self.step_model.params],
            alias_src=[
                str(s) + '_states'
                for s in [
                    self.scores_t_prev,
                    self.tokens_t_prev,
                    self.hypo_t_prev,
                    self.attention_t_prev,
                ]
            ],
            alias_dst=all_outputs,
            alias_offset=[0] * 4,
            recurrent_states=recurrent_states,
            initial_recurrent_state_ids=[
                all_inputs.index(state_config.initial_value)
                for state_config in state_configs
            ],
            link_internal=[str(l) for l in link_internal],
            link_external=[str(l) for l in link_external],
            link_offset=link_offset,
            link_window=link_window,
            backward_link_internal=[],
            backward_link_external=[],
            backward_link_offset=[],
            step_net=self.step_model.net.Proto(),
            timestep=str(self.timestep),
            outputs_with_grads=[],
            enable_rnn_executor=1,
        )

        score_t_all, tokens_t_all, hypo_t_all, attention_t_all = results[:4]

        output_token_beam_list = self.model.net.Cast(
            tokens_t_all,
            'output_token_beam_list',
            to=core.DataType.INT32,
        )
        output_prev_index_beam_list = self.model.net.Cast(
            hypo_t_all,
            'output_prev_index_beam_list',
            to=core.DataType.INT32,
        )
        output_score_beam_list = self.model.net.Alias(
            score_t_all,
            'output_score_beam_list',
        )
        output_attention_weights_beam_list = self.model.net.Alias(
            attention_t_all,
            'output_attention_weights_beam_list',
        )

        return (
            output_token_beam_list,
            output_prev_index_beam_list,
            output_score_beam_list,
            output_attention_weights_beam_list,
        )
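

# Illustrative sketch only, not part of the original module: one way a caller
# might recover the best hypothesis from the beam lists returned by
# BeamSearchForwardOnly.apply(), assuming they have been fetched as nested
# arrays of shape [num_steps, 1, beam_size]. The names below are hypothetical.
def _backtrack_best_hypothesis_example(
    token_beam_list,
    prev_index_beam_list,
    score_beam_list,
):
    num_steps = len(token_beam_list)
    # Start from the highest-scoring hypothesis at the final step.
    last_scores = score_beam_list[num_steps - 1][0]
    beam = max(range(len(last_scores)), key=lambda b: last_scores[b])
    tokens = []
    # Walk the predecessor indices backwards to reconstruct the token path.
    for step in reversed(range(num_steps)):
        tokens.append(int(token_beam_list[step][0][beam]))
        beam = int(prev_index_beam_list[step][0][beam])
    tokens.reverse()
    return tokens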