Caffe2 - Python API
A deep learning, cross-platform ML framework
recurrent.py
## @package recurrent
# Module caffe2.python.recurrent
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace
from future.utils import viewitems, viewkeys

def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, timestep=None, scope=None, outputs_with_grads=(0,),
        recompute_blobs_on_backward=None, forward_only=False,
):
    '''
    net: the main net that the recurrent operator should be added to

    cell_net: the net which is executed in a recurrent fashion at every
    timestep

    inputs: sequences to be fed into the recurrent net. Currently only one
    input is supported. It has to be in the format T x N x (D1...Dk), where
    T is the length of the sequence, N is the batch size, and (D1...Dk) are
    the remaining dimensions.

    initial_cell_inputs: inputs of the cell_net for the 0th timestep.
    The format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary from cell_net input names at moment t+1 to
    output names at moment t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided,
    "timestep" is used.

    scope: internal blobs are scoped in the format
        <scope_name>/<blob_name>
    If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of the output blobs which will
    receive an error gradient (from outside the recurrent network) during
    backpropagation

    recompute_blobs_on_backward: a list of blobs that will be recomputed
    for the backward pass, and thus need not be stored for each forward
    timestep

    forward_only: if True, only forward steps are executed
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))
    # Determine the inputs that are considered to be references: these are
    # the ones not referred to in inputs or initial_cell_inputs.
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    if not forward_only:
        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
            cell_net.Proto().op, inner_outputs_map)
        backward_mapping = {str(k): v for k, v in viewitems(backward_mapping)}

        backward_cell_net = core.Net("RecurrentBackwardStep")
        del backward_cell_net.Proto().op[:]

        if recompute_blobs_on_backward is not None:
            # Insert operators to re-compute the specified blobs.
            # They are added in the same order as in the forward pass, thus
            # the order is correct.
            recompute_blobs_on_backward = {str(b) for b in
                                           recompute_blobs_on_backward}

            for op in cell_net.Proto().op:
                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                    backward_cell_net.Proto().op.extend([op])
                    # This fires if the recomputed ops compute outputs other
                    # than the ones declared for recomputation.
                    assert set(op.output).issubset(recompute_blobs_on_backward)

        backward_cell_net.Proto().op.extend(backward_ops)
        # compute blobs used but not defined in the backward pass
        backward_ssa, backward_blob_versions = core.get_ssa(
            backward_cell_net.Proto())
        undefined = core.get_undefined_blobs(backward_ssa)

        # also add to the output list the intermediate outputs of fwd_step that
        # are used by backward.
        ssa, blob_versions = core.get_ssa(cell_net.Proto())
        scratches = [
            blob
            for blob, ver in viewitems(blob_versions)
            if (ver > 0 and
                blob in undefined and
                blob not in cell_net.Proto().external_output)
        ]
        backward_cell_net.Proto().external_input.extend(scratches)
        backward_cell_net.Proto().type = 'simple'
    else:
        backward_cell_net = None

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At moment t, the corresponding data block is at position t + offset
    # in the recurrent_states tensor.
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world.
    # Format: (internal_blob, external_blob, offset).
    # A negative offset counts from the end, a positive one from the beginning.
    aliases = []

    # States that hold the inputs to the cell net
    recurrent_states = []

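    # As a concrete illustration (hypothetical names): for a cell input
    # "hidden_t_prev" linked to cell output "hidden_t", the loop below adds
    # forward links such as ("hidden_t_prev", "<scope>/hidden_t_prev_states", 0)
    # and ("hidden_t", "<scope>/hidden_t_prev_states", 1), plus aliases such as
    # ("<scope>/hidden_t_prev_states", "hidden_t_all", 1) and
    # ("<scope>/hidden_t_prev_states", "hidden_t_last", -1).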
    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time,
        # or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))

        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        if backward_cell_net is not None:
            backward_links.append((cell_output + "_grad", states_grad, 1))
            backward_cell_net.Proto().external_input.append(
                str(cell_output) + "_grad")

            recurrent_input_grad = cell_input + "_grad"
            if not backward_blob_versions.get(recurrent_input_grad, 0):
                # If nothing writes to this recurrent input gradient, we still
                # need to make sure it reaches the states grad blob. We do this
                # via backward_links, which triggers an alias. This logic is
                # used, for example, in the SumOp case.
                backward_links.append(
                    (backward_mapping[cell_input], states_grad, 0))
            else:
                backward_links.append((recurrent_input_grad, states_grad, 0))


    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))

    if backward_cell_net is not None:
        for input_t, input_blob in inputs:
            backward_links.append((
                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
            ))
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_input)
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Split into separate lists so we can pass them to C++,
    # where they are assembled back together.
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    # Make sure that recurrent gradients accumulate with internal gradients
    # (if a blob in the backward_cell_net receives a gradient both from an
    # external connection and from within the backward_cell_net, those
    # gradients need to be added together, rather than one overwriting
    # the other).
    if backward_cell_net is not None:
        proto = backward_cell_net.Proto()
        operators = []
        while len(proto.op) > 0:
            op = proto.op[-1]
            proto.op.remove(op)
            operators.append(op)
        for op in operators[::-1]:
            proto.op.extend([op])
            for j, output_blob in enumerate(op.output):
                if output_blob in proto.external_input:
                    # An in-place operation won't cause issues because it
                    # takes the existing value of the blob into account.
                    if output_blob in op.input:
                        continue
                    output_blob = core.BlobReference(output_blob)
                    accum_blob = output_blob + "_accum"
                    proto.op[-1].output[j] = str(accum_blob)
                    backward_cell_net.Sum(
                        [output_blob, accum_blob],
                        [output_blob],
                    )

    def map_to_dual_list(m):
        return [str(x) for x in list(m.keys())] + \
            [str(x) for x in list(m.values())]

    backward_args = {}
    if backward_cell_net is not None:
        backward_mapping_keys = set(viewkeys(backward_mapping))
        backward_link_internal, backward_link_external, backward_link_offset = \
            unpack_triple(backward_links)
        params = [x for x in references if x in backward_mapping_keys]
        param_grads = [
            str(backward_mapping[x])
            for x in references
            if x in backward_mapping_keys
        ]
        if recompute_blobs_on_backward is None:
            recompute_blobs_on_backward = set()
        backward_args = {
            'param': [all_inputs.index(p) for p in params],
            'backward_link_internal': [str(l) for l in backward_link_internal],
            'backward_link_external': [str(l) for l in backward_link_external],
            'backward_link_offset': backward_link_offset,
            'outputs_with_grads': outputs_with_grads,
            'recompute_blobs_on_backward': [
                str(b) for b in recompute_blobs_on_backward
            ],
            'param_grads': param_grads,
        }
        if len(backward_cell_net.Proto().op) != 0:
            backward_args['backward_step_net'] = backward_cell_net.Proto()

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(i) for i in recurrent_inputs
        ],
        link_internal=[str(l) for l in link_internal],
        link_external=[str(l) for l in link_external],
        link_offset=link_offset,
        enable_rnn_executor=1,
        step_net=cell_net.Proto(),
        timestep="timestep" if timestep is None else str(timestep),
        **backward_args
    )

    # Restore the net type, since 'rnn' is not recognized outside RNNs
    cell_net.Proto().type = 'simple'

    # The last output is a list of step workspaces, which is only needed
    # internally for gradient propagation.
    return results[:-1]


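# A minimal usage sketch of recurrent_net (the cell, blob names, and scope
# below are placeholders, not part of this module): the cell simply adds the
# current input to the previous hidden state, and recurrent_net unrolls it
# over a T x N x D sequence blob.
def _example_sum_rnn(net, seq_blob, h_init_blob):
    # Single-step cell: hidden_t = input_t + hidden_t_prev
    step = core.Net("sum_cell")
    hidden_t = step.Sum(["input_t", "hidden_t_prev"], ["hidden_t"])
    step.AddExternalOutput(hidden_t)
    # Returns ("hidden_t_all", "hidden_t_last"): all hidden states stacked
    # over time, and the state at the last timestep.
    return recurrent_net(
        net=net,
        cell_net=step,
        inputs=[("input_t", seq_blob)],
        initial_cell_inputs=[("hidden_t_prev", h_init_blob)],
        links={"hidden_t_prev": "hidden_t"},
        scope="sum_rnn",
    )

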
def set_rnn_executor_config(rnn_op, num_threads=None, max_cuda_streams=None):
    from caffe2.proto import caffe2_pb2
    assert rnn_op.type in {'RecurrentNetwork', 'RecurrentNetworkGradient'}

    def add_arg(s, v):
        a = caffe2_pb2.Argument()
        a.name = "rnn_executor." + s
        a.i = v
        rnn_op.arg.extend([a])

    if num_threads is not None:
        add_arg('num_threads', num_threads)
    if max_cuda_streams is not None:
        add_arg('max_cuda_streams', max_cuda_streams)


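# A minimal sketch of applying set_rnn_executor_config (the thread and stream
# counts below are placeholder values): tune the RNN executor of every
# recurrent op in a net built with recurrent_net.
def _example_configure_rnn_executor(net, num_threads=4, max_cuda_streams=2):
    for op in net.Proto().op:
        if op.type in {'RecurrentNetwork', 'RecurrentNetworkGradient'}:
            set_rnn_executor_config(
                op,
                num_threads=num_threads,
                max_cuda_streams=max_cuda_streams,
            )

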
def retrieve_step_blobs(net, prefix='rnn'):
    '''
    Retrieves blobs from step workspaces (which contain the intermediate
    recurrent network computation for each timestep) and puts them in the
    global workspace. This allows access to the contents of this intermediate
    computation in Python. Returns the list of extracted blob names.

    net: the net from which the step workspace blobs should be extracted

    prefix: prefix to add to the extracted blob names when placing them in
    the global workspace
    '''
    count = 1
    output_list = []
    for op in net.Proto().op:
        if op.type == "RecurrentNetwork":
            blob_name = prefix + "_" + str(count)
            count = count + 1
            scratch_workspaces_blob_name = op.output[-1]
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RecurrentNetworkBlobFetcher",
                    [scratch_workspaces_blob_name],
                    [blob_name],
                    prefix=prefix
                )
            )
            output_list += workspace.FetchBlob(blob_name).tolist()
    return output_list
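

# A minimal sketch of using retrieve_step_blobs (the prefix "debug_rnn" is a
# placeholder): run a net containing a RecurrentNetwork op, then copy its
# per-timestep intermediates into the global workspace for inspection.
def _example_inspect_step_blobs(net):
    # Assumes the net's external inputs have already been fed to the workspace.
    workspace.RunNetOnce(net)
    # Each returned name can be read back with workspace.FetchBlob(name).
    return retrieve_step_blobs(net, prefix='debug_rnn')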