3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
13 @context.define_context()
16 Scope-driven mechanism for building nets, loops and conditional blocks. 18 name: NetBuilder's name 19 initial_scope: list of blobs that are available for reading/writing 21 from caffe2.python.net_builder import NetBuilder, ops 22 with NetBuilder() as nb: 26 ops.stop_if(ops.LE([c, ops.Const(0)])) 27 ops.Add([c, ops.Const(-1)], [c]) 28 with ops.If(ops.GE([c, ops.Const(3)])): 29 ops.Add([d, ops.Const(10)], [d]) 32 step = core.to_execution_step(nb) 34 def __init__(self, name=None, initial_scope=None, _stop_blob_required=False,
35 _stop_blob=
None, _fullname=
None, _use_control_ops=
False):
36 parent = NetBuilder.current(required=
False)
37 assert not _fullname
or not name,
'Cannot set both _fullname and name' 38 assert not _use_control_ops
or \
39 (
not _stop_blob_required
and not _stop_blob), \
40 'Stop blobs are not used with control operators' 41 self.
name = _fullname
or '/'.join(
42 n
for n
in (parent.name
if parent
else None, name)
if n
49 parent._update_lexical_scope()
60 Returns the BlobReference to the stop_blob of this NetBuilder. 61 If one is not yet available, creates one. 62 This function assumes that the stop_blob() will be used immediatelly 63 in the current net, so it doesn't initialize it if the current net is 64 the first of the builder. 67 'Stop blobs are not used with control operators' 71 net.NextName(
'stop_blob'), net=net)
74 self._children.insert(0,
core.Net(
'stop_blob_init'))
78 def stop_if(self, blob):
80 'Stop blobs are not used with control operators' 82 ops.Or([stop_blob, blob], [stop_blob])
85 def _assert_mutable(self):
87 'This NetBuilder (%s) has been built already.' % self.
name)
89 def _update_lexical_scope(self):
91 Updates lexical scope based on the current list of children. 92 Lexical scope contains names of blobs that are currently available 93 and were introduced in the net builder 99 elif isinstance(child, NetBuilder)
and child._use_control_ops:
102 def _reset_children(self):
107 def add(self, child):
111 assert isinstance(child,
core.Net)
or (
112 isinstance(child, NetBuilder)
and child._use_control_ops), \
113 "Expected Net or NetBuilder with control ops" 116 self._children.append(child)
123 def current_net(self, name=None):
131 if hasattr(child,
'freeze'):
140 def __exit__(self, etype, *args):
144 merged_net = NetBuilder.merge_nets(
146 assert merged_net,
"Expected a non-empty merge of children" 150 if etype
is not None:
153 'This NetBuilder (%s) requires a stop condition ' % self.
name +
154 'to be set with `stop` or `stop_if`')
157 def merge_nets(nets_or_builders, outer_blob_names):
178 for n
in nets_or_builders:
180 if isinstance(n, NetBuilder):
181 assert n._use_control_ops, \
182 "Merging of NetBuilder supported only for control ops" 184 assert len(nets) == 1
and isinstance(nets[0],
core.Net), \
185 "Invalid control op net builder" 196 external_outputs = [o
for o
in net.Proto().external_output
197 if o
in outer_blob_names]
198 net.Proto().external_output[:] = external_outputs
202 return self.
name or 'Un-named NetBuilder' 207 Operations to be used in the context of a NetBuilder. 209 def net(self, net=None, name=None):
211 Retrieves the current net, or add a new net to the builder. 213 net: If provided, add the given net to the active builder. 214 Else, returns the current Net or creates a new one as needed. 215 name: if provided, creates a new Net with given name and makes 216 it the new current net of the active builder. Cannot 217 be provided if net is provided. 219 assert name
is None or net
is None, (
220 'Cannot provide both `net` and `name`.')
222 NetBuilder.current().add(net)
224 return NetBuilder.current().current_net(name=name)
228 Adds an operator call to the currently active Net. 230 if op_type.startswith(
'__'):
231 raise AttributeError()
233 if NetBuilder.current(required=
False)
is None:
234 raise AttributeError(
'No active NetBuilder.')
235 return getattr(self.
net(), op_type)
239 Creates a local task group which will execute as the next step of 240 the current NetBuilder. 243 group = NetBuilder.current()
252 Stop execution of the current execution step. 257 In the example, 'b' will never be printed. 259 return self.
stop_if(ops.Const(
True))
263 Stop execution of the current execution step if the 264 condition `blob` is met. 267 ops.stop_if(ops.LE([x, ops.Const(0)])) 269 In the example, 'b' will only be printed if the value of scalar 270 tensor 'x' is greater than 0. 272 return NetBuilder.current().
stop_if(blob)
274 def loop(self, iters=None, name=None):
276 Creates a NetBuilder that will execute in a loop as the next step of 277 the current NetBuilder. If `iters` is provided, the loop will execute 278 for `iters` iterations and then stop. `iters` can be a constant or a 279 BlobReference. If `iters` is not provided, the loop will execute 280 until `ops.stop` or `ops.stop_if` is called. 284 ops.stop_if(ops.LE([a, ops.Const(0)])) 286 ops.Add([a, ops.Const(-1)], [a]) 287 Above, 'a' will be printed 5 times, with values 5 to 1. 289 with ops.loop(10) as loop: 290 ops.LogInfo(loop.iter()) 291 This will print the numbers from 0 to 9. 293 x = ops.Add([ops.Const(10), ops.Const(10)]) 294 with ops.loop(x) as loop: 295 ops.LogInfo(loop.iter()) 296 This will print the numbers from 0 to 19. 298 return NetBuilder.current().add(
_Loop(iters, name=name))
302 Creates a NetBuilder that will execute once as the next step of the 303 current NetBuilder. After execution, a bool tensor will indicate 304 whether the inner execution was halted with `stop` or `stop_if`. 307 with ops.stop_guard() as sg1: 309 ops.Print(ops.Const('did not stop')) 311 with ops.stop_guard() as sg2: 313 ops.Print(ops.Const('did not stop')) 314 ops.Print(sg1.has_stopped(), []) 315 ops.Print(sg2.has_stopped(), []) 316 In the example, 'did not stop' will be printed once, 317 followed by True and False. 319 return NetBuilder.current().add(
320 _StopGuard(has_stopped_blob=has_stopped_blob, name=name))
322 def If(self, cond, name=None):
324 Creates a NetBuilder that will execute once as the next step of the 325 current NetBuilder if the blob `cond` is True. 327 with ops.If(ops.Const(True)): 328 ops.Print(ops.Const('Will print')) 329 with ops.If(ops.Const(False)): 330 ops.Print(ops.Const('Wont print')) 331 The example will print 'Will print' once. 333 return NetBuilder.current().add(
_RunIf(cond, name=name))
337 Same as If, but uses 'If' operator instead of execution step logic 339 return NetBuilder.current().add(
_RunIfNet(cond, name=name))
343 Else branch of IfNet, has to be specified immediately after IfNet. 345 with ops.IfNet(ops.LT([x, y])): 354 NetBuilder for 'While' control operator 356 return NetBuilder.current().add(
_RunWhileNet(name=name))
360 Loop's condition, executed within WhileNet context 362 assert isinstance(NetBuilder.current(), _RunWhileNet), \
363 "Use of Condition outside of WhileNet" 368 Defines operations that will be executed once at task startup. 369 Useful when implementing processors, that don't have access to the Task 372 This setup will be run only once, even if multiple instances of the task 373 will run in parallel. For instance-local initialization, use 374 `task_instance_init` instead. 377 def my_processor(rec): 378 with ops.task_init(): 382 ops.Add(rec[0](), zero), ops.Add(rec[1](), two)) 385 self.
net().add_attribute(Task.TASK_SETUP, setup)
390 Define operations to be executed once at task shutdown. 391 Useful when implementing processors, that don't have access to the Task 394 This shutdown will be run only once, after all concurrent instances of 395 the task have already finished. For instance-local shutdown, 396 use `task_instance_exit` instead. 399 def read_queue(queue): 400 with ops.task_exit(): 401 queue.close(ops.net()) 402 return queue.read(ops.net()) 405 self.
net().add_attribute(Task.TASK_SETUP, setup)
410 Defines operations that will be executed once at startup of each 411 instance of a task. This can be seen as "thread_local" initialization. 412 It is guaranteed to run only after all `task_init` logic finishes. 414 This setup will be run concurrently for each instance of a task. 415 For global task initialization, use `task_init` instead. 418 self.
net().add_attribute(Task.TASK_INSTANCE_SETUP, setup)
423 Defines operations that will be executed once at shutdown of each 424 instance of a task. This can be seen as "thread_local" finalization. 426 This shutdown will be run concurrently for each instance of a task. 427 For global task shutdown, use `task_exit` instead. 430 self.
net().add_attribute(Task.TASK_INSTANCE_SETUP, setup)
435 Similar to `task_init`, but executes at TaskGroup's startup instead, 436 before any task of the group starts executing. This will run only 437 once on each node, before initialization of any task, so it can be 438 used e.g. to initialize blobs shared across tasks. 441 self.
net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
446 Similar to `task_exit`, but executes at TaskGroup's exit instead, 447 after all tasks of the group finished execution. 448 This will run only once on each node. 451 self.
net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
456 Define operations to be executed at every time interval from 457 task start-up to finish. These operations are guaranteed to 458 execute at least once after all other operations of the task are 462 with ops.task_reporter(interval_ms=10000): 463 ops.LogInfo('10s elapsed') 469 Similar to task_report, but operations defined within this block 470 will run repeatedly for as long as any of the tasks in the current 471 TaskGroup have not finished. 480 def __init__(self, interval_ms, net=None, name=None):
481 NetBuilder.__init__(self, name)
485 def __exit__(self, etype, *args):
487 step = core.to_execution_step(self)
490 self._net.add_attribute(Task.REPORT_STEP, step)
492 TaskGroup.current().report_step(
494 NetBuilder.__exit__(self, etype, *args)
501 def __init__(self, type, name=None):
502 NetBuilder.__init__(self, name)
505 def setup(self, net):
506 if self.
type == _SetupBuilder.INIT:
507 return core.to_execution_step(self)
510 if self.
type == _SetupBuilder.EXIT:
511 return core.to_execution_step(self)
515 def __init__(self, name=None):
516 NetBuilder.__init__(self, name)
518 def __exit__(self, etype, *args):
519 if etype
is None and self.
_stop_blob is not None:
521 NetBuilder.__exit__(self, etype, *args)
525 def __init__(self, has_stopped_blob=None, name=None):
526 _RunOnce.__init__(self, name)
531 r = _RunOnce.__enter__(self)
535 def __exit__(self, etype, *args):
538 ops.Const(
False, blob_out=self.
_stopped)
539 _RunOnce.__exit__(self, etype, *args)
543 Return a blob that will be set to scalar bool `True` after 544 this net builder ran, iff it was halted early. 546 assert self.
_ran,
'Context not used yet.' 551 def __init__(self, iters=None, name=None):
552 NetBuilder.__init__(self, name, _stop_blob_required=
True)
553 if iters
is not None:
554 self.
_inc = ops.Const(1)
555 self.
_iter = ops.Const(0)
558 else ops.Const(iters))
564 'This loop does not have a number of iterations.')
565 assert self.
_iter is not None, (
566 'iter() must be called from inside the loop context')
570 builder = NetBuilder.__enter__(self)
575 def __exit__(self, type, *args):
576 if type
is None and self.
_num_iters is not None:
578 NetBuilder.__exit__(self, type, *args)
582 def __init__(self, cond_blob=None, name=None, _already_ran=None):
583 _RunOnce.__init__(self, name)
584 assert cond_blob
or _already_ran
586 if _already_ran
is None:
591 self.
_else_blob = _already_ran
if cond_blob
is None else (
592 ops.Or([_already_ran, ops.Not(cond_blob)]))
595 r = _RunOnce.__enter__(self)
600 def Elif(self, cond, name=None):
601 assert not self.
_is_else,
'Else not allowed for an Else.' 602 return NetBuilder.current().add(
_RunIf(
605 def Else(self, name=None):
606 assert not self.
_is_else,
'Elif not allowed for an Else.' 607 return NetBuilder.current().add(
613 Generates a single net that uses If operator 615 def __init__(self, cond_blob, name=None):
616 NetBuilder.__init__(self, name=name, _use_control_ops=
True)
617 assert cond_blob,
'Conditional blob is not specified for an If net' 622 def add(self, child):
623 return NetBuilder.add(self, child)
625 def __exit__(self, type, *args):
641 NetBuilder.__exit__(self, type, *args)
646 Else branch for _RunIfNet builder 648 def __init__(self, name=None):
649 NetBuilder.__init__(self, name=name, _use_control_ops=
True)
650 parent = NetBuilder.current(required=
False)
651 assert parent
and len(parent._children) > 0
and \
652 isinstance(parent._children[-1], _RunIfNet), \
653 'Invalid use of Else builder' 656 def __exit__(self, type, *args):
661 self._if_builder._else_net = NetBuilder.merge_nets(
663 if self._if_builder._else_net:
667 self._if_builder._cond_blob,
669 self._if_builder._then_net,
670 self._if_builder._else_net)
671 self._if_builder._current_net = if_else_net
672 self._if_builder._children = [if_else_net]
673 NetBuilder.__exit__(self, type, *args)
678 Generates a single net that uses While operator 680 def __init__(self, name=None):
681 NetBuilder.__init__(self, name=name, _use_control_ops=
True)
684 def __exit__(self, type, *args):
687 'Condition builder must be specified in While op' 689 _cond_blob = self._cond_builder._cond_blob
690 _cond_net = self._cond_builder._cond_net
694 loop_body_net = NetBuilder.merge_nets(
696 if not loop_body_net:
697 loop_body_net =
core.Net(
'empty_loop_body_net')
701 loop_body_net, _cond_net)
705 NetBuilder.__exit__(self, type, *args)
710 Computes loop's condition, used in the context of WhileNet. 711 Last operator must have a single scalar boolean output that will be used 712 as a condition value, no other blobs created in the condition net are 713 visible outside of it 715 def __init__(self, name=None):
716 NetBuilder.__init__(self, name=name, _use_control_ops=
True)
717 parent = NetBuilder.current(required=
False)
718 assert parent
and isinstance(parent, _RunWhileNet), \
719 'Invalid use of loop condition builder' 720 assert not parent._cond_builder, \
721 'Multiple loop condition builders specified' 722 assert len(parent._children) == 0, \
723 'Condition definition must be specified before the loop\'s body' 724 parent._cond_builder = self
728 def __exit__(self, type, *args):
734 assert self.
_cond_net,
'Invalid loop condition specified' 735 assert len(self._cond_net.Proto().op) > 0,
'Invalid condition net' 736 last_op = self._cond_net.Proto().op[-1]
737 assert len(last_op.output) == 1,
'Invalid condition net' 742 NetBuilder.__exit__(self, type, *args)
def _reset_children(self)
def loop(self, iters=None, name=None)
def local_reporter(self, interval_ms=1000, name=None)
def task_instance_init(self)
def Else(self, name=None)
def WhileNet(self, name=None)
def _update_lexical_scope(self)
def current_net(self, name=None)
def __getattr__(self, op_type)
def local_exit(self, name=None)
def If(self, cond, name=None)
def IfNet(self, cond, name=None)
def task_instance_exit(self)
def stop_guard(self, has_stopped_blob=None, name=None)
def net(self, net=None, name=None)
def _assert_mutable(self)
def Condition(self, name=None)
def task_reporter(self, interval_ms=1000, name=None)