Caffe2 - Python API
A deep learning, cross-platform ML framework
rnn_cell.py
1 ## @package rnn_cell
2 # Module caffe2.python.rnn_cell
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import functools
9 import inspect
10 import itertools
11 import logging
12 import numpy as np
13 import random
14 import six
15 from future.utils import viewkeys
16 
17 from caffe2.proto import caffe2_pb2
18 from caffe2.python.attention import (
19  apply_dot_attention,
20  apply_recurrent_attention,
21  apply_regular_attention,
22  apply_soft_coverage_attention,
23  AttentionType,
24 )
25 from caffe2.python import core, recurrent, workspace, brew, scope, utils
26 from caffe2.python.modeling.parameter_sharing import ParameterSharing
27 from caffe2.python.modeling.parameter_info import ParameterTags
28 from caffe2.python.modeling.initializers import Initializer
29 from caffe2.python.model_helper import ModelHelper
30 
31 
32 def _RectifyName(blob_reference_or_name):
33  if blob_reference_or_name is None:
34  return None
35  if isinstance(blob_reference_or_name, six.string_types):
36  return core.ScopedBlobReference(blob_reference_or_name)
37  if not isinstance(blob_reference_or_name, core.BlobReference):
38  raise Exception("Unknown blob reference type")
39  return blob_reference_or_name
40 
41 
42 def _RectifyNames(blob_references_or_names):
43  if blob_references_or_names is None:
44  return None
45  return list(map(_RectifyName, blob_references_or_names))
46 
47 
48 class RNNCell(object):
49  '''
50  Base class for writing recurrent / stateful operations.
51 
52  One needs to implement 2 methods: apply_override
53  and get_state_names_override.
54 
55  As a result, the base class provides the apply_over_sequence method, which
56  allows you to apply recurrent operations over a sequence of any length.
57 
58  Optionally, you can add input and output preparation steps by overriding
59  the corresponding methods.
60  '''
61  def __init__(self, name=None, forward_only=False, initializer=None):
62  self.name = name
63  self.recompute_blobs = []
64  self.forward_only = forward_only
65  self._initializer = initializer
66 
67  @property
68  def initializer(self):
69  return self._initializer
70 
71  @initializer.setter
72  def initializer(self, value):
73  self._initializer = value
74 
75  def scope(self, name):
76  return self.name + '/' + name if self.name is not None else name
77 
78  def apply_over_sequence(
79  self,
80  model,
81  inputs,
82  seq_lengths=None,
83  initial_states=None,
84  outputs_with_grads=None,
85  ):
86  if initial_states is None:
87  with scope.NameScope(self.name):
88  if self.initializer is None:
89  raise Exception("Either initial states "
90  "or initializer have to be set")
91  initial_states = self.initializer.create_states(model)
92 
93  preprocessed_inputs = self.prepare_input(model, inputs)
94  step_model = ModelHelper(name=self.name, param_model=model)
95  input_t, timestep = step_model.net.AddScopedExternalInputs(
96  'input_t',
97  'timestep',
98  )
99  utils.raiseIfNotEqual(
100  len(initial_states), len(self.get_state_names()),
101  "Number of initial state values provided doesn't match the number "
102  "of states"
103  )
104  states_prev = step_model.net.AddScopedExternalInputs(*[
105  s + '_prev' for s in self.get_state_names()
106  ])
107  states = self._apply(
108  model=step_model,
109  input_t=input_t,
110  seq_lengths=seq_lengths,
111  states=states_prev,
112  timestep=timestep,
113  )
114 
115  external_outputs = set(step_model.net.Proto().external_output)
116  for state in states:
117  if state not in external_outputs:
118  step_model.net.AddExternalOutput(state)
119 
120  if outputs_with_grads is None:
121  outputs_with_grads = [self.get_output_state_index() * 2]
122 
123  # states_for_all_steps consists of a combination of
124  # states gathered for all steps and the final states. It looks like this:
125  # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
126  states_for_all_steps = recurrent.recurrent_net(
127  net=model.net,
128  cell_net=step_model.net,
129  inputs=[(input_t, preprocessed_inputs)],
130  initial_cell_inputs=list(zip(states_prev, initial_states)),
131  links=dict(zip(states_prev, states)),
132  timestep=timestep,
133  scope=self.name,
134  forward_only=self.forward_only,
135  outputs_with_grads=outputs_with_grads,
136  recompute_blobs_on_backward=self.recompute_blobs,
137  )
138 
139  output = self._prepare_output_sequence(
140  model,
141  states_for_all_steps,
142  )
143  return output, states_for_all_steps
144 
145  def apply(self, model, input_t, seq_lengths, states, timestep):
146  input_t = self.prepare_input(model, input_t)
147  states = self._apply(
148  model, input_t, seq_lengths, states, timestep)
149  output = self._prepare_output(model, states)
150  return output, states
151 
152  def _apply(
153  self,
154  model, input_t, seq_lengths, states, timestep, extra_inputs=None
155  ):
156  '''
157  This method uses apply_override provided by a custom cell.
158  On top of that, it takes care of applying self.scope() to all the outputs,
159  while all the inputs stay within the scope this function was called
160  from.
161  '''
162  args = self._rectify_apply_inputs(
163  input_t, seq_lengths, states, timestep, extra_inputs)
164  with core.NameScope(self.name):
165  return self.apply_override(model, *args)
166 
167  def _rectify_apply_inputs(
168  self, input_t, seq_lengths, states, timestep, extra_inputs):
169  '''
170  Before applying a scope we make sure that all external blob names
171  are converted to blob references, so that further scoping doesn't affect them.
172  '''
173 
174  input_t, seq_lengths, timestep = _RectifyNames(
175  [input_t, seq_lengths, timestep])
176  states = _RectifyNames(states)
177  if extra_inputs:
178  extra_input_names, extra_input_sizes = zip(*extra_inputs)
179  extra_inputs = _RectifyNames(extra_input_names)
180  extra_inputs = zip(extra_inputs, extra_input_sizes)
181 
182  arg_names = inspect.getargspec(self.apply_override).args
183  rectified = [input_t, seq_lengths, states, timestep]
184  if 'extra_inputs' in arg_names:
185  rectified.append(extra_inputs)
186  return rectified
187 
188 
189  def apply_override(
190  self,
191  model, input_t, seq_lengths, states, timestep, extra_inputs=None,
192  ):
193  '''
194  A single step of a recurrent network to be implemented by each custom
195  RNNCell.
196 
197  model: ModelHelper object that new operators will be added to
198 
199  input_t: single input with shape (1, batch_size, input_dim)
200 
201  seq_lengths: blob containing sequence lengths which would be passed to
202  LSTMUnit operator
203 
204  states: previous recurrent states
205 
206  timestep: current recurrent iteration. Could be used together with
207  seq_lengths in order to determine if some shorter sequences
208  in the batch have already ended.
209 
210  extra_inputs: list of tuples (input, dim). specifies additional input
211  which is not subject to prepare_input(). (useful when a cell is a
212  component of a larger recurrent structure, e.g., attention)
213  '''
214  raise NotImplementedError('Abstract method')
215 
216  def prepare_input(self, model, input_blob):
217  '''
218  If some operations in _apply method depend only on the input,
219  not on recurrent states, they could be computed in advance.
220 
221  model: ModelHelper object that new operators will be added to
222 
223  input_blob: either the whole input sequence with shape
224  (sequence_length, batch_size, input_dim) or a single input with shape
225  (1, batch_size, input_dim).
226  '''
227  return input_blob
228 
229  def get_output_state_index(self):
230  '''
231  Return index into state list of the "primary" step-wise output.
232  '''
233  return 0
234 
235  def get_state_names(self):
236  '''
237  Returns recurrent state names with self.name scoping applied
238  '''
239  return list(map(self.scope, self.get_state_names_override()))
240 
241  def get_state_names_override(self):
242  '''
243  Override this function in your custom cell.
244  It should return the names of the recurrent states.
245 
246  It's required by apply_over_sequence method in order to allocate
247  recurrent states for all steps with meaningful names.
248  '''
249  raise NotImplementedError('Abstract method')
250 
251  def get_output_dim(self):
252  '''
253  Specifies the dimension (number of units) of stepwise output.
254  '''
255  raise NotImplementedError('Abstract method')
256 
257  def _prepare_output(self, model, states):
258  '''
259  Allows arbitrary post-processing of primary output.
260  '''
261  return states[self.get_output_state_index()]
262 
263  def _prepare_output_sequence(self, model, state_outputs):
264  '''
265  Allows arbitrary post-processing of primary sequence output.
266 
267  (Note that state_outputs alternates between full-sequence and final
268  output for each state, thus the index multiplier 2.)
269  '''
270  output_sequence_index = 2 * self.get_output_state_index()
271  return state_outputs[output_sequence_index]
272 
273 
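# Example (illustrative, not part of rnn_cell.py): a minimal custom cell that
# follows the contract described in the RNNCell docstring by overriding
# apply_override and get_state_names_override. The class name and its
# Sum+Tanh update rule are hypothetical; input and state are assumed to have
# the same shape, so prepare_input() is left as the identity.
class _ExampleAccumulatorCell(RNNCell):
    def apply_override(
        self, model, input_t, seq_lengths, states, timestep, extra_inputs=None,
    ):
        # new_state = tanh(input_t + previous_state)
        hidden_t = model.net.Sum([input_t, states[0]], 'hidden_t')
        hidden_t = model.net.Tanh(hidden_t, hidden_t)
        model.net.AddExternalOutput(hidden_t)
        return (hidden_t,)

    def get_state_names_override(self):
        return ['hidden_t']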
274 class LSTMInitializer(object):
275  def __init__(self, hidden_size):
276  self.hidden_size = hidden_size
277 
278  def create_states(self, model):
279  return [
280  model.create_param(
281  param_name='initial_hidden_state',
282  initializer=Initializer(operator_name='ConstantFill',
283  value=0.0),
284  shape=[self.hidden_size],
285  ),
286  model.create_param(
287  param_name='initial_cell_state',
288  initializer=Initializer(operator_name='ConstantFill',
289  value=0.0),
290  shape=[self.hidden_size],
291  )
292  ]
293 
294 
295 # based on https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
296 class BasicRNNCell(RNNCell):
297  def __init__(
298  self,
299  input_size,
300  hidden_size,
301  forget_bias,
302  memory_optimization,
303  drop_states=False,
304  initializer=None,
305  activation=None,
306  **kwargs
307  ):
308  super(BasicRNNCell, self).__init__(**kwargs)
309  self.drop_states = drop_states
310  self.input_size = input_size
311  self.hidden_size = hidden_size
312  self.activation = activation
313 
314  if self.activation not in ['relu', 'tanh']:
315  raise RuntimeError(
316  'BasicRNNCell with unknown activation function (%s)'
317  % self.activation)
318 
319  def apply_override(
320  self,
321  model,
322  input_t,
323  seq_lengths,
324  states,
325  timestep,
326  extra_inputs=None,
327  ):
328  hidden_t_prev = states[0]
329 
330  gates_t = brew.fc(
331  model,
332  hidden_t_prev,
333  'gates_t',
334  dim_in=self.hidden_size,
335  dim_out=self.hidden_size,
336  axis=2,
337  )
338 
339  brew.sum(model, [gates_t, input_t], gates_t)
340  if self.activation == 'tanh':
341  hidden_t = model.net.Tanh(gates_t, 'hidden_t')
342  elif self.activation == 'relu':
343  hidden_t = model.net.Relu(gates_t, 'hidden_t')
344  else:
345  raise RuntimeError(
346  'BasicRNNCell with unknown activation function (%s)'
347  % self.activation)
348 
349  if seq_lengths is not None:
350  # TODO If this codepath becomes popular, it may be worth
351  # taking a look at optimizing it - for now a simple
352  # implementation is used to round out compatibility with
353  # ONNX.
354  timestep = model.net.CopyFromCPUInput(
355  timestep, 'timestep_gpu')
356  valid_b = model.net.GT(
357  [seq_lengths, timestep], 'valid_b', broadcast=1)
358  invalid_b = model.net.LE(
359  [seq_lengths, timestep], 'invalid_b', broadcast=1)
360  valid = model.net.Cast(valid_b, 'valid', to='float')
361  invalid = model.net.Cast(invalid_b, 'invalid', to='float')
362 
363  hidden_valid = model.net.Mul(
364  [hidden_t, valid],
365  'hidden_valid',
366  broadcast=1,
367  axis=1,
368  )
369  if self.drop_states:
370  hidden_t = hidden_valid
371  else:
372  hidden_invalid = model.net.Mul(
373  [hidden_t_prev, invalid],
374  'hidden_invalid',
375  broadcast=1, axis=1)
376  hidden_t = model.net.Add(
377  [hidden_valid, hidden_invalid], hidden_t)
378  return (hidden_t,)
379 
380  def prepare_input(self, model, input_blob):
381  return brew.fc(
382  model,
383  input_blob,
384  self.scope('i2h'),
385  dim_in=self.input_size,
386  dim_out=self.hidden_size,
387  axis=2,
388  )
389 
390  def get_state_names(self):
391  return (self.scope('hidden_t'),)
392 
393  def get_output_dim(self):
394  return self.hidden_size
395 
396 
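# Example (illustrative, not part of rnn_cell.py): BasicRNNCell requires an
# 'activation' of 'relu' or 'tanh'; with the BasicRNN helper defined at the
# bottom of this file it is forwarded through **cell_kwargs. Blob names and
# sizes are hypothetical, and the input/state blobs must be fed by the caller.
def _example_basic_rnn(model):
    hidden_all, hidden_last = BasicRNN(
        model=model,
        input_blob='input_blob',      # T x N x 8
        seq_lengths=None,             # all sequences use the full length
        initial_states=('hidden_init',),
        dim_in=8,
        dim_out=16,
        scope='basic_rnn',
        activation='tanh',
    )
    return hidden_all, hidden_last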
397 class LSTMCell(RNNCell):
398 
399  def __init__(
400  self,
401  input_size,
402  hidden_size,
403  forget_bias,
404  memory_optimization,
405  drop_states=False,
406  initializer=None,
407  **kwargs
408  ):
409  super(LSTMCell, self).__init__(initializer=initializer, **kwargs)
410  self.initializer = initializer or LSTMInitializer(
411  hidden_size=hidden_size)
412 
413  self.input_size = input_size
414  self.hidden_size = hidden_size
415  self.forget_bias = float(forget_bias)
416  self.memory_optimization = memory_optimization
417  self.drop_states = drop_states
418  self.gates_size = 4 * self.hidden_size
419 
420  def apply_override(
421  self,
422  model,
423  input_t,
424  seq_lengths,
425  states,
426  timestep,
427  extra_inputs=None,
428  ):
429  hidden_t_prev, cell_t_prev = states
430 
431  fc_input = hidden_t_prev
432  fc_input_dim = self.hidden_size
433 
434  if extra_inputs is not None:
435  extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
436  fc_input = brew.concat(
437  model,
438  [hidden_t_prev] + list(extra_input_blobs),
439  'gates_concatenated_input_t',
440  axis=2,
441  )
442  fc_input_dim += sum(extra_input_sizes)
443 
444  gates_t = brew.fc(
445  model,
446  fc_input,
447  'gates_t',
448  dim_in=fc_input_dim,
449  dim_out=self.gates_size,
450  axis=2,
451  )
452  brew.sum(model, [gates_t, input_t], gates_t)
453 
454  if seq_lengths is not None:
455  inputs = [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
456  else:
457  inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]
458 
459  hidden_t, cell_t = model.net.LSTMUnit(
460  inputs,
461  ['hidden_state', 'cell_state'],
462  forget_bias=self.forget_bias,
463  drop_states=self.drop_states,
464  sequence_lengths=(seq_lengths is not None),
465  )
466  model.net.AddExternalOutputs(hidden_t, cell_t)
467  if self.memory_optimization:
468  self.recompute_blobs = [gates_t]
469 
470  return hidden_t, cell_t
471 
472  def get_input_params(self):
473  return {
474  'weights': self.scope('i2h') + '_w',
475  'biases': self.scope('i2h') + '_b',
476  }
477 
478  def get_recurrent_params(self):
479  return {
480  'weights': self.scope('gates_t') + '_w',
481  'biases': self.scope('gates_t') + '_b',
482  }
483 
484  def prepare_input(self, model, input_blob):
485  return brew.fc(
486  model,
487  input_blob,
488  self.scope('i2h'),
489  dim_in=self.input_size,
490  dim_out=self.gates_size,
491  axis=2,
492  )
493 
494  def get_state_names_override(self):
495  return ['hidden_t', 'cell_t']
496 
497  def get_output_dim(self):
498  return self.hidden_size
499 
500 
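# Example (illustrative, not part of rnn_cell.py): driving an LSTMCell over a
# whole sequence with apply_over_sequence. Blob names and sizes are
# hypothetical; 'input_blob' (T x N x 10), 'seq_lengths' (N, int32) and the
# two initial state blobs are expected to be fed by the caller.
def _example_lstm_cell_over_sequence(model):
    cell = LSTMCell(
        input_size=10,
        hidden_size=20,
        forget_bias=0.0,
        memory_optimization=False,
        name='example_lstm',
        forward_only=True,
    )
    hidden_all, states = cell.apply_over_sequence(
        model=model,
        inputs='input_blob',
        seq_lengths='seq_lengths',
        initial_states=('hidden_init', 'cell_init'),
    )
    # states alternates all-steps and final outputs per state:
    # (hidden_all, hidden_last, cell_all, cell_last)
    return hidden_all, states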
501 class LayerNormLSTMCell(RNNCell):
502 
503  def __init__(
504  self,
505  input_size,
506  hidden_size,
507  forget_bias,
508  memory_optimization,
509  drop_states=False,
510  initializer=None,
511  **kwargs
512  ):
513  super(LayerNormLSTMCell, self).__init__(
514  initializer=initializer, **kwargs
515  )
516  self.initializer = initializer or LSTMInitializer(
517  hidden_size=hidden_size
518  )
519 
520  self.input_size = input_size
521  self.hidden_size = hidden_size
522  self.forget_bias = float(forget_bias)
523  self.memory_optimization = memory_optimization
524  self.drop_states = drop_states
525  self.gates_size = 4 * self.hidden_size
526 
527  def _apply(
528  self,
529  model,
530  input_t,
531  seq_lengths,
532  states,
533  timestep,
534  extra_inputs=None,
535  ):
536  hidden_t_prev, cell_t_prev = states
537 
538  fc_input = hidden_t_prev
539  fc_input_dim = self.hidden_size
540 
541  if extra_inputs is not None:
542  extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
543  fc_input = brew.concat(
544  model,
545  [hidden_t_prev] + list(extra_input_blobs),
546  self.scope('gates_concatenated_input_t'),
547  axis=2,
548  )
549  fc_input_dim += sum(extra_input_sizes)
550 
551  gates_t = brew.fc(
552  model,
553  fc_input,
554  self.scope('gates_t'),
555  dim_in=fc_input_dim,
556  dim_out=self.gates_size,
557  axis=2,
558  )
559  brew.sum(model, [gates_t, input_t], gates_t)
560 
561  # the brew.layer_norm call is the only difference from LSTMCell
562  gates_t, _, _ = brew.layer_norm(
563  model,
564  self.scope('gates_t'),
565  self.scope('gates_t_norm'),
566  dim_in=self.gates_size,
567  axis=-1,
568  )
569 
570  hidden_t, cell_t = model.net.LSTMUnit(
571  [
572  hidden_t_prev,
573  cell_t_prev,
574  gates_t,
575  seq_lengths,
576  timestep,
577  ],
578  self.get_state_names(),
579  forget_bias=self.forget_bias,
580  drop_states=self.drop_states,
581  )
582  model.net.AddExternalOutputs(hidden_t, cell_t)
583  if self.memory_optimization:
584  self.recompute_blobs = [gates_t]
585 
586  return hidden_t, cell_t
587 
588  def get_input_params(self):
589  return {
590  'weights': self.scope('i2h') + '_w',
591  'biases': self.scope('i2h') + '_b',
592  }
593 
594  def prepare_input(self, model, input_blob):
595  return brew.fc(
596  model,
597  input_blob,
598  self.scope('i2h'),
599  dim_in=self.input_size,
600  dim_out=self.gates_size,
601  axis=2,
602  )
603 
604  def get_state_names(self):
605  return (self.scope('hidden_t'), self.scope('cell_t'))
606 
607 
608 class MILSTMCell(LSTMCell):
609 
610  def _apply(
611  self,
612  model,
613  input_t,
614  seq_lengths,
615  states,
616  timestep,
617  extra_inputs=None,
618  ):
619  hidden_t_prev, cell_t_prev = states
620 
621  fc_input = hidden_t_prev
622  fc_input_dim = self.hidden_size
623 
624  if extra_inputs is not None:
625  extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
626  fc_input = brew.concat(
627  model,
628  [hidden_t_prev] + list(extra_input_blobs),
629  self.scope('gates_concatenated_input_t'),
630  axis=2,
631  )
632  fc_input_dim += sum(extra_input_sizes)
633 
634  prev_t = brew.fc(
635  model,
636  fc_input,
637  self.scope('prev_t'),
638  dim_in=fc_input_dim,
639  dim_out=self.gates_size,
640  axis=2,
641  )
642 
643  # defining initializers for MI parameters
644  alpha = model.create_param(
645  self.scope('alpha'),
646  shape=[self.gates_size],
647  initializer=Initializer('ConstantFill', value=1.0),
648  )
649  beta_h = model.create_param(
650  self.scope('beta1'),
651  shape=[self.gates_size],
652  initializer=Initializer('ConstantFill', value=1.0),
653  )
654  beta_i = model.create_param(
655  self.scope('beta2'),
656  shape=[self.gates_size],
657  initializer=Initializer('ConstantFill', value=1.0),
658  )
659  b = model.create_param(
660  self.scope('b'),
661  shape=[self.gates_size],
662  initializer=Initializer('ConstantFill', value=0.0),
663  )
664 
665  # alpha * input_t + beta_h
666  # Shape: [1, batch_size, 4 * hidden_size]
667  alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
668  [input_t, alpha, beta_h],
669  self.scope('alpha_by_input_t_plus_beta_h'),
670  axis=2,
671  )
672  # (alpha * input_t + beta_h) * prev_t =
673  # alpha * input_t * prev_t + beta_h * prev_t
674  # Shape: [1, batch_size, 4 * hidden_size]
675  alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
676  [alpha_by_input_t_plus_beta_h, prev_t],
677  self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
678  )
679  # beta_i * input_t + b
680  # Shape: [1, batch_size, 4 * hidden_size]
681  beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
682  [input_t, beta_i, b],
683  self.scope('beta_i_by_input_t_plus_b'),
684  axis=2,
685  )
686  # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
687  # Shape: [1, batch_size, 4 * hidden_size]
688  gates_t = brew.sum(
689  model,
690  [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
691  self.scope('gates_t')
692  )
693  hidden_t, cell_t = model.net.LSTMUnit(
694  [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
695  [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
696  forget_bias=self.forget_bias,
697  drop_states=self.drop_states,
698  )
699  model.net.AddExternalOutputs(
700  cell_t,
701  hidden_t,
702  )
703  if self.memory_optimization:
704  self.recompute_blobs = [gates_t]
705  return hidden_t, cell_t
706 
707 
708 class LayerNormMILSTMCell(MILSTMCell):
709 
710  def _apply(
711  self,
712  model,
713  input_t,
714  seq_lengths,
715  states,
716  timestep,
717  extra_inputs=None,
718  ):
719  hidden_t_prev, cell_t_prev = states
720 
721  fc_input = hidden_t_prev
722  fc_input_dim = self.hidden_size
723 
724  if extra_inputs is not None:
725  extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
726  fc_input = brew.concat(
727  model,
728  [hidden_t_prev] + list(extra_input_blobs),
729  self.scope('gates_concatenated_input_t'),
730  axis=2,
731  )
732  fc_input_dim += sum(extra_input_sizes)
733 
734  prev_t = brew.fc(
735  model,
736  fc_input,
737  self.scope('prev_t'),
738  dim_in=fc_input_dim,
739  dim_out=self.gates_size,
740  axis=2,
741  )
742 
743  # defining initializers for MI parameters
744  alpha = model.create_param(
745  self.scope('alpha'),
746  shape=[self.gates_size],
747  initializer=Initializer('ConstantFill', value=1.0),
748  )
749  beta_h = model.create_param(
750  self.scope('beta1'),
751  shape=[self.gates_size],
752  initializer=Initializer('ConstantFill', value=1.0),
753  )
754  beta_i = model.create_param(
755  self.scope('beta2'),
756  shape=[self.gates_size],
757  initializer=Initializer('ConstantFill', value=1.0),
758  )
759  b = model.create_param(
760  self.scope('b'),
761  shape=[self.gates_size],
762  initializer=Initializer('ConstantFill', value=0.0),
763  )
764 
765  # alpha * input_t + beta_h
766  # Shape: [1, batch_size, 4 * hidden_size]
767  alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
768  [input_t, alpha, beta_h],
769  self.scope('alpha_by_input_t_plus_beta_h'),
770  axis=2,
771  )
772  # (alpha * input_t + beta_h) * prev_t =
773  # alpha * input_t * prev_t + beta_h * prev_t
774  # Shape: [1, batch_size, 4 * hidden_size]
775  alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
776  [alpha_by_input_t_plus_beta_h, prev_t],
777  self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
778  )
779  # beta_i * input_t + b
780  # Shape: [1, batch_size, 4 * hidden_size]
781  beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
782  [input_t, beta_i, b],
783  self.scope('beta_i_by_input_t_plus_b'),
784  axis=2,
785  )
786  # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
787  # Shape: [1, batch_size, 4 * hidden_size]
788  gates_t = brew.sum(
789  model,
790  [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
791  self.scope('gates_t')
792  )
793  # the brew.layer_norm call is the only difference from MILSTMCell._apply
794  gates_t, _, _ = brew.layer_norm(
795  model,
796  self.scope('gates_t'),
797  self.scope('gates_t_norm'),
798  dim_in=self.gates_size,
799  axis=-1,
800  )
801  hidden_t, cell_t = model.net.LSTMUnit(
802  [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
803  [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
804  forget_bias=self.forget_bias,
805  drop_states=self.drop_states,
806  )
807  model.net.AddExternalOutputs(
808  cell_t,
809  hidden_t,
810  )
811  if self.memory_optimization:
812  self.recompute_blobs = [gates_t]
813  return hidden_t, cell_t
814 
815 
816 class DropoutCell(RNNCell):
817  '''
818  Wraps arbitrary RNNCell, applying dropout to its output (but not to the
819  recurrent connection for the corresponding state).
820  '''
821 
822  def __init__(
823  self,
824  internal_cell,
825  dropout_ratio=None,
826  use_cudnn=False,
827  **kwargs
828  ):
829  self.internal_cell = internal_cell
830  self.dropout_ratio = dropout_ratio
831  assert 'is_test' in kwargs, "Argument 'is_test' is required"
832  self.is_test = kwargs.pop('is_test')
833  self.use_cudnn = use_cudnn
834  super(DropoutCell, self).__init__(**kwargs)
835 
836  self.prepare_input = internal_cell.prepare_input
837  self.get_output_state_index = internal_cell.get_output_state_index
838  self.get_state_names = internal_cell.get_state_names
839  self.get_output_dim = internal_cell.get_output_dim
840 
841  self.mask = 0
842 
843  def _apply(
844  self,
845  model,
846  input_t,
847  seq_lengths,
848  states,
849  timestep,
850  extra_inputs=None,
851  ):
852  return self.internal_cell._apply(
853  model,
854  input_t,
855  seq_lengths,
856  states,
857  timestep,
858  extra_inputs,
859  )
860 
861  def _prepare_output(self, model, states):
862  output = self.internal_cell._prepare_output(
863  model,
864  states,
865  )
866  if self.dropout_ratio is not None:
867  output = self._apply_dropout(model, output)
868  return output
869 
870  def _prepare_output_sequence(self, model, state_outputs):
871  output = self.internal_cell._prepare_output_sequence(
872  model,
873  state_outputs,
874  )
875  if self.dropout_ratio is not None:
876  output = self._apply_dropout(model, output)
877  return output
878 
879  def _apply_dropout(self, model, output):
880  if self.dropout_ratio and not self.forward_only:
881  with core.NameScope(self.name or ''):
882  output = brew.dropout(
883  model,
884  output,
885  str(output) + '_with_dropout_mask{}'.format(self.mask),
886  ratio=float(self.dropout_ratio),
887  is_test=self.is_test,
888  use_cudnn=self.use_cudnn,
889  )
890  self.mask += 1
891  return output
892 
893 
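# Example (illustrative, not part of rnn_cell.py): wrapping a cell in
# DropoutCell so dropout is applied to its output but not to the recurrent
# connection. Names and sizes are hypothetical; 'is_test' is required.
def _example_dropout_wrapped_cell():
    inner = LSTMCell(
        input_size=10,
        hidden_size=20,
        forget_bias=0.0,
        memory_optimization=False,
        name='inner_lstm',
        forward_only=False,
    )
    return DropoutCell(
        internal_cell=inner,
        dropout_ratio=0.3,
        is_test=0,
        name='dropout_lstm',
        forward_only=False,
    )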
894 class MultiRNNCellInitializer(object):
895  def __init__(self, cells):
896  self.cells = cells
897 
898  def create_states(self, model):
899  states = []
900  for i, cell in enumerate(self.cells):
901  if cell.initializer is None:
902  raise Exception("Either initial states "
903  "or initializer have to be set")
904 
905  with core.NameScope("layer_{}".format(i)),\
906  core.NameScope(cell.name):
907  states.extend(cell.initializer.create_states(model))
908  return states
909 
910 
911 class MultiRNNCell(RNNCell):
912  '''
913  Multilayer RNN via the composition of RNNCell instances.
914 
915  It is the responsibility of the calling code to ensure the compatibility
916  of the successive layers in terms of input/output dimensionality, etc.,
917  and to ensure that their blobs do not have name conflicts, typically by
918  creating the cells with names that specify layer number.
919 
920  Assumes first state (recurrent output) for each layer should be the input
921  to the next layer.
922  '''
923 
924  def __init__(self, cells, residual_output_layers=None, **kwargs):
925  '''
926  cells: list of RNNCell instances, from input to output side.
927 
928  name: string designating network component (for scoping)
929 
930  residual_output_layers: list of indices of layers whose input will
931  be added elementwise to their output. (It is the
932  responsibility of the client code to ensure shape compatibility.)
933  Note that layer 0 (zero) cannot have residual output because of the
934  timing of prepare_input().
935 
936  forward_only: used to construct inference-only network.
937  '''
938  super(MultiRNNCell, self).__init__(**kwargs)
939  self.cells = cells
940 
941  if residual_output_layers is None:
942  self.residual_output_layers = []
943  else:
944  self.residual_output_layers = residual_output_layers
945 
946  output_index_per_layer = []
947  base_index = 0
948  for cell in self.cells:
949  output_index_per_layer.append(
950  base_index + cell.get_output_state_index(),
951  )
952  base_index += len(cell.get_state_names())
953 
954  self.output_connected_layers = []
955  self.output_indices = []
956  for i in range(len(self.cells) - 1):
957  if (i + 1) in self.residual_output_layers:
958  self.output_connected_layers.append(i)
959  self.output_indices.append(output_index_per_layer[i])
960  else:
961  self.output_connected_layers = []
962  self.output_indices = []
963  self.output_connected_layers.append(len(self.cells) - 1)
964  self.output_indices.append(output_index_per_layer[-1])
965 
966  self.state_names = []
967  for i, cell in enumerate(self.cells):
968  self.state_names.extend(
969  map(self.layer_scoper(i), cell.get_state_names())
970  )
971 
971 
972  self.initializer = MultiRNNCellInitializer(cells)
973 
974  def layer_scoper(self, layer_id):
975  def helper(name):
976  return "{}/layer_{}/{}".format(self.name, layer_id, name)
977  return helper
978 
979  def prepare_input(self, model, input_blob):
980  input_blob = _RectifyName(input_blob)
981  with core.NameScope(self.name or ''):
982  return self.cells[0].prepare_input(model, input_blob)
983 
984  def _apply(
985  self,
986  model,
987  input_t,
988  seq_lengths,
989  states,
990  timestep,
991  extra_inputs=None,
992  ):
993  '''
994  Because below we will do scoping across layers, we need
995  to make sure that string blob names are converted to BlobReference
996  objects.
997  '''
998 
999  input_t, seq_lengths, states, timestep, extra_inputs = \
1000  self._rectify_apply_inputs(
1001  input_t, seq_lengths, states, timestep, extra_inputs)
1002 
1003  states_per_layer = [len(cell.get_state_names()) for cell in self.cells]
1004  assert len(states) == sum(states_per_layer)
1005 
1006  next_states = []
1007  states_index = 0
1008 
1009  layer_input = input_t
1010  for i, layer_cell in enumerate(self.cells):
1011  # If cells don't have different names we still
1012  # take care of scoping
1013  with core.NameScope(self.name), core.NameScope("layer_{}".format(i)):
1014  num_states = states_per_layer[i]
1015  layer_states = states[states_index:(states_index + num_states)]
1016  states_index += num_states
1017 
1018  if i > 0:
1019  prepared_input = layer_cell.prepare_input(
1020  model, layer_input)
1021  else:
1022  prepared_input = layer_input
1023 
1024  layer_next_states = layer_cell._apply(
1025  model,
1026  prepared_input,
1027  seq_lengths,
1028  layer_states,
1029  timestep,
1030  extra_inputs=(None if i > 0 else extra_inputs),
1031  )
1032  # Since we're using here non-public method _apply,
1033  # instead of apply, we have to manually extract output
1034  # from states
1035  if i != len(self.cells) - 1:
1036  layer_output = layer_cell._prepare_output(
1037  model,
1038  layer_next_states,
1039  )
1040  if i > 0 and i in self.residual_output_layers:
1041  layer_input = brew.sum(
1042  model,
1043  [layer_output, layer_input],
1044  self.scope('residual_output_{}'.format(i)),
1045  )
1046  else:
1047  layer_input = layer_output
1048 
1049  next_states.extend(layer_next_states)
1050  return next_states
1051 
1052  def get_state_names(self):
1053  return self.state_names
1054 
1055  def get_output_state_index(self):
1056  index = 0
1057  for cell in self.cells[:-1]:
1058  index += len(cell.get_state_names())
1059  index += self.cells[-1].get_output_state_index()
1060  return index
1061 
1062  def _prepare_output(self, model, states):
1063  connected_outputs = []
1064  state_index = 0
1065  for i, cell in enumerate(self.cells):
1066  num_states = len(cell.get_state_names())
1067  if i in self.output_connected_layers:
1068  layer_states = states[state_index:state_index + num_states]
1069  layer_output = cell._prepare_output(
1070  model,
1071  layer_states
1072  )
1073  connected_outputs.append(layer_output)
1074  state_index += num_states
1075  if len(connected_outputs) > 1:
1076  output = brew.sum(
1077  model,
1078  connected_outputs,
1079  self.scope('residual_output'),
1080  )
1081  else:
1082  output = connected_outputs[0]
1083  return output
1084 
1085  def _prepare_output_sequence(self, model, states):
1086  connected_outputs = []
1087  state_index = 0
1088  for i, cell in enumerate(self.cells):
1089  num_states = 2 * len(cell.get_state_names())
1090  if i in self.output_connected_layers:
1091  layer_states = states[state_index:state_index + num_states]
1092  layer_output = cell._prepare_output_sequence(
1093  model,
1094  layer_states
1095  )
1096  connected_outputs.append(layer_output)
1097  state_index += num_states
1098  if len(connected_outputs) > 1:
1099  output = brew.sum(
1100  model,
1101  connected_outputs,
1102  self.scope('residual_output_sequence'),
1103  )
1104  else:
1105  output = connected_outputs[0]
1106  return output
1107 
1108 
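# Example (illustrative, not part of rnn_cell.py): a two-layer stack built
# with MultiRNNCell. Per the docstring above, each layer gets its own name to
# avoid blob conflicts; marking layer 1 as a residual output layer adds its
# input to its output, so its input and output dimensions must match.
def _example_multi_layer_cell():
    layer0 = LSTMCell(
        input_size=10, hidden_size=20, forget_bias=0.0,
        memory_optimization=False, name='layer_0_lstm', forward_only=False,
    )
    layer1 = LSTMCell(
        input_size=20, hidden_size=20, forget_bias=0.0,
        memory_optimization=False, name='layer_1_lstm', forward_only=False,
    )
    return MultiRNNCell(
        [layer0, layer1],
        residual_output_layers=[1],
        name='stacked_lstm',
        forward_only=False,
    )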
1109 class AttentionCell(RNNCell):
1110 
1111  def __init__(
1112  self,
1113  encoder_output_dim,
1114  encoder_outputs,
1115  encoder_lengths,
1116  decoder_cell,
1117  decoder_state_dim,
1118  attention_type,
1119  weighted_encoder_outputs,
1120  attention_memory_optimization,
1121  **kwargs
1122  ):
1123  super(AttentionCell, self).__init__(**kwargs)
1124  self.encoder_output_dim = encoder_output_dim
1125  self.encoder_outputs = encoder_outputs
1126  self.encoder_lengths = encoder_lengths
1127  self.decoder_cell = decoder_cell
1128  self.decoder_state_dim = decoder_state_dim
1129  self.weighted_encoder_outputs = weighted_encoder_outputs
1130  self.encoder_outputs_transposed = None
1131  assert attention_type in [
1132  AttentionType.Regular,
1133  AttentionType.Recurrent,
1134  AttentionType.Dot,
1135  AttentionType.SoftCoverage,
1136  ]
1137  self.attention_type = attention_type
1138  self.attention_memory_optimization = attention_memory_optimization
1139 
1140  def _apply(
1141  self,
1142  model,
1143  input_t,
1144  seq_lengths,
1145  states,
1146  timestep,
1147  extra_inputs=None,
1148  ):
1149  if self.attention_type == AttentionType.SoftCoverage:
1150  decoder_prev_states = states[:-2]
1151  attention_weighted_encoder_context_t_prev = states[-2]
1152  coverage_t_prev = states[-1]
1153  else:
1154  decoder_prev_states = states[:-1]
1155  attention_weighted_encoder_context_t_prev = states[-1]
1156 
1157  assert extra_inputs is None
1158 
1159  decoder_states = self.decoder_cell._apply(
1160  model,
1161  input_t,
1162  seq_lengths,
1163  decoder_prev_states,
1164  timestep,
1165  extra_inputs=[(
1166  attention_weighted_encoder_context_t_prev,
1167  self.encoder_output_dim,
1168  )],
1169  )
1170 
1171  self.hidden_t_intermediate = self.decoder_cell._prepare_output(
1172  model,
1173  decoder_states,
1174  )
1175 
1176  if self.attention_type == AttentionType.Recurrent:
1177  (
1178  attention_weighted_encoder_context_t,
1179  self.attention_weights_3d,
1180  attention_blobs,
1181  ) = apply_recurrent_attention(
1182  model=model,
1183  encoder_output_dim=self.encoder_output_dim,
1184  encoder_outputs_transposed=self.encoder_outputs_transposed,
1185  weighted_encoder_outputs=self.weighted_encoder_outputs,
1186  decoder_hidden_state_t=self.hidden_t_intermediate,
1187  decoder_hidden_state_dim=self.decoder_state_dim,
1188  scope=self.name,
1189  attention_weighted_encoder_context_t_prev=(
1190  attention_weighted_encoder_context_t_prev
1191  ),
1192  encoder_lengths=self.encoder_lengths,
1193  )
1194  elif self.attention_type == AttentionType.Regular:
1195  (
1196  attention_weighted_encoder_context_t,
1197  self.attention_weights_3d,
1198  attention_blobs,
1199  ) = apply_regular_attention(
1200  model=model,
1201  encoder_output_dim=self.encoder_output_dim,
1202  encoder_outputs_transposed=self.encoder_outputs_transposed,
1203  weighted_encoder_outputs=self.weighted_encoder_outputs,
1204  decoder_hidden_state_t=self.hidden_t_intermediate,
1205  decoder_hidden_state_dim=self.decoder_state_dim,
1206  scope=self.name,
1207  encoder_lengths=self.encoder_lengths,
1208  )
1209  elif self.attention_type == AttentionType.Dot:
1210  (
1211  attention_weighted_encoder_context_t,
1212  self.attention_weights_3d,
1213  attention_blobs,
1214  ) = apply_dot_attention(
1215  model=model,
1216  encoder_output_dim=self.encoder_output_dim,
1217  encoder_outputs_transposed=self.encoder_outputs_transposed,
1218  decoder_hidden_state_t=self.hidden_t_intermediate,
1219  decoder_hidden_state_dim=self.decoder_state_dim,
1220  scope=self.name,
1221  encoder_lengths=self.encoder_lengths,
1222  )
1223  elif self.attention_type == AttentionType.SoftCoverage:
1224  (
1225  attention_weighted_encoder_context_t,
1226  self.attention_weights_3d,
1227  attention_blobs,
1228  coverage_t,
1229  ) = apply_soft_coverage_attention(
1230  model=model,
1231  encoder_output_dim=self.encoder_output_dim,
1232  encoder_outputs_transposed=self.encoder_outputs_transposed,
1233  weighted_encoder_outputs=self.weighted_encoder_outputs,
1234  decoder_hidden_state_t=self.hidden_t_intermediate,
1235  decoder_hidden_state_dim=self.decoder_state_dim,
1236  scope=self.name,
1237  encoder_lengths=self.encoder_lengths,
1238  coverage_t_prev=coverage_t_prev,
1239  coverage_weights=self.coverage_weights,
1240  )
1241  else:
1242  raise Exception('Attention type {} not implemented'.format(
1243  self.attention_type
1244  ))
1245 
1245 
1246  if self.attention_memory_optimization:
1247  self.recompute_blobs.extend(attention_blobs)
1248 
1249  output = list(decoder_states) + [attention_weighted_encoder_context_t]
1250  if self.attention_type == AttentionType.SoftCoverage:
1251  output.append(coverage_t)
1252 
1253  output[self.decoder_cell.get_output_state_index()] = model.Copy(
1254  output[self.decoder_cell.get_output_state_index()],
1255  self.scope('hidden_t_external'),
1256  )
1257  model.net.AddExternalOutputs(*output)
1258 
1259  return output
1260 
1261  def get_attention_weights(self):
1262  # [batch_size, encoder_length, 1]
1263  return self.attention_weights_3d
1264 
1265  def prepare_input(self, model, input_blob):
1266  if self.encoder_outputs_transposed is None:
1267  self.encoder_outputs_transposed = brew.transpose(
1268  model,
1269  self.encoder_outputs,
1270  self.scope('encoder_outputs_transposed'),
1271  axes=[1, 2, 0],
1272  )
1273  if (
1274  self.weighted_encoder_outputs is None and
1275  self.attention_type != AttentionType.Dot
1276  ):
1277  self.weighted_encoder_outputs = brew.fc(
1278  model,
1279  self.encoder_outputs,
1280  self.scope('weighted_encoder_outputs'),
1281  dim_in=self.encoder_output_dim,
1282  dim_out=self.encoder_output_dim,
1283  axis=2,
1284  )
1285 
1286  return self.decoder_cell.prepare_input(model, input_blob)
1287 
1288  def build_initial_coverage(self, model):
1289  """
1290  initial_coverage is always zeros of shape [encoder_length],
1291  whose shape must be determined programmatically during network
1292  computation.
1293 
1294  This method also sets self.coverage_weights, a separate transform
1295  of encoder_outputs which is used to determine the coverage contribution
1296  to attention.
1297  """
1298  assert self.attention_type == AttentionType.SoftCoverage
1299 
1300  # [encoder_length, batch_size, encoder_output_dim]
1301  self.coverage_weights = brew.fc(
1302  model,
1303  self.encoder_outputs,
1304  self.scope('coverage_weights'),
1305  dim_in=self.encoder_output_dim,
1306  dim_out=self.encoder_output_dim,
1307  axis=2,
1308  )
1309 
1310  encoder_length = model.net.Slice(
1311  model.net.Shape(self.encoder_outputs),
1312  starts=[0],
1313  ends=[1],
1314  )
1315  if (
1316  scope.CurrentDeviceScope() is not None and
1317  core.IsGPUDeviceType(scope.CurrentDeviceScope().device_type)
1318  ):
1319  encoder_length = model.net.CopyGPUToCPU(
1320  encoder_length,
1321  'encoder_length_cpu',
1322  )
1323  # total attention weight applied across decoding steps
1324  # shape: [encoder_length]
1325  initial_coverage = model.net.ConstantFill(
1326  encoder_length,
1327  self.scope('initial_coverage'),
1328  value=0.0,
1329  input_as_shape=1,
1330  )
1331  return initial_coverage
1332 
1333  def get_state_names(self):
1334  state_names = list(self.decoder_cell.get_state_names())
1335  state_names[self.get_output_state_index()] = self.scope(
1336  'hidden_t_external',
1337  )
1338  state_names.append(self.scope('attention_weighted_encoder_context_t'))
1339  if self.attention_type == AttentionType.SoftCoverage:
1340  state_names.append(self.scope('coverage_t'))
1341  return state_names
1342 
1343  def get_output_dim(self):
1344  return self.decoder_state_dim + self.encoder_output_dim
1345 
1346  def get_output_state_index(self):
1347  return self.decoder_cell.get_output_state_index()
1348 
1349  def _prepare_output(self, model, states):
1350  if self.attention_type == AttentionType.SoftCoverage:
1351  attention_context = states[-2]
1352  else:
1353  attention_context = states[-1]
1354 
1355  with core.NameScope(self.name or ''):
1356  output = brew.concat(
1357  model,
1358  [self.hidden_t_intermediate, attention_context],
1359  'states_and_context_combination',
1360  axis=2,
1361  )
1362 
1363  return output
1364 
1365  def _prepare_output_sequence(self, model, state_outputs):
1366  if self.attention_type == AttentionType.SoftCoverage:
1367  decoder_state_outputs = state_outputs[:-4]
1368  else:
1369  decoder_state_outputs = state_outputs[:-2]
1370 
1371  decoder_output = self.decoder_cell._prepare_output_sequence(
1372  model,
1373  decoder_state_outputs,
1374  )
1375 
1376  if self.attention_type == AttentionType.SoftCoverage:
1377  attention_context_index = 2 * (len(self.get_state_names()) - 2)
1378  else:
1379  attention_context_index = 2 * (len(self.get_state_names()) - 1)
1380 
1381  with core.NameScope(self.name or ''):
1382  output = brew.concat(
1383  model,
1384  [
1385  decoder_output,
1386  state_outputs[attention_context_index],
1387  ],
1388  'states_and_context_combination',
1389  axis=2,
1390  )
1391  return output
1392 
1393 
1394 class LSTMWithAttentionCell(AttentionCell):
1395 
1396  def __init__(
1397  self,
1398  encoder_output_dim,
1399  encoder_outputs,
1400  encoder_lengths,
1401  decoder_input_dim,
1402  decoder_state_dim,
1403  name,
1404  attention_type,
1405  weighted_encoder_outputs,
1406  forget_bias,
1407  lstm_memory_optimization,
1408  attention_memory_optimization,
1409  forward_only=False,
1410  ):
1411  decoder_cell = LSTMCell(
1412  input_size=decoder_input_dim,
1413  hidden_size=decoder_state_dim,
1414  forget_bias=forget_bias,
1415  memory_optimization=lstm_memory_optimization,
1416  name='{}/decoder'.format(name),
1417  forward_only=False,
1418  drop_states=False,
1419  )
1420  super(LSTMWithAttentionCell, self).__init__(
1421  encoder_output_dim=encoder_output_dim,
1422  encoder_outputs=encoder_outputs,
1423  encoder_lengths=encoder_lengths,
1424  decoder_cell=decoder_cell,
1425  decoder_state_dim=decoder_state_dim,
1426  name=name,
1427  attention_type=attention_type,
1428  weighted_encoder_outputs=weighted_encoder_outputs,
1429  attention_memory_optimization=attention_memory_optimization,
1430  forward_only=forward_only,
1431  )
1432 
1433 
1434 class MILSTMWithAttentionCell(AttentionCell):
1435 
1436  def __init__(
1437  self,
1438  encoder_output_dim,
1439  encoder_outputs,
1440  decoder_input_dim,
1441  decoder_state_dim,
1442  name,
1443  attention_type,
1444  weighted_encoder_outputs,
1445  forget_bias,
1446  lstm_memory_optimization,
1447  attention_memory_optimization,
1448  forward_only=False,
1449  ):
1450  decoder_cell = MILSTMCell(
1451  input_size=decoder_input_dim,
1452  hidden_size=decoder_state_dim,
1453  forget_bias=forget_bias,
1454  memory_optimization=lstm_memory_optimization,
1455  name='{}/decoder'.format(name),
1456  forward_only=False,
1457  drop_states=False,
1458  )
1459  super(MILSTMWithAttentionCell, self).__init__(
1460  encoder_output_dim=encoder_output_dim,
1461  encoder_outputs=encoder_outputs,
1462  decoder_cell=decoder_cell,
1463  decoder_state_dim=decoder_state_dim,
1464  name=name,
1465  attention_type=attention_type,
1466  weighted_encoder_outputs=weighted_encoder_outputs,
1467  attention_memory_optimization=attention_memory_optimization,
1468  forward_only=forward_only,
1469  )
1470 
1471 
1472 def _LSTM(
1473  cell_class,
1474  model,
1475  input_blob,
1476  seq_lengths,
1477  initial_states,
1478  dim_in,
1479  dim_out,
1480  scope=None,
1481  outputs_with_grads=(0,),
1482  return_params=False,
1483  memory_optimization=False,
1484  forget_bias=0.0,
1485  forward_only=False,
1486  drop_states=False,
1487  return_last_layer_only=True,
1488  static_rnn_unroll_size=None,
1489  **cell_kwargs
1490 ):
1491  '''
1492  Adds a standard LSTM recurrent network operator to a model.
1493 
1494  cell_class: LSTMCell or compatible subclass
1495 
1496  model: ModelHelper object that new operators will be added to
1497 
1498  input_blob: the input sequence in a format T x N x D
1499  where T is sequence size, N - batch size and D - input dimension
1500 
1501  seq_lengths: blob containing sequence lengths which would be passed to
1502  LSTMUnit operator
1503 
1504  initial_states: a list of (2 * num_layers) blobs representing the initial
1505  hidden and cell states of each layer. If this argument is None,
1506  these states will be added to the model as network parameters.
1507 
1508  dim_in: input dimension
1509 
1510  dim_out: number of units per LSTM layer
1511  (use int for single-layer LSTM, list of ints for multi-layer)
1512 
1513  outputs_with_grads: position indices of output blobs for LAST LAYER which
1514  will receive external error gradient during backpropagation.
1515  These outputs are: (h_all, h_last, c_all, c_last)
1516 
1517  return_params: if True, will return a dictionary of parameters of the LSTM
1518 
1519  memory_optimization: if enabled, the LSTM step is recomputed on backward
1520  step so that we don't need to store forward activations for each
1521  timestep. Saves memory with cost of computation.
1522 
1523  forget_bias: forget gate bias (default 0.0)
1524 
1525  forward_only: whether to create a backward pass
1526 
1527  drop_states: drop invalid states, passed through to LSTMUnit operator
1528 
1529  return_last_layer_only: only return outputs from final layer
1530  (so that the length of the result does not depend on the number of layers)
1531 
1532  static_rnn_unroll_size: if not None, we will use static RNN which is
1533  unrolled into Caffe2 graph. The size of the unroll is the value of
1534  this parameter.
1535  '''
1536  if type(dim_out) is not list and type(dim_out) is not tuple:
1537  dim_out = [dim_out]
1538  num_layers = len(dim_out)
1539 
1540  cells = []
1541  for i in range(num_layers):
1542  cell = cell_class(
1543  input_size=(dim_in if i == 0 else dim_out[i - 1]),
1544  hidden_size=dim_out[i],
1545  forget_bias=forget_bias,
1546  memory_optimization=memory_optimization,
1547  name=scope if num_layers == 1 else None,
1548  forward_only=forward_only,
1549  drop_states=drop_states,
1550  **cell_kwargs
1551  )
1552  cells.append(cell)
1553 
1554  cell = MultiRNNCell(
1555  cells,
1556  name=scope,
1557  forward_only=forward_only,
1558  ) if num_layers > 1 else cells[0]
1559 
1560  cell = (
1561  cell if static_rnn_unroll_size is None
1562  else UnrolledCell(cell, static_rnn_unroll_size))
1563 
1564  # outputs_with_grads argument indexes into final layer
1565  outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
1566  _, result = cell.apply_over_sequence(
1567  model=model,
1568  inputs=input_blob,
1569  seq_lengths=seq_lengths,
1570  initial_states=initial_states,
1571  outputs_with_grads=outputs_with_grads,
1572  )
1573 
1574  if return_last_layer_only:
1575  result = result[4 * (num_layers - 1):]
1576  if return_params:
1577  result = list(result) + [{
1578  'input': cell.get_input_params(),
1579  'recurrent': cell.get_recurrent_params(),
1580  }]
1581  return tuple(result)
1582 
1583 
1584 LSTM = functools.partial(_LSTM, LSTMCell)
1585 BasicRNN = functools.partial(_LSTM, BasicRNNCell)
1586 MILSTM = functools.partial(_LSTM, MILSTMCell)
1587 LayerNormLSTM = functools.partial(_LSTM, LayerNormLSTMCell)
1588 LayerNormMILSTM = functools.partial(_LSTM, LayerNormMILSTMCell)
1589 
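# Example (illustrative, not part of rnn_cell.py): an end-to-end sketch of the
# LSTM helper defined above. Blob names, shapes and values are hypothetical;
# the initial states are fed as 1 x N x dim_out blobs.
def _example_lstm_helper():
    T, N, D, H = 5, 2, 8, 16
    model = ModelHelper(name='lstm_helper_example')
    model.net.AddExternalInputs('input_blob', 'seq_lengths', 'h0', 'c0')
    hidden_all, hidden_last, cell_all, cell_last = LSTM(
        model=model,
        input_blob='input_blob',      # T x N x D
        seq_lengths='seq_lengths',    # N, int32
        initial_states=('h0', 'c0'),  # each 1 x N x H
        dim_in=D,
        dim_out=H,
        scope='example_lstm',
        forward_only=True,
    )
    workspace.FeedBlob('input_blob', np.random.randn(T, N, D).astype(np.float32))
    workspace.FeedBlob('seq_lengths', np.array([T] * N, dtype=np.int32))
    workspace.FeedBlob('h0', np.zeros((1, N, H), dtype=np.float32))
    workspace.FeedBlob('c0', np.zeros((1, N, H), dtype=np.float32))
    workspace.RunNetOnce(model.param_init_net)
    workspace.RunNetOnce(model.net)
    return workspace.FetchBlob(hidden_all)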
1590 
1591 class UnrolledCell(RNNCell):
1592  def __init__(self, cell, T):
1593  self.T = T
1594  self.cell = cell
1595 
1596  def apply_over_sequence(
1597  self,
1598  model,
1599  inputs,
1600  seq_lengths,
1601  initial_states,
1602  outputs_with_grads=None,
1603  ):
1604  inputs = self.cell.prepare_input(model, inputs)
1605 
1606  # Now they are blob references - outputs of splitting the input sequence
1607  split_inputs = model.net.Split(
1608  inputs,
1609  [str(inputs) + "_timestep_{}".format(i)
1610  for i in range(self.T)],
1611  axis=0)
1612  if self.T == 1:
1613  split_inputs = [split_inputs]
1614 
1615  states = initial_states
1616  all_states = []
1617  for t in range(0, self.T):
1618  scope_name = "timestep_{}".format(t)
1619  # Parameters of all timesteps are shared
1620  with ParameterSharing({scope_name: ''}),\
1621  scope.NameScope(scope_name):
1622  timestep = model.param_init_net.ConstantFill(
1623  [], "timestep", value=t, shape=[1],
1624  dtype=core.DataType.INT32,
1625  device_option=core.DeviceOption(caffe2_pb2.CPU))
1626  states = self.cell._apply(
1627  model=model,
1628  input_t=split_inputs[t],
1629  seq_lengths=seq_lengths,
1630  states=states,
1631  timestep=timestep,
1632  )
1633  all_states.append(states)
1634 
1635  all_states = zip(*all_states)
1636  all_states = [
1637  model.net.Concat(
1638  list(full_output),
1639  [
1640  str(full_output[0])[len("timestep_0/"):] + "_concat",
1641  str(full_output[0])[len("timestep_0/"):] + "_concat_info"
1642 
1643  ],
1644  axis=0)[0]
1645  for full_output in all_states
1646  ]
1647  outputs = tuple(
1648  six.next(it) for it in
1649  itertools.cycle([iter(all_states), iter(states)])
1650  )
1651  outputs_without_grad = set(range(len(outputs))) - set(
1652  outputs_with_grads)
1653  for i in outputs_without_grad:
1654  model.net.ZeroGradient(outputs[i], [])
1655  logging.debug("Added 0 gradients for blobs: %s",
1656  [outputs[i] for i in outputs_without_grad])
1657 
1658  final_output = self.cell._prepare_output_sequence(model, outputs)
1659 
1660  return final_output, outputs
1661 
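# Example (illustrative, not part of rnn_cell.py): requesting a statically
# unrolled LSTM through the LSTM helper defined above. The initial states are
# passed as BlobReferences so that the per-timestep name scoping applied
# inside UnrolledCell does not rewrite them; names and sizes are hypothetical.
def _example_unrolled_lstm(model):
    initial_states = [core.BlobReference('h0'), core.BlobReference('c0')]
    return LSTM(
        model=model,
        input_blob='input_blob',           # exactly T x N x 8
        seq_lengths=None,
        initial_states=initial_states,
        dim_in=8,
        dim_out=16,
        scope='unrolled_lstm',
        static_rnn_unroll_size=4,          # T must equal the unroll size
    )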
1662 
1663 def GetLSTMParamNames():
1664  weight_params = ["input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
1665  bias_params = ["input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
1666  return {'weights': weight_params, 'biases': bias_params}
1667 
1668 
1669 def InitFromLSTMParams(lstm_pblobs, param_values):
1670  '''
1671  Set the parameters of LSTM based on predefined values
1672  '''
1673  weight_params = GetLSTMParamNames()['weights']
1674  bias_params = GetLSTMParamNames()['biases']
1675  for input_type in viewkeys(param_values):
1676  weight_values = [
1677  param_values[input_type][w].flatten()
1678  for w in weight_params
1679  ]
1680  wmat = np.array([])
1681  for w in weight_values:
1682  wmat = np.append(wmat, w)
1683  bias_values = [
1684  param_values[input_type][b].flatten()
1685  for b in bias_params
1686  ]
1687  bm = np.array([])
1688  for b in bias_values:
1689  bm = np.append(bm, b)
1690 
1691  weights_blob = lstm_pblobs[input_type]['weights']
1692  bias_blob = lstm_pblobs[input_type]['biases']
1693  cur_weight = workspace.FetchBlob(weights_blob)
1694  cur_biases = workspace.FetchBlob(bias_blob)
1695 
1696  workspace.FeedBlob(
1697  weights_blob,
1698  wmat.reshape(cur_weight.shape).astype(np.float32))
1699  workspace.FeedBlob(
1700  bias_blob,
1701  bm.reshape(cur_biases.shape).astype(np.float32))
1702 
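# Example (illustrative, not part of rnn_cell.py): the nested dict layout
# expected by InitFromLSTMParams, built from GetLSTMParamNames(). `cell` is
# assumed to be an LSTMCell whose param_init_net has already been run, and the
# numpy values here are just placeholders.
def _example_init_from_lstm_params(cell, hidden_size, input_size):
    names = GetLSTMParamNames()

    def gate_values(in_dim):
        values = {}
        for w in names['weights']:
            values[w] = np.random.randn(hidden_size, in_dim).astype(np.float32)
        for b in names['biases']:
            values[b] = np.zeros(hidden_size, dtype=np.float32)
        return values

    param_values = {
        'input': gate_values(input_size),
        'recurrent': gate_values(hidden_size),
    }
    lstm_pblobs = {
        'input': cell.get_input_params(),
        'recurrent': cell.get_recurrent_params(),
    }
    # Overwrites the already-initialized i2h and gates_t blobs in the workspace.
    InitFromLSTMParams(lstm_pblobs, param_values)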
1703 
1704 def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
1705  scope, recurrent_params=None, input_params=None,
1706  num_layers=1, return_params=False):
1707  '''
1708  CuDNN version of LSTM for GPUs.
1709  input_blob Blob containing the input. Will need to be available
1710  when param_init_net is run, because the sequence lengths
1711  and batch sizes will be inferred from the size of this
1712  blob.
1713  initial_states tuple of (hidden_init, cell_init) blobs
1714  dim_in input dimensions
1715  dim_out output/hidden dimension
1716  scope namescope to apply
1717  recurrent_params dict of blobs containing values for recurrent
1718  gate weights, biases (if None, use random init values)
1719  See GetLSTMParamNames() for format.
1720  input_params dict of blobs containing values for input
1721  gate weights, biases (if None, use random init values)
1722  See GetLSTMParamNames() for format.
1723  num_layers number of LSTM layers
1724  return_params if True, returns (param_extract_net, param_mapping)
1725  where param_extract_net is a net that when run, will
1726  populate the blobs specified in param_mapping with the
1727  current gate weights and biases (input/recurrent).
1728  Useful for assigning the values back to non-cuDNN
1729  LSTM.
1730  '''
1731  with core.NameScope(scope):
1732  weight_params = GetLSTMParamNames()['weights']
1733  bias_params = GetLSTMParamNames()['biases']
1734 
1735  input_weight_size = dim_out * dim_in
1736  upper_layer_input_weight_size = dim_out * dim_out
1737  recurrent_weight_size = dim_out * dim_out
1738  input_bias_size = dim_out
1739  recurrent_bias_size = dim_out
1740 
1741  def init(layer, pname, input_type):
1742  input_weight_size_for_layer = input_weight_size if layer == 0 else \
1743  upper_layer_input_weight_size
1744  if pname in weight_params:
1745  sz = input_weight_size_for_layer if input_type == 'input' \
1746  else recurrent_weight_size
1747  elif pname in bias_params:
1748  sz = input_bias_size if input_type == 'input' \
1749  else recurrent_bias_size
1750  else:
1751  assert False, "unknown parameter type {}".format(pname)
1752  return model.param_init_net.UniformFill(
1753  [],
1754  "lstm_init_{}_{}_{}".format(input_type, pname, layer),
1755  shape=[sz])
1756 
1757  # Multiply by 4 since we have 4 gates per LSTM unit
1758  first_layer_sz = input_weight_size + recurrent_weight_size + \
1759  input_bias_size + recurrent_bias_size
1760  upper_layer_sz = upper_layer_input_weight_size + \
1761  recurrent_weight_size + input_bias_size + \
1762  recurrent_bias_size
1763  total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)
1764 
1765  weights = model.create_param(
1766  'lstm_weight',
1767  shape=[total_sz],
1768  initializer=Initializer('UniformFill'),
1769  tags=ParameterTags.WEIGHT,
1770  )
1771 
1772  lstm_args = {
1773  'hidden_size': dim_out,
1774  'rnn_mode': 'lstm',
1775  'bidirectional': 0, # TODO
1776  'dropout': 1.0, # TODO
1777  'input_mode': 'linear', # TODO
1778  'num_layers': num_layers,
1779  'engine': 'CUDNN'
1780  }
1781 
1782  param_extract_net = core.Net("lstm_param_extractor")
1783  param_extract_net.AddExternalInputs([input_blob, weights])
1784  param_extract_mapping = {}
1785 
1786  # Populate the weights-blob from blobs containing parameters for
1787  # the individual components of the LSTM, such as forget/input gate
1788  # weights and biases. Also, create a special param_extract_net that
1789  # can be used to grab those individual params from the black-box
1790  # weights blob. These results can be then fed to InitFromLSTMParams()
1791  for input_type in ['input', 'recurrent']:
1792  param_extract_mapping[input_type] = {}
1793  p = recurrent_params if input_type == 'recurrent' else input_params
1794  if p is None:
1795  p = {}
1796  for pname in weight_params + bias_params:
1797  for j in range(0, num_layers):
1798  values = p[pname] if pname in p else init(j, pname, input_type)
1799  model.param_init_net.RecurrentParamSet(
1800  [input_blob, weights, values],
1801  weights,
1802  layer=j,
1803  input_type=input_type,
1804  param_type=pname,
1805  **lstm_args
1806  )
1807  if pname not in param_extract_mapping[input_type]:
1808  param_extract_mapping[input_type][pname] = {}
1809  b = param_extract_net.RecurrentParamGet(
1810  [input_blob, weights],
1811  ["lstm_{}_{}_{}".format(input_type, pname, j)],
1812  layer=j,
1813  input_type=input_type,
1814  param_type=pname,
1815  **lstm_args
1816  )
1817  param_extract_mapping[input_type][pname][j] = b
1818 
1819  (hidden_input_blob, cell_input_blob) = initial_states
1820  output, hidden_output, cell_output, rnn_scratch, dropout_states = \
1821  model.net.Recurrent(
1822  [input_blob, hidden_input_blob, cell_input_blob, weights],
1823  ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
1824  "lstm_rnn_scratch", "lstm_dropout_states"],
1825  seed=random.randint(0, 100000), # TODO: dropout seed
1826  **lstm_args
1827  )
1828  model.net.AddExternalOutputs(
1829  hidden_output, cell_output, rnn_scratch, dropout_states)
1830 
1831  if return_params:
1832  param_extract = param_extract_net, param_extract_mapping
1833  return output, hidden_output, cell_output, param_extract
1834  else:
1835  return output, hidden_output, cell_output
1836 
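# Example (illustrative, not part of rnn_cell.py): building the cuDNN LSTM and
# extracting its per-gate parameters, e.g. to copy them into a non-cuDNN LSTM
# via InitFromLSTMParams as described in the docstring above. Names are
# hypothetical and a CUDA-enabled build with an active GPU device scope is
# assumed.
def _example_cudnn_lstm(model):
    input_blob = core.BlobReference('input_blob')           # T x N x 8
    initial_states = (core.BlobReference('hidden_init'),
                      core.BlobReference('cell_init'))
    output, hidden_output, cell_output, (extract_net, mapping) = cudnn_LSTM(
        model=model,
        input_blob=input_blob,
        initial_states=initial_states,
        dim_in=8,
        dim_out=16,
        scope='cudnn_lstm',
        return_params=True,
    )
    # After training, running extract_net fills the blobs in `mapping` with
    # the current gate weights and biases.
    return output, extract_net, mapping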
1837 
1838 def LSTMWithAttention(
1839  model,
1840  decoder_inputs,
1841  decoder_input_lengths,
1842  initial_decoder_hidden_state,
1843  initial_decoder_cell_state,
1844  initial_attention_weighted_encoder_context,
1845  encoder_output_dim,
1846  encoder_outputs,
1847  encoder_lengths,
1848  decoder_input_dim,
1849  decoder_state_dim,
1850  scope,
1851  attention_type=AttentionType.Regular,
1852  outputs_with_grads=(0, 4),
1853  weighted_encoder_outputs=None,
1854  lstm_memory_optimization=False,
1855  attention_memory_optimization=False,
1856  forget_bias=0.0,
1857  forward_only=False,
1858 ):
1859  '''
1860  Adds a LSTM with attention mechanism to a model.
1861 
1862  The implementation is based on https://arxiv.org/abs/1409.0473, with
1863  a small difference in the order
1864  in which we compute the new attention context and new hidden state, similarly to
1865  https://arxiv.org/abs/1508.04025.
1866 
1867  The model uses encoder-decoder naming conventions,
1868  where the decoder is the sequence the op is iterating over,
1869  while computing the attention context over the encoder.
1870 
1871  model: ModelHelper object that new operators will be added to
1872 
1873  decoder_inputs: the input sequence in a format T x N x D
1874  where T is sequence size, N - batch size and D - input dimension
1875 
1876  decoder_input_lengths: blob containing sequence lengths
1877  which would be passed to LSTMUnit operator
1878 
1879  initial_decoder_hidden_state: initial hidden state of LSTM
1880 
1881  initial_decoder_cell_state: initial cell state of LSTM
1882 
1883  initial_attention_weighted_encoder_context: initial attention context
1884 
1885  encoder_output_dim: dimension of encoder outputs
1886 
1887  encoder_outputs: the sequence, on which we compute the attention context
1888  at every iteration
1889 
1890  encoder_lengths: a tensor with lengths of each encoder sequence in batch
1891  (may be None, meaning all encoder sequences are of same length)
1892 
1893  decoder_input_dim: input dimension (last dimension on decoder_inputs)
1894 
1895  decoder_state_dim: size of hidden states of LSTM
1896 
1897  attention_type: One of: AttentionType.Regular, AttentionType.Recurrent.
1898  Determines which type of attention mechanism to use.
1899 
1900  outputs_with_grads: position indices of output blobs which will receive
1901  external error gradient during backpropagation
1902 
1903  weighted_encoder_outputs: encoder outputs to be used to compute attention
1904  weights. In the basic case it's just a linear transformation of
1905  encoder outputs (that is the default, when weighted_encoder_outputs is None).
1906  However, it can be something more complicated - like a separate
1907  encoder network (for example, in case of convolutional encoder)
1908 
1909  lstm_memory_optimization: recompute LSTM activations on backward pass, so
1910  we don't need to store their values in forward passes
1911 
1912  attention_memory_optimization: recompute attention for backward pass
1913 
1914  forward_only: whether to create only forward pass
1915  '''
1916  cell = LSTMWithAttentionCell(
1917  encoder_output_dim=encoder_output_dim,
1918  encoder_outputs=encoder_outputs,
1919  encoder_lengths=encoder_lengths,
1920  decoder_input_dim=decoder_input_dim,
1921  decoder_state_dim=decoder_state_dim,
1922  name=scope,
1923  attention_type=attention_type,
1924  weighted_encoder_outputs=weighted_encoder_outputs,
1925  forget_bias=forget_bias,
1926  lstm_memory_optimization=lstm_memory_optimization,
1927  attention_memory_optimization=attention_memory_optimization,
1928  forward_only=forward_only,
1929  )
1930  initial_states = [
1931  initial_decoder_hidden_state,
1932  initial_decoder_cell_state,
1933  initial_attention_weighted_encoder_context,
1934  ]
1935  if attention_type == AttentionType.SoftCoverage:
1936  initial_states.append(cell.build_initial_coverage(model))
1937  _, result = cell.apply_over_sequence(
1938  model=model,
1939  inputs=decoder_inputs,
1940  seq_lengths=decoder_input_lengths,
1941  initial_states=initial_states,
1942  outputs_with_grads=outputs_with_grads,
1943  )
1944  return result
1945 
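# Example (illustrative, not part of rnn_cell.py): wiring LSTMWithAttention for
# a decoder that attends over precomputed encoder outputs. All blob names and
# dimensions are hypothetical and the blobs are expected to be fed by the
# caller.
def _example_lstm_with_attention(model):
    results = LSTMWithAttention(
        model=model,
        decoder_inputs='decoder_inputs',                  # T x N x 16
        decoder_input_lengths='decoder_input_lengths',    # N, int32
        initial_decoder_hidden_state='initial_h',
        initial_decoder_cell_state='initial_c',
        initial_attention_weighted_encoder_context='initial_context',
        encoder_output_dim=32,
        encoder_outputs='encoder_outputs',                # T_enc x N x 32
        encoder_lengths='encoder_lengths',                # N, int32 (or None)
        decoder_input_dim=16,
        decoder_state_dim=24,
        scope='attention_decoder',
        attention_type=AttentionType.Regular,
    )
    # results alternates all-steps and final outputs per state:
    # (hidden_all, hidden_last, cell_all, cell_last, context_all, context_last)
    return results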
1946 
1947 def _layered_LSTM(
1948  model, input_blob, seq_lengths, initial_states,
1949  dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
1950  memory_optimization=False, forget_bias=0.0, forward_only=False,
1951  drop_states=False, create_lstm=None):
1952  params = locals() # leave it as a first line to grab all params
1953  params.pop('create_lstm')
1954  if not isinstance(dim_out, list):
1955  return create_lstm(**params)
1956  elif len(dim_out) == 1:
1957  params['dim_out'] = dim_out[0]
1958  return create_lstm(**params)
1959 
1960  assert len(dim_out) != 0, "dim_out list can't be empty"
1961  assert return_params is False, "return_params not supported for layering"
1962  for i, output_dim in enumerate(dim_out):
1963  params.update({
1964  'dim_out': output_dim
1965  })
1966  output, last_output, all_states, last_state = create_lstm(**params)
1967  params.update({
1968  'input_blob': output,
1969  'dim_in': output_dim,
1970  'initial_states': (last_output, last_state),
1971  'scope': scope + '_layer_{}'.format(i + 1)
1972  })
1973  return output, last_output, all_states, last_state
1974 
1975 
1976 layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)
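# Example (illustrative, not part of rnn_cell.py): layered_LSTM chains one LSTM
# per entry of dim_out, feeding each layer's output sequence and final states
# into the next. Names and sizes are hypothetical.
def _example_layered_lstm(model):
    return layered_LSTM(
        model=model,
        input_blob='input_blob',      # T x N x 8
        seq_lengths='seq_lengths',
        initial_states=('h0', 'c0'),
        dim_in=8,
        dim_out=[16, 16],
        scope='layered_lstm',
    )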