Caffe2 - Python API
A deep learning, cross platform ML framework
rnn_cell.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package rnn_cell
17 # Module caffe2.python.rnn_cell
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import functools
24 import inspect
25 import itertools
26 import logging
27 import numpy as np
28 import random
29 import six
30 from future.utils import viewkeys
31 
32 from caffe2.proto import caffe2_pb2
33 from caffe2.python.attention import (
34  apply_dot_attention,
35  apply_recurrent_attention,
36  apply_regular_attention,
37  apply_soft_coverage_attention,
38  AttentionType,
39 )
40 from caffe2.python import core, recurrent, workspace, brew, scope, utils
41 from caffe2.python.modeling.parameter_sharing import ParameterSharing
42 from caffe2.python.modeling.parameter_info import ParameterTags
43 from caffe2.python.modeling.initializers import Initializer
44 from caffe2.python.model_helper import ModelHelper
45 
46 
47 def _RectifyName(blob_reference_or_name):
48  if blob_reference_or_name is None:
49  return None
50  if isinstance(blob_reference_or_name, six.string_types):
51  return core.ScopedBlobReference(blob_reference_or_name)
52  if not isinstance(blob_reference_or_name, core.BlobReference):
53  raise Exception("Unknown blob reference type")
54  return blob_reference_or_name
55 
56 
57 def _RectifyNames(blob_references_or_names):
58  if blob_references_or_names is None:
59  return None
60  return list(map(_RectifyName, blob_references_or_names))
61 
62 
63 class RNNCell(object):
64  '''
65  Base class for writing recurrent / stateful operations.
66 
67  One needs to implement 2 methods: apply_override
68  and get_state_names_override.
69 
70  In return, the base class provides the apply_over_sequence method, which
71  lets you apply the recurrent operation over a sequence of any length.
72 
73  Optionally, you can add input and output preparation steps by overriding
74  the corresponding methods.
75  '''
76  def __init__(self, name=None, forward_only=False, initializer=None):
77  self.name = name
78  self.recompute_blobs = []
79  self.forward_only = forward_only
80  self._initializer = initializer
81 
82  @property
83  def initializer(self):
84  return self._initializer
85 
86  @initializer.setter
87  def initializer(self, value):
88  self._initializer = value
89 
90  def scope(self, name):
91  return self.name + '/' + name if self.name is not None else name
92 
93  def apply_over_sequence(
94  self,
95  model,
96  inputs,
97  seq_lengths=None,
98  initial_states=None,
99  outputs_with_grads=None,
100  ):
101  if initial_states is None:
102  with scope.NameScope(self.name):
103  if self.initializer is None:
104  raise Exception("Either initial states "
105  "or initializer have to be set")
106  initial_states = self.initializer.create_states(model)
107 
108  preprocessed_inputs = self.prepare_input(model, inputs)
109  step_model = ModelHelper(name=self.name, param_model=model)
110  input_t, timestep = step_model.net.AddScopedExternalInputs(
111  'input_t',
112  'timestep',
113  )
114  utils.raiseIfNotEqual(
115  len(initial_states), len(self.get_state_names()),
116  "Number of initial state values provided doesn't match the number "
117  "of states"
118  )
119  states_prev = step_model.net.AddScopedExternalInputs(*[
120  s + '_prev' for s in self.get_state_names()
121  ])
122  states = self._apply(
123  model=step_model,
124  input_t=input_t,
125  seq_lengths=seq_lengths,
126  states=states_prev,
127  timestep=timestep,
128  )
129 
130  external_outputs = set(step_model.net.Proto().external_output)
131  for state in states:
132  if state not in external_outputs:
133  step_model.net.AddExternalOutput(state)
134 
135  if outputs_with_grads is None:
136  outputs_with_grads = [self.get_output_state_index() * 2]
137 
138  # states_for_all_steps consists of the combination of
139  # states gathered for all steps and the final states. It looks like this:
140  # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
141  states_for_all_steps = recurrent.recurrent_net(
142  net=model.net,
143  cell_net=step_model.net,
144  inputs=[(input_t, preprocessed_inputs)],
145  initial_cell_inputs=list(zip(states_prev, initial_states)),
146  links=dict(zip(states_prev, states)),
147  timestep=timestep,
148  scope=self.name,
149  forward_only=self.forward_only,
150  outputs_with_grads=outputs_with_grads,
151  recompute_blobs_on_backward=self.recompute_blobs,
152  )
153 
154  output = self._prepare_output_sequence(
155  model,
156  states_for_all_steps,
157  )
158  return output, states_for_all_steps
159 
160  def apply(self, model, input_t, seq_lengths, states, timestep):
161  input_t = self.prepare_input(model, input_t)
162  states = self._apply(
163  model, input_t, seq_lengths, states, timestep)
164  output = self._prepare_output(model, states)
165  return output, states
166 
167  def _apply(
168  self,
169  model, input_t, seq_lengths, states, timestep, extra_inputs=None
170  ):
171  '''
172  This method uses apply_override provided by a custom cell.
173  On top of that, it takes care of applying self.scope() to all the
174  outputs, while all the inputs stay within the scope this function was
175  called from.
176  '''
177  args = self._rectify_apply_inputs(
178  input_t, seq_lengths, states, timestep, extra_inputs)
179  with core.NameScope(self.name):
180  return self.apply_override(model, *args)
181 
182  def _rectify_apply_inputs(
183  self, input_t, seq_lengths, states, timestep, extra_inputs):
184  '''
185  Before applying a scope we make sure that all external blob names
186  are converted to blob references, so that further scoping doesn't affect them.
187  '''
188 
189  input_t, seq_lengths, timestep = _RectifyNames(
190  [input_t, seq_lengths, timestep])
191  states = _RectifyNames(states)
192  if extra_inputs:
193  extra_input_names, extra_input_sizes = zip(*extra_inputs)
194  extra_inputs = _RectifyNames(extra_input_names)
195  extra_inputs = zip(extra_inputs, extra_input_sizes)
196 
197  arg_names = inspect.getargspec(self.apply_override).args
198  rectified = [input_t, seq_lengths, states, timestep]
199  if 'extra_inputs' in arg_names:
200  rectified.append(extra_inputs)
201  return rectified
202 
203 
204  def apply_override(
205  self,
206  model, input_t, seq_lengths, states, timestep, extra_inputs=None,
207  ):
208  '''
209  A single step of a recurrent network to be implemented by each custom
210  RNNCell.
211 
212  model: ModelHelper object to which new operators will be added
213 
214  input_t: single input with shape (1, batch_size, input_dim)
215 
216  seq_lengths: blob containing sequence lengths, which will be passed to
217  the LSTMUnit operator
218 
219  states: previous recurrent states
220 
221  timestep: current recurrent iteration. Can be used together with
222  seq_lengths in order to determine if some shorter sequences
223  in the batch have already ended.
224 
225  extra_inputs: list of tuples (input, dim). Specifies additional inputs
226  which are not subject to prepare_input(). (Useful when a cell is a
227  component of a larger recurrent structure, e.g., attention.)
228  '''
229  raise NotImplementedError('Abstract method')
230 
231  def prepare_input(self, model, input_blob):
232  '''
233  If some operations in the _apply method depend only on the input,
234  not on the recurrent states, they can be computed in advance.
235 
236  model: ModelHelper object to which new operators will be added
237 
238  input_blob: either the whole input sequence with shape
239  (sequence_length, batch_size, input_dim) or a single input with shape
240  (1, batch_size, input_dim).
241  '''
242  return input_blob
243 
244  def get_output_state_index(self):
245  '''
246  Return index into state list of the "primary" step-wise output.
247  '''
248  return 0
249 
250  def get_state_names(self):
251  '''
252  Returns recurrent state names with self.name scoping applied
253  '''
254  return list(map(self.scope, self.get_state_names_override()))
255 
256  def get_state_names_override(self):
257  '''
258  Override this function in your custom cell.
259  It should return the names of the recurrent states.
260 
261  It's required by apply_over_sequence method in order to allocate
262  recurrent states for all steps with meaningful names.
263  '''
264  raise NotImplementedError('Abstract method')
265 
266  def get_output_dim(self):
267  '''
268  Specifies the dimension (number of units) of stepwise output.
269  '''
270  raise NotImplementedError('Abstract method')
271 
272  def _prepare_output(self, model, states):
273  '''
274  Allows arbitrary post-processing of primary output.
275  '''
276  return states[self.get_output_state_index()]
277 
278  def _prepare_output_sequence(self, model, state_outputs):
279  '''
280  Allows arbitrary post-processing of primary sequence output.
281 
282  (Note that state_outputs alternates between full-sequence and final
283  output for each state, thus the index multiplier 2.)
284  '''
285  output_sequence_index = 2 * self.get_output_state_index()
286  return state_outputs[output_sequence_index]
287 
288 
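# A minimal sketch of a custom cell, as described in the RNNCell docstring:
# only apply_override and get_state_names_override are required. The class
# name and the simple running-sum update below are illustrative assumptions,
# not part of the Caffe2 API.
class _ExampleAccumulatorCell(RNNCell):
    '''
    Keeps a running sum of its (already prepared) inputs as its only state.
    '''
    def __init__(self, dim, **kwargs):
        super(_ExampleAccumulatorCell, self).__init__(**kwargs)
        self.dim = dim

    def apply_override(
        self, model, input_t, seq_lengths, states, timestep,
        extra_inputs=None,
    ):
        # One recurrent state: the sum of all inputs seen so far. The base
        # class scopes this output under self.name via _apply().
        sum_t = model.net.Add([states[0], input_t], 'sum_t')
        model.net.AddExternalOutput(sum_t)
        return (sum_t,)

    def get_state_names_override(self):
        return ['sum_t']

    def get_output_dim(self):
        return self.dim
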
289 class LSTMInitializer(object):
290  def __init__(self, hidden_size):
291  self.hidden_size = hidden_size
292 
293  def create_states(self, model):
294  return [
295  model.create_param(
296  param_name='initial_hidden_state',
297  initializer=Initializer(operator_name='ConstantFill',
298  value=0.0),
299  shape=[self.hidden_size],
300  ),
301  model.create_param(
302  param_name='initial_cell_state',
303  initializer=Initializer(operator_name='ConstantFill',
304  value=0.0),
305  shape=[self.hidden_size],
306  )
307  ]
308 
309 
310 # based on http://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
311 class BasicRNNCell(RNNCell):
312  def __init__(
313  self,
314  input_size,
315  hidden_size,
316  forget_bias,
317  memory_optimization,
318  drop_states=False,
319  initializer=None,
320  activation=None,
321  **kwargs
322  ):
323  super(BasicRNNCell, self).__init__(**kwargs)
324  self.drop_states = drop_states
325  self.input_size = input_size
326  self.hidden_size = hidden_size
327  self.activation = activation
328 
329  if self.activation not in ['relu', 'tanh']:
330  raise RuntimeError(
331  'BasicRNNCell with unknown activation function (%s)'
332  % self.activation)
333 
334  def apply_override(
335  self,
336  model,
337  input_t,
338  seq_lengths,
339  states,
340  timestep,
341  extra_inputs=None,
342  ):
343  hidden_t_prev = states[0]
344 
345  gates_t = brew.fc(
346  model,
347  hidden_t_prev,
348  'gates_t',
349  dim_in=self.hidden_size,
350  dim_out=self.hidden_size,
351  axis=2,
352  )
353 
354  brew.sum(model, [gates_t, input_t], gates_t)
355  if self.activation == 'tanh':
356  hidden_t = model.net.Tanh(gates_t, 'hidden_t')
357  elif self.activation == 'relu':
358  hidden_t = model.net.Relu(gates_t, 'hidden_t')
359  else:
360  raise RuntimeError(
361  'BasicRNNCell with unknown activation function (%s)'
362  % self.activation)
363 
364  if seq_lengths is not None:
365  # TODO If this codepath becomes popular, it may be worth
366  # taking a look at optimizing it - for now a simple
367  # implementation is used to round out compatibility with
368  # ONNX.
369  timestep = model.net.CopyFromCPUInput(
370  timestep, 'timestep_gpu')
371  valid_b = model.net.GT(
372  [seq_lengths, timestep], 'valid_b', broadcast=1)
373  invalid_b = model.net.LE(
374  [seq_lengths, timestep], 'invalid_b', broadcast=1)
375  valid = model.net.Cast(valid_b, 'valid', to='float')
376  invalid = model.net.Cast(invalid_b, 'invalid', to='float')
377 
378  hidden_valid = model.net.Mul(
379  [hidden_t, valid],
380  'hidden_valid',
381  broadcast=1,
382  axis=1,
383  )
384  if self.drop_states:
385  hidden_t = hidden_valid
386  else:
387  hidden_invalid = model.net.Mul(
388  [hidden_t_prev, invalid],
389  'hidden_invalid',
390  broadcast=1, axis=1)
391  hidden_t = model.net.Add(
392  [hidden_valid, hidden_invalid], hidden_t)
393  return (hidden_t,)
394 
395  def prepare_input(self, model, input_blob):
396  return brew.fc(
397  model,
398  input_blob,
399  self.scope('i2h'),
400  dim_in=self.input_size,
401  dim_out=self.hidden_size,
402  axis=2,
403  )
404 
405  def get_state_names(self):
406  return (self.scope('hidden_t'),)
407 
408  def get_output_dim(self):
409  return self.hidden_size
410 
411 
412 class LSTMCell(RNNCell):
413 
414  def __init__(
415  self,
416  input_size,
417  hidden_size,
418  forget_bias,
419  memory_optimization,
420  drop_states=False,
421  initializer=None,
422  **kwargs
423  ):
424  super(LSTMCell, self).__init__(initializer=initializer, **kwargs)
425  self.initializer = initializer or LSTMInitializer(
426  hidden_size=hidden_size)
427 
428  self.input_size = input_size
429  self.hidden_size = hidden_size
430  self.forget_bias = float(forget_bias)
431  self.memory_optimization = memory_optimization
432  self.drop_states = drop_states
433  self.gates_size = 4 * self.hidden_size
434 
435  def apply_override(
436  self,
437  model,
438  input_t,
439  seq_lengths,
440  states,
441  timestep,
442  extra_inputs=None,
443  ):
444  hidden_t_prev, cell_t_prev = states
445 
446  fc_input = hidden_t_prev
447  fc_input_dim = self.hidden_size
448 
449  if extra_inputs is not None:
450  extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
451  fc_input = brew.concat(
452  model,
453  [hidden_t_prev] + list(extra_input_blobs),
454  'gates_concatenated_input_t',
455  axis=2,
456  )
457  fc_input_dim += sum(extra_input_sizes)
458 
459  gates_t = brew.fc(
460  model,
461  fc_input,
462  'gates_t',
463  dim_in=fc_input_dim,
464  dim_out=self.gates_size,
465  axis=2,
466  )
467  brew.sum(model, [gates_t, input_t], gates_t)
468 
469  if seq_lengths is not None:
470  inputs = [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep]
471  else:
472  inputs = [hidden_t_prev, cell_t_prev, gates_t, timestep]
473 
474  hidden_t, cell_t = model.net.LSTMUnit(
475  inputs,
476  ['hidden_state', 'cell_state'],
477  forget_bias=self.forget_bias,
478  drop_states=self.drop_states,
479  sequence_lengths=(seq_lengths is not None),
480  )
481  model.net.AddExternalOutputs(hidden_t, cell_t)
482  if self.memory_optimization:
483  self.recompute_blobs = [gates_t]
484 
485  return hidden_t, cell_t
486 
487  def get_input_params(self):
488  return {
489  'weights': self.scope('i2h') + '_w',
490  'biases': self.scope('i2h') + '_b',
491  }
492 
493  def get_recurrent_params(self):
494  return {
495  'weights': self.scope('gates_t') + '_w',
496  'biases': self.scope('gates_t') + '_b',
497  }
498 
499  def prepare_input(self, model, input_blob):
500  return brew.fc(
501  model,
502  input_blob,
503  self.scope('i2h'),
504  dim_in=self.input_size,
505  dim_out=self.gates_size,
506  axis=2,
507  )
508 
509  def get_state_names_override(self):
510  return ['hidden_t', 'cell_t']
511 
512  def get_output_dim(self):
513  return self.hidden_size
514 
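# A minimal sketch of driving a single LSTMCell with apply_over_sequence;
# the blob names and toy dimensions here are illustrative assumptions.
def _example_lstm_cell_over_sequence():
    model = ModelHelper(name='lstm_cell_example')
    input_blob, seq_lengths = model.net.AddExternalInputs(
        'input_blob', 'seq_lengths')
    cell = LSTMCell(
        input_size=4,
        hidden_size=8,
        forget_bias=0.0,
        memory_optimization=False,
        name='example_lstm',
    )
    # With initial_states=None, the LSTMInitializer created in __init__
    # adds the initial hidden/cell states as network parameters.
    output_sequence, states = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
    )
    # states is (hidden_all, hidden_last, cell_all, cell_last)
    return output_sequence, states
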
515 
516 class LayerNormLSTMCell(RNNCell):
517 
518  def __init__(
519  self,
520  input_size,
521  hidden_size,
522  forget_bias,
523  memory_optimization,
524  drop_states=False,
525  initializer=None,
526  **kwargs
527  ):
528  super(LayerNormLSTMCell, self).__init__(
529  initializer=initializer, **kwargs
530  )
531  self.initializer = initializer or LSTMInitializer(
532  hidden_size=hidden_size
533  )
534 
535  self.input_size = input_size
536  self.hidden_size = hidden_size
537  self.forget_bias = float(forget_bias)
538  self.memory_optimization = memory_optimization
539  self.drop_states = drop_states
540  self.gates_size = 4 * self.hidden_size
541 
542  def _apply(
543  self,
544  model,
545  input_t,
546  seq_lengths,
547  states,
548  timestep,
549  extra_inputs=None,
550  ):
551  hidden_t_prev, cell_t_prev = states
552 
553  fc_input = hidden_t_prev
554  fc_input_dim = self.hidden_size
555 
556  if extra_inputs is not None:
557  extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
558  fc_input = brew.concat(
559  model,
560  [hidden_t_prev] + list(extra_input_blobs),
561  self.scope('gates_concatenated_input_t'),
562  axis=2,
563  )
564  fc_input_dim += sum(extra_input_sizes)
565 
566  gates_t = brew.fc(
567  model,
568  fc_input,
569  self.scope('gates_t'),
570  dim_in=fc_input_dim,
571  dim_out=self.gates_size,
572  axis=2,
573  )
574  brew.sum(model, [gates_t, input_t], gates_t)
575 
576  # the brew.layer_norm call is the only difference from LSTMCell
577  gates_t, _, _ = brew.layer_norm(
578  model,
579  self.scope('gates_t'),
580  self.scope('gates_t_norm'),
581  dim_in=self.gates_size,
582  axis=-1,
583  )
584 
585  hidden_t, cell_t = model.net.LSTMUnit(
586  [
587  hidden_t_prev,
588  cell_t_prev,
589  gates_t,
590  seq_lengths,
591  timestep,
592  ],
593  self.get_state_names(),
594  forget_bias=self.forget_bias,
595  drop_states=self.drop_states,
596  )
597  model.net.AddExternalOutputs(hidden_t, cell_t)
598  if self.memory_optimization:
599  self.recompute_blobs = [gates_t]
600 
601  return hidden_t, cell_t
602 
603  def get_input_params(self):
604  return {
605  'weights': self.scope('i2h') + '_w',
606  'biases': self.scope('i2h') + '_b',
607  }
608 
609  def prepare_input(self, model, input_blob):
610  return brew.fc(
611  model,
612  input_blob,
613  self.scope('i2h'),
614  dim_in=self.input_size,
615  dim_out=self.gates_size,
616  axis=2,
617  )
618 
619  def get_state_names(self):
620  return (self.scope('hidden_t'), self.scope('cell_t'))
621 
622 
623 class MILSTMCell(LSTMCell):
624 
625  def _apply(
626  self,
627  model,
628  input_t,
629  seq_lengths,
630  states,
631  timestep,
632  extra_inputs=None,
633  ):
634  hidden_t_prev, cell_t_prev = states
635 
636  fc_input = hidden_t_prev
637  fc_input_dim = self.hidden_size
638 
639  if extra_inputs is not None:
640  extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
641  fc_input = brew.concat(
642  model,
643  [hidden_t_prev] + list(extra_input_blobs),
644  self.scope('gates_concatenated_input_t'),
645  axis=2,
646  )
647  fc_input_dim += sum(extra_input_sizes)
648 
649  prev_t = brew.fc(
650  model,
651  fc_input,
652  self.scope('prev_t'),
653  dim_in=fc_input_dim,
654  dim_out=self.gates_size,
655  axis=2,
656  )
657 
658  # defining initializers for MI parameters
659  alpha = model.create_param(
660  self.scope('alpha'),
661  shape=[self.gates_size],
662  initializer=Initializer('ConstantFill', value=1.0),
663  )
664  beta_h = model.create_param(
665  self.scope('beta1'),
666  shape=[self.gates_size],
667  initializer=Initializer('ConstantFill', value=1.0),
668  )
669  beta_i = model.create_param(
670  self.scope('beta2'),
671  shape=[self.gates_size],
672  initializer=Initializer('ConstantFill', value=1.0),
673  )
674  b = model.create_param(
675  self.scope('b'),
676  shape=[self.gates_size],
677  initializer=Initializer('ConstantFill', value=0.0),
678  )
679 
680  # alpha * input_t + beta_h
681  # Shape: [1, batch_size, 4 * hidden_size]
682  alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
683  [input_t, alpha, beta_h],
684  self.scope('alpha_by_input_t_plus_beta_h'),
685  axis=2,
686  )
687  # (alpha * input_t + beta_h) * prev_t =
688  # alpha * input_t * prev_t + beta_h * prev_t
689  # Shape: [1, batch_size, 4 * hidden_size]
690  alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
691  [alpha_by_input_t_plus_beta_h, prev_t],
692  self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
693  )
694  # beta_i * input_t + b
695  # Shape: [1, batch_size, 4 * hidden_size]
696  beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
697  [input_t, beta_i, b],
698  self.scope('beta_i_by_input_t_plus_b'),
699  axis=2,
700  )
701  # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
702  # Shape: [1, batch_size, 4 * hidden_size]
703  gates_t = brew.sum(
704  model,
705  [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
706  self.scope('gates_t')
707  )
708  hidden_t, cell_t = model.net.LSTMUnit(
709  [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
710  [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
711  forget_bias=self.forget_bias,
712  drop_states=self.drop_states,
713  )
714  model.net.AddExternalOutputs(
715  cell_t,
716  hidden_t,
717  )
718  if self.memory_optimization:
719  self.recompute_blobs = [gates_t]
720  return hidden_t, cell_t
721 
722 
723 class LayerNormMILSTMCell(MILSTMCell):
724 
725  def _apply(
726  self,
727  model,
728  input_t,
729  seq_lengths,
730  states,
731  timestep,
732  extra_inputs=None,
733  ):
734  hidden_t_prev, cell_t_prev = states
735 
736  fc_input = hidden_t_prev
737  fc_input_dim = self.hidden_size
738 
739  if extra_inputs is not None:
740  extra_input_blobs, extra_input_sizes = zip(*extra_inputs)
741  fc_input = brew.concat(
742  model,
743  [hidden_t_prev] + list(extra_input_blobs),
744  self.scope('gates_concatenated_input_t'),
745  axis=2,
746  )
747  fc_input_dim += sum(extra_input_sizes)
748 
749  prev_t = brew.fc(
750  model,
751  fc_input,
752  self.scope('prev_t'),
753  dim_in=fc_input_dim,
754  dim_out=self.gates_size,
755  axis=2,
756  )
757 
758  # defining initializers for MI parameters
759  alpha = model.create_param(
760  self.scope('alpha'),
761  shape=[self.gates_size],
762  initializer=Initializer('ConstantFill', value=1.0),
763  )
764  beta_h = model.create_param(
765  self.scope('beta1'),
766  shape=[self.gates_size],
767  initializer=Initializer('ConstantFill', value=1.0),
768  )
769  beta_i = model.create_param(
770  self.scope('beta2'),
771  shape=[self.gates_size],
772  initializer=Initializer('ConstantFill', value=1.0),
773  )
774  b = model.create_param(
775  self.scope('b'),
776  shape=[self.gates_size],
777  initializer=Initializer('ConstantFill', value=0.0),
778  )
779 
780  # alpha * input_t + beta_h
781  # Shape: [1, batch_size, 4 * hidden_size]
782  alpha_by_input_t_plus_beta_h = model.net.ElementwiseLinear(
783  [input_t, alpha, beta_h],
784  self.scope('alpha_by_input_t_plus_beta_h'),
785  axis=2,
786  )
787  # (alpha * input_t + beta_h) * prev_t =
788  # alpha * input_t * prev_t + beta_h * prev_t
789  # Shape: [1, batch_size, 4 * hidden_size]
790  alpha_by_input_t_plus_beta_h_by_prev_t = model.net.Mul(
791  [alpha_by_input_t_plus_beta_h, prev_t],
792  self.scope('alpha_by_input_t_plus_beta_h_by_prev_t')
793  )
794  # beta_i * input_t + b
795  # Shape: [1, batch_size, 4 * hidden_size]
796  beta_i_by_input_t_plus_b = model.net.ElementwiseLinear(
797  [input_t, beta_i, b],
798  self.scope('beta_i_by_input_t_plus_b'),
799  axis=2,
800  )
801  # alpha * input_t * prev_t + beta_h * prev_t + beta_i * input_t + b
802  # Shape: [1, batch_size, 4 * hidden_size]
803  gates_t = brew.sum(
804  model,
805  [alpha_by_input_t_plus_beta_h_by_prev_t, beta_i_by_input_t_plus_b],
806  self.scope('gates_t')
807  )
808  # the brew.layer_norm call is the only difference from MILSTMCell._apply
809  gates_t, _, _ = brew.layer_norm(
810  model,
811  self.scope('gates_t'),
812  self.scope('gates_t_norm'),
813  dim_in=self.gates_size,
814  axis=-1,
815  )
816  hidden_t, cell_t = model.net.LSTMUnit(
817  [hidden_t_prev, cell_t_prev, gates_t, seq_lengths, timestep],
818  [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
819  forget_bias=self.forget_bias,
820  drop_states=self.drop_states,
821  )
822  model.net.AddExternalOutputs(
823  cell_t,
824  hidden_t,
825  )
826  if self.memory_optimization:
827  self.recompute_blobs = [gates_t]
828  return hidden_t, cell_t
829 
830 
831 class DropoutCell(RNNCell):
832  '''
833  Wraps arbitrary RNNCell, applying dropout to its output (but not to the
834  recurrent connection for the corresponding state).
835  '''
836 
837  def __init__(
838  self,
839  internal_cell,
840  dropout_ratio=None,
841  use_cudnn=False,
842  **kwargs
843  ):
844  self.internal_cell = internal_cell
845  self.dropout_ratio = dropout_ratio
846  assert 'is_test' in kwargs, "Argument 'is_test' is required"
847  self.is_test = kwargs.pop('is_test')
848  self.use_cudnn = use_cudnn
849  super(DropoutCell, self).__init__(**kwargs)
850 
851  self.prepare_input = internal_cell.prepare_input
852  self.get_output_state_index = internal_cell.get_output_state_index
853  self.get_state_names = internal_cell.get_state_names
854  self.get_output_dim = internal_cell.get_output_dim
855 
856  self.mask = 0
857 
858  def _apply(
859  self,
860  model,
861  input_t,
862  seq_lengths,
863  states,
864  timestep,
865  extra_inputs=None,
866  ):
867  return self.internal_cell._apply(
868  model,
869  input_t,
870  seq_lengths,
871  states,
872  timestep,
873  extra_inputs,
874  )
875 
876  def _prepare_output(self, model, states):
877  output = self.internal_cell._prepare_output(
878  model,
879  states,
880  )
881  if self.dropout_ratio is not None:
882  output = self._apply_dropout(model, output)
883  return output
884 
885  def _prepare_output_sequence(self, model, state_outputs):
886  output = self.internal_cell._prepare_output_sequence(
887  model,
888  state_outputs,
889  )
890  if self.dropout_ratio is not None:
891  output = self._apply_dropout(model, output)
892  return output
893 
894  def _apply_dropout(self, model, output):
895  if self.dropout_ratio and not self.forward_only:
896  with core.NameScope(self.name or ''):
897  output = brew.dropout(
898  model,
899  output,
900  str(output) + '_with_dropout_mask{}'.format(self.mask),
901  ratio=float(self.dropout_ratio),
902  is_test=self.is_test,
903  use_cudnn=self.use_cudnn,
904  )
905  self.mask += 1
906  return output
907 
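# A minimal sketch of wrapping a cell in DropoutCell; the ratio, names, and
# dimensions are illustrative assumptions. Dropout is applied only to the
# wrapped cell's output, not to its recurrent connection.
def _example_dropout_cell():
    lstm = LSTMCell(
        input_size=4,
        hidden_size=8,
        forget_bias=0.0,
        memory_optimization=False,
        name='layer_0',
    )
    # 'is_test' is required so that dropout can be disabled at test time.
    return DropoutCell(lstm, dropout_ratio=0.1, is_test=False, name='drop_0')
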
908 
909 class MultiRNNCellInitializer(object):
910  def __init__(self, cells):
911  self.cells = cells
912 
913  def create_states(self, model):
914  states = []
915  for i, cell in enumerate(self.cells):
916  if cell.initializer is None:
917  raise Exception("Either initial states "
918  "or initializer have to be set")
919 
920  with core.NameScope("layer_{}".format(i)),\
921  core.NameScope(cell.name):
922  states.extend(cell.initializer.create_states(model))
923  return states
924 
925 
926 class MultiRNNCell(RNNCell):
927  '''
928  Multilayer RNN via the composition of RNNCell instances.
929 
930  It is the responsibility of calling code to ensure the compatibility
931  of the successive layers in terms of input/output dimensionality, etc.,
932  and to ensure that their blobs do not have name conflicts, typically by
933  creating the cells with names that specify the layer number.
934 
935  Assumes the first state (recurrent output) of each layer is the input
936  to the next layer.
937  '''
938 
939  def __init__(self, cells, residual_output_layers=None, **kwargs):
940  '''
941  cells: list of RNNCell instances, from input to output side.
942 
943  name: string designating network component (for scoping)
944 
945  residual_output_layers: list of indices of layers whose input will
946  be added elementwise to their output. (It is the
947  responsibility of the client code to ensure shape compatibility.)
948  Note that layer 0 (zero) cannot have residual output because of the
949  timing of prepare_input().
950 
951  forward_only: used to construct inference-only network.
952  '''
953  super(MultiRNNCell, self).__init__(**kwargs)
954  self.cells = cells
955 
956  if residual_output_layers is None:
957  self.residual_output_layers = []
958  else:
959  self.residual_output_layers = residual_output_layers
960 
961  output_index_per_layer = []
962  base_index = 0
963  for cell in self.cells:
964  output_index_per_layer.append(
965  base_index + cell.get_output_state_index(),
966  )
967  base_index += len(cell.get_state_names())
968 
969  self.output_connected_layers = []
970  self.output_indices = []
971  for i in range(len(self.cells) - 1):
972  if (i + 1) in self.residual_output_layers:
973  self.output_connected_layers.append(i)
974  self.output_indices.append(output_index_per_layer[i])
975  else:
976  self.output_connected_layers = []
977  self.output_indices = []
978  self.output_connected_layers.append(len(self.cells) - 1)
979  self.output_indices.append(output_index_per_layer[-1])
980 
981  self.state_names = []
982  for i, cell in enumerate(self.cells):
983  self.state_names.extend(
984  map(self.layer_scoper(i), cell.get_state_names())
985  )
986 
987  self.initializer = MultiRNNCellInitializer(cells)
988 
989  def layer_scoper(self, layer_id):
990  def helper(name):
991  return "layer_{}/{}".format(layer_id, name)
992  return helper
993 
994  def prepare_input(self, model, input_blob):
995  return self.cells[0].prepare_input(model, input_blob)
996 
997  def _apply(
998  self,
999  model,
1000  input_t,
1001  seq_lengths,
1002  states,
1003  timestep,
1004  extra_inputs=None,
1005  ):
1006  '''
1007  Because below we will do scoping across layers, we need
1008  to make sure that string blob names are converted to BlobReference
1009  objects.
1010  '''
1011 
1012  input_t, seq_lengths, states, timestep, extra_inputs = \
1013  self._rectify_apply_inputs(
1014  input_t, seq_lengths, states, timestep, extra_inputs)
1015 
1016  states_per_layer = [len(cell.get_state_names()) for cell in self.cells]
1017  assert len(states) == sum(states_per_layer)
1018 
1019  next_states = []
1020  states_index = 0
1021 
1022  layer_input = input_t
1023  for i, layer_cell in enumerate(self.cells):
1024  # Even if cells don't have different names, we still
1025  # take care of scoping
1026  with core.NameScope(self.name), core.NameScope("layer_{}".format(i)):
1027  num_states = states_per_layer[i]
1028  layer_states = states[states_index:(states_index + num_states)]
1029  states_index += num_states
1030 
1031  if i > 0:
1032  prepared_input = layer_cell.prepare_input(
1033  model, layer_input)
1034  else:
1035  prepared_input = layer_input
1036 
1037  layer_next_states = layer_cell._apply(
1038  model,
1039  prepared_input,
1040  seq_lengths,
1041  layer_states,
1042  timestep,
1043  extra_inputs=(None if i > 0 else extra_inputs),
1044  )
1045  # Since we're using the non-public method _apply here
1046  # instead of apply, we have to manually extract the output
1047  # from the states
1048  if i != len(self.cells) - 1:
1049  layer_output = layer_cell._prepare_output(
1050  model,
1051  layer_next_states,
1052  )
1053  if i > 0 and i in self.residual_output_layers:
1054  layer_input = brew.sum(
1055  model,
1056  [layer_output, layer_input],
1057  self.scope('residual_output_{}'.format(i)),
1058  )
1059  else:
1060  layer_input = layer_output
1061 
1062  next_states.extend(layer_next_states)
1063  return next_states
1064 
1065  def get_state_names(self):
1066  return self.state_names
1067 
1068  def get_output_state_index(self):
1069  index = 0
1070  for cell in self.cells[:-1]:
1071  index += len(cell.get_state_names())
1072  index += self.cells[-1].get_output_state_index()
1073  return index
1074 
1075  def _prepare_output(self, model, states):
1076  connected_outputs = []
1077  state_index = 0
1078  for i, cell in enumerate(self.cells):
1079  num_states = len(cell.get_state_names())
1080  if i in self.output_connected_layers:
1081  layer_states = states[state_index:state_index + num_states]
1082  layer_output = cell._prepare_output(
1083  model,
1084  layer_states
1085  )
1086  connected_outputs.append(layer_output)
1087  state_index += num_states
1088  if len(connected_outputs) > 1:
1089  output = brew.sum(
1090  model,
1091  connected_outputs,
1092  self.scope('residual_output'),
1093  )
1094  else:
1095  output = connected_outputs[0]
1096  return output
1097 
1098  def _prepare_output_sequence(self, model, states):
1099  connected_outputs = []
1100  state_index = 0
1101  for i, cell in enumerate(self.cells):
1102  num_states = 2 * len(cell.get_state_names())
1103  if i in self.output_connected_layers:
1104  layer_states = states[state_index:state_index + num_states]
1105  layer_output = cell._prepare_output_sequence(
1106  model,
1107  layer_states
1108  )
1109  connected_outputs.append(layer_output)
1110  state_index += num_states
1111  if len(connected_outputs) > 1:
1112  output = brew.sum(
1113  model,
1114  connected_outputs,
1115  self.scope('residual_output_sequence'),
1116  )
1117  else:
1118  output = connected_outputs[0]
1119  return output
1120 
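# A minimal sketch of stacking two LSTM layers with MultiRNNCell; the names
# and dimensions are illustrative assumptions. The second layer's input_size
# must equal the first layer's hidden_size, and residual layers additionally
# need matching input/output shapes.
def _example_multi_rnn_cell():
    cells = [
        LSTMCell(input_size=4, hidden_size=8, forget_bias=0.0,
                 memory_optimization=False, name='layer_0'),
        LSTMCell(input_size=8, hidden_size=8, forget_bias=0.0,
                 memory_optimization=False, name='layer_1'),
    ]
    return MultiRNNCell(cells, residual_output_layers=[1], name='stack')
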
1121 
1122 class AttentionCell(RNNCell):
1123 
1124  def __init__(
1125  self,
1126  encoder_output_dim,
1127  encoder_outputs,
1128  encoder_lengths,
1129  decoder_cell,
1130  decoder_state_dim,
1131  attention_type,
1132  weighted_encoder_outputs,
1133  attention_memory_optimization,
1134  **kwargs
1135  ):
1136  super(AttentionCell, self).__init__(**kwargs)
1137  self.encoder_output_dim = encoder_output_dim
1138  self.encoder_outputs = encoder_outputs
1139  self.encoder_lengths = encoder_lengths
1140  self.decoder_cell = decoder_cell
1141  self.decoder_state_dim = decoder_state_dim
1142  self.weighted_encoder_outputs = weighted_encoder_outputs
1143  self.encoder_outputs_transposed = None
1144  assert attention_type in [
1145  AttentionType.Regular,
1146  AttentionType.Recurrent,
1147  AttentionType.Dot,
1148  AttentionType.SoftCoverage,
1149  ]
1150  self.attention_type = attention_type
1151  self.attention_memory_optimization = attention_memory_optimization
1152 
1153  def _apply(
1154  self,
1155  model,
1156  input_t,
1157  seq_lengths,
1158  states,
1159  timestep,
1160  extra_inputs=None,
1161  ):
1162  if self.attention_type == AttentionType.SoftCoverage:
1163  decoder_prev_states = states[:-2]
1164  attention_weighted_encoder_context_t_prev = states[-2]
1165  coverage_t_prev = states[-1]
1166  else:
1167  decoder_prev_states = states[:-1]
1168  attention_weighted_encoder_context_t_prev = states[-1]
1169 
1170  assert extra_inputs is None
1171 
1172  decoder_states = self.decoder_cell._apply(
1173  model,
1174  input_t,
1175  seq_lengths,
1176  decoder_prev_states,
1177  timestep,
1178  extra_inputs=[(
1179  attention_weighted_encoder_context_t_prev,
1180  self.encoder_output_dim,
1181  )],
1182  )
1183 
1184  self.hidden_t_intermediate = self.decoder_cell._prepare_output(
1185  model,
1186  decoder_states,
1187  )
1188 
1189  if self.attention_type == AttentionType.Recurrent:
1190  (
1191  attention_weighted_encoder_context_t,
1192  self.attention_weights_3d,
1193  attention_blobs,
1194  ) = apply_recurrent_attention(
1195  model=model,
1196  encoder_output_dim=self.encoder_output_dim,
1197  encoder_outputs_transposed=self.encoder_outputs_transposed,
1198  weighted_encoder_outputs=self.weighted_encoder_outputs,
1199  decoder_hidden_state_t=self.hidden_t_intermediate,
1200  decoder_hidden_state_dim=self.decoder_state_dim,
1201  scope=self.name,
1202  attention_weighted_encoder_context_t_prev=(
1203  attention_weighted_encoder_context_t_prev
1204  ),
1205  encoder_lengths=self.encoder_lengths,
1206  )
1207  elif self.attention_type == AttentionType.Regular:
1208  (
1209  attention_weighted_encoder_context_t,
1210  self.attention_weights_3d,
1211  attention_blobs,
1212  ) = apply_regular_attention(
1213  model=model,
1214  encoder_output_dim=self.encoder_output_dim,
1215  encoder_outputs_transposed=self.encoder_outputs_transposed,
1216  weighted_encoder_outputs=self.weighted_encoder_outputs,
1217  decoder_hidden_state_t=self.hidden_t_intermediate,
1218  decoder_hidden_state_dim=self.decoder_state_dim,
1219  scope=self.name,
1220  encoder_lengths=self.encoder_lengths,
1221  )
1222  elif self.attention_type == AttentionType.Dot:
1223  (
1224  attention_weighted_encoder_context_t,
1225  self.attention_weights_3d,
1226  attention_blobs,
1227  ) = apply_dot_attention(
1228  model=model,
1229  encoder_output_dim=self.encoder_output_dim,
1230  encoder_outputs_transposed=self.encoder_outputs_transposed,
1231  decoder_hidden_state_t=self.hidden_t_intermediate,
1232  decoder_hidden_state_dim=self.decoder_state_dim,
1233  scope=self.name,
1234  encoder_lengths=self.encoder_lengths,
1235  )
1236  elif self.attention_type == AttentionType.SoftCoverage:
1237  (
1238  attention_weighted_encoder_context_t,
1239  self.attention_weights_3d,
1240  attention_blobs,
1241  coverage_t,
1242  ) = apply_soft_coverage_attention(
1243  model=model,
1244  encoder_output_dim=self.encoder_output_dim,
1245  encoder_outputs_transposed=self.encoder_outputs_transposed,
1246  weighted_encoder_outputs=self.weighted_encoder_outputs,
1247  decoder_hidden_state_t=self.hidden_t_intermediate,
1248  decoder_hidden_state_dim=self.decoder_state_dim,
1249  scope=self.name,
1250  encoder_lengths=self.encoder_lengths,
1251  coverage_t_prev=coverage_t_prev,
1252  coverage_weights=self.coverage_weights,
1253  )
1254  else:
1255  raise Exception('Attention type {} not implemented'.format(
1256  self.attention_type
1257  ))
1258 
1259  if self.attention_memory_optimization:
1260  self.recompute_blobs.extend(attention_blobs)
1261 
1262  output = list(decoder_states) + [attention_weighted_encoder_context_t]
1263  if self.attention_type == AttentionType.SoftCoverage:
1264  output.append(coverage_t)
1265 
1266  output[self.decoder_cell.get_output_state_index()] = model.Copy(
1267  output[self.decoder_cell.get_output_state_index()],
1268  self.scope('hidden_t_external'),
1269  )
1270  model.net.AddExternalOutputs(*output)
1271 
1272  return output
1273 
1274  def get_attention_weights(self):
1275  # [batch_size, encoder_length, 1]
1276  return self.attention_weights_3d
1277 
1278  def prepare_input(self, model, input_blob):
1279  if self.encoder_outputs_transposed is None:
1280  self.encoder_outputs_transposed = brew.transpose(
1281  model,
1282  self.encoder_outputs,
1283  self.scope('encoder_outputs_transposed'),
1284  axes=[1, 2, 0],
1285  )
1286  if (
1287  self.weighted_encoder_outputs is None and
1288  self.attention_type != AttentionType.Dot
1289  ):
1290  self.weighted_encoder_outputs = brew.fc(
1291  model,
1292  self.encoder_outputs,
1293  self.scope('weighted_encoder_outputs'),
1294  dim_in=self.encoder_output_dim,
1295  dim_out=self.encoder_output_dim,
1296  axis=2,
1297  )
1298 
1299  return self.decoder_cell.prepare_input(model, input_blob)
1300 
1301  def build_initial_coverage(self, model):
1302  """
1303  initial_coverage is always zeros of shape [encoder_length],
1304  whose shape must be determined programmatically during network
1305  computation.
1306 
1307  This method also sets self.coverage_weights, a separate transform
1308  of encoder_outputs which is used to determine the coverage contribution
1309  to attention.
1310  """
1311  assert self.attention_type == AttentionType.SoftCoverage
1312 
1313  # [encoder_length, batch_size, encoder_output_dim]
1314  self.coverage_weights = brew.fc(
1315  model,
1316  self.encoder_outputs,
1317  self.scope('coverage_weights'),
1318  dim_in=self.encoder_output_dim,
1319  dim_out=self.encoder_output_dim,
1320  axis=2,
1321  )
1322 
1323  encoder_length = model.net.Slice(
1324  model.net.Shape(self.encoder_outputs),
1325  starts=[0],
1326  ends=[1],
1327  )
1328  if (
1329  scope.CurrentDeviceScope() is not None and
1330  scope.CurrentDeviceScope().device_type == caffe2_pb2.CUDA
1331  ):
1332  encoder_length = model.net.CopyGPUToCPU(
1333  encoder_length,
1334  'encoder_length_cpu',
1335  )
1336  # total attention weight applied across decoding steps
1337  # shape: [encoder_length]
1338  initial_coverage = model.net.ConstantFill(
1339  encoder_length,
1340  self.scope('initial_coverage'),
1341  value=0.0,
1342  input_as_shape=1,
1343  )
1344  return initial_coverage
1345 
1346  def get_state_names(self):
1347  state_names = list(self.decoder_cell.get_state_names())
1348  state_names[self.get_output_state_index()] = self.scope(
1349  'hidden_t_external',
1350  )
1351  state_names.append(self.scope('attention_weighted_encoder_context_t'))
1352  if self.attention_type == AttentionType.SoftCoverage:
1353  state_names.append(self.scope('coverage_t'))
1354  return state_names
1355 
1356  def get_output_dim(self):
1357  return self.decoder_state_dim + self.encoder_output_dim
1358 
1359  def get_output_state_index(self):
1360  return self.decoder_cell.get_output_state_index()
1361 
1362  def _prepare_output(self, model, states):
1363  if self.attention_type == AttentionType.SoftCoverage:
1364  attention_context = states[-2]
1365  else:
1366  attention_context = states[-1]
1367 
1368  with core.NameScope(self.name or ''):
1369  output = brew.concat(
1370  model,
1371  [self.hidden_t_intermediate, attention_context],
1372  'states_and_context_combination',
1373  axis=2,
1374  )
1375 
1376  return output
1377 
1378  def _prepare_output_sequence(self, model, state_outputs):
1379  if self.attention_type == AttentionType.SoftCoverage:
1380  decoder_state_outputs = state_outputs[:-4]
1381  else:
1382  decoder_state_outputs = state_outputs[:-2]
1383 
1384  decoder_output = self.decoder_cell._prepare_output_sequence(
1385  model,
1386  decoder_state_outputs,
1387  )
1388 
1389  if self.attention_type == AttentionType.SoftCoverage:
1390  attention_context_index = 2 * (len(self.get_state_names()) - 2)
1391  else:
1392  attention_context_index = 2 * (len(self.get_state_names()) - 1)
1393 
1394  with core.NameScope(self.name or ''):
1395  output = brew.concat(
1396  model,
1397  [
1398  decoder_output,
1399  state_outputs[attention_context_index],
1400  ],
1401  'states_and_context_combination',
1402  axis=2,
1403  )
1404  return output
1405 
1406 
1407 class LSTMWithAttentionCell(AttentionCell):
1408 
1409  def __init__(
1410  self,
1411  encoder_output_dim,
1412  encoder_outputs,
1413  encoder_lengths,
1414  decoder_input_dim,
1415  decoder_state_dim,
1416  name,
1417  attention_type,
1418  weighted_encoder_outputs,
1419  forget_bias,
1420  lstm_memory_optimization,
1421  attention_memory_optimization,
1422  forward_only=False,
1423  ):
1424  decoder_cell = LSTMCell(
1425  input_size=decoder_input_dim,
1426  hidden_size=decoder_state_dim,
1427  forget_bias=forget_bias,
1428  memory_optimization=lstm_memory_optimization,
1429  name='{}/decoder'.format(name),
1430  forward_only=False,
1431  drop_states=False,
1432  )
1433  super(LSTMWithAttentionCell, self).__init__(
1434  encoder_output_dim=encoder_output_dim,
1435  encoder_outputs=encoder_outputs,
1436  encoder_lengths=encoder_lengths,
1437  decoder_cell=decoder_cell,
1438  decoder_state_dim=decoder_state_dim,
1439  name=name,
1440  attention_type=attention_type,
1441  weighted_encoder_outputs=weighted_encoder_outputs,
1442  attention_memory_optimization=attention_memory_optimization,
1443  forward_only=forward_only,
1444  )
1445 
1446 
1447 class MILSTMWithAttentionCell(AttentionCell):
1448 
1449  def __init__(
1450  self,
1451  encoder_output_dim,
1452  encoder_outputs,
1453  decoder_input_dim,
1454  decoder_state_dim,
1455  name,
1456  attention_type,
1457  weighted_encoder_outputs,
1458  forget_bias,
1459  lstm_memory_optimization,
1460  attention_memory_optimization,
1461  forward_only=False,
1462  ):
1463  decoder_cell = MILSTMCell(
1464  input_size=decoder_input_dim,
1465  hidden_size=decoder_state_dim,
1466  forget_bias=forget_bias,
1467  memory_optimization=lstm_memory_optimization,
1468  name='{}/decoder'.format(name),
1469  forward_only=False,
1470  drop_states=False,
1471  )
1472  super(MILSTMWithAttentionCell, self).__init__(
1473  encoder_output_dim=encoder_output_dim,
1474  encoder_outputs=encoder_outputs,
1475  decoder_cell=decoder_cell,
1476  decoder_state_dim=decoder_state_dim,
1477  name=name,
1478  attention_type=attention_type,
1479  weighted_encoder_outputs=weighted_encoder_outputs,
1480  attention_memory_optimization=attention_memory_optimization,
1481  forward_only=forward_only,
1482  )
1483 
1484 
1485 def _LSTM(
1486  cell_class,
1487  model,
1488  input_blob,
1489  seq_lengths,
1490  initial_states,
1491  dim_in,
1492  dim_out,
1493  scope=None,
1494  outputs_with_grads=(0,),
1495  return_params=False,
1496  memory_optimization=False,
1497  forget_bias=0.0,
1498  forward_only=False,
1499  drop_states=False,
1500  return_last_layer_only=True,
1501  static_rnn_unroll_size=None,
1502  **cell_kwargs
1503 ):
1504  '''
1505  Adds a standard LSTM recurrent network operator to a model.
1506 
1507  cell_class: LSTMCell or compatible subclass
1508 
1509  model: ModelHelper object new operators would be added to
1510 
1511  input_blob: the input sequence in the format T x N x D,
1512  where T is the sequence size, N the batch size and D the input dimension
1513 
1514  seq_lengths: blob containing sequence lengths which would be passed to
1515  LSTMUnit operator
1516 
1517  initial_states: a list of (2 * num_layers) blobs representing the initial
1518  hidden and cell states of each layer. If this argument is None,
1519  these states will be added to the model as network parameters.
1520 
1521  dim_in: input dimension
1522 
1523  dim_out: number of units per LSTM layer
1524  (use int for single-layer LSTM, list of ints for multi-layer)
1525 
1526  outputs_with_grads : position indices of output blobs for LAST LAYER which
1527  will receive external error gradient during backpropagation.
1528  These outputs are: (h_all, h_last, c_all, c_last)
1529 
1530  return_params: if True, will return a dictionary of parameters of the LSTM
1531 
1532  memory_optimization: if enabled, the LSTM step is recomputed on backward
1533  step so that we don't need to store forward activations for each
1534  timestep. Saves memory with cost of computation.
1535 
1536  forget_bias: forget gate bias (default 0.0)
1537 
1538  forward_only: whether to create a backward pass
1539 
1540  drop_states: drop invalid states, passed through to LSTMUnit operator
1541 
1542  return_last_layer_only: only return outputs from the final layer
1543  (so that the length of results does not depend on the number of layers)
1544 
1545  static_rnn_unroll_size: if not None, we will use static RNN which is
1546  unrolled into Caffe2 graph. The size of the unroll is the value of
1547  this parameter.
1548  '''
1549  if type(dim_out) is not list and type(dim_out) is not tuple:
1550  dim_out = [dim_out]
1551  num_layers = len(dim_out)
1552 
1553  cells = []
1554  for i in range(num_layers):
1555  cell = cell_class(
1556  input_size=(dim_in if i == 0 else dim_out[i - 1]),
1557  hidden_size=dim_out[i],
1558  forget_bias=forget_bias,
1559  memory_optimization=memory_optimization,
1560  name=scope,
1561  forward_only=forward_only,
1562  drop_states=drop_states,
1563  **cell_kwargs
1564  )
1565  cells.append(cell)
1566 
1567  cell = MultiRNNCell(
1568  cells,
1569  name=scope,
1570  forward_only=forward_only,
1571  ) if num_layers > 1 else cells[0]
1572 
1573  cell = (
1574  cell if static_rnn_unroll_size is None
1575  else UnrolledCell(cell, static_rnn_unroll_size))
1576 
1577  # outputs_with_grads argument indexes into final layer
1578  outputs_with_grads = [4 * (num_layers - 1) + i for i in outputs_with_grads]
1579  _, result = cell.apply_over_sequence(
1580  model=model,
1581  inputs=input_blob,
1582  seq_lengths=seq_lengths,
1583  initial_states=initial_states,
1584  outputs_with_grads=outputs_with_grads,
1585  )
1586 
1587  if return_last_layer_only:
1588  result = result[4 * (num_layers - 1):]
1589  if return_params:
1590  result = list(result) + [{
1591  'input': cell.get_input_params(),
1592  'recurrent': cell.get_recurrent_params(),
1593  }]
1594  return tuple(result)
1595 
1596 
1597 LSTM = functools.partial(_LSTM, LSTMCell)
1598 BasicRNN = functools.partial(_LSTM, BasicRNNCell)
1599 MILSTM = functools.partial(_LSTM, MILSTMCell)
1600 LayerNormLSTM = functools.partial(_LSTM, LayerNormLSTMCell)
1601 LayerNormMILSTM = functools.partial(_LSTM, LayerNormMILSTMCell)
1602 
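# A minimal end-to-end sketch for the LSTM helper above; the blob names, toy
# dimensions and random data are illustrative assumptions.
def _example_lstm_helper():
    T, N, D, H = 5, 2, 4, 8  # sequence length, batch, input dim, hidden dim
    model = ModelHelper(name='lstm_helper_example')
    input_blob, seq_lengths = model.net.AddExternalInputs(
        'input_blob', 'seq_lengths')
    output, hidden_last, _, cell_last = LSTM(
        model=model,
        input_blob=input_blob,
        seq_lengths=seq_lengths,
        initial_states=None,  # initial states become network parameters
        dim_in=D,
        dim_out=H,
        scope='lstm1',
    )
    workspace.FeedBlob(
        'input_blob', np.random.randn(T, N, D).astype(np.float32))
    workspace.FeedBlob('seq_lengths', np.array([T] * N, dtype=np.int32))
    workspace.RunNetOnce(model.param_init_net)
    workspace.RunNetOnce(model.net)
    return workspace.FetchBlob(output)  # shape (T, N, H)
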
1603 
1604 class UnrolledCell(RNNCell):
1605  def __init__(self, cell, T):
1606  self.T = T
1607  self.cell = cell
1608 
1609  def apply_over_sequence(
1610  self,
1611  model,
1612  inputs,
1613  seq_lengths,
1614  initial_states,
1615  outputs_with_grads=None,
1616  ):
1617  inputs = self.cell.prepare_input(model, inputs)
1618 
1619  # Now they are blob references - outputs of splitting the input sequence
1620  split_inputs = model.net.Split(
1621  inputs,
1622  [str(inputs) + "_timestep_{}".format(i)
1623  for i in range(self.T)],
1624  axis=0)
1625  if self.T == 1:
1626  split_inputs = [split_inputs]
1627 
1628  states = initial_states
1629  all_states = []
1630  for t in range(0, self.T):
1631  scope_name = "timestep_{}".format(t)
1632  # Parameters of all timesteps are shared
1633  with ParameterSharing({scope_name: ''}),\
1634  scope.NameScope(scope_name):
1635  timestep = model.param_init_net.ConstantFill(
1636  [], "timestep", value=t, shape=[1],
1637  dtype=core.DataType.INT32,
1638  device_option=core.DeviceOption(caffe2_pb2.CPU))
1639  states = self.cell._apply(
1640  model=model,
1641  input_t=split_inputs[t],
1642  seq_lengths=seq_lengths,
1643  states=states,
1644  timestep=timestep,
1645  )
1646  all_states.append(states)
1647 
1648  all_states = zip(*all_states)
1649  all_states = [
1650  model.net.Concat(
1651  list(full_output),
1652  [
1653  str(full_output[0])[len("timestep_0/"):] + "_concat",
1654  str(full_output[0])[len("timestep_0/"):] + "_concat_info"
1655 
1656  ],
1657  axis=0)[0]
1658  for full_output in all_states
1659  ]
1660  outputs = tuple(
1661  six.next(it) for it in
1662  itertools.cycle([iter(all_states), iter(states)])
1663  )
1664  outputs_without_grad = set(range(len(outputs))) - set(
1665  outputs_with_grads)
1666  for i in outputs_without_grad:
1667  model.net.ZeroGradient(outputs[i], [])
1668  logging.debug("Added 0 gradients for blobs: %s",
1669  [outputs[i] for i in outputs_without_grad])
1670 
1671  final_output = self.cell._prepare_output_sequence(model, outputs)
1672 
1673  return final_output, outputs
1674 
1675 
1676 def GetLSTMParamNames():
1677  weight_params = ["input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
1678  bias_params = ["input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
1679  return {'weights': weight_params, 'biases': bias_params}
1680 
1681 
1682 def InitFromLSTMParams(lstm_pblobs, param_values):
1683  '''
1684  Set the parameters of an LSTM based on predefined values
1685  '''
1686  weight_params = GetLSTMParamNames()['weights']
1687  bias_params = GetLSTMParamNames()['biases']
1688  for input_type in viewkeys(param_values):
1689  weight_values = [
1690  param_values[input_type][w].flatten()
1691  for w in weight_params
1692  ]
1693  wmat = np.array([])
1694  for w in weight_values:
1695  wmat = np.append(wmat, w)
1696  bias_values = [
1697  param_values[input_type][b].flatten()
1698  for b in bias_params
1699  ]
1700  bm = np.array([])
1701  for b in bias_values:
1702  bm = np.append(bm, b)
1703 
1704  weights_blob = lstm_pblobs[input_type]['weights']
1705  bias_blob = lstm_pblobs[input_type]['biases']
1706  cur_weight = workspace.FetchBlob(weights_blob)
1707  cur_biases = workspace.FetchBlob(bias_blob)
1708 
1709  workspace.FeedBlob(
1710  weights_blob,
1711  wmat.reshape(cur_weight.shape).astype(np.float32))
1712  workspace.FeedBlob(
1713  bias_blob,
1714  bm.reshape(cur_biases.shape).astype(np.float32))
1715 
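# A sketch of the param_values layout expected by InitFromLSTMParams; the
# random arrays below stand in for real pretrained per-gate parameters, and
# input_size/hidden_size are illustrative assumptions. lstm_pblobs is the
# 'input'/'recurrent' params dict as returned by LSTM(return_params=True).
def _example_init_from_lstm_params(lstm_pblobs, input_size, hidden_size):
    names = GetLSTMParamNames()
    param_values = {}
    for input_type, dim_in in [('input', input_size),
                               ('recurrent', hidden_size)]:
        # One (hidden_size x dim_in) matrix per gate; flattened and
        # concatenated in GetLSTMParamNames() order by InitFromLSTMParams.
        param_values[input_type] = {
            name: np.random.randn(hidden_size, dim_in).astype(np.float32)
            for name in names['weights']
        }
        for name in names['biases']:
            param_values[input_type][name] = np.zeros(
                hidden_size, dtype=np.float32)
    InitFromLSTMParams(lstm_pblobs, param_values)
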
1716 
1717 def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
1718  scope, recurrent_params=None, input_params=None,
1719  num_layers=1, return_params=False):
1720  '''
1721  CuDNN version of LSTM for GPUs.
1722  input_blob Blob containing the input. Will need to be available
1723  when param_init_net is run, because the sequence lengths
1724  and batch sizes will be inferred from the size of this
1725  blob.
1726  initial_states tuple of (hidden_init, cell_init) blobs
1727  dim_in input dimensions
1728  dim_out output/hidden dimension
1729  scope namescope to apply
1730  recurrent_params dict of blobs containing values for recurrent
1731  gate weights, biases (if None, use random init values)
1732  See GetLSTMParamNames() for format.
1733  input_params dict of blobs containing values for input
1734  gate weights, biases (if None, use random init values)
1735  See GetLSTMParamNames() for format.
1736  num_layers number of LSTM layers
1737  return_params if True, returns (param_extract_net, param_mapping)
1738  where param_extract_net is a net that when run, will
1739  populate the blobs specified in param_mapping with the
1740  current gate weights and biases (input/recurrent).
1741  Useful for assigning the values back to non-cuDNN
1742  LSTM.
1743  '''
1744  with core.NameScope(scope):
1745  weight_params = GetLSTMParamNames()['weights']
1746  bias_params = GetLSTMParamNames()['biases']
1747 
1748  input_weight_size = dim_out * dim_in
1749  upper_layer_input_weight_size = dim_out * dim_out
1750  recurrent_weight_size = dim_out * dim_out
1751  input_bias_size = dim_out
1752  recurrent_bias_size = dim_out
1753 
1754  def init(layer, pname, input_type):
1755  input_weight_size_for_layer = input_weight_size if layer == 0 else \
1756  upper_layer_input_weight_size
1757  if pname in weight_params:
1758  sz = input_weight_size_for_layer if input_type == 'input' \
1759  else recurrent_weight_size
1760  elif pname in bias_params:
1761  sz = input_bias_size if input_type == 'input' \
1762  else recurrent_bias_size
1763  else:
1764  assert False, "unknown parameter type {}".format(pname)
1765  return model.param_init_net.UniformFill(
1766  [],
1767  "lstm_init_{}_{}_{}".format(input_type, pname, layer),
1768  shape=[sz])
1769 
1770  # Multiply by 4 since we have 4 gates per LSTM unit
1771  first_layer_sz = input_weight_size + recurrent_weight_size + \
1772  input_bias_size + recurrent_bias_size
1773  upper_layer_sz = upper_layer_input_weight_size + \
1774  recurrent_weight_size + input_bias_size + \
1775  recurrent_bias_size
1776  total_sz = 4 * (first_layer_sz + (num_layers - 1) * upper_layer_sz)
1777 
1778  weights = model.create_param(
1779  'lstm_weight',
1780  shape=[total_sz],
1781  initializer=Initializer('UniformFill'),
1782  tags=ParameterTags.WEIGHT,
1783  )
1784 
1785  lstm_args = {
1786  'hidden_size': dim_out,
1787  'rnn_mode': 'lstm',
1788  'bidirectional': 0, # TODO
1789  'dropout': 1.0, # TODO
1790  'input_mode': 'linear', # TODO
1791  'num_layers': num_layers,
1792  'engine': 'CUDNN'
1793  }
1794 
1795  param_extract_net = core.Net("lstm_param_extractor")
1796  param_extract_net.AddExternalInputs([input_blob, weights])
1797  param_extract_mapping = {}
1798 
1799  # Populate the weights-blob from blobs containing parameters for
1800  # the individual components of the LSTM, such as forget/input gate
1801  # weights and biases. Also, create a special param_extract_net that
1802  # can be used to grab those individual params from the black-box
1803  # weights blob. These results can be then fed to InitFromLSTMParams()
1804  for input_type in ['input', 'recurrent']:
1805  param_extract_mapping[input_type] = {}
1806  p = recurrent_params if input_type == 'recurrent' else input_params
1807  if p is None:
1808  p = {}
1809  for pname in weight_params + bias_params:
1810  for j in range(0, num_layers):
1811  values = p[pname] if pname in p else init(j, pname, input_type)
1812  model.param_init_net.RecurrentParamSet(
1813  [input_blob, weights, values],
1814  weights,
1815  layer=j,
1816  input_type=input_type,
1817  param_type=pname,
1818  **lstm_args
1819  )
1820  if pname not in param_extract_mapping[input_type]:
1821  param_extract_mapping[input_type][pname] = {}
1822  b = param_extract_net.RecurrentParamGet(
1823  [input_blob, weights],
1824  ["lstm_{}_{}_{}".format(input_type, pname, j)],
1825  layer=j,
1826  input_type=input_type,
1827  param_type=pname,
1828  **lstm_args
1829  )
1830  param_extract_mapping[input_type][pname][j] = b
1831 
1832  (hidden_input_blob, cell_input_blob) = initial_states
1833  output, hidden_output, cell_output, rnn_scratch, dropout_states = \
1834  model.net.Recurrent(
1835  [input_blob, hidden_input_blob, cell_input_blob, weights],
1836  ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
1837  "lstm_rnn_scratch", "lstm_dropout_states"],
1838  seed=random.randint(0, 100000), # TODO: dropout seed
1839  **lstm_args
1840  )
1841  model.net.AddExternalOutputs(
1842  hidden_output, cell_output, rnn_scratch, dropout_states)
1843 
1844  if return_params:
1845  param_extract = param_extract_net, param_extract_mapping
1846  return output, hidden_output, cell_output, param_extract
1847  else:
1848  return output, hidden_output, cell_output
1849 
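# A minimal sketch of calling cudnn_LSTM under a GPU device scope; it assumes
# a CUDA build of Caffe2, and the blob names and dimensions are illustrative.
def _example_cudnn_lstm():
    model = ModelHelper(name='cudnn_lstm_example')
    with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 0)):
        input_blob, h0, c0 = model.net.AddExternalInputs(
            'input_blob', 'hidden_init', 'cell_init')
        output, hidden_output, cell_output = cudnn_LSTM(
            model=model,
            input_blob=input_blob,
            initial_states=(h0, c0),
            dim_in=4,
            dim_out=8,
            scope='cudnn_lstm',
        )
    return output, hidden_output, cell_output
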
1850 
1851 def LSTMWithAttention(
1852  model,
1853  decoder_inputs,
1854  decoder_input_lengths,
1855  initial_decoder_hidden_state,
1856  initial_decoder_cell_state,
1857  initial_attention_weighted_encoder_context,
1858  encoder_output_dim,
1859  encoder_outputs,
1860  encoder_lengths,
1861  decoder_input_dim,
1862  decoder_state_dim,
1863  scope,
1864  attention_type=AttentionType.Regular,
1865  outputs_with_grads=(0, 4),
1866  weighted_encoder_outputs=None,
1867  lstm_memory_optimization=False,
1868  attention_memory_optimization=False,
1869  forget_bias=0.0,
1870  forward_only=False,
1871 ):
1872  '''
1873  Adds a LSTM with attention mechanism to a model.
1874 
1875  The implementation is based on https://arxiv.org/abs/1409.0473, with
1876  a small difference in the order
1877  in which we compute the new attention context and the new hidden state,
1878  similar to https://arxiv.org/abs/1508.04025.
1879 
1880  The model uses encoder-decoder naming conventions,
1881  where the decoder is the sequence the op is iterating over,
1882  while computing the attention context over the encoder.
1883 
1884  model: ModelHelper object new operators would be added to
1885 
1886  decoder_inputs: the input sequence in the format T x N x D,
1887  where T is the sequence size, N the batch size and D the input dimension
1888 
1889  decoder_input_lengths: blob containing sequence lengths
1890  which would be passed to LSTMUnit operator
1891 
1892  initial_decoder_hidden_state: initial hidden state of LSTM
1893 
1894  initial_decoder_cell_state: initial cell state of LSTM
1895 
1896  initial_attention_weighted_encoder_context: initial attention context
1897 
1898  encoder_output_dim: dimension of encoder outputs
1899 
1900  encoder_outputs: the sequence, on which we compute the attention context
1901  at every iteration
1902 
1903  encoder_lengths: a tensor with lengths of each encoder sequence in batch
1904  (may be None, meaning all encoder sequences are of same length)
1905 
1906  decoder_input_dim: input dimension (last dimension on decoder_inputs)
1907 
1908  decoder_state_dim: size of hidden states of LSTM
1909 
1910  attention_type: one of AttentionType.Regular, AttentionType.Recurrent,
1911  AttentionType.Dot, AttentionType.SoftCoverage. Determines which one to use.
1912 
1913  outputs_with_grads : position indices of output blobs which will receive
1914  external error gradient during backpropagation
1915 
1916  weighted_encoder_outputs: encoder outputs to be used to compute attention
1917  weights. In the basic case it's just a linear transformation of the
1918  encoder outputs (that is the default, when weighted_encoder_outputs is None).
1919  However, it can be something more complicated, like a separate
1920  encoder network (for example, in the case of a convolutional encoder)
1921 
1922  lstm_memory_optimization: recompute LSTM activations on backward pass, so
1923  we don't need to store their values in forward passes
1924 
1925  attention_memory_optimization: recompute attention for backward pass
1926 
1927  forward_only: whether to create only forward pass
1928  '''
1929  cell = LSTMWithAttentionCell(
1930  encoder_output_dim=encoder_output_dim,
1931  encoder_outputs=encoder_outputs,
1932  encoder_lengths=encoder_lengths,
1933  decoder_input_dim=decoder_input_dim,
1934  decoder_state_dim=decoder_state_dim,
1935  name=scope,
1936  attention_type=attention_type,
1937  weighted_encoder_outputs=weighted_encoder_outputs,
1938  forget_bias=forget_bias,
1939  lstm_memory_optimization=lstm_memory_optimization,
1940  attention_memory_optimization=attention_memory_optimization,
1941  forward_only=forward_only,
1942  )
1943  initial_states = [
1944  initial_decoder_hidden_state,
1945  initial_decoder_cell_state,
1946  initial_attention_weighted_encoder_context,
1947  ]
1948  if attention_type == AttentionType.SoftCoverage:
1949  initial_states.append(cell.build_initial_coverage(model))
1950  _, result = cell.apply_over_sequence(
1951  model=model,
1952  inputs=decoder_inputs,
1953  seq_lengths=decoder_input_lengths,
1954  initial_states=initial_states,
1955  outputs_with_grads=outputs_with_grads,
1956  )
1957  return result
1958 
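# A minimal sketch of wiring up LSTMWithAttention; the encoder network itself
# is elided, and the blob names and dimensions are illustrative assumptions.
def _example_lstm_with_attention():
    model = ModelHelper(name='attention_example')
    (decoder_inputs, decoder_input_lengths, hidden_init, cell_init,
     context_init, encoder_outputs, encoder_lengths) = \
        model.net.AddExternalInputs(
            'decoder_inputs', 'decoder_input_lengths', 'hidden_init',
            'cell_init', 'context_init', 'encoder_outputs', 'encoder_lengths')
    # Returns hidden/cell/context sequences and their final values.
    return LSTMWithAttention(
        model=model,
        decoder_inputs=decoder_inputs,
        decoder_input_lengths=decoder_input_lengths,
        initial_decoder_hidden_state=hidden_init,
        initial_decoder_cell_state=cell_init,
        initial_attention_weighted_encoder_context=context_init,
        encoder_output_dim=8,
        encoder_outputs=encoder_outputs,
        encoder_lengths=encoder_lengths,
        decoder_input_dim=4,
        decoder_state_dim=8,
        scope='attention_decoder',
        attention_type=AttentionType.Regular,
    )
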
1959 
1960 def _layered_LSTM(
1961  model, input_blob, seq_lengths, initial_states,
1962  dim_in, dim_out, scope, outputs_with_grads=(0,), return_params=False,
1963  memory_optimization=False, forget_bias=0.0, forward_only=False,
1964  drop_states=False, create_lstm=None):
1965  params = locals() # leave it as a first line to grab all params
1966  params.pop('create_lstm')
1967  if not isinstance(dim_out, list):
1968  return create_lstm(**params)
1969  elif len(dim_out) == 1:
1970  params['dim_out'] = dim_out[0]
1971  return create_lstm(**params)
1972 
1973  assert len(dim_out) != 0, "dim_out list can't be empty"
1974  assert return_params is False, "return_params not supported for layering"
1975  for i, output_dim in enumerate(dim_out):
1976  params.update({
1977  'dim_out': output_dim
1978  })
1979  output, last_output, all_states, last_state = create_lstm(**params)
1980  params.update({
1981  'input_blob': output,
1982  'dim_in': output_dim,
1983  'initial_states': (last_output, last_state),
1984  'scope': scope + '_layer_{}'.format(i + 1)
1985  })
1986  return output, last_output, all_states, last_state
1987 
1988 
1989 layered_LSTM = functools.partial(_layered_LSTM, create_lstm=LSTM)