Caffe2 - Python API
A deep learning, cross-platform ML framework
control_ops_util.py
1 ## @package control_ops_util
2 # Module caffe2.python.control_ops_util
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core
9 
10 
def get_external_blob_names(net, lexical_scope):
    """
    Compute the outer (external) blobs a net reads from and writes to.
    Inputs:
      net - net (core.Net) whose operators are inspected;
      lexical_scope - all outer blob names visible to the net
    Returns:
      (input_names, output_names) where input_names is the set of blobs the
      net reads before any of its operators writes them, and output_names is
      the set of blobs written by some operator of the net whose names are
      also in lexical_scope.
    Raises:
      AssertionError if a blob read before being written is not in
      lexical_scope.
    """
    proto = net.Proto()

    # Blobs consumed before the net itself defines them must come from the
    # enclosing scope.
    ssa_form, _ = core.get_ssa(proto)
    input_names = core.get_undefined_blobs(ssa_form)
    for name in input_names:
        assert str(name) in lexical_scope, \
            "Input blob " + name + " is undefined"

    # Only writes to names that are already visible outside count as
    # external outputs; every other written blob stays local to the net.
    output_names = {
        blob
        for op in proto.op
        for blob in op.output
        if blob in lexical_scope
    }
    return input_names, output_names
34 
35 
def _add_branch_do_net(if_net, branch_net, input_names, output_names,
                       do_net_name, workspace_suffix):
    """
    Build a net holding a single Do op that runs one If branch.
    Inputs:
      if_net - net that will own the 'If' op; used to allocate a uniquely
               named workspace pointer blob;
      branch_net - then/else branch net (core.Net) run by the Do op;
      input_names/output_names - sets of outer blob names the branch reads
                                 and writes (see get_external_blob_names);
      do_net_name - name of the created wrapper net;
      workspace_suffix - suffix appended to if_net.Name() to name the
                         workspace pointer blob
    Returns:
      (do_net, workspace_blob)
    """
    do_net = core.Net(do_net_name)

    # Sorted so the generated NetDef is identical from run to run; iterating
    # a set of str directly depends on hash randomization.
    input_blobs = [core.BlobReference(name=b, net=None)
                   for b in sorted(input_names)]
    output_blobs = [core.BlobReference(name=b, net=None)
                    for b in sorted(output_names)]

    # Outer blobs are referenced by position in the concatenated
    # inputs + outputs list; list.index picks the first position, so a blob
    # that is both read and written maps to its input slot.
    input_output_names_ordered = [str(b) for b in input_blobs + output_blobs]
    outer_blob_names = sorted(set(input_names) | set(output_names))
    outer_blob_names_idx = [
        input_output_names_ordered.index(b) for b in outer_blob_names]

    # Use the parent net's name to keep the blob name unique across subnets.
    workspace_blob = if_net.NextScopedBlob(if_net.Name() + workspace_suffix)
    input_blobs.append(workspace_blob)
    output_blobs.append(workspace_blob)

    do_net.Do(
        input_blobs,
        output_blobs,
        net=branch_net.Proto(),
        inner_blobs=outer_blob_names,
        outer_blobs_idx=outer_blob_names_idx)
    do_net.AddExternalOutput(*output_blobs)
    return do_net, workspace_blob


def add_if_op(if_net, cond_blob, lexical_scope, then_net, else_net=None):
    """
    A helper function to add an If op to the net.
    Automatically determines whether blobs in the then/else subnets are
    external (from the outer workspace) or local (visible only inside the
    subnet's workspace), based on lexical scope - the set of all outer blob
    names visible to the 'If' operator. All blobs in the then/else subnets
    whose names appear in lexical scope, and all blobs that are read before
    being written, are considered outer blobs. These blobs must exist in the
    outer workspace, the then/else subnets can read their values, and new
    values written into them are visible outside of the 'If' operator. All
    other blobs are local and exist only within the inner then/else
    workspaces.
    Inputs:
      if_net - net to add an If op to;
      cond_blob - scalar bool blob reference, used as If condition;
      lexical_scope - a set of outer blob names visible to then/else branches;
      then_net/else_net - nets (core.Net) for then/else branches; else_net
                          may be None
    Raises:
      AssertionError (from get_external_blob_names) if a branch reads a blob,
      before writing it, that is not in lexical_scope.
    """
    # Resolve both branches' outer blobs before touching if_net, so an
    # invalid branch fails without leaving partial ops behind.
    then_input_blob_names, then_output_blob_names = get_external_blob_names(
        then_net, lexical_scope)
    else_input_blob_names = set()
    else_output_blob_names = set()
    if else_net is not None:
        else_input_blob_names, else_output_blob_names = \
            get_external_blob_names(else_net, lexical_scope)

    # Sorted for a reproducible NetDef (set order of str is hash-seeded).
    input_blob_names = sorted(then_input_blob_names | else_input_blob_names)
    output_blob_names = sorted(then_output_blob_names | else_output_blob_names)

    if_inputs = [cond_blob]
    if_inputs += [core.BlobReference(name=b, net=None)
                  for b in input_blob_names]
    if_outputs = [core.BlobReference(name=b, net=None)
                  for b in output_blob_names]

    if_args = {}

    do_then_net, do_then_workspace_blob = _add_branch_do_net(
        if_net, then_net, then_input_blob_names, then_output_blob_names,
        'do_then_net', '/workspace_if_then')
    # make sure that added workspace pointer blobs are in if inputs/outputs
    if_inputs.append(do_then_workspace_blob)
    if_outputs.append(do_then_workspace_blob)
    if_args['then_net'] = do_then_net.Proto()

    do_else_workspace_blob = None
    if else_net is not None:
        do_else_net, do_else_workspace_blob = _add_branch_do_net(
            if_net, else_net, else_input_blob_names, else_output_blob_names,
            'do_else_net', '/workspace_if_else')
        if_inputs.append(do_else_workspace_blob)
        if_outputs.append(do_else_workspace_blob)
        if_args['else_net'] = do_else_net.Proto()

    if_net.CreateScope([], [do_then_workspace_blob])
    if do_else_workspace_blob is not None:
        if_net.CreateScope([], [do_else_workspace_blob])
    if_net.If(if_inputs, if_outputs, **if_args)
    if_net.AddExternalOutput(*if_outputs)
139 
140 
def add_while_op(
        while_net, cond_blob, lexical_scope, loop_body_net,
        condition_body_net=None):
    """
    A helper function to add a While op to the net. The same rules for
    determining outer and inner blobs as for the 'If' operator apply to the
    'While' operator's loop and condition subnets. If specified, the
    condition net is executed in a separate workspace before the first
    iteration and after each iteration; some operator of it must write the
    condition blob, which is then exported to the outer workspace.
    Inputs:
      while_net - net to add a While op to;
      cond_blob - scalar bool blob reference, used as a stop condition;
      lexical_scope - a set of outer blob names visible to the loop's body;
      loop_body_net - net to execute on each iteration;
      condition_body_net - net to compute condition value (optional)
    Raises:
      AssertionError if the loop or condition net reads an undefined outer
      blob, or if the condition net never writes cond_blob.
    """
    input_blob_names, output_blob_names = get_external_blob_names(
        loop_body_net, lexical_scope)

    # Since it's possible that loop is not going to run even once
    # we have to add loop's external outputs into inputs
    input_blob_names |= output_blob_names

    # Sorted for a reproducible NetDef (set order of str is hash-seeded).
    loop_inputs = [core.BlobReference(name=b, net=None)
                   for b in sorted(input_blob_names)]
    loop_outputs = [core.BlobReference(name=b, net=None)
                    for b in sorted(output_blob_names)]

    # Copies taken before the workspace pointer is appended below.
    while_inputs = [cond_blob] + loop_inputs
    while_outputs = list(loop_outputs)

    do_loop_body_net = core.Net('do_loop_body_net')

    # Outer blobs are referenced by their first position in the
    # concatenated inputs + outputs list.
    loop_input_output_names_ordered = [
        str(b) for b in (loop_inputs + loop_outputs)]
    loop_body_outer_blob_names = sorted(input_blob_names | output_blob_names)
    loop_body_outer_blob_names_idx = [
        loop_input_output_names_ordered.index(b)
        for b in loop_body_outer_blob_names]

    do_loop_body_workspace_blob = while_net.NextScopedBlob(
        while_net.Name() + '/workspace_loop_body')

    loop_inputs.append(do_loop_body_workspace_blob)
    loop_outputs.append(do_loop_body_workspace_blob)
    # make sure that added workspace pointer blobs are in While inputs/outputs
    while_inputs.append(do_loop_body_workspace_blob)
    while_outputs.append(do_loop_body_workspace_blob)

    do_loop_body_net.Do(
        loop_inputs,
        loop_outputs,
        net=loop_body_net.Proto(),
        inner_blobs=loop_body_outer_blob_names,
        outer_blobs_idx=loop_body_outer_blob_names_idx,
        copy_external_blobs=True)
    do_loop_body_net.AddExternalOutput(*loop_outputs)

    while_args = {}
    while_args['loop_net'] = do_loop_body_net.Proto()

    cond_workspace_blob = None
    if condition_body_net is not None:
        cond_input_blob_names, cond_output_blob_names = \
            get_external_blob_names(condition_body_net, lexical_scope)

        # make sure condition blob is written by condition net and is
        # visible outside of it (even when it is not in lexical scope)
        assert any(str(cond_blob) in op.output
                   for op in condition_body_net.Proto().op), \
            "Condition net does not write into condition blob"
        cond_output_blob_names.add(str(cond_blob))

        cond_inputs = [core.BlobReference(name=b, net=None)
                       for b in sorted(cond_input_blob_names)]
        cond_outputs = [core.BlobReference(name=b, net=None)
                        for b in sorted(cond_output_blob_names)]

        condition_net = core.Net('do_loop_condition_net')

        cond_input_output_names_ordered = [
            str(b) for b in (cond_inputs + cond_outputs)]
        cond_body_outer_blob_names = sorted(
            cond_input_blob_names | cond_output_blob_names)
        cond_body_outer_blob_names_idx = [
            cond_input_output_names_ordered.index(b)
            for b in cond_body_outer_blob_names]

        cond_workspace_blob = while_net.NextScopedBlob(
            while_net.Name() + '/workspace_loop_cond')
        cond_inputs.append(cond_workspace_blob)
        cond_outputs.append(cond_workspace_blob)

        condition_net.Do(
            cond_inputs,
            cond_outputs,
            net=condition_body_net.Proto(),
            inner_blobs=cond_body_outer_blob_names,
            outer_blobs_idx=cond_body_outer_blob_names_idx)
        condition_net.AddExternalOutput(*cond_outputs)

        while_args['cond_net'] = condition_net.Proto()

        # Add the condition net's outer blobs (including its workspace
        # pointer) that the loop body did not already contribute.
        while_inputs += [b for b in cond_inputs
                         if str(b) not in input_blob_names]
        while_outputs += [b for b in cond_outputs
                          if str(b) not in output_blob_names]

        # cond_blob is an input of the While op; if nothing outside defines
        # it, create it in the outer net first.
        if str(cond_blob) not in lexical_scope:
            while_net.ConstantFill(
                [],
                cond_blob,
                dtype=core.DataType.BOOL,
                value=False)

    while_net.CreateScope([], [do_loop_body_workspace_blob])
    if cond_workspace_blob is not None:
        while_net.CreateScope([], [cond_workspace_blob])
    while_net.While(while_inputs, while_outputs, **while_args)
    while_net.AddExternalOutput(*while_outputs)