from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import logging
import sys

import numpy as np

from abc import ABCMeta, abstractmethod
from future.utils import viewitems
from six import with_metaclass

from caffe2.python import core, rnn_cell, workspace
from caffe2.python.models.seq2seq.beam_search import BeamSearchForwardOnly
from caffe2.python.models.seq2seq.seq2seq_model_helper import Seq2SeqModelHelper
import caffe2.python.models.seq2seq.seq2seq_util as seq2seq_util

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stderr))

def _weighted_sum(model, values, weight, output_name):
    # Pair every value blob with the shared weight blob, then flatten into
    # the interleaved [value, weight, value, weight, ...] layout that
    # WeightedSum expects.
    values_weights = zip(values, [weight] * len(values))
    values_weights_flattened = [x for v_w in values_weights for x in v_w]
    return model.net.WeightedSum(
        values_weights_flattened,
        output_name,
    )
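
# Illustrative note (not part of the original file): WeightedSum consumes an
# interleaved [value_0, weight_0, value_1, weight_1, ...] input list, so with
# a shared scalar weight w = 1 / num_decoders the helper above computes
# w * values[0] + w * values[1] + ... -- the ensemble average used in
# __init__ below.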

class Seq2SeqModelCaffe2EnsembleDecoderBase(with_metaclass(ABCMeta, object)):

    @abstractmethod
    def get_model_file(self, model):
        pass

    @abstractmethod
    def get_db_type(self):
        pass
    def build_word_rewards(self, vocab_size, word_reward, unk_reward):
        # Additive per-token rewards for the beam search: special tokens get
        # no reward, and UNK gets an extra (typically negative) offset.
        word_rewards = np.full([vocab_size], word_reward, dtype=np.float32)
        word_rewards[seq2seq_util.PAD_ID] = 0
        word_rewards[seq2seq_util.GO_ID] = 0
        word_rewards[seq2seq_util.EOS_ID] = 0
        word_rewards[seq2seq_util.UNK_ID] = word_reward + unk_reward
        return word_rewards
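    # Illustrative sketch (not part of the original file), assuming the usual
    # seq2seq_util token ids (PAD=0, GO=1, EOS=2, UNK=3): with vocab_size=6,
    # word_reward=0.1 and unk_reward=-0.5 the vector above is
    #   [0.0, 0.0, 0.0, -0.4, 0.1, 0.1]
    # The beam search adds word_rewards[token] to each expanded hypothesis's
    # log-probability, so a negative unk_reward discourages UNK outputs.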

    def load_models(self):
        db_reader = 'reader'
        for model, scope_name in zip(
            self.models,
            self.decoder_scope_names,
        ):
            # Collect this decoder's parameters by their scope prefix.
            params_for_current_model = [
                param
                for param in self.model.GetAllParams()
                if str(param).startswith(scope_name)
            ]
            assert workspace.RunOperatorOnce(core.CreateOperator(
                'CreateDB',
                [], [db_reader],
                db=self.get_model_file(model),
                db_type=self.get_db_type())
            ), 'Failed to create db {}'.format(self.get_model_file(model))
            assert workspace.RunOperatorOnce(core.CreateOperator(
                'Load',
                [db_reader],
                params_for_current_model,
                load_all=1,
                add_prefix=scope_name + '/',
                strip_prefix='gpu_0/',
            )), 'Failed to load db {}'.format(self.get_model_file(model))
            logger.info('Model {} is loaded from a checkpoint {}'.format(
                scope_name, self.get_model_file(model)))
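
# Illustrative note (not part of the original file): with strip_prefix='gpu_0/'
# and add_prefix='model0/', a checkpoint blob saved as 'gpu_0/decoder_embeddings'
# is restored into this workspace as 'model0/decoder_embeddings', matching the
# per-model name scopes created in _build_decoder below.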

class Seq2SeqModelCaffe2EnsembleDecoder(Seq2SeqModelCaffe2EnsembleDecoderBase):

    def get_model_file(self, model):
        return model['model_file']

    def get_db_type(self):
        return 'minidb'

    def scope(self, scope_name, blob_name):
        return (
            scope_name + '/' + blob_name
            if scope_name is not None
            else blob_name
        )

    def _build_decoder(
        self,
        model,
        step_model,
        model_params,
        scope,
        previous_tokens,
        timestep,
        fake_seq_lengths,
    ):
        attention_type = model_params['attention']
        assert attention_type in ['none', 'regular']
        use_attention = (attention_type != 'none')

        with core.NameScope(scope):
            encoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.source_vocab_size,
                embedding_size=model_params['encoder_embedding_size'],
                name='encoder_embeddings',
                freeze_embeddings=False,
            )

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_states,
            final_encoder_cell_states,
            encoder_units_per_layer,
        ) = seq2seq_util.build_embedding_encoder(
            model=model,
            encoder_params=model_params['encoder_type'],
            num_decoder_layers=len(model_params['decoder_layer_configs']),
            inputs=self.encoder_inputs,
            input_lengths=self.encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=encoder_embeddings,
            embedding_size=model_params['encoder_embedding_size'],
            use_attention=use_attention,
            num_gpus=0,
            forward_only=True,
            scope=scope,
        )

        with core.NameScope(scope):
            if use_attention:
                # Tile along the batch axis so every one of the beam_size
                # hypotheses sees its own copy of the encoder states.
                encoder_outputs = model.net.Tile(
                    encoder_outputs,
                    'encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            if weighted_encoder_outputs is not None:
                weighted_encoder_outputs = model.net.Tile(
                    weighted_encoder_outputs,
                    'weighted_encoder_outputs_tiled',
                    tiles=self.beam_size,
                    axis=1,
                )

            decoder_embeddings = seq2seq_util.build_embeddings(
                model=model,
                vocab_size=self.target_vocab_size,
                embedding_size=model_params['decoder_embedding_size'],
                name='decoder_embeddings',
                freeze_embeddings=False,
            )
            embedded_tokens_t_prev = step_model.net.Gather(
                [decoder_embeddings, previous_tokens],
                'embedded_tokens_t_prev',
            )

        decoder_cells = []
        decoder_units_per_layer = []
        for i, layer_config in enumerate(model_params['decoder_layer_configs']):
            num_units = layer_config['num_units']
            decoder_units_per_layer.append(num_units)
            if i == 0:
                input_size = model_params['decoder_embedding_size']
            else:
                input_size = (
                    model_params['decoder_layer_configs'][i - 1]['num_units']
                )
            cell = rnn_cell.LSTMCell(
                forward_only=True,
                input_size=input_size,
                hidden_size=num_units,
                forget_bias=0.0,
                memory_optimization=False,
            )
            decoder_cells.append(cell)

        with core.NameScope(scope):
            if final_encoder_hidden_states is not None:
                for i in range(len(final_encoder_hidden_states)):
                    if final_encoder_hidden_states[i] is not None:
                        final_encoder_hidden_states[i] = model.net.Tile(
                            final_encoder_hidden_states[i],
                            'final_encoder_hidden_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            if final_encoder_cell_states is not None:
                for i in range(len(final_encoder_cell_states)):
                    if final_encoder_cell_states[i] is not None:
                        final_encoder_cell_states[i] = model.net.Tile(
                            final_encoder_cell_states[i],
                            'final_encoder_cell_tiled_{}'.format(i),
                            tiles=self.beam_size,
                            axis=1,
                        )
            initial_states = \
                seq2seq_util.build_initial_rnn_decoder_states(
                    model=model,
                    encoder_units_per_layer=encoder_units_per_layer,
                    decoder_units_per_layer=decoder_units_per_layer,
                    final_encoder_hidden_states=final_encoder_hidden_states,
                    final_encoder_cell_states=final_encoder_cell_states,
                    use_attention=use_attention,
                )

        attention_decoder = seq2seq_util.LSTMWithAttentionDecoder(
            encoder_outputs=encoder_outputs,
            encoder_output_dim=encoder_units_per_layer[-1],
            encoder_lengths=None,
            vocab_size=self.target_vocab_size,
            attention_type=attention_type,
            embedding_size=model_params['decoder_embedding_size'],
            decoder_num_units=decoder_units_per_layer[-1],
            decoder_cells=decoder_cells,
            weighted_encoder_outputs=weighted_encoder_outputs,
            name=scope,
        )
        states_prev = step_model.net.AddExternalInputs(*[
            '{}/{}_prev'.format(scope, s)
            for s in attention_decoder.get_state_names()
        ])
        decoder_outputs, states = attention_decoder.apply(
            model=step_model,
            input_t=embedded_tokens_t_prev,
            seq_lengths=fake_seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        state_configs = [
            BeamSearchForwardOnly.StateConfig(
                initial_value=initial_state,
                state_prev_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state_prev, offset=0, window=1),
                state_link=BeamSearchForwardOnly.LinkConfig(
                    blob=state, offset=1, window=1),
            )
            for initial_state, state_prev, state in zip(
                initial_states, states_prev, states)
        ]

        with core.NameScope(scope):
            decoder_outputs_flattened, _ = step_model.net.Reshape(
                [decoder_outputs],
                ['decoder_outputs_flattened',
                 'decoder_outputs_and_contexts_combination_old_shape'],
                shape=[-1, attention_decoder.get_output_dim()],
            )
            output_logits = seq2seq_util.output_projection(
                model=step_model,
                decoder_outputs=decoder_outputs_flattened,
                decoder_output_size=attention_decoder.get_output_dim(),
                target_vocab_size=self.target_vocab_size,
                decoder_softmax_size=model_params['decoder_softmax_size'],
            )
            output_probs = step_model.net.Softmax(
                output_logits, 'output_probs')
            output_log_probs = step_model.net.Log(
                output_probs, 'output_log_probs')
            if use_attention:
                attention_weights = attention_decoder.get_attention_weights()
            else:
                # No attention: emit an all-zero weight tensor of the right
                # shape so downstream averaging still works.
                attention_weights = step_model.net.ConstantFill(
                    [self.encoder_inputs],
                    'zero_attention_weights_tmp_1',
                    value=0.0,
                )
                attention_weights = step_model.net.Transpose(
                    attention_weights, 'zero_attention_weights_tmp_2')
                attention_weights = step_model.net.Tile(
                    attention_weights,
                    'zero_attention_weights_tmp',
                    tiles=self.beam_size,
                    axis=0,
                )
        return state_configs, output_log_probs, attention_weights
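
    # Illustrative note (not part of the original file): each StateConfig
    # returned above wires one recurrent state through the beam search --
    # initial_value seeds step 0, state_prev_link exposes the previous step's
    # slice of the state tensor to the step net, and state_link collects the
    # slice the step net writes. The offset/window values shown are an
    # assumption, not taken from this listing.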

    def __init__(self, translate_params):
        self.models = translate_params['ensemble_models']
        decoding_params = translate_params['decoding_params']
        self.beam_size = decoding_params['beam_size']

        assert len(self.models) > 0
        source_vocab = self.models[0]['source_vocab']
        target_vocab = self.models[0]['target_vocab']
        for model in self.models:
            assert model['source_vocab'] == source_vocab
            assert model['target_vocab'] == target_vocab

        self.source_vocab_size = len(source_vocab)
        self.target_vocab_size = len(target_vocab)

        self.decoder_scope_names = [
            'model{}'.format(i) for i in range(len(self.models))
        ]

        self.model = Seq2SeqModelHelper(init_params=True)

        self.encoder_inputs = self.model.net.AddExternalInput(
            'encoder_inputs')
        self.encoder_lengths = self.model.net.AddExternalInput(
            'encoder_lengths')
        self.max_output_seq_len = self.model.net.AddExternalInput(
            'max_output_seq_len')

        fake_seq_lengths = self.model.param_init_net.ConstantFill(
            [],
            'fake_seq_lengths',
            shape=[self.beam_size],
            value=100000,
            dtype=core.DataType.INT32,
        )

        beam_decoder = BeamSearchForwardOnly(
            beam_size=self.beam_size,
            model=self.model,
            go_token_id=seq2seq_util.GO_ID,
            eos_token_id=seq2seq_util.EOS_ID,
        )
        step_model = beam_decoder.get_step_model()

        state_configs = []
        output_log_probs = []
        attention_weights = []
        for model, scope_name in zip(
            self.models,
            self.decoder_scope_names,
        ):
            (
                state_configs_per_decoder,
                output_log_probs_per_decoder,
                attention_weights_per_decoder,
            ) = self._build_decoder(
                model=self.model,
                step_model=step_model,
                model_params=model['model_params'],
                scope=scope_name,
                previous_tokens=beam_decoder.get_previous_tokens(),
                timestep=beam_decoder.get_timestep(),
                fake_seq_lengths=fake_seq_lengths,
            )
            state_configs.extend(state_configs_per_decoder)
            output_log_probs.append(output_log_probs_per_decoder)
            if attention_weights_per_decoder is not None:
                attention_weights.append(attention_weights_per_decoder)

        assert len(attention_weights) > 0
        num_decoders_with_attention_blob = (
            self.model.param_init_net.ConstantFill(
                [],
                'num_decoders_with_attention_blob',
                value=1 / float(len(attention_weights)),
                shape=[1],
            )
        )
        # Ensemble-average the per-decoder attention weights.
        attention_weights_average = _weighted_sum(
            model=self.model,
            values=attention_weights,
            weight=num_decoders_with_attention_blob,
            output_name='attention_weights_average',
        )

        num_decoders_blob = self.model.param_init_net.ConstantFill(
            [],
            'num_decoders_blob',
            value=1 / float(len(output_log_probs)),
            shape=[1],
        )
        # Ensemble-average the per-decoder log-probabilities.
        output_log_probs_average = _weighted_sum(
            model=self.model,
            values=output_log_probs,
            weight=num_decoders_blob,
            output_name='output_log_probs_average',
        )

        word_rewards = self.model.param_init_net.ConstantFill(
            [],
            'word_rewards',
            shape=[self.target_vocab_size],
            value=0.0,
            dtype=core.DataType.FLOAT,
        )

        (
            self.output_token_beam_list,
            self.output_prev_index_beam_list,
            self.output_score_beam_list,
            self.output_attention_weights_beam_list,
        ) = beam_decoder.apply(
            inputs=self.encoder_inputs,
            length=self.max_output_seq_len,
            log_probs=output_log_probs_average,
            attentions=attention_weights_average,
            state_configs=state_configs,
            data_dependencies=[],
            word_rewards=word_rewards,
        )

        workspace.RunNetOnce(self.model.param_init_net)
        workspace.FeedBlob(
            'word_rewards',
            self.build_word_rewards(
                vocab_size=self.target_vocab_size,
                word_reward=translate_params['decoding_params']['word_reward'],
                unk_reward=translate_params['decoding_params']['unk_reward'],
            )
        )

        logger.info('Params created: ')
        for param in self.model.params:
            logger.info(param)

    def decode(self, numberized_input, max_output_seq_len):
        # Source tokens are fed one per row, in reverse order.
        workspace.FeedBlob(
            self.encoder_inputs,
            np.array([
                [token_id] for token_id in reversed(numberized_input)
            ]).astype(dtype=np.int32),
        )
        workspace.FeedBlob(self.encoder_lengths, np.array(
            [len(numberized_input)]).astype(dtype=np.int32))
        workspace.FeedBlob(self.max_output_seq_len, np.array(
            [max_output_seq_len]).astype(dtype=np.int64))

        workspace.RunNet(self.model.net)

        num_steps = max_output_seq_len
        score_beam_list = workspace.FetchBlob(self.output_score_beam_list)
        token_beam_list = workspace.FetchBlob(self.output_token_beam_list)
        prev_index_beam_list = workspace.FetchBlob(
            self.output_prev_index_beam_list)
        attention_weights_beam_list = workspace.FetchBlob(
            self.output_attention_weights_beam_list)

        best_indices = (num_steps, 0)
        for i in range(num_steps + 1):
            for hyp_index in range(self.beam_size):
                # Candidates: hypotheses that just emitted EOS, or any
                # hypothesis once the step budget is exhausted.
                if (
                    (token_beam_list[i][hyp_index][0] ==
                        seq2seq_util.EOS_ID or i == num_steps) and
                    score_beam_list[i][hyp_index][0] >
                        score_beam_list[best_indices[0]][best_indices[1]][0]
                ):
                    best_indices = (i, hyp_index)

        i, hyp_index = best_indices
        output = []
        attention_weights_per_token = []
        best_score = -score_beam_list[i][hyp_index][0]
        # Walk the back-pointers from the best hypothesis to step 1.
        while i > 0:
            output.append(token_beam_list[i][hyp_index][0])
            attention_weights_per_token.append(
                attention_weights_beam_list[i][hyp_index]
            )
            hyp_index = prev_index_beam_list[i][hyp_index][0]
            i -= 1

        attention_weights_per_token = reversed(attention_weights_per_token)
        # The encoder input was reversed, so flip each attention vector back
        # and truncate it to the true source length.
        attention_weights_per_token = [
            list(reversed(attention_weights))[:len(numberized_input)]
            for attention_weights in attention_weights_per_token
        ]
        output = list(reversed(output))
        return output, attention_weights_per_token, best_score
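
    # Illustrative usage (not part of the original file), assuming a decoder
    # constructed from translate_params and loaded via load_models():
    #
    #   tokens, alignment, score = decoder.decode(
    #       numberized_sentence,                  # list of source token ids
    #       2 * len(numberized_sentence) + 5,     # output-length budget
    #   )
    #
    # run_seq2seq_beam_decoder below uses exactly this length budget.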

def run_seq2seq_beam_decoder(args, model_params, decoding_params):
    source_vocab = seq2seq_util.gen_vocab(
        args.source_corpus,
        args.unk_threshold,
    )
    logger.info('Source vocab size {}'.format(len(source_vocab)))
    target_vocab = seq2seq_util.gen_vocab(
        args.target_corpus,
        args.unk_threshold,
    )
    inversed_target_vocab = {v: k for (k, v) in viewitems(target_vocab)}
    logger.info('Target vocab size {}'.format(len(target_vocab)))

    decoder = Seq2SeqModelCaffe2EnsembleDecoder(
        translate_params=dict(
            ensemble_models=[dict(
                source_vocab=source_vocab,
                target_vocab=target_vocab,
                model_params=model_params,
                model_file=args.checkpoint,
            )],
            decoding_params=decoding_params,
        ),
    )
    decoder.load_models()

    for line in sys.stdin:
        numerized_source_sentence = seq2seq_util.get_numberized_sentence(
            line,
            source_vocab,
        )
        translation, alignment, _ = decoder.decode(
            numerized_source_sentence,
            2 * len(numerized_source_sentence) + 5,
        )
        print(' '.join([inversed_target_vocab[tid] for tid in translation]))
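
# Example invocation (hypothetical file paths), translating stdin line by line:
#
#   cat test.src | python translate.py \
#       --source-corpus train.src --target-corpus train.tgt \
#       --checkpoint checkpoint.minidb --use-attention --beam-size 6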

def main():
    parser = argparse.ArgumentParser(
        description='Caffe2: Seq2Seq Translation',
    )
    parser.add_argument('--source-corpus', type=str, default=None,
                        help='Path to source corpus in a text file format. Each '
                        'line in the file should contain a single sentence',
                        required=True)
    parser.add_argument('--target-corpus', type=str, default=None,
                        help='Path to target corpus in a text file format',
                        required=True)
    parser.add_argument('--unk-threshold', type=int, default=50,
                        help='Threshold frequency under which token becomes '
                        'labeled unknown token')
    parser.add_argument('--use-bidirectional-encoder', action='store_true',
                        help='Set flag to use bidirectional recurrent network '
                        'in the encoder')
    parser.add_argument('--use-attention', action='store_true',
                        help='Set flag to use seq2seq with attention model')
    parser.add_argument('--encoder-cell-num-units', type=int, default=512,
                        help='Number of cell units per encoder layer')
    parser.add_argument('--encoder-num-layers', type=int, default=2,
                        help='Number of encoder layers')
    parser.add_argument('--decoder-cell-num-units', type=int, default=512,
                        help='Number of cell units in the decoder layer')
    parser.add_argument('--decoder-num-layers', type=int, default=2,
                        help='Number of decoder layers')
    parser.add_argument('--encoder-embedding-size', type=int, default=256,
                        help='Size of embedding in the encoder layer')
    parser.add_argument('--decoder-embedding-size', type=int, default=512,
                        help='Size of embedding in the decoder layer')
    parser.add_argument('--decoder-softmax-size', type=int, default=None,
                        help='Size of softmax layer in the decoder')
    parser.add_argument('--beam-size', type=int, default=6,
                        help='Size of beam for the decoder')
    parser.add_argument('--word-reward', type=float, default=0.0,
                        help='Reward per each word generated.')
    parser.add_argument('--unk-reward', type=float, default=0.0,
                        help='Reward per each UNK token generated. '
                        'Typically should be negative.')
    parser.add_argument('--checkpoint', type=str, default=None,
                        help='Path to checkpoint', required=True)

    args = parser.parse_args()

    # Build one config dict per layer. A comprehension gives each layer its
    # own dict; with list multiplication the dicts would alias, and halving
    # the first layer's units below would halve every layer.
    encoder_layer_configs = [
        dict(num_units=args.encoder_cell_num_units)
        for _ in range(args.encoder_num_layers)
    ]

    if args.use_bidirectional_encoder:
        assert args.encoder_cell_num_units % 2 == 0
        # Each direction of the bidirectional first layer gets half the
        # units; integer division keeps num_units an int.
        encoder_layer_configs[0]['num_units'] //= 2

    decoder_layer_configs = [
        dict(num_units=args.decoder_cell_num_units)
        for _ in range(args.decoder_num_layers)
    ]

    run_seq2seq_beam_decoder(
        args,
        model_params=dict(
            attention=('regular' if args.use_attention else 'none'),
            decoder_layer_configs=decoder_layer_configs,
            encoder_type=dict(
                encoder_layer_configs=encoder_layer_configs,
                use_bidirectional_encoder=args.use_bidirectional_encoder,
            ),
            encoder_embedding_size=args.encoder_embedding_size,
            decoder_embedding_size=args.decoder_embedding_size,
            decoder_softmax_size=args.decoder_softmax_size,
        ),
        decoding_params=dict(
            beam_size=args.beam_size,
            word_reward=args.word_reward,
            unk_reward=args.unk_reward,
        ),
    )

if __name__ == '__main__':
    main()