from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
import inspect
import itertools
import logging
import random

import numpy as np
import six
from future.utils import viewkeys

from caffe2.proto import caffe2_pb2
from caffe2.python.attention import (
    apply_dot_attention,
    apply_recurrent_attention,
    apply_regular_attention,
    apply_soft_coverage_attention,
    AttentionType,
)
from caffe2.python import core, recurrent, workspace, brew, scope, utils
from caffe2.python.model_helper import ModelHelper
from caffe2.python.modeling.parameter_sharing import ParameterSharing
from caffe2.python.modeling.parameter_info import ParameterTags
from caffe2.python.modeling.initializers import Initializer
def _RectifyName(blob_reference_or_name):
    if blob_reference_or_name is None:
        return None
    if isinstance(blob_reference_or_name, six.string_types):
        return core.ScopedBlobReference(blob_reference_or_name)
    if not isinstance(blob_reference_or_name, core.BlobReference):
        raise Exception("Unknown blob reference type")
    return blob_reference_or_name
def _RectifyNames(blob_references_or_names):
    if blob_references_or_names is None:
        return None
    return list(map(_RectifyName, blob_references_or_names))
class RNNCell(object):
    '''
    Base class for writing recurrent / stateful operations.

    One needs to implement 2 methods: apply_override
    and get_state_names_override.

    As a result, the base class provides the apply_over_sequence method,
    which lets you apply the recurrent operation over a sequence of any
    length (see the usage sketch after this class).

    Optionally, you can add input and output preparation steps by overriding
    the corresponding methods.
    '''
    def __init__(self, name=None, forward_only=False, initializer=None):
        self.name = name
        self.recompute_blobs = []
        self.forward_only = forward_only
        self._initializer = initializer

    @property
    def initializer(self):
        return self._initializer

    @initializer.setter
    def initializer(self, value):
        self._initializer = value

    def scope(self, name):
        return self.name + '/' + name if self.name is not None else name
    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths=None,
        initial_states=None,
        outputs_with_grads=None,
    ):
        if initial_states is None:
            with scope.NameScope(self.name):
                if self.initializer is None:
                    raise Exception(
                        "Either initial states or initializer have to be set")
                initial_states = self.initializer.create_states(model)
        preprocessed_inputs = self.prepare_input(model, inputs)
        step_model = ModelHelper(name=self.name, param_model=model)
        input_t, timestep = step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
        )

        utils.raiseIfNotEqual(
            len(initial_states), len(self.get_state_names()),
            "Number of initial state values provided doesn't match the number "
            "of states"
        )

        states_prev = step_model.net.AddScopedExternalInputs(*[
            s + '_prev' for s in self.get_state_names()
        ])
        states = self._apply(
            model=step_model,
            input_t=input_t,
            seq_lengths=seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        external_outputs = set(step_model.net.Proto().external_output)
        for state in states:
            if state not in external_outputs:
                step_model.net.AddExternalOutput(state)

        if outputs_with_grads is None:
            outputs_with_grads = [self.get_output_state_index() * 2]

        # unroll the single-step net over the whole sequence
        states_for_all_steps = recurrent.recurrent_net(
            net=model.net,
            cell_net=step_model.net,
            inputs=[(input_t, preprocessed_inputs)],
            initial_cell_inputs=list(zip(states_prev, initial_states)),
            links=dict(zip(states_prev, states)),
            timestep=timestep,
            scope=self.name,
            forward_only=self.forward_only,
            outputs_with_grads=outputs_with_grads,
            recompute_blobs_on_backward=self.recompute_blobs,
        )

        output = self._prepare_output_sequence(
            model,
            states_for_all_steps,
        )
        return output, states_for_all_steps
    def apply(self, model, input_t, seq_lengths, states, timestep):
        input_t = self.prepare_input(model, input_t)
        states = self._apply(
            model, input_t, seq_lengths, states, timestep)
        output = self._prepare_output(model, states)
        return output, states
    def _apply(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None,
    ):
        '''
        This method uses apply_override provided by a custom cell.
        On top of that it takes care of applying self.scope() to all the
        outputs, while all the inputs stay within the scope this function
        was called from.
        '''
        args = self._rectify_apply_inputs(
            input_t, seq_lengths, states, timestep, extra_inputs)
        with core.NameScope(self.name):
            return self.apply_override(model, *args)
    def _rectify_apply_inputs(
            self, input_t, seq_lengths, states, timestep, extra_inputs):
        '''
        Before applying a scope we make sure that all external blob names
        are converted to blob references, so that further scoping does not
        affect them.
        '''
        input_t, seq_lengths, timestep = _RectifyNames(
            [input_t, seq_lengths, timestep])
        states = _RectifyNames(states)
        if extra_inputs:
            extra_input_names, extra_input_sizes = zip(*extra_inputs)
            extra_inputs = _RectifyNames(extra_input_names)
            # pair the rectified references (not the raw names) with sizes
            extra_inputs = zip(extra_inputs, extra_input_sizes)

        arg_names = inspect.getargspec(self.apply_override).args
        rectified = [input_t, seq_lengths, states, timestep]
        if 'extra_inputs' in arg_names:
            rectified.append(extra_inputs)
        return rectified
    def apply_override(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None,
    ):
        '''
        A single step of a recurrent network to be implemented by each custom
        RNNCell.

        model: ModelHelper object new operators would be added to

        input_t: single input with shape (1, batch_size, input_dim)

        seq_lengths: blob containing sequence lengths which would be passed to
        LSTMUnit operator

        states: previous recurrent states

        timestep: current recurrent iteration. Could be used together with
        seq_lengths in order to determine if some shorter sequences
        in the batch have already ended.

        extra_inputs: list of tuples (input, dim). Specifies additional input
        which is not subject to prepare_input(). (Useful when a cell is a
        component of a larger recurrent structure, e.g., attention.)
        '''
        raise NotImplementedError('Abstract method')

    def prepare_input(self, model, input_blob):
        '''
        If some operations in _apply method depend only on the input,
        not on recurrent states, they could be computed in advance.

        model: ModelHelper object new operators would be added to

        input_blob: either the whole input sequence with shape
        (sequence_length, batch_size, input_dim) or a single input with shape
        (1, batch_size, input_dim).
        '''
        return input_blob

    def get_output_state_index(self):
        '''
        Return index into state list of the "primary" step-wise output.
        '''
        return 0

    def get_state_names(self):
        '''
        Returns recurrent state names with self.name scoping applied
        '''
        return list(map(self.scope, self.get_state_names_override()))

    def get_state_names_override(self):
        '''
        Override this function in your custom cell.
        It should return the names of the recurrent states.

        It's required by apply_over_sequence method in order to allocate
        recurrent states for all steps with meaningful names.
        '''
        raise NotImplementedError('Abstract method')

    def get_output_dim(self):
        '''
        Specifies the dimension (number of units) of stepwise output.
        '''
        raise NotImplementedError('Abstract method')
    def _prepare_output(self, model, states):
        '''
        Allows arbitrary post-processing of primary output.
        '''
        return states[self.get_output_state_index()]

    def _prepare_output_sequence(self, model, state_outputs):
        '''
        Allows arbitrary post-processing of primary sequence output.

        (Note that state_outputs alternates between full-sequence and final
        output for each state, thus the index multiplier 2.)
        '''
        output_sequence_index = 2 * self.get_output_state_index()
        return state_outputs[output_sequence_index]

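
# Usage sketch for the base class above (illustrative only; blob names and
# dimensions are arbitrary, and the helper function itself is not part of
# the original module): construct a concrete cell and let
# apply_over_sequence() unroll the single-step net over the whole sequence.
def _example_apply_over_sequence():
    model = ModelHelper(name='rnn_cell_example')
    input_sequence = model.net.AddExternalInput('input_sequence')
    cell = LSTMCell(
        input_size=40,
        hidden_size=64,
        forget_bias=0.0,
        memory_optimization=False,
        name='example_lstm',
        forward_only=True,
    )
    # With no initial_states argument, the cell falls back to its
    # initializer (LSTMInitializer here), which adds zero-valued starting
    # states to the model.
    output_sequence, all_states = cell.apply_over_sequence(
        model=model,
        inputs=input_sequence,
    )
    return output_sequence, all_states
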
class LSTMInitializer(object):
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

    def create_states(self, model):
        return [
            model.create_param(
                param_name='initial_hidden_state',
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            ),
            model.create_param(
                param_name='initial_cell_state',
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            ),
        ]

class BasicRNNCell(RNNCell):
    def __init__(self, input_size, hidden_size, forget_bias,
                 memory_optimization, drop_states=False, initializer=None,
                 activation=None, **kwargs):
        super(BasicRNNCell, self).__init__(**kwargs)
        self.drop_states = drop_states
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation

        if self.activation not in ['relu', 'tanh']:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        hidden_t_prev = states[0]

        # fully connected transform of the previous hidden state
        gates_t = brew.fc(
            model,
            hidden_t_prev,
            'gates_t',
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )
        brew.sum(model, [gates_t, input_t], gates_t)
        if self.activation == 'tanh':
            hidden_t = model.net.Tanh(gates_t, 'hidden_t')
        elif self.activation == 'relu':
            hidden_t = model.net.Relu(gates_t, 'hidden_t')
        else:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

        if seq_lengths is not None:
            # mask out steps that are past the end of each sequence
            timestep = model.net.CopyFromCPUInput(
                timestep, 'timestep_gpu')
            valid_b = model.net.GT(
                [seq_lengths, timestep], 'valid_b', broadcast=1)
            invalid_b = model.net.LE(
                [seq_lengths, timestep], 'invalid_b', broadcast=1)
            valid = model.net.Cast(valid_b, 'valid', to='float')
            invalid = model.net.Cast(invalid_b, 'invalid', to='float')

            hidden_valid = model.net.Mul(
                [hidden_t, valid],
                'hidden_valid',
                broadcast=1,
                axis=1,
            )
            if self.drop_states:
                hidden_t = hidden_valid
            else:
                hidden_invalid = model.net.Mul(
                    [hidden_t_prev, invalid],
                    'hidden_invalid',
                    broadcast=1, axis=1)
                hidden_t = model.net.Add(
                    [hidden_valid, hidden_invalid], hidden_t)
        return (hidden_t,)
    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'),)

    def get_output_dim(self):
        return self.hidden_size

class LSTMCell(RNNCell):
    def __init__(self, input_size, hidden_size, forget_bias,
                 memory_optimization, drop_states=False, initializer=None,
                 **kwargs):
        super(LSTMCell, self).__init__(initializer=initializer, **kwargs)
        self.initializer = initializer or LSTMInitializer(
            hidden_size=hidden_size)

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
    def apply_override(self, model, input_t, seq_lengths, states, timestep,
                       extra_inputs=None):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                'gates_concatenated_input_t',
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        gates_t = brew.fc(
            model,
            fc_input,
            'gates_t',
            dim_in=fc_input_dim,
            dim_out=4 * self.hidden_size,  # i, f, o, g gates
            axis=2,
        )
        brew.sum(model, [gates_t, input_t], gates_t)

        if seq_lengths is not None:
            inputs = [hidden_t_prev, cell_t_prev, gates_t,
                      seq_lengths, timestep]
        else:
            inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]

        hidden_t, cell_t = model.net.LSTMUnit(
            inputs,
            ['hidden_state', 'cell_state'],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=(seq_lengths is not None),
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t
    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def get_recurrent_params(self):
        return {
            'weights': self.scope('gates_t') + '_w',
            'biases': self.scope('gates_t') + '_b',
        }
    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=4 * self.hidden_size,
            axis=2,
        )

    def get_state_names_override(self):
        return ['hidden_t', 'cell_t']

    def get_output_dim(self):
        return self.hidden_size

class LayerNormLSTMCell(RNNCell):
    def __init__(self, input_size, hidden_size, forget_bias,
                 memory_optimization, drop_states=False, initializer=None,
                 **kwargs):
        super(LayerNormLSTMCell, self).__init__(
            initializer=initializer, **kwargs
        )
        self.initializer = initializer or LSTMInitializer(
            hidden_size=hidden_size
        )

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
    def _apply(self, model, input_t, seq_lengths, states, timestep,
               extra_inputs=None):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        gates_t = brew.fc(
            model,
            fc_input,
            self.scope('gates_t'),
            dim_in=fc_input_dim,
            dim_out=4 * self.hidden_size,
            axis=2,
        )
        brew.sum(model, [gates_t, input_t], gates_t)

        # layer-normalize the gate pre-activations
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=4 * self.hidden_size,
            axis=-1,
        )

        hidden_t, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
            [self.scope('hidden_state'), self.scope('cell_state')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t
    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def prepare_input(self, model, input_blob):
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=4 * self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'), self.scope('cell_t'))

class MILSTMCell(LSTMCell):

    def _apply(self, model, input_t, seq_lengths, states, timestep,
               extra_inputs=None):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=4 * self.hidden_size,
            axis=2,
        )

        # multiplicative-integration parameters
        alpha = model.create_param(
            self.scope('alpha'),
            shape=[4 * self.hidden_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[4 * self.hidden_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[4 * self.hidden_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[4 * self.hidden_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        # alpha * input_t + beta_h
        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        # (alpha * input_t + beta_h) * prev_t
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        # beta_i * input_t + b
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        # gates = alpha * input_t * prev_t + beta_h * prev_t
        #         + beta_i * input_t + b
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )
        hidden_t, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t

class LayerNormMILSTMCell(MILSTMCell):

    def _apply(self, model, input_t, seq_lengths, states, timestep,
               extra_inputs=None):
        hidden_t_prev, cell_t_prev = states

        fc_input = hidden_t_prev
        fc_input_dim = self.hidden_size

        if extra_inputs is not None:
            extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
            fc_input = brew.concat(
                model,
                [hidden_t_prev] + list(extra_input_blobs),
                self.scope('gates_concatenated_input_t'),
                axis=2,
            )
            fc_input_dim += sum(extra_input_sizes)

        prev_t = brew.fc(
            model,
            fc_input,
            self.scope('prev_t'),
            dim_in=fc_input_dim,
            dim_out=4 * self.hidden_size,
            axis=2,
        )

        alpha = model.create_param(
            self.scope('alpha'),
            shape=[4 * self.hidden_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_h = model.create_param(
            self.scope('beta1'),
            shape=[4 * self.hidden_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        beta_i = model.create_param(
            self.scope('beta2'),
            shape=[4 * self.hidden_size],
            initializer=Initializer('ConstantFill', value=1.0),
        )
        b = model.create_param(
            self.scope('b'),
            shape=[4 * self.hidden_size],
            initializer=Initializer('ConstantFill', value=0.0),
        )

        alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
            [input_t, alpha, beta_h],
            self.scope('alpha_by_input_t_plus_beta_h'),
            axis=2,
        )
        alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
            [alpha_by_input_t_plus_beta_h, prev_t],
            self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
        )
        beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
            [input_t, beta_i, b],
            self.scope('beta_i_by_input_t_plus_b'),
            axis=2,
        )
        gates_t = brew.sum(
            model,
            [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
            self.scope('gates_t')
        )
        # same multiplicative-integration gates as MILSTMCell, followed by
        # layer normalization of the gate pre-activations
        gates_t, _, _ = brew.layer_norm(
            model,
            self.scope('gates_t'),
            self.scope('gates_t_norm'),
            dim_in=4 * self.hidden_size,
            axis=-1,
        )
        hidden_t, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
        )
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t

class DropoutCell(RNNCell):
    '''
    Wraps arbitrary RNNCell, applying dropout to its output (but not to the
    recurrent connection for the corresponding state).
    '''

    def __init__(self, internal_cell, dropout_ratio=None, **kwargs):
        self.internal_cell = internal_cell
        self.dropout_ratio = dropout_ratio
        assert 'is_test' in kwargs, "Argument 'is_test' is required"
        self.is_test = kwargs.pop('is_test')
        super(DropoutCell, self).__init__(**kwargs)

        self.prepare_input = internal_cell.prepare_input
        self.get_output_state_index = internal_cell.get_output_state_index
        self.get_state_names = internal_cell.get_state_names
        self.get_output_dim = internal_cell.get_output_dim

        self.mask = 0

    def _apply(self, model, input_t, seq_lengths, states, timestep,
               extra_inputs=None):
        return self.internal_cell._apply(
            model,
            input_t,
            seq_lengths,
            states,
            timestep,
            extra_inputs,
        )

    def _prepare_output(self, model, states):
        output = self.internal_cell._prepare_output(
            model, states)
        if self.dropout_ratio is not None:
            output = self._apply_dropout(model, output)
        return output

    def _prepare_output_sequence(self, model, state_outputs):
        output = self.internal_cell._prepare_output_sequence(
            model, state_outputs)
        if self.dropout_ratio is not None:
            output = self._apply_dropout(model, output)
        return output

    def _apply_dropout(self, model, output):
        if self.dropout_ratio and not self.forward_only:
            with core.NameScope(self.name or ''):
                output = brew.dropout(
                    model,
                    output,
                    str(output) + '_with_dropout_mask{}'.format(self.mask),
                    ratio=float(self.dropout_ratio),
                    is_test=self.is_test,
                )
                self.mask += 1
        return output

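
# Illustrative sketch of wrapping a cell in DropoutCell (names and sizes are
# arbitrary, and this helper is only an example): the inner cell's recurrent
# state is untouched; only its step-wise output is dropped out.
def _example_dropout_cell():
    inner = LSTMCell(
        input_size=40,
        hidden_size=64,
        forget_bias=0.0,
        memory_optimization=False,
        name='lstm_layer_0',
        forward_only=True,
    )
    return DropoutCell(
        internal_cell=inner,
        dropout_ratio=0.3,
        name='lstm_layer_0_dropout',
        forward_only=True,
        is_test=0,
    )
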
class MultiRNNCellInitializer(object):
    def __init__(self, cells):
        self.cells = cells

    def create_states(self, model):
        states = []
        for i, cell in enumerate(self.cells):
            if cell.initializer is None:
                raise Exception(
                    "Either initial states or initializer have to be set")

            with core.NameScope("layer_{}".format(i)),\
                    core.NameScope(cell.name):
                states.extend(cell.initializer.create_states(model))
        return states

class MultiRNNCell(RNNCell):
    '''
    Multilayer RNN via the composition of RNNCell instances.

    It is the responsibility of calling code to ensure the compatibility
    of the successive layers in terms of input/output dimensionality, etc.,
    and to ensure that their blobs do not have name conflicts, typically by
    creating the cells with names that specify layer number.

    Assumes the first state (recurrent output) of each layer is the input
    to the next layer.
    '''

    def __init__(self, cells, residual_output_layers=None, **kwargs):
        '''
        cells: list of RNNCell instances, from input to output side.

        name: string designating network component (for scoping)

        residual_output_layers: list of indices of layers whose input will
        be added elementwise to their output. (It is the responsibility of
        the client code to ensure shape compatibility.) Note that layer 0
        (zero) cannot have residual output because of the timing of
        prepare_input().

        forward_only: used to construct inference-only network.
        '''
        super(MultiRNNCell, self).__init__(**kwargs)
        self.cells = cells

        if residual_output_layers is None:
            self.residual_output_layers = []
        else:
            self.residual_output_layers = residual_output_layers

        output_index_per_layer = []
        base_index = 0
        for cell in self.cells:
            output_index_per_layer.append(
                base_index + cell.get_output_state_index(),
            )
            base_index += len(cell.get_state_names())

        self.output_connected_layers = []
        self.output_indices = []
        for i in range(len(self.cells) - 1):
            if (i + 1) in self.residual_output_layers:
                self.output_connected_layers.append(i)
                self.output_indices.append(output_index_per_layer[i])
        self.output_connected_layers.append(len(self.cells) - 1)
        self.output_indices.append(output_index_per_layer[-1])

        self.state_names = []
        for i, cell in enumerate(self.cells):
            self.state_names.extend(
                map(self.layer_scoper(i), cell.get_state_names())
            )

        self.initializer = MultiRNNCellInitializer(cells)
    def layer_scoper(self, layer_id):
        def helper(name):
            return "{}/layer_{}/{}".format(self.name, layer_id, name)
        return helper

    def prepare_input(self, model, input_blob):
        input_blob = _RectifyName(input_blob)
        with core.NameScope(self.name or ''):
            return self.cells[0].prepare_input(model, input_blob)
    def _apply(self, model, input_t, seq_lengths, states, timestep,
               extra_inputs=None):
        '''
        Because below we will do scoping across layers, we need
        to make sure that string blob names are converted to BlobReferences
        before the scoping is applied.
        '''
        input_t, seq_lengths, states, timestep, extra_inputs = \
            self._rectify_apply_inputs(
                input_t, seq_lengths, states, timestep, extra_inputs)

        states_per_layer = [len(cell.get_state_names()) for cell in self.cells]
        assert len(states) == sum(states_per_layer)

        next_states = []
        states_index = 0

        layer_input = input_t
        for i, layer_cell in enumerate(self.cells):
            with core.NameScope(self.name), core.NameScope(
                    "layer_{}".format(i)):
                num_states = states_per_layer[i]
                layer_states = states[states_index:(states_index + num_states)]
                states_index += num_states

                if i > 0:
                    prepared_input = layer_cell.prepare_input(
                        model, layer_input)
                else:
                    prepared_input = layer_input

                layer_next_states = layer_cell._apply(
                    model,
                    prepared_input,
                    seq_lengths,
                    layer_states,
                    timestep,
                    extra_inputs=(None if i > 0 else extra_inputs),
                )

                if i != len(self.cells) - 1:
                    layer_output = layer_cell._prepare_output(
                        model, layer_next_states)
                    if i > 0 and i in self.residual_output_layers:
                        layer_input = brew.sum(
                            model,
                            [layer_output, layer_input],
                            self.scope('residual_output_{}'.format(i)),
                        )
                    else:
                        layer_input = layer_output

            next_states.extend(layer_next_states)
        return next_states
    def get_state_names(self):
        return self.state_names

    def get_output_state_index(self):
        index = 0
        for cell in self.cells[:-1]:
            index += len(cell.get_state_names())
        index += self.cells[-1].get_output_state_index()
        return index
    def _prepare_output(self, model, states):
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output(
                    model, layer_states)
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output'),
            )
        else:
            output = connected_outputs[0]
        return output
    def _prepare_output_sequence(self, model, states):
        connected_outputs = []
        state_index = 0
        for i, cell in enumerate(self.cells):
            num_states = 2 * len(cell.get_state_names())
            if i in self.output_connected_layers:
                layer_states = states[state_index:state_index + num_states]
                layer_output = cell._prepare_output_sequence(
                    model, layer_states)
                connected_outputs.append(layer_output)
            state_index += num_states
        if len(connected_outputs) > 1:
            output = brew.sum(
                model,
                connected_outputs,
                self.scope('residual_output_sequence'),
            )
        else:
            output = connected_outputs[0]
        return output

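
# Illustrative sketch of stacking cells with MultiRNNCell (names and
# dimensions are arbitrary, and this helper is only an example): layer names
# carry the layer index so their blobs don't collide, and consecutive layer
# dimensions must line up (64 -> 64 here).
def _example_multi_cell():
    layers = [
        LSTMCell(
            input_size=40 if i == 0 else 64,
            hidden_size=64,
            forget_bias=0.0,
            memory_optimization=False,
            name='layer_{}'.format(i),
            forward_only=True,
        )
        for i in range(2)
    ]
    return MultiRNNCell(layers, name='stacked_lstm', forward_only=True)
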
class AttentionCell(RNNCell):

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_cell,
        decoder_state_dim,
        attention_type,
        weighted_encoder_outputs,
        attention_memory_optimization,
        **kwargs
    ):
        super(AttentionCell, self).__init__(**kwargs)
        self.encoder_output_dim = encoder_output_dim
        self.encoder_outputs = encoder_outputs
        self.encoder_lengths = encoder_lengths
        self.decoder_cell = decoder_cell
        self.decoder_state_dim = decoder_state_dim
        self.weighted_encoder_outputs = weighted_encoder_outputs
        self.encoder_outputs_transposed = None
        assert attention_type in [
            AttentionType.Regular,
            AttentionType.Recurrent,
            AttentionType.Dot,
            AttentionType.SoftCoverage,
        ]
        self.attention_type = attention_type
        self.attention_memory_optimization = attention_memory_optimization
    def _apply(self, model, input_t, seq_lengths, states, timestep,
               extra_inputs=None):
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_prev_states = states[:-2]
            attention_weighted_encoder_context_t_prev = states[-2]
            coverage_t_prev = states[-1]
        else:
            decoder_prev_states = states[:-1]
            attention_weighted_encoder_context_t_prev = states[-1]

        assert extra_inputs is None

        # the attention context from the previous step is fed into the
        # wrapped decoder cell as an extra input
        decoder_states = self.decoder_cell._apply(
            model,
            input_t,
            seq_lengths,
            decoder_prev_states,
            timestep,
            extra_inputs=[(
                attention_weighted_encoder_context_t_prev,
                self.encoder_output_dim,
            )],
        )

        self.hidden_t_intermediate = self.decoder_cell._prepare_output(
            model, decoder_states)

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                # ...
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
            )
        elif self.attention_type == AttentionType.Regular:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                # ...
            )
        elif self.attention_type == AttentionType.Dot:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_dot_attention(
                # ...
            )
        elif self.attention_type == AttentionType.SoftCoverage:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
                coverage_t,
            ) = apply_soft_coverage_attention(
                # ...
                coverage_t_prev=coverage_t_prev,
            )
        else:
            raise Exception('Attention type {} not implemented'.format(
                self.attention_type
            ))

        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)

        output = list(decoder_states) + [attention_weighted_encoder_context_t]
        if self.attention_type == AttentionType.SoftCoverage:
            output.append(coverage_t)

        output[self.decoder_cell.get_output_state_index()] = model.Copy(
            output[self.decoder_cell.get_output_state_index()],
            self.scope('hidden_t_external'),
        )
        model.net.AddExternalOutputs(*output)

        return output

    def get_attention_weights(self):
        return self.attention_weights_3d
    def prepare_input(self, model, input_blob):
        if self.encoder_outputs_transposed is None:
            self.encoder_outputs_transposed = brew.transpose(
                model,
                self.encoder_outputs,
                self.scope('encoder_outputs_transposed'),
                axes=[1, 2, 0],
            )
        if self.weighted_encoder_outputs is None:
            self.weighted_encoder_outputs = brew.fc(
                model,
                self.encoder_outputs,
                self.scope('weighted_encoder_outputs'),
                dim_in=self.encoder_output_dim,
                dim_out=self.encoder_output_dim,
                axis=2,
            )

        return self.decoder_cell.prepare_input(model, input_blob)
    def build_initial_coverage(self, model):
        '''
        initial_coverage is always zeros of shape [encoder_length],
        which shape must be determined programmatically during network
        construction.

        This method also sets self.coverage_weights, a separate transform
        of encoder_outputs which is used to determine the coverage
        contribution to attention.
        '''
        self.coverage_weights = brew.fc(
            model,
            self.encoder_outputs,
            self.scope('coverage_weights'),
            dim_in=self.encoder_output_dim,
            dim_out=self.encoder_output_dim,
            axis=2,
        )

        encoder_length = model.net.Slice(
            model.net.Shape(self.encoder_outputs),
            starts=[0],
            ends=[1],
        )
        if (
            scope.CurrentDeviceScope() is not None and
            core.IsGPUDeviceType(scope.CurrentDeviceScope().device_type)
        ):
            encoder_length = model.net.CopyGPUToCPU(
                encoder_length,
                'encoder_length_cpu',
            )
        initial_coverage = model.net.ConstantFill(
            encoder_length,
            self.scope('initial_coverage'),
            value=0.0,
            input_as_shape=1,
        )
        return initial_coverage
    def get_state_names(self):
        state_names = list(self.decoder_cell.get_state_names())
        state_names[self.get_output_state_index()] = self.scope(
            'hidden_t_external',
        )
        state_names.append(self.scope(
            'attention_weighted_encoder_context_t'))
        if self.attention_type == AttentionType.SoftCoverage:
            state_names.append(self.scope('coverage_t'))
        return state_names

    def get_output_dim(self):
        return self.decoder_state_dim + self.encoder_output_dim

    def get_output_state_index(self):
        return self.decoder_cell.get_output_state_index()
    def _prepare_output(self, model, states):
        if self.attention_type == AttentionType.SoftCoverage:
            attention_context = states[-2]
        else:
            attention_context = states[-1]

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [self.hidden_t_intermediate, attention_context],
                'states_and_context_combination',
                axis=2,
            )
        return output
    def _prepare_output_sequence(self, model, state_outputs):
        if self.attention_type == AttentionType.SoftCoverage:
            decoder_state_outputs = state_outputs[:-4]
        else:
            decoder_state_outputs = state_outputs[:-2]

        decoder_output = self.decoder_cell._prepare_output_sequence(
            model,
            decoder_state_outputs,
        )

        if self.attention_type == AttentionType.SoftCoverage:
            attention_context_index = 2 * (len(self.get_state_names()) - 2)
        else:
            attention_context_index = 2 * (len(self.get_state_names()) - 1)

        with core.NameScope(self.name or ''):
            output = brew.concat(
                model,
                [
                    decoder_output,
                    state_outputs[attention_context_index],
                ],
                'states_and_context_combination',
                axis=2,
            )
        return output

class LSTMWithAttentionCell(AttentionCell):

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        decoder_cell = LSTMCell(
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            name='{}/decoder'.format(name),
            forward_only=False,
            drop_states=False,
        )
        super(LSTMWithAttentionCell, self).__init__(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            decoder_cell=decoder_cell,
            decoder_state_dim=decoder_state_dim,
            name=name,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
            forward_only=forward_only,
        )

class MILSTMWithAttentionCell(AttentionCell):

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        encoder_lengths,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
        forward_only=False,
    ):
        decoder_cell = MILSTMCell(
            input_size=decoder_input_dim,
            hidden_size=decoder_state_dim,
            forget_bias=forget_bias,
            memory_optimization=lstm_memory_optimization,
            name='{}/decoder'.format(name),
            forward_only=False,
            drop_states=False,
        )
        super(MILSTMWithAttentionCell, self).__init__(
            encoder_output_dim=encoder_output_dim,
            encoder_outputs=encoder_outputs,
            encoder_lengths=encoder_lengths,
            decoder_cell=decoder_cell,
            decoder_state_dim=decoder_state_dim,
            name=name,
            attention_type=attention_type,
            weighted_encoder_outputs=weighted_encoder_outputs,
            attention_memory_optimization=attention_memory_optimization,
            forward_only=forward_only,
        )

def _LSTM(
    cell_class,
    model,
    input_blob,
    seq_lengths,
    initial_states,
    dim_in,
    dim_out,
    scope,
    outputs_with_grads=(0,),
    return_params=False,
    memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
    drop_states=False,
    return_last_layer_only=True,
    static_rnn_unroll_size=None,
):
    '''
    Adds a standard LSTM recurrent network operator to a model.

    cell_class: LSTMCell or compatible subclass

    model: ModelHelper object new operators would be added to

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which would be passed to
    LSTMUnit operator

    initial_states: a list of (2 * num_layers) blobs representing the initial
    hidden and cell states of each layer. If this argument is None,
    these states will be added to the model as network parameters.

    dim_in: input dimension

    dim_out: number of units per LSTM layer
    (use int for single-layer LSTM, list of ints for multi-layer)

    outputs_with_grads: position indices of output blobs for LAST LAYER which
    will receive external error gradient during backpropagation.
    These outputs are: (h_all, h_last, c_all, c_last)

    return_params: if True, will return a dictionary of parameters of the LSTM

    memory_optimization: if enabled, the LSTM step is recomputed on backward
    step so that we don't need to store forward activations for each
    timestep. Saves memory at the cost of computation.

    forget_bias: forget gate bias (default 0.0)

    forward_only: whether to create only the forward pass (no backward pass)

    drop_states: drop invalid states, passed through to LSTMUnit operator

    return_last_layer_only: only return outputs from the final layer
    (so that the length of the result does not depend on the number of layers)

    static_rnn_unroll_size: if not None, we will use a static RNN which is
    unrolled into the Caffe2 graph. The size of the unroll is the value of
    this parameter.
    '''
    if type(dim_out) is not list and type(dim_out) is not tuple:
        dim_out = [dim_out]
    num_layers = len(dim_out)

    cells = []
    for i in range(num_layers):
        cell = cell_class(
            input_size=(dim_in if i == 0 else dim_out[i - 1]),
            hidden_size=dim_out[i],
            forget_bias=forget_bias,
            memory_optimization=memory_optimization,
            name=scope if num_layers == 1 else None,
            forward_only=forward_only,
            drop_states=drop_states,
        )
        cells.append(cell)

    cell = MultiRNNCell(
        cells,
        name=scope,
        forward_only=forward_only,
    ) if num_layers > 1 else cells[0]

    cell = (
        cell
        if static_rnn_unroll_size is None
        else UnrolledCell(cell, static_rnn_unroll_size)
    )

    # outputs_with_grads indexes into the last layer's outputs, so offset it
    # by the 4 outputs of every preceding layer
    outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
    _, result = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )

    if return_last_layer_only:
        result = result[4 * (num_layers - 1):]
    if return_params:
        result = list(result) + [{
            'input': cell.get_input_params(),
            'recurrent': cell.get_recurrent_params(),
        }]
    return tuple(result)

LSTM = functools.partial(_LSTM, LSTMCell)
BasicRNN = functools.partial(_LSTM, BasicRNNCell)
MILSTM = functools.partial(_LSTM, MILSTMCell)
LayerNormLSTM = functools.partial(_LSTM, LayerNormLSTMCell)
LayerNormMILSTM = functools.partial(_LSTM, LayerNormMILSTMCell)

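
# The partials above are plain Python callables, so adding a (possibly
# multi-layer) LSTM is a single call. An illustrative sketch (blob names and
# dimensions are arbitrary, and this helper is only an example):
def _example_lstm_helper():
    model = ModelHelper(name='lstm_helper_example')
    input_sequence, seq_lengths = model.net.AddExternalInput(
        'input_sequence', 'seq_lengths')
    # With return_last_layer_only=True (the default) the result is
    # (h_all, h_last, c_all, c_last) of the last layer.
    h_all, h_last, c_all, c_last = LSTM(
        model=model,
        input_blob=input_sequence,
        seq_lengths=seq_lengths,
        initial_states=None,   # states become network parameters
        dim_in=40,
        dim_out=[64, 64],      # list of ints => two stacked layers
        scope='example_lstm',
        forward_only=True,
    )
    return h_all, h_last, c_all, c_last
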
class UnrolledCell(RNNCell):
    def __init__(self, cell, T):
        self.T = T
        self.cell = cell

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths=None,
        initial_states=None,
        outputs_with_grads=None,
    ):
        inputs = self.cell.prepare_input(model, inputs)

        # split the input sequence into one blob per timestep
        split_inputs = model.net.Split(
            inputs,
            [str(inputs) + "_timestep_{}".format(i)
             for i in range(self.T)],
            axis=0)
        if self.T == 1:
            split_inputs = [split_inputs]

        states = initial_states
        all_states = []
        for t in range(0, self.T):
            scope_name = "timestep_{}".format(t)
            # parameters of all timesteps are shared
            with ParameterSharing({scope_name: ''}),\
                    scope.NameScope(scope_name):
                timestep = model.param_init_net.ConstantFill(
                    [], "timestep", value=t, shape=[1],
                    dtype=core.DataType.INT32,
                    device_option=core.DeviceOption(caffe2_pb2.CPU))
                states = self.cell._apply(
                    model=model,
                    input_t=split_inputs[t],
                    seq_lengths=seq_lengths,
                    states=states,
                    timestep=timestep,
                )
            all_states.append(states)

        all_states = zip(*all_states)
        all_states = [
            model.net.Concat(
                list(full_output),
                [
                    str(full_output[0])[len("timestep_0/"):] + "_concat",
                    str(full_output[0])[len("timestep_0/"):] + "_concat_info"
                ],
                axis=0)[0]
            for full_output in all_states
        ]
        # interleave full-sequence outputs with final-step outputs, matching
        # the output layout of recurrent_net
        outputs = tuple(
            six.next(it) for it in
            itertools.cycle([iter(all_states), iter(states)])
        )
        outputs_without_grad = set(range(len(outputs))) - set(
            outputs_with_grads)
        for i in outputs_without_grad:
            model.net.ZeroGradient(outputs[i], [])
        logging.debug("Added 0 gradients for blobs: %s",
                      [outputs[i] for i in outputs_without_grad])

        final_output = self.cell._prepare_output_sequence(model, outputs)

        return final_output, outputs

def GetLSTMParamNames():
    weight_params = [
        "input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
    bias_params = [
        "input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
    return {'weights': weight_params, 'biases': bias_params}

def InitFromLSTMParams(lstm_pblobs, param_values):
    '''
    Set the parameters of LSTM based on predefined values
    '''
    weight_params = GetLSTMParamNames()['weights']
    bias_params = GetLSTMParamNames()['biases']
    for input_type in viewkeys(param_values):
        weight_values = [
            param_values[input_type][w].flatten()
            for w in weight_params
        ]
        wmat = np.array([])
        for w in weight_values:
            wmat = np.append(wmat, w)
        bias_values = [
            param_values[input_type][b].flatten()
            for b in bias_params
        ]
        bm = np.array([])
        for b in bias_values:
            bm = np.append(bm, b)

        weights_blob = lstm_pblobs[input_type]['weights']
        bias_blob = lstm_pblobs[input_type]['biases']
        cur_weight = workspace.FetchBlob(weights_blob)
        cur_biases = workspace.FetchBlob(bias_blob)

        workspace.FeedBlob(
            weights_blob,
            wmat.reshape(cur_weight.shape).astype(np.float32))
        workspace.FeedBlob(
            bias_blob,
            bm.reshape(cur_biases.shape).astype(np.float32))

def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.

    input_blob          Blob containing the input. Will need to be available
                        when param_init_net is run, because the sequence
                        lengths and batch sizes will be inferred from the
                        size of this blob.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimension
    dim_out             output/hidden dimension
    scope               namescope to apply
    recurrent_params    dict of blobs containing values for recurrent
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    input_params        dict of blobs containing values for input
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    num_layers          number of LSTM layers
    return_params       if True, returns (param_extract_net, param_mapping)
                        where param_extract_net is a net that, when run, will
                        populate the blobs specified in param_mapping with the
                        current gate weights and biases (input/recurrent).
                        Useful for assigning the values back to a non-cuDNN
                        LSTM.
    '''
    with core.NameScope(scope):
        weight_params = GetLSTMParamNames()['weights']
        bias_params = GetLSTMParamNames()['biases']

        input_weight_size = dim_out * dim_in
        upper_layer_input_weight_size = dim_out * dim_out
        recurrent_weight_size = dim_out * dim_out
        input_bias_size = dim_out
        recurrent_bias_size = dim_out

        def init(layer, pname, input_type):
            input_weight_size_for_layer = input_weight_size if layer == 0 else \
                upper_layer_input_weight_size
            if pname in weight_params:
                sz = input_weight_size_for_layer if input_type == 'input' \
                    else recurrent_weight_size
            elif pname in bias_params:
                sz = input_bias_size if input_type == 'input' \
                    else recurrent_bias_size
            else:
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz])

        # multiply by 4 since we have 4 gates per LSTM unit
        first_layer_sz = input_weight_size + recurrent_weight_size + \
            input_bias_size + recurrent_bias_size
        upper_layer_sz = upper_layer_input_weight_size + \
            recurrent_weight_size + input_bias_size + \
            recurrent_bias_size
        total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)

        weights = model.create_param(
            'lstm_weight',
            shape=[total_sz],
            initializer=Initializer('UniformFill'),
            tags=ParameterTags.WEIGHT,
        )

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            # ...
            'input_mode': 'linear',
            'num_layers': num_layers,
            'engine': 'CUDNN',
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs([input_blob, weights])
        param_extract_mapping = {}

        # Populate the weights blob from the blobs containing parameters for
        # the individual components of the LSTM (gate weights and biases).
        # Also create a special param_extract_net that can be used to pull
        # those individual params back out of the packed weights blob; its
        # results can then be fed to InitFromLSTMParams().
        for input_type in ['input', 'recurrent']:
            param_extract_mapping[input_type] = {}
            p = recurrent_params if input_type == 'recurrent' else \
                input_params
            if p is None:
                p = {}
            for pname in weight_params + bias_params:
                for j in range(0, num_layers):
                    values = p[pname] if pname in p \
                        else init(j, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    if pname not in param_extract_mapping[input_type]:
                        param_extract_mapping[input_type][pname] = {}
                    b = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, j)],
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    param_extract_mapping[input_type][pname][j] = b

        (hidden_input_blob, cell_input_blob) = initial_states
        output, hidden_output, cell_output, rnn_scratch, dropout_states = \
            model.net.Recurrent(
                [input_blob, hidden_input_blob, cell_input_blob, weights],
                ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
                 "lstm_rnn_scratch", "lstm_dropout_states"],
                seed=random.randint(0, 100000),
                **lstm_args
            )
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

    if return_params:
        param_extract = param_extract_net, param_extract_mapping
        return output, hidden_output, cell_output, param_extract
    else:
        return output, hidden_output, cell_output

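
# Illustrative sketch of cudnn_LSTM with parameter extraction (the blob
# names, dimensions, and this helper itself are only an example): with
# return_params=True, running `extract_net` fills the blobs referenced by
# `mapping` with the current gate weights/biases, which is useful when
# transferring them to a non-cuDNN LSTM (see InitFromLSTMParams above).
def _example_cudnn_lstm(model, input_blob, hidden_init, cell_init):
    output, hidden_out, cell_out, (extract_net, mapping) = cudnn_LSTM(
        model=model,
        input_blob=input_blob,
        initial_states=(hidden_init, cell_init),
        dim_in=40,
        dim_out=64,
        scope='cudnn_lstm',
        return_params=True,
    )
    return output, hidden_out, cell_out, extract_net, mapping
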
def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    encoder_lengths,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
    forward_only=False,
):
    '''
    Adds a LSTM with attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order in which we compute the new attention
    context and the new hidden state, similarly to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions,
    where the decoder is the sequence the op is iterating over,
    while computing the attention context over the encoder.

    model: ModelHelper object new operators would be added to

    decoder_inputs: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    decoder_input_lengths: blob containing sequence lengths
    which would be passed to LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence, on which we compute the attention context
    at every iteration

    encoder_lengths: a tensor with lengths of each encoder sequence in batch
    (may be None, meaning all encoder sequences are of same length)

    decoder_input_dim: input dimension (last dimension on decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
    Determines which type of attention mechanism to use.

    outputs_with_grads: position indices of output blobs which will receive
    external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case it's just a linear transformation of encoder
    outputs (that's the default, when weighted_encoder_outputs is None).
    However, it can be something more complicated - like a separate encoder
    network (for example, in case of a convolutional encoder)

    lstm_memory_optimization: recompute LSTM activations on backward pass, so
    we don't need to store their values in forward passes

    attention_memory_optimization: recompute attention for backward pass

    forward_only: whether to create only the forward pass
    '''
    cell = LSTMWithAttentionCell(
        encoder_output_dim=encoder_output_dim,
        encoder_outputs=encoder_outputs,
        encoder_lengths=encoder_lengths,
        decoder_input_dim=decoder_input_dim,
        decoder_state_dim=decoder_state_dim,
        name=scope,
        attention_type=attention_type,
        weighted_encoder_outputs=weighted_encoder_outputs,
        forget_bias=forget_bias,
        lstm_memory_optimization=lstm_memory_optimization,
        attention_memory_optimization=attention_memory_optimization,
        forward_only=forward_only,
    )
    initial_states = [
        initial_decoder_hidden_state,
        initial_decoder_cell_state,
        initial_attention_weighted_encoder_context,
    ]
    if attention_type == AttentionType.SoftCoverage:
        initial_states.append(cell.build_initial_coverage(model))
    _, result = cell.apply_over_sequence(
        model=model,
        inputs=decoder_inputs,
        seq_lengths=decoder_input_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )
    return result

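
# Illustrative sketch of wiring LSTMWithAttention for a decoder that attends
# over pre-computed encoder outputs (all blob names, dimensions, and this
# helper are only an example):
def _example_lstm_with_attention(model, decoder_inputs, decoder_lengths,
                                 h0, c0, ctx0, encoder_outputs):
    result = LSTMWithAttention(
        model=model,
        decoder_inputs=decoder_inputs,
        decoder_input_lengths=decoder_lengths,
        initial_decoder_hidden_state=h0,
        initial_decoder_cell_state=c0,
        initial_attention_weighted_encoder_context=ctx0,
        encoder_output_dim=256,
        encoder_outputs=encoder_outputs,
        encoder_lengths=None,
        decoder_input_dim=128,
        decoder_state_dim=256,
        scope='attention_decoder',
        attention_type=AttentionType.Regular,
        forward_only=True,
    )
    # `result` interleaves full-sequence and final-step blobs for each state
    # (hidden, cell, attention context); the default outputs_with_grads=(0, 4)
    # therefore targets the hidden-state and attention-context sequences.
    return result
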
def _layered_LSTM(
        model, input_blob, seq_lengths, initial_states,
        dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
        memory_optimization=False, forget_bias=0.0, forward_only=False,
        drop_states=False, create_lstm=None):
    params = locals()  # leave this as the first line to grab all params
    params.pop('create_lstm')
    if not isinstance(dim_out, list):
        return create_lstm(**params)
    elif len(dim_out) == 1:
        params['dim_out'] = dim_out[0]
        return create_lstm(**params)

    assert len(dim_out) != 0, "dim_out list can't be empty"
    assert return_params is False, "return_params not supported for layering"
    for i, output_dim in enumerate(dim_out):
        params.update({
            'dim_out': output_dim
        })
        output, last_output, all_states, last_state = create_lstm(**params)
        params.update({
            'input_blob': output,
            'dim_in': output_dim,
            'initial_states': (last_output, last_state),
            'scope': scope + '_layer_{}'.format(i + 1)
        })
    return output, last_output, all_states, last_state


layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)
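
# Illustrative sketch of layered_LSTM (dimensions and this helper are only an
# example): one LSTM() is built per entry of dim_out, with each layer's output
# sequence and final states feeding the next layer's scope.
def _example_layered_lstm(model, input_blob, seq_lengths, init_states):
    return layered_LSTM(
        model=model,
        input_blob=input_blob,
        seq_lengths=seq_lengths,
        initial_states=init_states,
        dim_in=40,
        dim_out=[64, 64],
        scope='layered_lstm',
        forward_only=True,
    )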