Caffe2 - Python API
A deep learning, cross platform ML framework
control.py
1 ## @package control
2 # Module caffe2.python.control
3 """
4 Implement functions for controlling execution of nets and steps, including
5  Do
6  DoParallel
7  For-loop
8  While-loop
9  Do-While-loop
10  Switch
11  If
12 """
13 
14 from __future__ import absolute_import
15 from __future__ import division
16 from __future__ import print_function
17 from __future__ import unicode_literals
18 
19 from caffe2.python import core
20 from future.utils import viewitems
21 
22 
# Bookkeeping for the names of the steps created by the control functions.
# ``_current_idx`` is the next numeric suffix tried when a name collides (a
# single counter shared by every base name), and ``_used_step_names`` holds
# every name handed out so far so that no name is returned twice.
_current_idx = 1
_used_step_names = set()


def _get_next_step_name(control_name, base_name):
    """Return a step name ``base_name/control_name`` that was never returned before.

    If the plain name is already taken, ``_<N>`` is appended, where N comes
    from the module-wide counter ``_current_idx``; the counter keeps growing
    until an unused name is produced. The chosen name is remembered.
    """
    global _current_idx
    stem = '%s/%s' % (base_name, control_name)
    candidate = stem
    while candidate in _used_step_names:
        candidate = '%s_%d' % (stem, _current_idx)
        _current_idx += 1
    _used_step_names.add(candidate)
    return candidate
38 
39 
def _MakeList(input):
    """Normalize a ``*args`` tuple into a list.

    Args:
        input: a tuple, typically a caller's ``*args``.

    Returns:
        - the sole element itself (not a copy) if ``input`` has exactly one
          element and that element is a list;
        - ``[element]`` if ``input`` has exactly one non-list element;
        - ``list(input)`` otherwise.

    Examples:
        (a, b, c)    --> [a, b, c]
        (a,)         --> [a]
        ([a, b, c],) --> [a, b, c]

    Raises:
        ValueError: if ``input`` is empty.
    """
    if not len(input):
        raise ValueError(
            'input cannot be empty.')
    if len(input) != 1:
        return list(input)
    sole = input[0]
    return sole if isinstance(sole, list) else [sole]
57 
58 
def _IsNets(nets_or_steps):
    """Return True if ``nets_or_steps`` is made only of ``core.Net`` objects.

    A non-list argument is checked directly. For a list, every element must
    be a Net, so an empty list also yields True.
    """
    if not isinstance(nets_or_steps, list):
        return isinstance(nets_or_steps, core.Net)
    return not any(not isinstance(item, core.Net) for item in nets_or_steps)
64 
65 
def _PrependNets(nets_or_steps, *nets):
    """Return a new list with ``nets`` placed before ``nets_or_steps``.

    ``nets_or_steps`` may be a single item or a list. If it holds only Nets,
    ``nets`` are prepended as-is. Otherwise ``nets`` are first wrapped in one
    ``Do('prepend', ...)`` step (NOTE(review): presumably because a step's
    children must be all nets or all steps).
    """
    targets = _MakeList((nets_or_steps,))
    head = _MakeList(nets)
    if not _IsNets(targets):
        head = [Do('prepend', head)]
    return head + targets
74 
def _AppendNets(nets_or_steps, *nets):
    """Return a new list with ``nets`` placed after ``nets_or_steps``.

    ``nets_or_steps`` may be a single item or a list. If it holds only Nets,
    ``nets`` are appended as-is. Otherwise ``nets`` are first wrapped in one
    ``Do('append', ...)`` step (NOTE(review): presumably because a step's
    children must be all nets or all steps).
    """
    targets = _MakeList((nets_or_steps,))
    tail = _MakeList(nets)
    if not _IsNets(targets):
        tail = [Do('append', tail)]
    return targets + tail
83 
def GetConditionBlobFromNet(condition_net):
    """
    Return the condition blob of ``condition_net``.

    The condition blob is the last external_output of the net, which must
    be a single bool.

    Args:
        condition_net: a ``core.Net`` with at least one external output.

    Returns:
        A ``core.BlobReference`` naming the net's last external output.

    Raises:
        AssertionError: if the net has no external outputs.
    """
    net_proto = condition_net.Proto()
    # The message must read the *proto's* name: ``condition_net.Proto`` is a
    # method, and ``.name`` on it raised AttributeError instead of reporting
    # the intended assertion failure.
    assert len(net_proto.external_output) > 0, (
        "Condition net %s must have at least one external output" %
        net_proto.name)
    # we need to use a blob reference here instead of a string
    # otherwise, it will add another name_scope to the input later
    # when we create new ops (such as OR of two inputs)
    return core.BlobReference(net_proto.external_output[-1])
97 
def BoolNet(*blobs_with_bool_value):
    """Build a net that assigns constant bool values to blobs.

    It is mainly used to initialize condition blobs. For example, in
    multi-task learning the reader_done blobs may be read before reader_net
    has run, so they must be initialized first.

    Args:
        blobs_with_bool_value: one or more (blob, bool_value) pairs, or a
            single list of such pairs. Each bool_value is written to the
            matching blob as a scalar (shape=[]) of dtype BOOL.

    Returns:
        bool_net: a Net whose external outputs are the filled blobs.

    Examples:
        - BoolNet((blob_1, bool_value_1), ..., (blob_n, bool_value_n))
        - BoolNet([(blob_1, bool_value_1), ..., (blob_n, bool_value_n)])
        - BoolNet((cond_1, bool_value_1))
    """
    pairs = _MakeList(blobs_with_bool_value)
    bool_net = core.Net('bool_net')
    for target_blob, constant in pairs:
        filled = bool_net.ConstantFill(
            [], [target_blob], shape=[], value=constant,
            dtype=core.DataType.BOOL)
        bool_net.AddExternalOutput(filled)
    return bool_net
128 
129 
def NotNet(condition_blob_or_net):
    """Build a net computing the logical NOT of a condition.

    Args:
        condition_blob_or_net: a blob, or a Net whose last external_output
            (a single bool) is taken as the condition.

    Returns:
        not_net: a new Net applying ``Not`` to the condition.
        out_blob: the output blob of not_net (also its external output).
    """
    if isinstance(condition_blob_or_net, core.Net):
        source = GetConditionBlobFromNet(condition_blob_or_net)
    else:
        source = condition_blob_or_net
    not_net = core.Net('not_net')
    negated = not_net.Not(source)
    not_net.AddExternalOutput(negated)
    return not_net, negated
152 
153 
def _CopyConditionBlobNet(condition_blob):
    """Build a condition net that copies ``condition_blob``.

    Args:
        condition_blob: a blob holding a single bool.

    Returns:
        condition_net: a new Net running ``Copy`` on condition_blob.
        out_blob: the copied blob (also the net's external output).
    """
    copy_net = core.Net('copy_condition_blob_net')
    copied = copy_net.Copy(condition_blob)
    copy_net.AddExternalOutput(copied)
    return copy_net, copied
169 
170 
def MergeConditionNets(name, condition_nets, relation):
    """
    Merge several condition nets into a single condition net.

    The ops and external inputs of every net are copied, in order, into a
    new net. Their condition blobs are folded left to right with
    ``relation``. The ``_attr_dict`` entries are accumulated as well.

    Args:
        name: name of the new condition net.
        condition_nets: a list of condition nets. The last external_output
            of each condition net must be a single bool value. All nets must
            share the merged net's device_option and type (asserted).
        relation: can be 'And' or 'Or'.

    Returns:
        - ``condition_nets`` unchanged if it is not a list;
        - its only element for a one-element list, or None for an empty list;
        - otherwise a new condition net whose last external output is the
          relation of all condition_nets.
    """
    if not isinstance(condition_nets, list):
        return condition_nets
    if len(condition_nets) <= 1:
        return condition_nets[0] if condition_nets else None

    merged_net = core.Net(name)
    last_cond = None
    for position, cond_net in enumerate(condition_nets):
        source_proto = cond_net.Proto()
        assert source_proto.device_option == merged_net.Proto().device_option
        assert source_proto.type == merged_net.Proto().type
        merged_net.Proto().op.extend(source_proto.op)
        merged_net.Proto().external_input.extend(source_proto.external_input)
        # External outputs are not copied: they are folded into one condition.
        cond_blob = GetConditionBlobFromNet(cond_net)
        if position:
            last_cond = merged_net.__getattr__(relation)([last_cond, cond_blob])
        else:
            last_cond = cond_blob
        for attr_key, attr_values in viewitems(cond_net._attr_dict):
            merged_net._attr_dict[attr_key] += attr_values

    merged_net.AddExternalOutput(last_cond)
    return merged_net
210 
211 
def CombineConditions(name, condition_nets, relation):
    """
    Combine the conditions of several nets into one condition net.

    Unlike MergeConditionNets, the bodies of condition_nets are not copied.
    The new net only combines their condition blobs, folded left to right
    with ``relation``.

    One example is multiple readers. Each reader net has a reader_done
    condition, and this function can build a net that checks whether all
    readers are done.

    Args:
        name: name of the new condition net.
        condition_nets: a list of condition nets. The last external_output
            of each condition net must be a single bool value.
        relation: can be 'And' or 'Or'.

    Returns:
        - None if condition_nets is empty/falsy;
        - for a single net, a fresh net that copies its condition blob;
        - otherwise a new condition net whose last external output is the
          relation of all condition_nets.

    Raises:
        ValueError: if condition_nets is truthy but not a list.
    """
    if not condition_nets:
        return None
    if not isinstance(condition_nets, list):
        raise ValueError('condition_nets must be a list of nets.')

    if len(condition_nets) == 1:
        copy_net, _ = _CopyConditionBlobNet(
            GetConditionBlobFromNet(condition_nets[0]))
        return copy_net

    combined_net = core.Net(name)
    last_cond = None
    for position, cond_net in enumerate(condition_nets):
        cond_blob = GetConditionBlobFromNet(cond_net)
        if position == 0:
            last_cond = cond_blob
            continue
        last_cond = combined_net.__getattr__(relation)([last_cond, cond_blob])

    combined_net.AddExternalOutput(last_cond)
    return combined_net
254 
255 
def Do(name, *nets_or_steps):
    """
    Execute the sequence of nets or steps once.

    A lone ExecutionStep is returned unchanged. Anything else is wrapped in
    a new scoped execution step with a unique name derived from ``name``.

    Examples:
        - Do('myDo', net1, net2, ..., net_n)
        - Do('myDo', list_of_nets)
        - Do('myDo', step1, step2, ..., step_n)
        - Do('myDo', list_of_steps)
    """
    items = _MakeList(nets_or_steps)
    if len(items) == 1 and isinstance(items[0], core.ExecutionStep):
        return items[0]
    return core.scoped_execution_step(_get_next_step_name('Do', name), items)
273 
274 
def DoParallel(name, *nets_or_steps):
    """
    Execute the nets or steps in parallel, waiting for all of them to finish.

    A lone ExecutionStep is returned unchanged. Anything else is wrapped in
    a new scoped execution step with ``concurrent_substeps=True``.

    Examples:
        - DoParallel('pDo', net1, net2, ..., net_n)
        - DoParallel('pDo', list_of_nets)
        - DoParallel('pDo', step1, step2, ..., step_n)
        - DoParallel('pDo', list_of_steps)
    """
    items = _MakeList(nets_or_steps)
    if len(items) == 1 and isinstance(items[0], core.ExecutionStep):
        return items[0]
    step_name = _get_next_step_name('DoParallel', name)
    return core.scoped_execution_step(
        step_name, items, concurrent_substeps=True)
294 
295 
def _RunOnceIf(name, condition_blob_or_net, nets_or_steps):
    """
    Execute nets_or_steps once if condition_blob_or_net evaluates as true.

    If condition_blob_or_net is a Net, the condition is its last
    external_output, which must be a single bool. That net runs before
    nets_or_steps to produce the condition. A NOT of the condition is used
    as the step's stop blob, so the step stops when the condition is false.
    """
    condition_not_net, stop_blob = NotNet(condition_blob_or_net)
    guard_nets = [condition_not_net]
    if isinstance(condition_blob_or_net, core.Net):
        guard_nets.insert(0, condition_blob_or_net)
    nets_or_steps = _PrependNets(nets_or_steps, *guard_nets)

    if not _IsNets(nets_or_steps):
        return core.scoped_execution_step(
            _get_next_step_name('_RunOnceIf', name),
            nets_or_steps,
            should_stop_blob=stop_blob,
            only_once=True,
        )

    # All-nets case: reset stop_blob to False before the inner step.
    # NOTE(review): presumably so that a True value left over from an earlier
    # run cannot end the step right after the condition net, as explained
    # for the same reset in While().
    bool_net = BoolNet((stop_blob, False))
    inner_step = core.scoped_execution_step(
        _get_next_step_name('_RunOnceIf-inner', name),
        nets_or_steps,
        should_stop_blob=stop_blob,
        only_once=True,
    )
    return Do(name + '/_RunOnceIf', bool_net, inner_step)
325 
326 
def _RunOnceIfNot(name, condition_blob_or_net, nets_or_steps):
    """
    Similar to _RunOnceIf(), but execute nets_or_steps once if
    condition_blob_or_net evaluates as false.

    The condition itself serves as the stop blob. For a Net it is the net's
    last external_output, and the net is prepended. For a blob, a Copy net
    is prepended and its output is used as the stop blob.
    """
    if isinstance(condition_blob_or_net, core.Net):
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
        leading_net = condition_blob_or_net
    else:
        leading_net, stop_blob = _CopyConditionBlobNet(condition_blob_or_net)
    nets_or_steps = _PrependNets(nets_or_steps, leading_net)

    step_name = _get_next_step_name('_RunOnceIfNot', name)
    return core.scoped_execution_step(
        step_name,
        nets_or_steps,
        should_stop_blob=stop_blob,
        only_once=True,
    )
345 
346 
def For(name, nets_or_steps, iter_num):
    """
    Execute nets_or_steps iter_num times.

    An init net creates a counter initialized to iter_num. Each iteration
    first runs a net applying CountDown to that counter, and the output of
    CountDown is the loop's stop blob.

    Args:
        nets_or_steps: an ExecutionStep or a Net, or a list of
            ExecutionSteps or a list of Nets.
        iter_num: the number of times to execute nets_or_steps.

    Returns:
        An ExecutionStep instance.
    """
    init_net = core.Net('init-net')
    counter = init_net.CreateCounter([], init_count=iter_num)
    iter_net = core.Net('For-iter')
    done_blob = iter_net.CountDown([counter])

    loop_step = core.scoped_execution_step(
        _get_next_step_name('For-inner', name),
        _PrependNets(nets_or_steps, iter_net),
        should_stop_blob=done_blob)
    init_step = Do(name + '/For-init-net', init_net)
    return Do(name + '/For', init_step, loop_step)
371 
372 
def While(name, condition_blob_or_net, nets_or_steps):
    """
    Execute nets_or_steps while condition_blob_or_net returns true.

    Args:
        condition_blob_or_net: if it is an instance of Net, its last
            external_output must be a single bool.
        nets_or_steps: an ExecutionStep or a Net, or a list of
            ExecutionSteps or a list of Nets.

    Returns:
        An ExecutionStep instance.
    """
    condition_not_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        guard_nets = (condition_blob_or_net, condition_not_net)
    else:
        guard_nets = (condition_not_net,)
    nets_or_steps = _PrependNets(nets_or_steps, *guard_nets)

    if not _IsNets(nets_or_steps):
        return core.scoped_execution_step(
            _get_next_step_name('While', name),
            nets_or_steps,
            should_stop_blob=stop_blob,
        )

    # Here the loop step's sub-nets are:
    #   [condition_blob_or_net, condition_not_net, nets_or_steps]
    # If stop_blob is pre-set to True (this may happen when While() is
    # called twice), the loop would exit right after executing
    # condition_blob_or_net. BoolNet therefore resets stop_blob to False.
    bool_net = BoolNet((stop_blob, False))
    loop_step = core.scoped_execution_step(
        _get_next_step_name('While-inner', name),
        nets_or_steps,
        should_stop_blob=stop_blob,
    )
    return Do(name + '/While', bool_net, loop_step)
411 
412 
def Until(name, condition_blob_or_net, nets_or_steps):
    """
    Similar to While(), but execute nets_or_steps while
    condition_blob_or_net returns false.

    For a Net, its last external_output is the stop blob and the net is
    prepended to nets_or_steps. A blob is used directly as the stop blob.
    """
    if not isinstance(condition_blob_or_net, core.Net):
        stop_blob = core.BlobReference(str(condition_blob_or_net))
    else:
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
        nets_or_steps = _PrependNets(nets_or_steps, condition_blob_or_net)

    step_name = _get_next_step_name('Until', name)
    return core.scoped_execution_step(
        step_name, nets_or_steps, should_stop_blob=stop_blob)
428 
429 
def DoWhile(name, condition_blob_or_net, nets_or_steps):
    """
    Execute nets_or_steps while condition_blob_or_net returns true.
    nets_or_steps are executed before condition_blob_or_net is evaluated.

    Args:
        condition_blob_or_net: if it is an instance of Net, its last
            external_output must be a single bool.
        nets_or_steps: an ExecutionStep or a Net, or a list of
            ExecutionSteps or a list of Nets.

    Returns:
        An ExecutionStep instance.
    """
    condition_not_net, stop_blob = NotNet(condition_blob_or_net)
    if isinstance(condition_blob_or_net, core.Net):
        trailing_nets = (condition_blob_or_net, condition_not_net)
    else:
        trailing_nets = (condition_not_net,)
    nets_or_steps = _AppendNets(nets_or_steps, *trailing_nets)

    # If stop_blob is pre-set to True (this may happen when DoWhile() is
    # called twice), the loop would exit after the first net/step in
    # nets_or_steps. BoolNet therefore resets stop_blob to False.
    bool_net = BoolNet((stop_blob, False))
    loop_step = core.scoped_execution_step(
        _get_next_step_name('DoWhile-inner', name),
        nets_or_steps,
        should_stop_blob=stop_blob,
    )
    return Do(name + '/DoWhile', bool_net, loop_step)
461 
462 
def DoUntil(name, condition_blob_or_net, nets_or_steps):
    """
    Similar to DoWhile(), but execute nets_or_steps while
    condition_blob_or_net returns false. nets_or_steps are executed before
    condition_blob_or_net is evaluated.

    Special case: if condition_blob_or_net is a blob that is pre-set to
    true, only the first net/step of nets_or_steps is executed before the
    loop exits. Be careful about the initial value of the condition blob
    when using DoUntil(), especially when DoUntil() is called twice.
    """
    if isinstance(condition_blob_or_net, core.Net):
        nets_or_steps = _AppendNets(nets_or_steps, condition_blob_or_net)
        stop_blob = GetConditionBlobFromNet(condition_blob_or_net)
        # If stop_blob is pre-set to True (this may happen when DoUntil() is
        # called twice), the loop would exit after the first net/step in
        # nets_or_steps. BoolNet therefore resets stop_blob to False.
        bool_net = BoolNet((stop_blob, False))
        loop_step = core.scoped_execution_step(
            _get_next_step_name('DoUntil-inner', name),
            nets_or_steps,
            should_stop_blob=stop_blob,
        )
        return Do(name + '/DoUntil', bool_net, loop_step)

    # Blob condition: used directly as the stop blob, with no reset.
    stop_blob = core.BlobReference(condition_blob_or_net)
    return core.scoped_execution_step(
        _get_next_step_name('DoUntil', name),
        nets_or_steps,
        should_stop_blob=stop_blob)
494 
495 
def Switch(name, *conditions):
    """
    Execute the steps whose condition is true.

    Each condition is a tuple (condition_blob_or_net, nets_or_steps).

    Note:
        1. Several steps can execute if their conditions are true.
        2. Every condition_blob_or_net that is a Net is executed once.

    Examples:
        - Switch('name', (cond_1, net_1), (cond_2, net_2), ..., (cond_n, net_n))
        - Switch('name', [(cond_1, net1), (cond_2, net_2), ..., (cond_n, net_n)])
        - Switch('name', (cond_1, net_1))
    """
    pairs = _MakeList(conditions)
    step_name = _get_next_step_name('Switch', name)
    branches = [_RunOnceIf(name + '/Switch', cond, body)
                for cond, body in pairs]
    return core.scoped_execution_step(step_name, branches)
514 
515 
def SwitchNot(name, *conditions):
    """
    Similar to Switch(), but execute the steps whose condition is false.

    Each condition is a tuple (condition_blob_or_net, nets_or_steps).
    """
    pairs = _MakeList(conditions)
    step_name = _get_next_step_name('SwitchNot', name)
    branches = [_RunOnceIfNot(name + '/SwitchNot', cond, body)
                for cond, body in pairs]
    return core.scoped_execution_step(step_name, branches)
525 
526 
def If(name, condition_blob_or_net,
       true_nets_or_steps, false_nets_or_steps=None):
    """
    Evaluate or execute condition_blob_or_net first. If the condition is
    true, execute true_nets_or_steps; otherwise execute false_nets_or_steps.

    If condition_blob_or_net is a Net, the condition is its last
    external_output, which must be a single bool. That Net is executed
    before both true/false_nets_or_steps to produce the condition.
    A falsy false_nets_or_steps (None or empty) means there is no false
    branch.
    """
    if not false_nets_or_steps:
        return _RunOnceIf(name + '/If',
                          condition_blob_or_net, true_nets_or_steps)

    condition_blob = (GetConditionBlobFromNet(condition_blob_or_net)
                      if isinstance(condition_blob_or_net, core.Net)
                      else condition_blob_or_net)
    true_branch = _RunOnceIf(name + '/If-true',
                             condition_blob_or_net, true_nets_or_steps)
    false_branch = _RunOnceIfNot(name + '/If-false',
                                 condition_blob, false_nets_or_steps)
    return Do(name + '/If', true_branch, false_branch)
553 
554 
def IfNot(name, condition_blob_or_net,
          true_nets_or_steps, false_nets_or_steps=None):
    """
    If condition_blob_or_net returns false, execute true_nets_or_steps;
    otherwise execute false_nets_or_steps.

    A falsy false_nets_or_steps (None or empty) means there is no second
    branch.
    """
    if not false_nets_or_steps:
        return _RunOnceIfNot(name + '/IfNot',
                             condition_blob_or_net, true_nets_or_steps)

    condition_blob = (GetConditionBlobFromNet(condition_blob_or_net)
                      if isinstance(condition_blob_or_net, core.Net)
                      else condition_blob_or_net)
    true_branch = _RunOnceIfNot(name + '/IfNot-true',
                                condition_blob_or_net, true_nets_or_steps)
    false_branch = _RunOnceIf(name + '/IfNot-false',
                              condition_blob, false_nets_or_steps)
    return Do(name + '/IfNot', true_branch, false_branch)