4 Implement functions for controlling execution of nets and steps, including 14 from __future__
import absolute_import
15 from __future__
import division
16 from __future__
import print_function
17 from __future__
import unicode_literals
20 from future.utils
import viewitems
26 _used_step_names = set()
29 def _get_next_step_name(control_name, base_name):
30 global _current_idx, _used_step_names
31 concat_name =
'%s/%s' % (base_name, control_name)
32 next_name = concat_name
33 while next_name
in _used_step_names:
34 next_name =
'%s_%d' % (concat_name, _current_idx)
36 _used_step_names.add(next_name)
43 (a, b, c) --> [a, b, c] 45 ([a, b, c]) --> [a, b, c] 49 'input cannot be empty.')
52 if not isinstance(output, list):
59 def _IsNets(nets_or_steps):
60 if isinstance(nets_or_steps, list):
61 return all(isinstance(n, core.Net)
for n
in nets_or_steps)
63 return isinstance(nets_or_steps, core.Net)
66 def _PrependNets(nets_or_steps, *nets):
67 nets_or_steps = _MakeList((nets_or_steps,))
68 nets = _MakeList(nets)
69 if _IsNets(nets_or_steps):
70 return nets + nets_or_steps
72 return [Do(
'prepend', nets)] + nets_or_steps
75 def _AppendNets(nets_or_steps, *nets):
76 nets_or_steps = _MakeList((nets_or_steps,))
77 nets = _MakeList(nets)
78 if _IsNets(nets_or_steps):
79 return nets_or_steps + nets
81 return nets_or_steps + [Do(
'append', nets)]
84 def GetConditionBlobFromNet(condition_net):
86 The condition blob is the last external_output that must 89 assert len(condition_net.Proto().external_output) > 0, (
90 "Condition net %s must has at least one external output" %
91 condition_net.Proto.name)
95 return core.BlobReference(condition_net.Proto().external_output[-1])
98 def BoolNet(*blobs_with_bool_value):
99 """A net assigning constant bool values to blobs. It is mainly used for 100 initializing condition blobs, for example, in multi-task learning, we 101 need to access reader_done blobs before reader_net run. In that case, 102 the reader_done blobs must be initialized. 105 blobs_with_bool_value: one or more (blob, bool_value) pairs. The net will 106 assign each bool_value to the corresponding blob. 109 bool_net: A net assigning constant bool values to blobs. 112 - BoolNet((blob_1, bool_value_1), ..., (blob_n, bool_value_n)) 113 - BoolNet([(blob_1, net1), ..., (blob_n, bool_value_n)]) 114 - BoolNet((cond_1, bool_value_1)) 116 blobs_with_bool_value = _MakeList(blobs_with_bool_value)
117 bool_net = core.Net(
'bool_net')
118 for blob, bool_value
in blobs_with_bool_value:
119 out_blob = bool_net.ConstantFill(
124 dtype=core.DataType.BOOL)
125 bool_net.AddExternalOutput(out_blob)
130 def NotNet(condition_blob_or_net):
131 """Not of a condition blob or net 134 condition_blob_or_net can be either blob or net. If condition_blob_or_net 135 is Net, the condition is its last external_output 136 that must be a single bool. 139 not_net: the net NOT the input 140 out_blob: the output blob of the not_net 142 if isinstance(condition_blob_or_net, core.Net):
143 condition_blob = GetConditionBlobFromNet(condition_blob_or_net)
145 condition_blob = condition_blob_or_net
147 not_net = core.Net(
'not_net')
148 out_blob = not_net.Not(condition_blob)
149 not_net.AddExternalOutput(out_blob)
151 return not_net, out_blob
154 def _CopyConditionBlobNet(condition_blob):
155 """Make a condition net that copies the condition_blob 158 condition_blob is a single bool. 161 not_net: the net NOT the input 162 out_blob: the output blob of the not_net 164 condition_net = core.Net(
'copy_condition_blob_net')
165 out_blob = condition_net.Copy(condition_blob)
166 condition_net.AddExternalOutput(out_blob)
168 return condition_net, out_blob
171 def MergeConditionNets(name, condition_nets, relation):
173 Merge multi condition nets into a single condition nets. 176 name: name of the new condition net. 177 condition_nets: a list of condition nets. The last external_output 178 of each condition net must be single bool value. 179 relation: can be 'And' or 'Or'. 182 - A new condition net. Its last external output is relation of all 185 if not isinstance(condition_nets, list):
186 return condition_nets
187 if len(condition_nets) <= 1:
188 return condition_nets[0]
if condition_nets
else None 190 merged_net = core.Net(name)
191 for i
in range(len(condition_nets)):
192 net_proto = condition_nets[i].Proto()
193 assert net_proto.device_option == merged_net.Proto().device_option
194 assert net_proto.type == merged_net.Proto().type
195 merged_net.Proto().op.extend(net_proto.op)
196 merged_net.Proto().external_input.extend(net_proto.external_input)
198 curr_cond = GetConditionBlobFromNet(condition_nets[i])
200 last_cond = curr_cond
202 last_cond = merged_net.__getattr__(relation)([last_cond, curr_cond])
204 for k, v
in viewitems(condition_nets[i]._attr_dict):
205 merged_net._attr_dict[k] += v
207 merged_net.AddExternalOutput(last_cond)
212 def CombineConditions(name, condition_nets, relation):
214 Combine conditions of multi nets into a single condition nets. Unlike 215 MergeConditionNets, the actual body of condition_nets is not copied into 216 the combine condition net. 218 One example is about multi readers. Each reader net has a reader_done 219 condition. When we want to check whether all readers are done, we can 220 use this function to build a new net. 223 name: name of the new condition net. 224 condition_nets: a list of condition nets. The last external_output 225 of each condition net must be single bool value. 226 relation: can be 'And' or 'Or'. 229 - A new condition net. Its last external output is relation of all 232 if not condition_nets:
234 if not isinstance(condition_nets, list):
235 raise ValueError(
'condition_nets must be a list of nets.')
237 if len(condition_nets) == 1:
238 condition_blob = GetConditionBlobFromNet(condition_nets[0])
239 condition_net, _ = _CopyConditionBlobNet(condition_blob)
242 combined_net = core.Net(name)
243 for i
in range(len(condition_nets)):
244 curr_cond = GetConditionBlobFromNet(condition_nets[i])
246 last_cond = curr_cond
248 last_cond = combined_net.__getattr__(relation)(
249 [last_cond, curr_cond])
251 combined_net.AddExternalOutput(last_cond)
256 def Do(name, *nets_or_steps):
258 Execute the sequence of nets or steps once. 261 - Do('myDo', net1, net2, ..., net_n) 262 - Do('myDo', list_of_nets) 263 - Do('myDo', step1, step2, ..., step_n) 264 - Do('myDo', list_of_steps) 266 nets_or_steps = _MakeList(nets_or_steps)
267 if (len(nets_or_steps) == 1
and isinstance(
268 nets_or_steps[0], core.ExecutionStep)):
269 return nets_or_steps[0]
271 return core.scoped_execution_step(
272 _get_next_step_name(
'Do', name), nets_or_steps)
275 def DoParallel(name, *nets_or_steps):
277 Execute the nets or steps in parallel, waiting for all of them to finish 280 - DoParallel('pDo', net1, net2, ..., net_n) 281 - DoParallel('pDo', list_of_nets) 282 - DoParallel('pDo', step1, step2, ..., step_n) 283 - DoParallel('pDo', list_of_steps) 285 nets_or_steps = _MakeList(nets_or_steps)
286 if (len(nets_or_steps) == 1
and isinstance(
287 nets_or_steps[0], core.ExecutionStep)):
288 return nets_or_steps[0]
290 return core.scoped_execution_step(
291 _get_next_step_name(
'DoParallel', name),
293 concurrent_substeps=
True)
296 def _RunOnceIf(name, condition_blob_or_net, nets_or_steps):
298 Execute nets_or_steps once if condition_blob_or_net evaluates as true. 300 If condition_blob_or_net is Net, the condition is its last external_output 301 that must be a single bool. And this net will be executed before 302 nets_or_steps so as to get the condition. 304 condition_not_net, stop_blob = NotNet(condition_blob_or_net)
305 if isinstance(condition_blob_or_net, core.Net):
306 nets_or_steps = _PrependNets(
307 nets_or_steps, condition_blob_or_net, condition_not_net)
309 nets_or_steps = _PrependNets(nets_or_steps, condition_not_net)
311 def if_step(control_name):
312 return core.scoped_execution_step(
313 _get_next_step_name(control_name, name),
315 should_stop_blob=stop_blob,
319 if _IsNets(nets_or_steps):
320 bool_net = BoolNet((stop_blob,
False))
321 return Do(name +
'/_RunOnceIf',
322 bool_net, if_step(
'_RunOnceIf-inner'))
324 return if_step(
'_RunOnceIf')
327 def _RunOnceIfNot(name, condition_blob_or_net, nets_or_steps):
329 Similar to _RunOnceIf() but Execute nets_or_steps once if 330 condition_blob_or_net evaluates as false. 332 if isinstance(condition_blob_or_net, core.Net):
333 condition_blob = GetConditionBlobFromNet(condition_blob_or_net)
334 nets_or_steps = _PrependNets(nets_or_steps, condition_blob_or_net)
336 copy_net, condition_blob = _CopyConditionBlobNet(condition_blob_or_net)
337 nets_or_steps = _PrependNets(nets_or_steps, copy_net)
339 return core.scoped_execution_step(
340 _get_next_step_name(
'_RunOnceIfNot', name),
342 should_stop_blob=condition_blob,
347 def For(name, nets_or_steps, iter_num):
349 Execute nets_or_steps iter_num times. 352 nets_or_steps: a ExecutionStep or a Net or a list of ExecutionSteps or 354 iter_num: the number times to execute the nets_or_steps. 357 A ExecutionStep instance. 359 init_net = core.Net(
'init-net')
360 iter_cnt = init_net.CreateCounter([], init_count=iter_num)
361 iter_net = core.Net(
'For-iter')
362 iter_done = iter_net.CountDown([iter_cnt])
364 for_step = core.scoped_execution_step(
365 _get_next_step_name(
'For-inner', name),
366 _PrependNets(nets_or_steps, iter_net),
367 should_stop_blob=iter_done)
368 return Do(name +
'/For',
369 Do(name +
'/For-init-net', init_net),
373 def While(name, condition_blob_or_net, nets_or_steps):
375 Execute nets_or_steps when condition_blob_or_net returns true. 378 condition_blob_or_net: If it is an instance of Net, its last 379 external_output must be a single bool. 380 nets_or_steps: a ExecutionStep or a Net or a list of ExecutionSteps or 384 A ExecutionStep instance. 386 condition_not_net, stop_blob = NotNet(condition_blob_or_net)
387 if isinstance(condition_blob_or_net, core.Net):
388 nets_or_steps = _PrependNets(
389 nets_or_steps, condition_blob_or_net, condition_not_net)
391 nets_or_steps = _PrependNets(nets_or_steps, condition_not_net)
393 def while_step(control_name):
394 return core.scoped_execution_step(
395 _get_next_step_name(control_name, name),
397 should_stop_blob=stop_blob,
400 if _IsNets(nets_or_steps):
407 bool_net = BoolNet((stop_blob,
False))
408 return Do(name +
'/While', bool_net, while_step(
'While-inner'))
410 return while_step(
'While')
413 def Until(name, condition_blob_or_net, nets_or_steps):
415 Similar to While() but execute nets_or_steps when 416 condition_blob_or_net returns false 418 if isinstance(condition_blob_or_net, core.Net):
419 stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
420 nets_or_steps = _PrependNets(nets_or_steps, condition_blob_or_net)
422 stop_blob = core.BlobReference(str(condition_blob_or_net))
424 return core.scoped_execution_step(
425 _get_next_step_name(
'Until', name),
427 should_stop_blob=stop_blob)
430 def DoWhile(name, condition_blob_or_net, nets_or_steps):
432 Execute nets_or_steps when condition_blob_or_net returns true. It will 433 execute nets_or_steps before evaluating condition_blob_or_net. 436 condition_blob_or_net: if it is an instance of Net, tts last external_output 437 must be a single bool. 438 nets_or_steps: a ExecutionStep or a Net or a list of ExecutionSteps or 442 A ExecutionStep instance. 444 condition_not_net, stop_blob = NotNet(condition_blob_or_net)
445 if isinstance(condition_blob_or_net, core.Net):
446 nets_or_steps = _AppendNets(
447 nets_or_steps, condition_blob_or_net, condition_not_net)
449 nets_or_steps = _AppendNets(nets_or_steps, condition_not_net)
455 bool_net = BoolNet((stop_blob,
False))
456 return Do(name +
'/DoWhile', bool_net, core.scoped_execution_step(
457 _get_next_step_name(
'DoWhile-inner', name),
459 should_stop_blob=stop_blob,
463 def DoUntil(name, condition_blob_or_net, nets_or_steps):
465 Similar to DoWhile() but execute nets_or_steps when 466 condition_blob_or_net returns false. It will execute 467 nets_or_steps before evaluating condition_blob_or_net. 469 Special case: if condition_blob_or_net is a blob and is pre-set to 470 true, then only the first net/step of nets_or_steps will be executed and 471 loop is exited. So you need to be careful about the initial value the 472 condition blob when using DoUntil(), esp when DoUntil() is called twice. 474 if not isinstance(condition_blob_or_net, core.Net):
475 stop_blob = core.BlobReference(condition_blob_or_net)
476 return core.scoped_execution_step(
477 _get_next_step_name(
'DoUntil', name),
479 should_stop_blob=stop_blob)
481 nets_or_steps = _AppendNets(nets_or_steps, condition_blob_or_net)
482 stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
488 bool_net = BoolNet((stop_blob,
False))
489 return Do(name +
'/DoUntil', bool_net, core.scoped_execution_step(
490 _get_next_step_name(
'DoUntil-inner', name),
492 should_stop_blob=stop_blob,
496 def Switch(name, *conditions):
498 Execute the steps for which the condition is true. 499 Each condition is a tuple (condition_blob_or_net, nets_or_steps). 501 1. Multi steps can be executed if their conditions are true. 502 2. The conditions_blob_or_net (if it is Net) of all steps will be 506 - Switch('name', (cond_1, net_1), (cond_2, net_2), ..., (cond_n, net_n)) 507 - Switch('name', [(cond_1, net1), (cond_2, net_2), ..., (cond_n, net_n)]) 508 - Switch('name', (cond_1, net_1)) 510 conditions = _MakeList(conditions)
511 return core.scoped_execution_step(
512 _get_next_step_name(
'Switch', name),
513 [_RunOnceIf(name +
'/Switch', cond, step)
for cond, step
in conditions])
516 def SwitchNot(name, *conditions):
518 Similar to Switch() but execute the steps for which the condition is False. 520 conditions = _MakeList(conditions)
521 return core.scoped_execution_step(
522 _get_next_step_name(
'SwitchNot', name),
523 [_RunOnceIfNot(name +
'/SwitchNot', cond, step)
524 for cond, step
in conditions])
527 def If(name, condition_blob_or_net,
528 true_nets_or_steps, false_nets_or_steps=
None):
530 condition_blob_or_net is first evaluated or executed. If the condition is 531 true, true_nets_or_steps is then executed, otherwise, false_nets_or_steps 534 If condition_blob_or_net is Net, the condition is its last external_output 535 that must be a single bool. And this Net will be executred before both 536 true/false_nets_or_steps so as to get the condition. 538 if not false_nets_or_steps:
539 return _RunOnceIf(name +
'/If',
540 condition_blob_or_net, true_nets_or_steps)
542 if isinstance(condition_blob_or_net, core.Net):
543 condition_blob = GetConditionBlobFromNet(condition_blob_or_net)
545 condition_blob = condition_blob_or_net
549 _RunOnceIf(name +
'/If-true',
550 condition_blob_or_net, true_nets_or_steps),
551 _RunOnceIfNot(name +
'/If-false', condition_blob, false_nets_or_steps)
555 def IfNot(name, condition_blob_or_net,
556 true_nets_or_steps, false_nets_or_steps=
None):
558 If condition_blob_or_net returns false, executes true_nets_or_steps, 559 otherwise executes false_nets_or_steps 561 if not false_nets_or_steps:
562 return _RunOnceIfNot(name +
'/IfNot',
563 condition_blob_or_net, true_nets_or_steps)
565 if isinstance(condition_blob_or_net, core.Net):
566 condition_blob = GetConditionBlobFromNet(condition_blob_or_net)
568 condition_blob = condition_blob_or_net
572 _RunOnceIfNot(name +
'/IfNot-true',
573 condition_blob_or_net, true_nets_or_steps),
574 _RunOnceIf(name +
'/IfNot-false', condition_blob, false_nets_or_steps)