3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
def get_external_blob_names(net, lexical_scope):
    """
    Returns a set of blobs a given net depends on and a set of
    output blobs that are written by the net.

    Inputs:
        net - net to return input/output blobs for;
        lexical_scope - all external blob names visible to the net

    Returns:
        (input_names, output_names) where
        input_names - the blobs core.get_undefined_blobs reports for the net's
            SSA form (i.e. blobs the net reads without defining them itself);
        output_names - set of blob names written by any operator of the net
            that are also present in lexical_scope.

    Raises:
        AssertionError - if the net reads a blob that is not in lexical_scope.
    """
    net_proto = net.Proto()
    net_ssa, _ = core.get_ssa(net_proto)
    input_names = core.get_undefined_blobs(net_ssa)
    for input_name in input_names:
        # str() in the message as well: plain concatenation of a non-str
        # name would raise TypeError and hide the actual problem.
        assert str(input_name) in lexical_scope, \
            "Input blob " + str(input_name) + " is undefined"

    # Only outputs whose names exist in the outer scope are external; all
    # other blobs written by the net stay local to its own workspace.
    output_names = set()
    for op in net_proto.op:
        for output in op.output:
            if output in lexical_scope:
                output_names.add(output)

    return input_names, output_names
def add_if_op(if_net, cond_blob, lexical_scope, then_net, else_net=None):
    """
    A helper function to add an If op to the net.

    Automatically determines whether blobs in the then/else subnets are external
    (from the outer workspace) or local (visible only inside subnet's workspace)
    based on lexical scope - set of all outer blob names visible to the 'If'
    operator. All the blobs in then/else subnets with names matching a name in
    lexical scope and all the blobs that are first used as the operators' inputs
    are considered outer blobs - these blobs must exist in the outer workspace,
    then/else subnets can read their values and new values written into these
    blobs will be visible outside of the 'If' operator. All other blobs are
    local - exist only within inner workspaces for then/else.

    Inputs:
        if_net - net to add an If op to;
        cond_blob - scalar bool blob reference, used as If condition;
        lexical_scope - a set of outer blob names visible to then/else branches;
        then_net/else_net - nets (core.Net) for then/else branches; else_net
            may be None, in which case no else branch is attached.

    Side effects:
        Adds CreateScope op(s) and an If op to if_net and registers the If
        op's outputs (including workspace pointer blobs) as external outputs.
    """
    then_input_blob_names, then_output_blob_names = get_external_blob_names(
        then_net, lexical_scope)

    # An absent else branch contributes no external blobs.
    else_input_blob_names = set()
    else_output_blob_names = set()
    if else_net is not None:
        else_input_blob_names, else_output_blob_names = get_external_blob_names(
            else_net, lexical_scope)

    input_blob_names = then_input_blob_names | else_input_blob_names
    output_blob_names = then_output_blob_names | else_output_blob_names

    if_inputs = [cond_blob]
    if_inputs += [core.BlobReference(name=b, net=None) for b in input_blob_names]
    if_outputs = [core.BlobReference(name=b, net=None) for b in output_blob_names]

    # --- then branch: wrap then_net into a Do op running in its own workspace
    do_then_net = core.Net('do_then_net')

    then_input_blobs = \
        [core.BlobReference(name=b, net=None) for b in then_input_blob_names]
    then_output_blobs = \
        [core.BlobReference(name=b, net=None) for b in then_output_blob_names]
    then_input_output_names_ordered = [
        str(b) for b in (then_input_blobs + then_output_blobs)]

    # Position of each outer blob name within the Do op's concatenated
    # input+output list; passed to Do as outer_blobs_idx next to inner_blobs.
    then_outer_blob_names = list(then_input_blob_names | then_output_blob_names)
    then_outer_blob_names_idx = [
        then_input_output_names_ordered.index(b) for b in then_outer_blob_names]

    # Prefix with the net's name so the workspace blob name stays unique
    # across multiple subnets.
    do_then_workspace_blob = if_net.NextScopedBlob(
        if_net.Name() + '/workspace_if_then')
    then_input_blobs.append(do_then_workspace_blob)
    then_output_blobs.append(do_then_workspace_blob)
    # The workspace pointer blob must also be an input/output of the If op.
    if_inputs.append(do_then_workspace_blob)
    if_outputs.append(do_then_workspace_blob)

    do_then_net.Do(
        then_input_blobs,
        then_output_blobs,
        net=then_net.Proto(),
        inner_blobs=then_outer_blob_names,
        outer_blobs_idx=then_outer_blob_names_idx)
    do_then_net.AddExternalOutput(*then_output_blobs)

    if_args = {'then_net': do_then_net.Proto()}

    # --- optional else branch, built the same way as the then branch
    do_else_workspace_blob = None
    if else_net is not None:
        do_else_net = core.Net('do_else_net')

        else_input_blobs = \
            [core.BlobReference(name=b, net=None) for b in else_input_blob_names]
        else_output_blobs = \
            [core.BlobReference(name=b, net=None) for b in else_output_blob_names]
        else_input_output_names_ordered = [
            str(b) for b in (else_input_blobs + else_output_blobs)]

        else_outer_blob_names = list(else_input_blob_names | else_output_blob_names)
        else_outer_blob_names_idx = [
            else_input_output_names_ordered.index(b) for b in else_outer_blob_names]

        do_else_workspace_blob = if_net.NextScopedBlob(
            if_net.Name() + '/workspace_if_else')
        else_input_blobs.append(do_else_workspace_blob)
        else_output_blobs.append(do_else_workspace_blob)
        # The workspace pointer blob must also be an input/output of the If op.
        if_inputs.append(do_else_workspace_blob)
        if_outputs.append(do_else_workspace_blob)

        do_else_net.Do(
            else_input_blobs,
            else_output_blobs,
            net=else_net.Proto(),
            inner_blobs=else_outer_blob_names,
            outer_blobs_idx=else_outer_blob_names_idx)
        do_else_net.AddExternalOutput(*else_output_blobs)
        if_args['else_net'] = do_else_net.Proto()

    # Workspace pointer blobs are created before the If op that uses them.
    if_net.CreateScope([], [do_then_workspace_blob])
    if do_else_workspace_blob is not None:
        if_net.CreateScope([], [do_else_workspace_blob])
    if_net.If(if_inputs, if_outputs, **if_args)
    if_net.AddExternalOutput(*if_outputs)
def add_while_op(
        while_net, cond_blob, lexical_scope, loop_body_net,
        condition_body_net=None):
    """
    A helper function to add a While op to the net.

    Same rules for determining outer and inner blobs as for the 'If' operator
    apply for the 'While' operator loop and condition subnets. If specified,
    the condition net is executed in a separate workspace before the first and
    after each iteration; it must write into the condition blob (a scalar
    boolean).

    Inputs:
        while_net - net to add a While op to;
        cond_blob - scalar bool blob reference, used as the loop condition;
        lexical_scope - a set of outer blob names visible to the loop's body;
        loop_body_net - net to execute on each iteration;
        condition_body_net - optional net to compute the condition value

    Raises:
        AssertionError - if condition_body_net has no operator writing
            cond_blob, or if a subnet reads a blob outside lexical_scope.

    Side effects:
        Adds CreateScope op(s), possibly a ConstantFill for cond_blob, and a
        While op to while_net, and registers the While op's outputs as
        external outputs.
    """
    input_blob_names, output_blob_names = get_external_blob_names(
        loop_body_net, lexical_scope)

    # The loop may not run even once, so the loop's external outputs are
    # also treated as inputs.
    input_blob_names |= output_blob_names

    loop_inputs = [core.BlobReference(name=b, net=None) for b in input_blob_names]
    loop_outputs = [core.BlobReference(name=b, net=None) for b in output_blob_names]

    while_inputs = [cond_blob] + loop_inputs
    while_outputs = list(loop_outputs)

    # --- loop body: wrap loop_body_net into a Do op with its own workspace
    do_loop_body_net = core.Net('do_loop_body_net')

    loop_input_output_names_ordered = [
        str(b) for b in (loop_inputs + loop_outputs)]
    # Position of each outer blob name within the Do op's concatenated
    # input+output list; passed to Do as outer_blobs_idx next to inner_blobs.
    loop_body_outer_blob_names = list(input_blob_names | output_blob_names)
    loop_body_outer_blob_names_idx = [
        loop_input_output_names_ordered.index(b)
        for b in loop_body_outer_blob_names]

    # Prefix with the net's name so the workspace blob name stays unique
    # across multiple subnets.
    do_loop_body_workspace_blob = while_net.NextScopedBlob(
        while_net.Name() + '/workspace_loop_body')

    loop_inputs.append(do_loop_body_workspace_blob)
    loop_outputs.append(do_loop_body_workspace_blob)
    # The workspace pointer blob must also be an input/output of the While op.
    while_inputs.append(do_loop_body_workspace_blob)
    while_outputs.append(do_loop_body_workspace_blob)

    do_loop_body_net.Do(
        loop_inputs,
        loop_outputs,
        net=loop_body_net.Proto(),
        inner_blobs=loop_body_outer_blob_names,
        outer_blobs_idx=loop_body_outer_blob_names_idx,
        copy_external_blobs=True)
    do_loop_body_net.AddExternalOutput(*loop_outputs)

    while_args = {'loop_net': do_loop_body_net.Proto()}

    # --- optional condition net, wrapped the same way
    cond_workspace_blob = None
    if condition_body_net is not None:
        cond_input_blob_names, cond_output_blob_names = get_external_blob_names(
            condition_body_net, lexical_scope)

        # The condition net must write the condition blob, and that blob
        # must be visible outside of it even if it is not in lexical_scope.
        found_condition_output = any(
            str(cond_blob) in op.output
            for op in condition_body_net.Proto().op)
        assert found_condition_output, \
            "Condition net does not write into condition blob"
        cond_output_blob_names.add(str(cond_blob))

        cond_inputs = [core.BlobReference(name=b, net=None)
                       for b in cond_input_blob_names]
        cond_outputs = [core.BlobReference(name=b, net=None)
                        for b in cond_output_blob_names]

        condition_net = core.Net('do_loop_condition_net')

        cond_input_output_names_ordered = [
            str(b) for b in (cond_inputs + cond_outputs)]
        cond_body_outer_blob_names = \
            list(cond_input_blob_names | cond_output_blob_names)
        cond_body_outer_blob_names_idx = [
            cond_input_output_names_ordered.index(b)
            for b in cond_body_outer_blob_names]

        cond_workspace_blob = while_net.NextScopedBlob(
            while_net.Name() + '/workspace_loop_cond')
        cond_inputs.append(cond_workspace_blob)
        cond_outputs.append(cond_workspace_blob)

        condition_net.Do(
            cond_inputs,
            cond_outputs,
            net=condition_body_net.Proto(),
            inner_blobs=cond_body_outer_blob_names,
            outer_blobs_idx=cond_body_outer_blob_names_idx)
        condition_net.AddExternalOutput(*cond_outputs)

        while_args['cond_net'] = condition_net.Proto()

        # Add the condition net's blobs (including its workspace pointer)
        # unless the loop body already contributes them.
        while_inputs += [b for b in cond_inputs
                         if str(b) not in input_blob_names]
        while_outputs += [b for b in cond_outputs
                          if str(b) not in output_blob_names]

        # cond_blob is an input of the While op; when it is not an existing
        # outer blob, create it here (initialized to False) — presumably so
        # it exists before the condition net first runs.
        if str(cond_blob) not in lexical_scope:
            while_net.ConstantFill(
                [],
                cond_blob,
                dtype=core.DataType.BOOL,
                value=False)

    # Workspace pointer blobs are created before the While op that uses them.
    while_net.CreateScope([], [do_loop_body_workspace_blob])
    if cond_workspace_blob is not None:
        while_net.CreateScope([], [cond_workspace_blob])
    while_net.While(while_inputs, while_outputs, **while_args)
    while_net.AddExternalOutput(*while_outputs)