Caffe2 - Python API
A deep learning, cross platform ML framework
recurrent.py
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

## @package recurrent
# Module caffe2.python.recurrent
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace
from future.utils import viewitems, viewkeys

def recurrent_net(
    net, cell_net, inputs, initial_cell_inputs,
    links, timestep=None, scope=None, outputs_with_grads=(0,),
    recompute_blobs_on_backward=None, forward_only=False,
):
    '''
    net: the main net that the recurrent operator should be added to

    cell_net: the cell net, which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one
    input is supported. It has to be in the format T x N x (D1...Dk), where T
    is the length of the sequence, N is the batch size and (D1...Dk) are the
    remaining dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at moment t+1 to
    cell_net output names at moment t. Currently we assume that each output
    becomes an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided,
    "timestep" is used.

    scope: internal blobs are scoped in the format
        <scope_name>/<blob_name>
    If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of output blobs which will receive
    an error gradient (from outside the recurrent network) during
    backpropagation

    recompute_blobs_on_backward: a list of blobs that will be recomputed
    during the backward pass, and thus need not be stored for each forward
    timestep

    forward_only: if True, only forward steps are executed
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine which inputs are considered references: those that are not
    # referred to in inputs or initial_cell_inputs.
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # Compute the backward pass of the cell net.
    if not forward_only:
        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
            cell_net.Proto().op, inner_outputs_map)
        backward_mapping = {str(k): v for k, v in viewitems(backward_mapping)}

        backward_cell_net = core.Net("RecurrentBackwardStep")
        del backward_cell_net.Proto().op[:]

        if recompute_blobs_on_backward is not None:
            # Insert operators to re-compute the specified blobs.
            # They are added in the same order as for the forward pass, thus
            # the order is correct.
            recompute_blobs_on_backward = {str(b) for b in
                                           recompute_blobs_on_backward}

            for op in cell_net.Proto().op:
                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                    backward_cell_net.Proto().op.extend([op])
                    # This fires if the recomputed ops produce outputs other
                    # than the ones declared for recomputation.
                    assert set(op.output).issubset(recompute_blobs_on_backward)

        backward_cell_net.Proto().op.extend(backward_ops)
        # Compute blobs used but not defined in the backward pass.
        backward_ssa, backward_blob_versions = core.get_ssa(
            backward_cell_net.Proto())
        undefined = core.get_undefined_blobs(backward_ssa)

        # Also add to the output list the intermediate outputs of fwd_step
        # that are used by backward.
        ssa, blob_versions = core.get_ssa(cell_net.Proto())
        scratches = [
            blob
            for blob, ver in viewitems(blob_versions)
            if (ver > 0 and
                blob in undefined and
                blob not in cell_net.Proto().external_output)
        ]
        backward_cell_net.Proto().external_input.extend(scratches)
        backward_cell_net.Proto().type = 'simple'
    else:
        backward_cell_net = None
    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'

    # Internal arguments used by the RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At moment t the corresponding data block is at position t + offset
    # in the recurrent_states tensor.
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world.
    # Format: (internal_blob, external_blob, offset).
    # A negative offset counts from the end, a positive one from the
    # beginning.
    aliases = []

    # Recurrent states hold the inputs to the cell net over time.
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time,
        # or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))

        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        if backward_cell_net is not None:
            backward_links.append((cell_output + "_grad", states_grad, 1))
            backward_cell_net.Proto().external_input.append(
                str(cell_output) + "_grad")

            recurrent_input_grad = cell_input + "_grad"
            if not backward_blob_versions.get(recurrent_input_grad, 0):
                # If nobody writes to this recurrent input gradient, we need
                # to make sure it still gets to the states grad blob. We do
                # this via backward_links, which triggers an alias. This
                # logic is used, for example, in the SumOp case.
                backward_links.append(
                    (backward_mapping[cell_input], states_grad, 0))
            else:
                backward_links.append((recurrent_input_grad, states_grad, 0))
    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))

    if backward_cell_net is not None:
        for input_t, input_blob in inputs:
            backward_links.append((
                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
            ))
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_input)
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Split the triples into separate lists so we can pass them to C++,
    # where they are assembled back together.
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    # Make sure that recurrent gradients accumulate with internal gradients
    # (if a blob in the backward_cell_net receives gradient from both an
    # external connection as well as from within the backward_cell_net,
    # those gradients need to be added together, rather than one overwriting
    # the other).
    if backward_cell_net is not None:
        proto = backward_cell_net.Proto()
        operators = []
        while len(proto.op) > 0:
            op = proto.op[-1]
            proto.op.remove(op)
            operators.append(op)
        for op in operators[::-1]:
            proto.op.extend([op])
            for j, output_blob in enumerate(op.output):
                if output_blob in proto.external_input:
                    # An in-place operation won't cause issues because it
                    # takes the existing value of the blob into account.
                    if output_blob in op.input:
                        continue
                    output_blob = core.BlobReference(output_blob)
                    accum_blob = output_blob + "_accum"
                    proto.op[-1].output[j] = str(accum_blob)
                    backward_cell_net.Sum(
                        [output_blob, accum_blob],
                        [output_blob],
                    )

    def map_to_dual_list(m):
        return [str(x) for x in list(m.keys())] + \
               [str(x) for x in list(m.values())]

    backward_args = {}
    if backward_cell_net is not None:
        backward_mapping_keys = set(viewkeys(backward_mapping))
        backward_link_internal, backward_link_external, backward_link_offset = \
            unpack_triple(backward_links)
        params = [x for x in references if x in backward_mapping_keys]
        param_grads = [
            str(backward_mapping[x])
            for x in references
            if x in backward_mapping_keys
        ]
        if recompute_blobs_on_backward is None:
            recompute_blobs_on_backward = set()
        backward_args = {
            'param': [all_inputs.index(p) for p in params],
            'backward_link_internal': [str(l) for l in backward_link_internal],
            'backward_link_external': [str(l) for l in backward_link_external],
            'backward_link_offset': backward_link_offset,
            'outputs_with_grads': outputs_with_grads,
            'recompute_blobs_on_backward': [
                str(b) for b in recompute_blobs_on_backward
            ],
            'param_grads': param_grads,
        }
        if len(backward_cell_net.Proto().op) != 0:
            backward_args['backward_step_net'] = backward_cell_net.Proto()

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(i) for i in recurrent_inputs
        ],
        link_internal=[str(l) for l in link_internal],
        link_external=[str(l) for l in link_external],
        link_offset=link_offset,
        enable_rnn_executor=1,
        step_net=cell_net.Proto(),
        timestep="timestep" if timestep is None else str(timestep),
        **backward_args
    )

    # Restore the net type, since 'rnn' is not recognized outside of RNNs.
    cell_net.Proto().type = 'simple'

    # The last output is a list of step workspaces, which is only needed
    # internally for gradient propagation.
    return results[:-1]
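# A minimal, illustrative sketch of how recurrent_net might be wired up with
# a simple tanh-RNN step net. All blob names ("seq_input", "hidden_init",
# "W", "b", ...) and shapes are assumptions made for this example, not
# requirements of the API: the sequence blob is assumed to be T x N x D and
# already projected to the hidden size, the initial state 1 x N x D, and the
# parameter blobs W and b are assumed to already exist in the caller's
# workspace.
def _simple_rnn_example(net, seq_input, hidden_init):
    cell_net = core.Net("simple_rnn_cell")
    # Per-timestep blobs as seen from inside the cell; each slice is
    # 1 x N x D.
    input_t = cell_net.AddExternalInput("input_t")
    hidden_t_prev = cell_net.AddExternalInput("hidden_t_prev")
    w = cell_net.AddExternalInput("W")
    b = cell_net.AddExternalInput("b")
    # hidden_t = tanh(input_t + FC(hidden_t_prev; W, b))
    prev_fc = cell_net.FC([hidden_t_prev, w, b], "prev_fc", axis=2)
    hidden_t = cell_net.Tanh(
        cell_net.Sum([input_t, prev_fc], "pre_activation"),
        "hidden_t",
    )
    cell_net.AddExternalOutput(hidden_t)

    # recurrent_net aliases each state blob to "<output>_all" (all timesteps)
    # and "<output>_last" (final timestep) and returns those blobs.
    hidden_all, hidden_last = recurrent_net(
        net=net,
        cell_net=cell_net,
        inputs=[(input_t, seq_input)],
        initial_cell_inputs=[(hidden_t_prev, hidden_init)],
        links={str(hidden_t_prev): str(hidden_t)},
        scope="simple_rnn",
    )
    return hidden_all, hidden_last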
def set_rnn_executor_config(rnn_op, num_threads=None, max_cuda_streams=None):
    from caffe2.proto import caffe2_pb2
    assert rnn_op.type in {'RecurrentNetwork', 'RecurrentNetworkGradient'}

    def add_arg(s, v):
        a = caffe2_pb2.Argument()
        a.name = "rnn_executor." + s
        a.i = v
        rnn_op.arg.extend([a])

    if num_threads is not None:
        add_arg('num_threads', num_threads)
    if max_cuda_streams is not None:
        add_arg('max_cuda_streams', max_cuda_streams)
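# A minimal, illustrative sketch of how set_rnn_executor_config might be used.
# It assumes the RecurrentNetwork op of interest is the last op added to the
# net, as is the case right after a recurrent_net() call; the thread and
# stream counts are arbitrary example values.
def _tune_rnn_executor_example(net):
    rnn_op = net.Proto().op[-1]
    set_rnn_executor_config(rnn_op, num_threads=4, max_cuda_streams=2)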
def retrieve_step_blobs(net, prefix='rnn'):
    '''
    Retrieves blobs from step workspaces (which contain intermediate recurrent
    network computation for each timestep) and puts them in the global
    workspace. This allows access to the contents of this intermediate
    computation in python. Returns the list of extracted blob names.

    net: the net from which the step workspace blobs should be extracted

    prefix: prefix to attach to the extracted blob names when placing them in
    the global workspace
    '''
    count = 1
    output_list = []
    for op in net.Proto().op:
        if op.type == "RecurrentNetwork":
            blob_name = prefix + "_" + str(count)
            count = count + 1
            scratch_workspaces_blob_name = op.output[-1]
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RecurrentNetworkBlobFetcher",
                    [scratch_workspaces_blob_name],
                    [blob_name],
                    prefix=prefix
                )
            )
            output_list += workspace.FetchBlob(blob_name).tolist()
    return output_list
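# A minimal, illustrative sketch of how retrieve_step_blobs might be used:
# run the main net once, then copy every per-timestep scratch blob into the
# global workspace. The "dbg" prefix is an arbitrary example value, and the
# net is assumed to contain at least one RecurrentNetwork op that has already
# produced its step workspaces.
def _inspect_step_blobs_example(net):
    workspace.RunNetOnce(net)
    step_blob_names = retrieve_step_blobs(net, prefix="dbg")
    # Each returned name now refers to a blob in the global workspace holding
    # one intermediate value from one timestep.
    print(step_blob_names)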