from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace
from future.utils import viewitems, viewkeys


def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, timestep=None, scope=None, outputs_with_grads=(0,),
        recompute_blobs_on_backward=None, forward_only=False,
):
    '''
    net: the main net that the recurrent operator should be added to

    cell_net: the step net which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one
    input is supported. It has to be in the format T x N x (D1...Dk), where T
    is the length of the sequence, N is the batch size and (D1...Dk) are the
    remaining dimensions.

    initial_cell_inputs: inputs of the cell_net for timestep 0.
        Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at moment t+1 to
        output names at moment t. Currently we assume that each output becomes
        an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided,
        "timestep" is used.

    scope: internal blobs are going to be scoped in the format
        <scope_name>/<blob_name>.
        If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of output blobs which will receive
        error gradient (from outside the recurrent network) during
        backpropagation.

    recompute_blobs_on_backward: a list of blobs that will be recomputed
        during the backward pass, and thus need not be stored for each
        forward timestep.

    forward_only: if True, only forward steps are executed
    '''
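    # A minimal usage sketch (all blob and net names below are illustrative,
    # not defined by this module): the cell net computes one timestep, and
    # recurrent_net unrolls it over the T x N x D sequence blob "input_seq".
    #
    #   outputs = recurrent_net(
    #       net=train_net,                        # main net the op is added to
    #       cell_net=step_net,                    # single-timestep net
    #       inputs=[("input_t", "input_seq")],    # (cell blob, sequence blob)
    #       initial_cell_inputs=[("hidden_t_prev", "hidden_init")],
    #       links={"hidden_t_prev": "hidden_t"},  # input at t+1 <- output at t
    #       scope="my_rnn",
    #   )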
"Only one input blob is supported so far" 51 input_blobs = [str(i[0])
for i
in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')
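    # Helper that prefixes internal blob names with the scope, so blobs from
    # different recurrent ops do not collide.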
    def s(name):
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]
    inner_outputs = list(cell_net.Proto().external_output)
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}
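    # Build the backward step net from the gradient operators of the cell net
    # (skipped entirely in forward-only mode).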
    if not forward_only:
        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
            cell_net.Proto().op, inner_outputs_map)
        backward_mapping = {str(k): v for k, v in viewitems(backward_mapping)}

        backward_cell_net = core.Net("RecurrentBackwardStep")
        del backward_cell_net.Proto().op[:]
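        # Operators whose outputs are marked for recomputation are copied into
        # the backward step net, so those outputs do not have to be stored for
        # every forward timestep.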
        if recompute_blobs_on_backward is not None:
            recompute_blobs_on_backward = {str(b) for b in
                                           recompute_blobs_on_backward}

            for op in cell_net.Proto().op:
                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                    backward_cell_net.Proto().op.extend([op])
                    assert set(op.output).issubset(recompute_blobs_on_backward)
        backward_cell_net.Proto().op.extend(backward_ops)
        backward_ssa, backward_blob_versions = core.get_ssa(
            backward_cell_net.Proto())
        undefined = core.get_undefined_blobs(backward_ssa)
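        # Intermediate blobs that the backward step reads but does not define
        # ("scratches") have to be exposed as extra inputs of the backward
        # step net.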
        ssa, blob_versions = core.get_ssa(cell_net.Proto())
        scratches = [
            blob
            for blob, ver in viewitems(blob_versions)
            if (ver > 0 and
                blob in undefined and
                blob not in cell_net.Proto().external_output)
        ]
        backward_cell_net.Proto().external_input.extend(scratches)
        backward_cell_net.Proto().type = 'simple'
    else:
        backward_cell_net = None

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'

    forward_links = []
    backward_links = []
    aliases = []

    recurrent_states = []
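    # forward_links / backward_links are (step_net_blob, state_blob, offset)
    # triples: at timestep t the linked blob maps to position t + offset of
    # the state tensor. aliases expose states externally as
    # (state_blob, external_name, offset) triples, with negative offsets
    # counted from the end as in Python indexing.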
    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # The state blob is (T + 1) x ... and stores the cell input/output for
        # every timestep (or their gradients during the backward pass).
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))

        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)
        if backward_cell_net is not None:
            backward_links.append((cell_output + "_grad", states_grad, 1))
            backward_cell_net.Proto().external_input.append(
                str(cell_output) + "_grad")

            recurrent_input_grad = cell_input + "_grad"
            if not backward_blob_versions.get(recurrent_input_grad, 0):
                # Nothing in the backward step writes this gradient directly,
                # so link the mapped gradient blob to the state gradient.
                backward_links.append(
                    (backward_mapping[cell_input], states_grad, 0))
            else:
                backward_links.append((recurrent_input_grad, states_grad, 0))
    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
    if backward_cell_net is not None:
        for input_t, input_blob in inputs:
            backward_links.append((
                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
            ))
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_input)
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_output)
    def unpack_triple(x):
        # Turn a list of 3-tuples into three parallel sequences (empty-safe).
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []
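    # The link and alias triples are split into parallel lists, which is the
    # argument format the RecurrentNetwork operator expects.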
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)
    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]
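    # If a blob in the backward step net receives gradient both through an
    # external link and from an operator inside the net, the contributions
    # must be summed rather than overwritten: redirect such op outputs to an
    # "_accum" blob and Sum it back into the original blob.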
    if backward_cell_net is not None:
        proto = backward_cell_net.Proto()
        operators = []
        while len(proto.op) > 0:
            op = proto.op[-1]
            operators.append(op)
            del proto.op[-1]
        for op in operators[::-1]:
            proto.op.extend([op])
            for j, output_blob in enumerate(op.output):
                if output_blob in proto.external_input:
                    # In-place ops already account for the existing value of
                    # the blob, so they need no accumulation.
                    if output_blob in op.input:
                        continue
                    output_blob = core.BlobReference(output_blob)
                    accum_blob = output_blob + "_accum"
                    proto.op[-1].output[j] = str(accum_blob)
                    backward_cell_net.Sum(
                        [output_blob, accum_blob],
                        [output_blob],
                    )
    def map_to_dual_list(m):
        return [str(x) for x in list(m.keys())] + \
            [str(x) for x in list(m.values())]
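    # Arguments describing the backward pass; left empty in forward-only mode.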
    backward_args = {}
    if backward_cell_net is not None:
        backward_mapping_keys = set(viewkeys(backward_mapping))
        backward_link_internal, backward_link_external, backward_link_offset = \
            unpack_triple(backward_links)
        params = [x for x in references if x in backward_mapping_keys]
        param_grads = [
            str(backward_mapping[x])
            for x in references
            if x in backward_mapping_keys
        ]
        if recompute_blobs_on_backward is None:
            recompute_blobs_on_backward = set()
        backward_args = {
            'param': [all_inputs.index(p) for p in params],
            'backward_link_internal': [str(l) for l in backward_link_internal],
            'backward_link_external': [str(l) for l in backward_link_external],
            'backward_link_offset': backward_link_offset,
            'outputs_with_grads': outputs_with_grads,
            'recompute_blobs_on_backward': [
                str(b) for b in recompute_blobs_on_backward
            ],
            'param_grads': param_grads,
        }
        if len(backward_cell_net.Proto().op) != 0:
            backward_args['backward_step_net'] = backward_cell_net.Proto()
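    # Add the RecurrentNetwork operator itself. The extra "step_workspaces"
    # output holds the per-timestep workspaces needed for the backward pass.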
    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(i) for i in recurrent_inputs
        ],
        link_internal=[str(l) for l in link_internal],
        link_external=[str(l) for l in link_external],
        link_offset=link_offset,
        enable_rnn_executor=1,
        step_net=cell_net.Proto(),
        timestep="timestep" if timestep is None else str(timestep),
        **backward_args
    )
    cell_net.Proto().type = 'simple'

    # The last output is the list of step workspaces, which is only needed
    # internally; do not expose it to the caller.
    return results[:-1]


def set_rnn_executor_config(rnn_op, num_threads=None, max_cuda_streams=None):
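    # Attaches "rnn_executor.*" arguments to a RecurrentNetwork or
    # RecurrentNetworkGradient operator so the RNN executor can be tuned.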
    from caffe2.proto import caffe2_pb2
    assert rnn_op.type in {'RecurrentNetwork', 'RecurrentNetworkGradient'}
    def add_arg(s, v):
        a = caffe2_pb2.Argument()
        a.name = "rnn_executor." + s
        a.i = v
        rnn_op.arg.extend([a])
    if num_threads is not None:
        add_arg('num_threads', num_threads)
    if max_cuda_streams is not None:
        add_arg('max_cuda_streams', max_cuda_streams)
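

# A usage sketch for set_rnn_executor_config (the net name is illustrative):
#
#   for op in train_net.Proto().op:
#       if op.type == 'RecurrentNetwork':
#           set_rnn_executor_config(op, num_threads=4, max_cuda_streams=2)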


def retrieve_step_blobs(net, prefix='rnn'):
    '''
    Retrieves blobs from step workspaces (which contain the intermediate
    recurrent network computation for each timestep) and puts them in the
    global workspace. This allows access to the contents of this intermediate
    computation in Python. Returns the list of extracted blob names.

    net: the net from which the step workspace blobs should be extracted

    prefix: prefix prepended to extracted blob names when placing them in the
    global workspace
    '''
    count = 1
    output_list = []
    for op in net.Proto().op:
        if op.type == "RecurrentNetwork":
            blob_name = prefix + "_" + str(count)
            count = count + 1
            scratch_workspaces_blob_name = op.output[-1]
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RecurrentNetworkBlobFetcher",
                    [scratch_workspaces_blob_name],
                    [blob_name],
                    prefix=prefix,
                )
            )
            output_list += workspace.FetchBlob(blob_name).tolist()
    return output_list
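

# A usage sketch for retrieve_step_blobs (net/prefix names are illustrative):
# after running a net that contains a RecurrentNetwork op, the per-timestep
# blobs can be pulled into the global workspace and inspected.
#
#   step_blob_names = retrieve_step_blobs(train_net, prefix='dbg')
#   for name in step_blob_names:
#       print(name, workspace.FetchBlob(name))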