Caffe2 - Python API
A deep learning, cross-platform ML framework
control_ops_util.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package control_ops_util
17 # Module caffe2.python.control_ops_util
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core
24 
25 
def get_external_blob_names(net, lexical_scope):
    """
    Collects the outer-scope blobs that a net reads and writes.

    Inputs:
        net - net to return input/output blobs for;
        lexical_scope - all external blob names visible to the net
    Returns:
        a pair (input_names, output_names):
        input_names - the blobs the net reads without defining them first, as
            reported by core.get_undefined_blobs on the net's SSA form
            (set-like; callers combine it with set operators);
        output_names - a set of the blobs written by the net's operators whose
            names appear in lexical_scope
    Raises:
        AssertionError (when assertions are enabled) if the net reads a blob
        that is not in lexical_scope
    """
    net_def = net.Proto()
    ssa_form, _ = core.get_ssa(net_def)
    external_reads = core.get_undefined_blobs(ssa_form)
    for read_name in external_reads:
        # Anything read before being written must come from the outer scope.
        assert str(read_name) in lexical_scope, \
            "Input blob " + read_name + " is undefined"

    # A write is only externally visible when the blob name is in scope.
    external_writes = {
        written
        for op_def in net_def.op
        for written in op_def.output
        if written in lexical_scope
    }
    return external_reads, external_writes
49 
50 
def add_if_op(if_net, cond_blob, lexical_scope, then_net, else_net=None):
    """
    A helper function to add an If op to the net.
    Automatically determines whether blobs in the then/else subnets are external
    (from the outer workspace) or local (visible only inside subnet's workspace)
    based on lexical scope - set of all outer blob names visible to the 'If'
    operator. All the blobs in then/else subnets with names matching a name in
    lexical scope and all the blobs that are first used as the operators' inputs
    are considered outer blobs - these blobs must exist in the outer workspace,
    then/else subnets can read their values and new values written into these
    blobs will be visible outside of the 'If' operator. All other blobs are
    local - exist only within inner workspaces for then/else.

    Each branch net is wrapped into a 'Do' op (inside a net named 'do_then_net'
    or 'do_else_net'). The blob holding each branch's workspace pointer is a
    scoped blob created on if_net by a 'CreateScope' op and is appended to the
    If op's inputs and outputs.

    Blob names are sorted before they become op inputs/outputs. The generated
    net protos therefore do not depend on set iteration order, which for
    strings changes between interpreter runs under hash randomization.

    Inputs:
        if_net - net to add an If op to;
        cond_blob - scalar bool blob reference, used as If condition;
        lexical_scope - a set of outer blob names visible to then/else branches;
        then_net/else_net - nets (core.Net) for then/else branches
    """
    def _wrap_in_do(branch_net, branch_input_names, branch_output_names,
                    do_net_name, workspace_suffix):
        # Wraps branch_net into a Do op inside a new net named do_net_name.
        # Returns (Do net proto, blob holding the branch workspace pointer).
        do_net = core.Net(do_net_name)
        input_blobs = [core.BlobReference(name=b, net=None)
                       for b in sorted(branch_input_names)]
        output_blobs = [core.BlobReference(name=b, net=None)
                        for b in sorted(branch_output_names)]
        # outer_blobs_idx indexes into (inputs + outputs) as they are before
        # the workspace blob is appended; a blob that is both read and written
        # binds to its position among the inputs (list.index finds the first).
        ordered_names = [str(b) for b in input_blobs + output_blobs]
        outer_blob_names = sorted(branch_input_names | branch_output_names)
        outer_blob_names_idx = [
            ordered_names.index(b) for b in outer_blob_names]

        # make sure to use net's name to have unique blob name across
        # multiple subnets
        workspace_blob = if_net.NextScopedBlob(
            if_net.Name() + workspace_suffix)
        input_blobs.append(workspace_blob)
        output_blobs.append(workspace_blob)

        do_net.Do(
            input_blobs,
            output_blobs,
            net=branch_net.Proto(),
            inner_blobs=outer_blob_names,
            outer_blobs_idx=outer_blob_names_idx)
        do_net.AddExternalOutput(*output_blobs)
        return do_net.Proto(), workspace_blob

    then_input_blob_names, then_output_blob_names = get_external_blob_names(
        then_net, lexical_scope)

    else_input_blob_names = set()
    else_output_blob_names = set()
    if else_net is not None:
        else_input_blob_names, else_output_blob_names = get_external_blob_names(
            else_net, lexical_scope)

    input_blob_names = then_input_blob_names | else_input_blob_names
    output_blob_names = then_output_blob_names | else_output_blob_names

    # cond_blob is the first If input, followed by the outer blobs read by
    # either branch.
    if_inputs = [cond_blob] + [core.BlobReference(name=b, net=None)
                               for b in sorted(input_blob_names)]
    if_outputs = [core.BlobReference(name=b, net=None)
                  for b in sorted(output_blob_names)]

    then_proto, then_workspace_blob = _wrap_in_do(
        then_net, then_input_blob_names, then_output_blob_names,
        'do_then_net', '/workspace_if_then')
    if_args = {'then_net': then_proto}
    workspace_blobs = [then_workspace_blob]

    if else_net is not None:
        else_proto, else_workspace_blob = _wrap_in_do(
            else_net, else_input_blob_names, else_output_blob_names,
            'do_else_net', '/workspace_if_else')
        if_args['else_net'] = else_proto
        workspace_blobs.append(else_workspace_blob)

    # make sure that added workspace pointer blobs are in If inputs/outputs
    if_inputs += workspace_blobs
    if_outputs += workspace_blobs
    for workspace_blob in workspace_blobs:
        if_net.CreateScope([], [workspace_blob])
    if_net.If(if_inputs, if_outputs, **if_args)
    if_net.AddExternalOutput(*if_outputs)
154 
155 
def add_while_op(
        while_net, cond_blob, lexical_scope, loop_body_net,
        condition_body_net=None):
    """
    A helper function to add a While op to the net. Same rules for determining
    outer and inner blobs as for the 'If' operator apply for the 'While'
    operator loop and condition subnets. If specified, condition net is
    executed in a separate workspace before the first and after each
    iteration, the last operator must have a single scalar boolean output that
    is written into the condition blob (this helper only asserts that some
    operator of the condition net writes into cond_blob).

    The loop body (and, if given, the condition net) is wrapped into a 'Do' op
    whose workspace pointer lives in a scoped blob created on while_net by a
    'CreateScope' op. The loop body's Do op is created with
    copy_external_blobs=True. If cond_blob is not in lexical_scope, it is
    created on while_net as a BOOL constant equal to False.

    Blob names are sorted before they become op inputs/outputs. The generated
    net protos therefore do not depend on set iteration order, which for
    strings changes between interpreter runs under hash randomization.

    Inputs:
        while_net - net to add a While op to;
        cond_blob - scalar bool blob reference, used as a stop condition;
        lexical_scope - a set of outer blob names visible to the loop's body;
        loop_body_net - net to execute on each iteration;
        condition_body_net - net to compute condition value
    """
    def _wrap_in_do(body_net, body_input_names, body_output_names,
                    do_net_name, workspace_suffix, **do_kwargs):
        # Wraps body_net into a Do op inside a new net named do_net_name.
        # Returns (Do net proto, outer input blobs, outer output blobs,
        # workspace pointer blob); the blob lists exclude the workspace blob.
        do_net = core.Net(do_net_name)
        input_blobs = [core.BlobReference(name=b, net=None)
                       for b in sorted(body_input_names)]
        output_blobs = [core.BlobReference(name=b, net=None)
                        for b in sorted(body_output_names)]
        # outer_blobs_idx indexes into (inputs + outputs) without the
        # workspace blob; a blob that is both read and written binds to its
        # position among the inputs (list.index finds the first match).
        ordered_names = [str(b) for b in input_blobs + output_blobs]
        outer_blob_names = sorted(body_input_names | body_output_names)
        outer_blob_names_idx = [
            ordered_names.index(b) for b in outer_blob_names]

        workspace_blob = while_net.NextScopedBlob(
            while_net.Name() + workspace_suffix)
        do_outputs = output_blobs + [workspace_blob]
        do_net.Do(
            input_blobs + [workspace_blob],
            do_outputs,
            net=body_net.Proto(),
            inner_blobs=outer_blob_names,
            outer_blobs_idx=outer_blob_names_idx,
            **do_kwargs)
        do_net.AddExternalOutput(*do_outputs)
        return do_net.Proto(), input_blobs, output_blobs, workspace_blob

    input_blob_names, output_blob_names = get_external_blob_names(
        loop_body_net, lexical_scope)

    # Since it's possible that loop is not going to run even once
    # we have to add loop's external outputs into inputs
    input_blob_names |= output_blob_names

    loop_proto, loop_inputs, loop_outputs, loop_workspace_blob = _wrap_in_do(
        loop_body_net, input_blob_names, output_blob_names,
        'do_loop_body_net', '/workspace_loop_body', copy_external_blobs=True)

    # cond_blob is the first While input; make sure that added workspace
    # pointer blobs are in While inputs/outputs
    while_inputs = [cond_blob] + loop_inputs + [loop_workspace_blob]
    while_outputs = loop_outputs + [loop_workspace_blob]
    while_args = {'loop_net': loop_proto}
    workspace_blobs = [loop_workspace_blob]

    if condition_body_net is not None:
        cond_input_blob_names, cond_output_blob_names = get_external_blob_names(
            condition_body_net, lexical_scope)

        # make sure condition blob is written by condition net and is
        # visible outside of it
        assert any(str(cond_blob) in op.output
                   for op in condition_body_net.Proto().op), \
            "Condition net does not write into condition blob"
        cond_output_blob_names.add(str(cond_blob))

        cond_proto, cond_inputs, cond_outputs, cond_workspace_blob = \
            _wrap_in_do(
                condition_body_net, cond_input_blob_names,
                cond_output_blob_names,
                'do_loop_condition_net', '/workspace_loop_cond')
        while_args['cond_net'] = cond_proto

        # Blobs already bound for the loop body are not listed twice.
        while_inputs += [b for b in cond_inputs
                         if str(b) not in input_blob_names]
        while_outputs += [b for b in cond_outputs
                          if str(b) not in output_blob_names]
        while_inputs.append(cond_workspace_blob)
        while_outputs.append(cond_workspace_blob)
        workspace_blobs.append(cond_workspace_blob)

    # cond_blob is not an outer blob, so create it here as a BOOL False.
    if str(cond_blob) not in lexical_scope:
        while_net.ConstantFill(
            [],
            cond_blob,
            dtype=core.DataType.BOOL,
            value=False)

    for workspace_blob in workspace_blobs:
        while_net.CreateScope([], [workspace_blob])
    while_net.While(while_inputs, while_outputs, **while_args)
    while_net.AddExternalOutput(*while_outputs)