Caffe2 - Python API
A deep learning, cross-platform ML framework
memonger.py
1 ## @package memonger
2 # Module caffe2.python.memonger
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import networkx as nx
9 import collections
10 import time
11 import copy
12 from caffe2.python import workspace, core
13 from caffe2.proto import caffe2_pb2
14 import enum
15 import logging
16 from future.utils import viewitems, viewvalues
18 
# Module-level logger used by every memonger routine; INFO is enabled so the
# optimization timing messages are emitted.
log = logging.getLogger("memonger")
log.setLevel(logging.INFO)
# Live range of a blob over a linearized op list: index of the first op that
# outputs it ("defined"), index of the last op that reads it ("used") and its
# size ("size"); any field may be None when unknown.
LiveRange = collections.namedtuple('LiveRange', ["defined", "used", "size"])
22 
23 
def share_grad_blobs(
    net,
    losses,
    param_grads,
    namescope,
    dont_share_blobs=None,
    share_activations=False,
    blob_shapes=None,
):
    '''
    Implements similar optimization as Torch's shareGradInput():
    for the gradients that are passed between layers, share blobs between
    operators when possible. This yields significant memory savings with
    deep networks.

    net: net wrapper exposing Proto(); the proto is deep-copied, the
        original is only read.
    losses: blobs whose str() names are handed (utf-8 encoded) to the
        recycling routine.
    param_grads: collection of gradient blob names that must never be shared.
    namescope: name prefix of the blobs to consider; a trailing "/" is
        appended when it is non-empty and lacks one.
    dont_share_blobs: optional set forwarded to the recycling routine
        (an empty set is used when None).
    share_activations: when True, blobs picked by the activation heuristic
        below are shared as well.
    blob_shapes: optional dict forwarded to the recycling routine
        (an empty dict is used when None).

    NOTE(review): `C` is presumably the caffe2 C extension module; its import
    is not visible in this listing - confirm it is in scope.

    Returns an optimized protobuf (assign to net._net)
    '''
    def is_grad_blob(b):
        name = str(b)
        # Note: need to look at _{namescope} pattern as it matches
        # to handle the auto-split gradients
        return name.endswith("_grad") and (name.startswith(namescope) or
            name.startswith("_" + namescope)) and name not in param_grads

    def is_grad_op(op):
        # TODO: something smarter
        return any(is_grad_blob(b) for b in list(op.input) + list(op.output))

    # Logger.warn() is a deprecated alias; warning() is the supported name.
    log.warning("NOTE: Executing memonger to optimize gradient memory")

    # Collect ops that have something to do with gradients
    if namescope != "" and not namescope.endswith("/"):
        namescope += "/"

    netproto = copy.deepcopy(net.Proto())
    activations = []
    external_output = set(net.Proto().external_output)

    # Hacky way to get activations, think of a better way
    for op in net.Proto().op:
        for b in op.output:
            if b + "_w" in op.input and b not in external_output:
                activations.append(b)

    # Remove last activations, as they are usually accessed externally
    activations = set(activations[:-2])

    # Gradient ops
    grad_op_indices = []
    for idx, op in enumerate(netproto.op):
        if is_grad_op(op):
            grad_op_indices.append(idx)

    shared_blobs = set()
    for op in net.Proto().op:
        for b in list(op.input) + list(op.output):
            if is_grad_blob(b) or (share_activations and b in activations):
                shared_blobs.add(b)
    start_time = time.time()
    optim_str = C.memonger_compute_blob_recycling_for_dag(
        netproto.SerializeToString(),
        [str(s).encode('utf-8') for s in losses],
        grad_op_indices,
        set(str(s).encode('utf-8') for s in shared_blobs),
        namescope.encode('utf-8'),
        set() if dont_share_blobs is None else dont_share_blobs,
        {} if blob_shapes is None else blob_shapes
    )

    log.info("Memonger memory optimization took {} secs".format(
        time.time() - start_time),
    )

    optim = caffe2_pb2.NetDef()
    optim.ParseFromString(optim_str)
    # Sanity checks: the optimized net must keep the same dataflow and must
    # not introduce in-place ops that were not there before.
    assert verify_graph_equality(net.Proto(), optim), \
        "Memonger graph is not equal to original."
    assert verify_inplace_blobs(net.Proto(), optim), \
        "Inplace assignments differ in memonger net."
    return optim
107 
108 
def optimize_inference_for_dag(net, input_blobs, namescope=""):
    '''
    Recycle activation blobs of an inference-only net.

    Every blob that is neither an external input nor an external output of
    'net' is treated as an activation and offered for recycling. Asserts
    that each activation is produced before it is consumed and that no op is
    a gradient op.

    NOTE(review): `C` is presumably the caffe2 C extension module; its import
    is not visible in this listing - confirm it is in scope.

    Returns the optimized NetDef protobuf.
    '''
    proto = net.Proto()
    netproto = copy.deepcopy(proto)
    external_input = set(proto.external_input)
    external_output = set(proto.external_output)

    def is_activation_blob(b):
        return not (b in external_input or b in external_output)

    activation_blobs = set()
    produced = set()
    op_indices = list(range(len(proto.op)))

    # Sanity check: check that all external inputs are properly accounted
    # and that no gradient ops are included in 'net'
    for op in list(proto.op):
        for blob in op.input:
            if not is_activation_blob(blob):
                continue
            activation_blobs.add(blob)
            assert blob in produced, "{} not in external input".format(blob)
        activation_blobs.update(
            blob for blob in op.output if is_activation_blob(blob))
        produced = produced.union(set(op.output))
        assert not op.is_gradient_op, \
            "You can only pass inference-only nets to optimize_inference_for_dag"

    start_time = time.time()
    encoded_inputs = [str(s).encode('utf-8') for s in input_blobs]
    encoded_activations = set(
        str(s).encode('utf-8') for s in activation_blobs)
    optim_str = C.memonger_compute_blob_recycling_for_dag(
        netproto.SerializeToString(),
        encoded_inputs,
        op_indices,
        encoded_activations,
        namescope.encode('utf-8'),
        set(),
        {}
    )
    elapsed = time.time() - start_time
    log.info("Memonger memory optimization took {} secs".format(elapsed))

    optim = caffe2_pb2.NetDef()
    optim.ParseFromString(optim_str)

    assert verify_graph_equality(net.Proto(), optim), \
        "Memonger graph is not equal to original."
    assert verify_inplace_blobs(net.Proto(), optim), \
        "Inplace assignments differ in memonger net."
    return optim
159 
160 
def estimate_memory_usage(protos, shapes, types, devicescope):
    '''
    Estimate memory usage of a model. This is an estimate because
    we assume a single threaded execution and miss some internal
    memory usage of operators. Only estimates the memory for a given
    device scope.

    Also, currently it does not handle correctly if blob sizes vary
    during execution, as it uses only the final blob size.

    protos: list of net protobufs, evaluated in order. NOTE: each proto is
        filtered in place so that only ops of 'devicescope' (plus Free and
        Alias ops) remain.
    shapes: dict mapping blob name to its shape.
    types: dict mapping blob name to its caffe2_pb2.TensorProto data type.
    devicescope: device option compared against each op's device_option.

    Blobs missing from 'shapes' or 'types' are logged and counted as 0 bytes.

    Returns (total, highwater, by op type) memory allocation in bytes.
    '''
    # The docstring has to be the first statement of the function, otherwise
    # it is just a discarded string and __doc__ stays None.
    import numpy as np

    sizeofs = {
        caffe2_pb2.TensorProto.DOUBLE: 8,
        caffe2_pb2.TensorProto.FLOAT: 4,
        caffe2_pb2.TensorProto.FLOAT16: 2,
        caffe2_pb2.TensorProto.INT32: 4,
        caffe2_pb2.TensorProto.INT8: 1,
        caffe2_pb2.TensorProto.UINT8: 1,
        caffe2_pb2.TensorProto.UINT16: 2,
        caffe2_pb2.TensorProto.INT16: 2,
        caffe2_pb2.TensorProto.BOOL: 1,
        caffe2_pb2.TensorProto.INT64: 8,
    }

    def split_net(proto):
        # Keep only the ops that run on the requested device, plus the
        # bookkeeping ops that release memory.
        ops = [op for op in proto.op if
               op.device_option == devicescope or op.type in {"Free", "Alias"}]
        del proto.op[:]
        proto.op.extend(ops)
        return proto

    def num_bytes(blob):
        if blob not in shapes or blob not in types:
            log.warning("Unknown blob encountered: {}".format(blob))
            return 0
        sizeof = sizeofs[types[blob]]
        return sizeof * np.prod(shapes[blob])

    protos = [split_net(proto) for proto in protos]
    allocs_by_ops = collections.defaultdict(lambda: 0)

    # Evaluate
    current_allocated = 0
    max_allocated = 0
    total_allocated = 0
    allocated = set()
    for proto in protos:
        for op in proto.op:
            if op.type == "Free" or op.type == "Alias":
                # These ops give their outputs' memory back.
                for o in op.output:
                    if o in allocated:
                        current_allocated -= num_bytes(o)
                        allocated.remove(o)
            else:
                # First time an output is seen counts as a fresh allocation.
                for output in op.output:
                    if output not in allocated:
                        nbytes = num_bytes(output)
                        total_allocated += nbytes
                        current_allocated += nbytes
                        max_allocated = max(max_allocated, current_allocated)
                        allocated.add(output)
                        allocs_by_ops[op.type] += nbytes

    return (total_allocated, max_allocated, allocs_by_ops)
227 
228 
def release_blobs_when_used(netproto, dont_free_blobs, selector_fun=None):
    '''
    Return a copy of 'netproto' in which a Free op follows the last consumer
    of every releasable blob, so that the blob's memory can be reclaimed.
    Use this only with efficient caching memory managers
    (such as CUB, --caffe2_cuda_memory_pool=cub).

    A blob is releasable when it is produced by a non-Alias op before any op
    has read it, is later read by some op, is not an external output, is not
    in 'dont_free_blobs' and is not the first input of an Alias op.

    @dont_free_blobs: is a set of blobs that should not be freed
    @selector_fun: optional lambda that return True if blob name
        can be released. Use for easy special filtering, like
        excluding blobs with "loss" in the name.

    Returns a new protobuffer. To use with a model, use:
        model.net._net = memonger.release_blobs_when_used(..)
    '''
    netproto = copy.deepcopy(netproto)
    consumed = set()
    releasable = set()
    aliased = set()

    for op in netproto.op:
        if op.type == 'Alias':
            aliased.add(op.input[0])
            continue
        consumed.update(op.input)
        for produced in op.output:
            if produced in consumed:
                continue
            if selector_fun is None or selector_fun(produced):
                releasable.add(produced)

    # Drop external outputs, blobs that are never read, and the excluded ones
    releasable = releasable - set(netproto.external_output)
    releasable = releasable.intersection(consumed)
    releasable = releasable - dont_free_blobs
    releasable = releasable - aliased

    new_ops = list(netproto.op)

    # Walk backwards so the first hit for a blob is its last use; inserting
    # at position + 1 does not disturb the indices still to be visited.
    position = len(netproto.op) - 1
    while position >= 0:
        for blob in netproto.op[position].input:
            if blob in releasable:
                releasable.remove(blob)
                new_ops.insert(
                    position + 1, core.CreateOperator("Free", [blob], [blob]))
        position -= 1

    del netproto.op[:]
    netproto.op.extend(new_ops)
    return netproto
280 
281 
def _find_source_nodes(g):
    ''' Return the nodes of 'g' that have no predecessors, in iteration order '''
    return [node for node in g if not list(g.predecessors(node))]
290 
291 
def _find_target_nodes(g):
    ''' Return the nodes of 'g' that have no successors, in iteration order '''
    return [node for node in g if not list(g.successors(node))]
300 
301 
def _add_single_target_ifneeded(g):
    ''' Make sure the graph has exactly one target (node without successors).

    If 'g' already has a single target it is returned unchanged. Otherwise a
    deep copy is returned in which a new node, numbered one above the largest
    existing node (0 if none is >= 0), is the successor of all old targets.
    '''
    targets = _find_target_nodes(g)
    assert len(targets) >= 1
    if len(targets) == 1:
        return g

    # Smallest index that is larger than every existing node.
    new_target = -1
    for node in g:
        if node > new_target:
            new_target = node
    new_target += 1

    augmented = copy.deepcopy(g)
    augmented.add_node(new_target)
    for node in targets:
        augmented.add_edge(node, new_target)
    return augmented
323 
324 
def _get_path(pred_list, dist_list):
    ''' Get the path from nx.bellman_ford()'s output

    pred_list: dict mapping node -> predecessor. The value may be a single
        node or None (networkx 1.x), or a list of predecessors, which for
        the source is [None] or an empty list (networkx 2.x).
    dist_list: dict mapping node -> distance from the source; all distances
        must be <= 0 (edge weights are negative).

    Returns the list of nodes from the source to the node with the most
    negative distance.
    '''

    # distances are negative
    assert all(dist_list[x] <= 0 for x in dist_list)
    # node with longest distance to source is the target
    target = min(dist_list, key=lambda x: dist_list[x])

    ret = []
    cur = target

    while cur is not None:
        ret.append(cur)
        entry = pred_list[cur]
        # networkx 2.x stores a list of predecessors. Newer releases use an
        # empty list for the source node, which must end the walk instead of
        # raising IndexError.
        # TODO(tulloch): are there cases with multiple predecessors?
        if isinstance(entry, (list, tuple)):
            cur = entry[0] if entry else None
        else:
            cur = entry

    return list(reversed(ret))
347 
348 
def _get_longest_paths(g, source_nodes):
    ''' Get the longest path for nodes in 'source_nodes'
        Find with bellman_ford() by setting weight = -1

    Works on a deep copy of 'g', so the caller's edge attributes are not
    touched. Returns a dict mapping each source node to its longest path
    (a list of nodes starting with that source).
    '''

    ng = copy.deepcopy(g)
    for u, v in ng.edges():
        ng[u][v]["weight"] = -1

    # networkx 2.x no longer provides bellman_ford(); its replacement
    # bellman_ford_predecessor_and_distance() returns the same (pred, dist).
    bellman_ford = getattr(nx, "bellman_ford", None)
    if bellman_ford is None:
        bellman_ford = nx.bellman_ford_predecessor_and_distance

    ret = {}
    for cn in source_nodes:
        pred, dist = bellman_ford(ng, cn, weight="weight")
        path = _get_path(pred, dist)
        assert path[0] == cn
        # With unit negative weights the distance equals minus the hop count.
        assert len(path) - 1 == -dist[path[-1]]
        ret[cn] = path

    return ret
367 
368 
def _build_tree(paths):
    ''' Merge the given paths into one tree.

    All paths must end in the same node, which becomes the root. Edges point
    from a node to the node that precedes it in a path, i.e. away from the
    root. Every tree node gets a "height" attribute.

    Returns (tree, root).
    '''
    assert all(cp[-1] == paths[0][-1] for cp in paths)
    root = paths[0][-1]

    tree = nx.DiGraph()
    tree.add_nodes_from({node for path in paths for node in path})
    for path in paths:
        for idx in range(len(path) - 1):
            # reversed direction: later node -> earlier node
            tree.add_edge(path[idx + 1], path[idx])

    _compute_tree_height(tree, root)
    return (tree, root)
385 
386 
def _compute_tree_height(g, root):
    ''' Compute the heights of the tree for all nodes
        Height of leaves are 0

    The height of every node reachable from 'root' through successors() is
    stored in that node's "height" attribute. Returns None.
    '''
    # 'g.node' is gone in recent networkx releases, while in networkx 1.x
    # 'g.nodes' is a plain method. Use whichever one is subscriptable.
    nodes = getattr(g, "nodes", None)
    node_attrs = nodes if hasattr(nodes, "__getitem__") else g.node

    def _get_height(root):
        children = list(g.successors(root))
        height = 0
        if children:
            child_heights = [_get_height(x) for x in children]
            height = max(child_heights) + 1
        node_attrs[root]["height"] = height
        return height

    _get_height(root)
401 
402 
def _sort_tree_leaves(g, root):
    ''' For each node, sort its child nodes based on the height of the nodes.
        Return the leaf nodes of the tree after sorting.

    Children are visited in ascending "height" order (ties keep the order
    given by successors()). The "height" attributes must already be set.
    '''
    # 'g.node' is gone in recent networkx releases, while in networkx 1.x
    # 'g.nodes' is a plain method. Use whichever one is subscriptable.
    nodes = getattr(g, "nodes", None)
    node_attrs = nodes if hasattr(nodes, "__getitem__") else g.node

    def _get_height(root):
        return node_attrs[root]["height"]

    def _get_sorted_leaves(root):
        children = list(g.successors(root))
        if not children:
            return [root]
        ret = []
        # sorted() is stable, so equal heights keep their original order.
        for child in sorted(children, key=_get_height):
            ret += _get_sorted_leaves(child)
        return ret

    return _get_sorted_leaves(root)
424 
425 
def topological_sort_traversal_longest_path(g):
    ''' The graph 'g' may contain several source nodes (nodes without incoming
        edge), which could be in any order and still be a valid
        topological sorting result. We would like to arrange these source nodes
        so that the average live spans of the computed blobs are shorter.
        The idea is to sort the source nodes based on the length of their path to
        the target node so that the one with longer path is used first.
        This is done by:
        - Add a single target node if there are multiple target nodes in 'g'.
        - Find the longest path between each source and the target node.
        - Convert the longest paths to a tree with the target node being the root
          and source nodes being the leaves.
        - Sort the nodes of the tree based on the height of the tree.

    Returns the list of all nodes of 'g' in the chosen topological order.
    '''
    gt = _add_single_target_ifneeded(g)
    source_nodes = _find_source_nodes(gt)
    lpaths = _get_longest_paths(gt, source_nodes)
    tree, root = _build_tree(list(viewvalues(lpaths)))
    sorted_sources = _sort_tree_leaves(tree, root)
    assert(sorted(sorted_sources) == sorted(source_nodes))

    # Compare the major version numerically; comparing the version strings
    # would order e.g. '10.0' before '2.0'.
    nx_major = int(nx.__version__.split('.')[0])
    if nx_major < 2:
        ret = nx.topological_sort(g, sorted_sources)
    else:
        # Manually making a sorted descendent list
        dependency_order = list(sorted_sources)
        seen_nodes = set(sorted_sources)
        for s in sorted_sources:
            desc = nx.descendants(g, s)
            for d in desc:
                if d not in seen_nodes:
                    seen_nodes.add(d)
                    dependency_order.append(d)
        sort_key = dict(
            (v, len(dependency_order) - i)
            for i, v in enumerate(dependency_order))
        ret = nx.algorithms.dag.lexicographical_topological_sort(
            g, key=lambda x: sort_key[x])
        ret = list(ret)
    # len(g) is the node count in every networkx version; 'g.node' no longer
    # exists in recent releases.
    assert(len(ret) == len(g))
    return ret
465 
466 
def topological_sort_traversal(g):
    ''' Return the nodes of 'g' as a list in networkx's topological order '''
    ordering = nx.topological_sort(g)
    return list(ordering)
469 
470 
def compute_ranges(linearized_ops, blob_sizes=None):
    ''' Compute the LiveRange of every blob touched by 'linearized_ops'.

    'defined' is the index of the first op that outputs the blob, 'used' the
    index of the last op that reads it; either stays None if it never
    happens. 'size' comes from 'blob_sizes' (None when no sizes are given).

    Returns a defaultdict mapping blob name -> LiveRange.
    '''
    if not blob_sizes:
        log.warning('Provide blob sizes to get more accurate assignments.')

    def _size_of(name):
        size = blob_sizes[name] if blob_sizes else None
        assert not blob_sizes or size is not None
        return size

    live = collections.defaultdict(
        lambda: LiveRange(defined=None, used=None, size=None))
    for step, op in enumerate(linearized_ops):
        for name in op.input:
            current = live[name]
            last_use = step if current.used is None else max(current.used, step)
            live[name] = current._replace(used=last_use, size=_size_of(name))
        for name in op.output:
            current = live[name]
            first_def = step if current.defined is None \
                else min(current.defined, step)
            live[name] = current._replace(
                defined=first_def, size=_size_of(name))

    return live
500 
501 
def is_compatible(candidate_range, assignment, static_blobs):
    ''' Tell whether a blob with 'candidate_range' may join 'assignment'.

    Only the last (name, range) entry of 'assignment' is inspected: the
    candidate fits when that entry is not a static blob, all needed indices
    are known, and the candidate is defined strictly after the entry's last
    use.
    '''
    last_name, last_range = assignment[-1]
    if last_name in static_blobs:
        return False
    needed = (candidate_range.defined, last_range.defined, last_range.used)
    if any(index is None for index in needed):
        return False
    return candidate_range.defined > last_range.used
510 
511 
def compute_blob_assignments(assignments):
    ''' Turn assignment groups into a rename map.

    Every blob of a group is mapped to the name of the group's last blob.
    Groups with exactly one entry need no renaming and are left out.
    '''
    renames = {}
    for group in assignments:
        if len(group) != 1:
            canonical, _ = group[-1]
            renames.update((blob, canonical) for (blob, _) in group)
    return renames
521 
522 
def _get_max_size(assignment):
    ''' Return the largest blob size in 'assignment' (0 if empty or unknown) '''
    if not assignment:
        return 0
    largest = max([entry[1].size for entry in assignment])
    if largest is None:
        return 0
    return largest
529 
530 
def get_memory_usage(assignments):
    ''' Sum, over all assignment groups, the size of each group's largest blob '''
    return sum(_get_max_size(group) for group in assignments)
536 
537 
def compute_assignments_greedy(ranges_sorted, init_assignments=None):
    ''' Greedily place every blob of 'ranges_sorted' into an assignment group.

    A blob goes to the compatible group whose current maximum size is closest
    to the blob's own size (first such group on ties); if no group is
    compatible it starts a new one. Blobs already present in
    'init_assignments' are skipped. 'init_assignments', when non-empty, is
    extended in place and returned.
    '''
    assignments = init_assignments or []
    already_placed = {entry[0] for group in assignments for entry in group}

    for (name, range_) in ranges_sorted:
        if name in already_placed:
            continue
        wanted_size = range_.size or 0
        found = False
        chosen = 0
        smallest_gap = float("inf")
        for position, group in enumerate(assignments):
            if not is_compatible(range_, group, []):
                continue
            found = True
            gap = abs(_get_max_size(group) - wanted_size)
            if gap < smallest_gap:
                smallest_gap, chosen = gap, position
        if not found:
            assignments.append([(name, range_)])
        else:
            assignments[chosen].append((name, range_))
    return assignments
562 
563 
def _get_count(assignments):
    ''' Return number of blobs in assignments '''
    if not assignments:
        return 0
    return sum(len(group) for group in assignments)
569 
570 
def compute_assignments_dp(ranges_sorted, init_assignment, counter=None):
    ''' Compute assignment for blobs in 'ranges_sorted' on top of 'init_assignment'
        using dynamic programming + recursion.

        ranges_sorted: blobs sorted by 'used'
        init_assignment: assignment to start with, blobs in 'ranges_sorted' should
                         not be used in 'init_assignment'
        counter: one-element list counting the (recursive) calls, used only
                 for progress logging; created when not given.

        Using f(b, k, init) to represent the best assignment for blobs b[0:k]
        given initial assignment 'init', we have
            f(b, k, init) = f(b, j, init) +
                            find_best(b[j:k], f(b, j, init))
        where j is the index of the last best assignment that is independent of
        blob b[k - 1] (b[k - 1] is compatible with all assignments in
        f(b, j, init)), and find_best(b1, init1) gives the best assignment
        for blobs in 'b1' based on the initial assignment 'init1', and blobs
        b1[0:-1] should be incompatible with b1[-1]. f(b, len(b), []) gives
        the best assignment for blobs 'b'.

        For find_best(b, init), since b[0:-1] are not compatible with b[-1], we
        could reduce it to a smaller problem to find best assignment for b[0:-1]
        as
            find_best(b, init) = min {
                f(b[0:-1], len(b) - 1, init - x) + [x, b[-1]] for x in init, or
                f(b[0:-1], len(b) - 1, init) + [b[-1]]
            }
        where min{} gives the assignment with minimum memory usage.

        Returns the best assignment (list of groups of (name, range) pairs).
        With no blobs to place, a copy of 'init_assignment' is returned.
    '''

    def _get_compatible_prev(candidate_range, best_assignments, cur_idx):
        ''' Find closest position k of best_assignments that is independent of
            candidate_range that candiate_range is compatible with all assignments
            in best_assignments[k].
            Return -1 if not found.
        '''
        def is_compatible_all(candidate_range, assignments):
            ''' return true if compatiable for all assignments in assignments '''
            return all([is_compatible(candidate_range[1], x, []) for x in assignments])

        ii = cur_idx - 1
        while ii >= 0:
            cba = best_assignments[ii]
            if is_compatible_all(candidate_range, cba):
                return ii
            ii -= 1
        return -1

    def _find_best(ranges, init_assignment, prev_best_assignment, counter):
        ''' Find the best assignment for blobs 'ranges' given an initialized
            assignment 'init_assignment'.

            Blobs in ranges[0:-1] should be incompatible with blob range[-1].
            'prev_best_assignment': best assignment for blobs in ranges[:-1]

            By assigning ranges[-1] to each assignment k in 'init_assignment' or
            in a new assignment, the problem becomes a smaller problem to find
            the best assignment for ranges[0:-1] given the initial assignment
            init_assigment[0:k, (k+1):-1].
        '''
        # Blob to check
        find_range = ranges[-1]
        # Blobs in ranges[0:-1] are incompatible with ranges[-1] so that we can
        # reduce it to a smaller problem.
        assert all(not is_compatible(x[1], [find_range], []) for x in ranges[0:-1])

        sz = len(init_assignment)
        best_candidates = []
        # Try to assign 'find_range' to each assignment in init_assignment
        for ii in range(sz):
            if not is_compatible(find_range[1], init_assignment[ii], []):
                continue
            cur_best = copy.deepcopy(init_assignment)
            cur_best[ii].append(find_range)
            if len(ranges) > 1:
                cur_best_tmp = [x for i, x in enumerate(cur_best) if i != ii]
                # reduce to a smaller dp problem
                cur_best_tmp = compute_assignments_dp(
                    ranges[:-1], cur_best_tmp, counter)
                cur_best = cur_best_tmp + [cur_best[ii]]
            best_candidates.append(cur_best)
        # Try to put 'find_range' in a new assignment
        best_candidates.append(prev_best_assignment + [[find_range]])

        ret = min(best_candidates, key=lambda x: get_memory_usage(x))
        return ret

    if not counter:
        counter = [0]
    counter[0] += 1

    # Nothing to place: the initial assignment is already the answer. Without
    # this guard the code below indexes into empty lists (IndexError), e.g.
    # when every blob of a net is static.
    if not ranges_sorted:
        return copy.deepcopy(init_assignment) if init_assignment else []

    if counter and counter[0] % 5000 == 0:
        rs = [ranges_sorted[0][1].defined, ranges_sorted[-1][1].used]
        log.info('Finding assignments {} ({} -> {})...'.format(
            counter[0], rs[0], rs[1]))

    init_assignment = init_assignment or []
    # best_assignments[k]: best assignments for first k blobs ranges_sorted[0:(k+1)]
    best_assignments = []
    # Find best assignment for blobs ranges_sorted[0:ii]
    for ii, cur_range in enumerate(ranges_sorted):
        # closest best_assignment that is independent of ranges_sorted[ii]
        prev_idx = _get_compatible_prev(cur_range, best_assignments, ii)
        prev_best = copy.deepcopy(init_assignment) if prev_idx < 0 else \
            copy.deepcopy(best_assignments[prev_idx])
        # Need to find best assignment for blobs in 'ranges_part'
        ranges_part = ranges_sorted[(prev_idx + 1):(ii + 1)]
        cur_best = _find_best(
            ranges_part, prev_best,
            best_assignments[-1] if best_assignments else init_assignment,
            counter)
        assert _get_count(cur_best) == _get_count(prev_best) + len(ranges_part)
        best_assignments.append(copy.deepcopy(cur_best))

    assert len(best_assignments) == len(ranges_sorted)

    best = best_assignments[-1]

    return best
689 
690 
def get_updated_ranges(ranges, max_live=None):
    ''' Set LiveRange.defined = -1 if it is None
        Set LiveRange.used = max_live if it is None
        Set LiveRange.size = 1 if it is None

    ranges: list of (name, LiveRange) pairs.
    max_live: value for missing 'used' fields. When None it is one more than
        the largest known 'used' index, or 0 if no blob has one.

    Returns a new list of (name, LiveRange) pairs; the input is not modified.
    '''

    def _get_max_live(ranges):
        # Test against None rather than truthiness: index 0 is a valid use,
        # and an empty candidate list must not reach max().
        used_indices = [x[1].used for x in ranges if x[1].used is not None]
        if not used_indices:
            return 0
        return max(used_indices) + 1

    def _update_range(x, max_live, size):
        cx = x
        if x[1].defined is None:
            cx = (cx[0], cx[1]._replace(defined=-1))
        if x[1].used is None:
            cx = (cx[0], cx[1]._replace(used=max_live))
        if x[1].size is None:
            cx = (cx[0], cx[1]._replace(size=size))
        return cx

    if max_live is None:
        max_live = _get_max_live(ranges)
    ranges = [_update_range(x, max_live, 1) for x in ranges]

    return ranges
716 
717 
def compute_assignments(ranges, static_blobs, algo):
    '''
    Group blobs so that the blobs of one group can share memory.

    ranges: dict mapping blob name -> LiveRange.
    static_blobs: collection of blob names that must not be shared; each of
        them ends up in a group of its own.
    algo: Method used to find assignments (AssignmentAlgorithm.GREEDY or
          AssignmentAlgorithm.DYNAMIC_PROGRAMMING).
          AssignmentAlgorithm.DYNAMIC_PROGRAMMING gives optimal solution at the
          cost of more computation.
          AssignmentAlgorithm.GREEDY may be better in the case 'blob_sizes' is
          not provided.

    Returns a list of groups, each a list of (name, LiveRange) pairs.
    Raises ValueError if 'algo' is not a known AssignmentAlgorithm.
    '''

    # Sort the ranges based on when they are last used.
    # If LiveRange.used is None, then the blob is never used and could
    # be consumed externally. Sort these to the end of the list as opposed
    # to the beginning so that they can be shared as well.
    ranges = sorted(
        viewitems(ranges),
        key=lambda p: (p[1].used is None, p[1].used),
    )
    # Update None values
    ranges = get_updated_ranges(ranges)

    # Sharable blobs
    ranges_sharable = [x for x in ranges if x[0] not in static_blobs]
    # Static blobs, not sharable
    ranges_static = [x for x in ranges if x[0] in static_blobs]

    log.info("Total sharable blobs {}".format(len(ranges_sharable)))

    best_assignment = []
    if algo == AssignmentAlgorithm.DYNAMIC_PROGRAMMING:
        best_assignment = compute_assignments_dp(ranges_sharable, [])
    elif algo == AssignmentAlgorithm.GREEDY:
        best_assignment = compute_assignments_greedy(ranges_sharable, [])
    else:
        # Asserting on the message string itself can never fail, so an
        # unknown algorithm used to be silently ignored.
        raise ValueError("Invalid algo name {}".format(algo))
    best_assignment += [[x] for x in ranges_static]

    # verify_assignments(best_assignment)

    return best_assignment
758 
759 
def verify_assignments(assignments):
    ''' Assert that, inside every group, each blob's last use comes strictly
        before the definition of the blob that follows it.
    '''
    for group in assignments:
        for pos in range(len(group) - 1):
            earlier, later = group[pos], group[pos + 1]
            assert earlier[1].used < later[1].defined
764 
765 
def compute_interference_graph(ops):
    ''' Build the dependency graph of 'ops'.

    Node i carries ops[i] in its "op" attribute. There is an edge i -> j
    (i < j) whenever op j reads a blob written by op i; the edge's "deps"
    attribute is the set of those blobs.
    '''
    graph = nx.DiGraph()
    for index, op in enumerate(ops):
        graph.add_node(index, op=op)

    for i, parent_op in enumerate(ops):
        for j, child_op in enumerate(ops):
            if j <= i:
                continue
            shared = set(child_op.input).intersection(parent_op.output)
            if shared:
                graph.add_edge(i, j, deps=shared)
                assert nx.is_directed_acyclic_graph(graph), child_op
    return graph
779 
780 
# Result of optimize_interference(): the rewritten net, the assignment groups
# (lists of (name, LiveRange) pairs) and the blob -> canonical blob rename map.
Optimization = collections.namedtuple(
    'Optimization', ['net', 'assignments', 'blob_assignments'])
783 
784 
def apply_assignments(net, blob_assignments):
    ''' Rename, in place, the inputs and outputs of every op of 'net'
        according to 'blob_assignments' (dict: blob -> canonical blob).
        Blobs without an entry keep their name.
    '''
    def canonical_name(blob):
        return blob_assignments.get(blob, blob)

    for op in net.op:
        # Descend into subnets of the recurrent network
        if op.type.startswith('RecurrentNetwork'):
            apply_recurrent_blob_assignments(op, blob_assignments, canonical_name)

        for idx in range(len(op.input)):
            op.input[idx] = canonical_name(op.input[idx])
        for idx in range(len(op.output)):
            op.output[idx] = canonical_name(op.output[idx])
800 
801 
802 
def apply_recurrent_blob_assignments(op, blob_assignments, canonical_name):
    ''' Apply 'blob_assignments' to the step nets of a recurrent op.

    Every argument whose name ends with "step_net" has its net and its
    external inputs renamed. For each renamed blob that is an input or output
    of 'op' itself, a "<blob>.rename" argument holding the new name is
    appended to the op.
    '''
    log.debug("Applying assignments to recurrent op: {}".format(op.type))
    for arg in [a for a in op.arg if a.name.endswith("step_net")]:
        step_net = arg.n
        apply_assignments(step_net, blob_assignments)
        for idx, ext_input in enumerate(step_net.external_input):
            if ext_input in blob_assignments:
                step_net.external_input[idx] = canonical_name(ext_input)

    # Store renamings
    op_blobs = list(op.input) + list(op.output)
    for blob, renamed in viewitems(blob_assignments):
        if blob not in op_blobs:
            continue
        rename_arg = caffe2_pb2.Argument()
        rename_arg.name = blob + ".rename"
        rename_arg.s = str(renamed).encode("ascii")
        op.arg.extend([rename_arg])
818 
819 
# Strategies accepted by compute_assignments() / optimize_interference() for
# grouping blobs that can share memory.
class AssignmentAlgorithm(enum.Enum):
    # Handled by compute_assignments_greedy().
    GREEDY = 0
    # Handled by compute_assignments_dp().
    DYNAMIC_PROGRAMMING = 1
823 
824 
def optimize_inference_fast(net, static_blobs):
    ''' Run the native inference optimizer on 'net' (a NetDef protobuf).

    'static_blobs' are passed along as utf-8 encoded names.
    NOTE(review): `C` is presumably the caffe2 C extension module; its import
    is not visible in this listing - confirm it is in scope.

    Returns the optimized NetDef protobuf.
    '''
    serialized_net = net.SerializeToString()
    static_names = [str(s).encode('utf-8') for s in static_blobs]
    result = caffe2_pb2.NetDef()
    result.ParseFromString(
        C.memonger_optimize_inference_net(serialized_net, static_names))
    return result
833 
834 
def optimize_interference(net, static_blobs,
                          ordering_function=topological_sort_traversal,
                          blob_sizes=None,
                          algo=AssignmentAlgorithm.GREEDY):
    """
    Reorder the ops of a copy of 'net' and let blobs with disjoint live
    ranges share a name. The input net is not modified.

    ordering_function: topological_sort_traversal or
        topological_sort_traversal_longest_path.
        topological_sort_traversal_longest_path gives better
        results but needs a bit more computation.
    algo: Method used to find assignments (AssignmentAlgorithm.GREEDY or
          AssignmentAlgorithm.DYNAMIC_PROGRAMMING).
          AssignmentAlgorithm.DYNAMIC_PROGRAMMING gives optimal solution at the
          cost of more computation.
          AssignmentAlgorithm.GREEDY may be better in the case 'blob_sizes' is
          not provided.

    Returns an Optimization tuple (net, assignments, blob_assignments).
    """
    # Steps:
    # 1) Traverse the execution graph with 'ordering_function' to generate
    #    an ordering of the node executions.
    # 2) Generate use-def ranges for each `blob` in that order.
    # 3) Assign blobs to `canonical blobs`
    # 4) Rename blobs to canonical blobs

    optimized = copy.deepcopy(net)
    graph = compute_interference_graph(optimized.op)
    linearized_ops = [optimized.op[idx] for idx in ordering_function(graph)]

    # Reorder ops in net based on the computed linearlized order.
    # If the graph has multiple topological orderings and if the NetDef's
    # ordering differs from the order used to compute ranges, then the
    # runtime might end up overwriting blobs before they are used.
    del optimized.op[:]
    optimized.op.extend(linearized_ops)

    live_ranges = compute_ranges(linearized_ops, blob_sizes)
    groups = compute_assignments(live_ranges, static_blobs, algo)
    renames = compute_blob_assignments(groups)
    apply_assignments(optimized, renames)
    return Optimization(
        net=optimized,
        blob_assignments=renames,
        assignments=groups)
881 
882 
def verify_inplace_blobs(net_a, net_b):
    """
    Check that the ops of net_a and net_b, compared pairwise in order, have
    the same type and the same in-place input/output positions. In
    particular this catches memonger adding an in-place assignment that did
    not exist before.

    Returns True when all compared pairs agree, False otherwise.
    """
    def inplace_positions(op):
        outputs = list(op.output)
        return [
            [in_idx, outputs.index(blob)]
            for in_idx, blob in enumerate(op.input)
            if blob in outputs
        ]

    for op_a, op_b in zip(net_a.op, net_b.op):
        if op_a.type != op_b.type:
            return False
        if inplace_positions(op_a) != inplace_positions(op_b):
            return False
    return True
903 
904 
def verify_graph_equality(net_a, net_b):
    """
    Determines if the execution of two graphs are identical.
    That is, all inputs blobs are mapped to the same output blobs
    for each operator in their respective positions.

    This is meant to check the output of memonger with the original graph.
    It assumes that the nets have same external input and output.

    Differences in the dependency structure are printed for debugging.

    O(E) runtime + O(1) amortized cost to hash for python dict
    """

    def parent_list(ops):
        # For every op, the indices of the ops that most recently wrote each
        # of its inputs (inputs nobody wrote yet are skipped).
        parents = [[] for _ in ops]
        last_writer = {}
        for idx, op in enumerate(ops):
            parents[idx].extend(
                last_writer[blob] for blob in op.input if blob in last_writer)
            for blob in op.output:
                last_writer[blob] = idx
        return parents

    # Operator wise equality checks
    if len(net_a.op) != len(net_b.op):
        return False
    for op_a, op_b in zip(net_a.op, net_b.op):
        mismatch = (
            op_a.type != op_b.type,
            op_a.device_option != op_b.device_option,
            op_a.engine != op_b.engine,
        )
        if any(mismatch):
            return False

    # Net wise equality check
    parents_a = parent_list(net_a.op)
    parents_b = parent_list(net_b.op)
    identical = parents_a == parents_b

    # Print debug info
    if not identical:
        for j, (a, b) in enumerate(zip(parents_a, parents_b)):
            if a != b:
                print("Difference {} vs {} \n {}".format(
                    j, net_a.op[j], net_b.op[j]))
                print("Parents: {} vs {}".format(a, b))

    return identical
954 
955 
# Result of compute_statistics(): total bytes of all blobs before sharing and
# the sum of each assignment group's largest blob after sharing.
Statistics = collections.namedtuple(
    'Statistics', ['baseline_nbytes', 'optimized_nbytes'])
958 
959 
def blob_nbytes(blob):
    ''' Return the 'nbytes' of 'blob' as fetched from the workspace.
        If fetching fails for any reason, log a warning and return 0.
    '''
    try:
        return workspace.FetchBlob(blob).nbytes
    except Exception:
        log.warning('Error when fetching blob {}'.format(blob))
        return 0
967 
968 
def compute_statistics(assignments):
    ''' Compare memory needs before and after sharing.

    baseline_nbytes is the sum of the sizes of all distinct blobs in
    'assignments'; optimized_nbytes is the sum of the largest blob size of
    each group. Sizes are looked up with blob_nbytes().
    '''
    blob_bytes = {}
    for group in assignments:
        for (blob, _) in group:
            blob_bytes[blob] = blob_nbytes(blob)

    baseline = sum(viewvalues(blob_bytes))
    optimized = sum(
        max(blob_bytes[blob] for (blob, _) in group)
        for group in assignments)
    return Statistics(
        baseline_nbytes=baseline,
        optimized_nbytes=optimized)
980 
981 
def collect_blob_sizes(net):
    ''' Return a dict mapping every input and output blob of the ops of 'net'
        to its size in bytes as reported by blob_nbytes().
    '''
    sizes = {}
    for op in net.op:
        for blob in list(op.input) + list(op.output):
            sizes[blob] = blob_nbytes(blob)
    return sizes