Caffe2 - Python API
A deep learning, cross-platform ML framework
hooks.py
1 from __future__ import absolute_import, division, print_function, unicode_literals
2 import collections
3 import weakref
4 import warnings
5 
6 
class RemovableHandle(object):
    """A handle which provides the capability to remove a hook.

    Every handle is assigned a unique integer ``id`` taken from the
    class-wide ``next_id`` counter. The handle keeps only a weak reference
    to the dict holding its hook, so it never keeps that dict alive on its
    own. Handles support pickling (``__getstate__`` / ``__setstate__``) and
    can be used in a ``with`` statement, which removes the hook on exit.
    """

    # Class-wide counter: the id the next freshly constructed handle takes.
    next_id = 0

    def __init__(self, hooks_dict):
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        handle_id = RemovableHandle.next_id
        RemovableHandle.next_id = handle_id + 1
        self.id = handle_id

    def remove(self):
        """Delete this handle's entry from the hooks dict.

        This is a no-op if the dict has already been garbage-collected or
        no longer contains this handle's id, so repeated calls are safe.
        """
        target = self.hooks_dict_ref()
        if target is None or self.id not in target:
            return
        del target[self.id]

    def __getstate__(self):
        # Dereference the weakref so pickle stores the dict itself (or None
        # if it has already been collected) rather than the weakref object.
        return self.hooks_dict_ref(), self.id

    def __setstate__(self, state):
        owner = state[0]
        if owner is not None:
            self.hooks_dict_ref = weakref.ref(owner)
        else:
            # The dict was already gone when pickled. Point at a temporary
            # dict that nothing else holds, so the weakref is dead
            # (immediately so on CPython, where the temporary is freed as
            # soon as it is dropped).
            self.hooks_dict_ref = weakref.ref(collections.OrderedDict())
        self.id = state[1]
        # Advance the counter so handles created later never reuse this id.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        # Implicitly returns None, so exceptions raised in the body propagate.
        self.remove()
39 
40 
def unserializable_hook(f):
    """Decorator that tags ``f`` as an intentionally unserializable hook.

    It sets a ``__torch_unserializable__`` attribute to ``True`` on the
    function and returns the very same function object. The tag is meant to
    suppress the warning that would otherwise be raised when serializing a
    tensor that has this hook registered.
    """
    setattr(f, "__torch_unserializable__", True)
    return f
49 
50 
def warn_if_has_hooks(tensor):
    """Warn about backward hooks on ``tensor`` that will not be serialized.

    Emits one warning (via ``warnings.warn``) for each hook in
    ``tensor._backward_hooks``, unless that hook function carries the
    ``__torch_unserializable__`` marker set by ``unserializable_hook``.

    Args:
        tensor: object with a ``_backward_hooks`` mapping. A falsy value,
            such as an empty dict or ``None``, means there are no hooks and
            nothing is emitted.
    """
    hooks = tensor._backward_hooks
    if not hooks:
        return
    for hook in hooks.values():
        # The opt-out marker lives on the hook function itself, not on the
        # mapping key (a handle id), so inspect the value.
        if not hasattr(hook, "__torch_unserializable__"):
            warnings.warn("backward hook {} on tensor will not be "
                          "serialized. If this is expected, you can "
                          "decorate the function with @torch.utils.hooks.unserializable_hook "
                          "to suppress this warning".format(repr(hook)))
57  "serialized. If this is expected, you can "
58  "decorate the function with @torch.utils.hooks.unserializable_hook "
59  "to suppress this warning".format(repr(hook)))