1 from __future__
import absolute_import, division, print_function, unicode_literals
8 """A handle which provides the capability to remove a hook.""" 12 def __init__(self, hooks_dict):
14 self.
id = RemovableHandle.next_id
15 RemovableHandle.next_id += 1
19 if hooks_dict
is not None and self.
id in hooks_dict:
20 del hooks_dict[self.
id]
22 def __getstate__(self):
25 def __setstate__(self, state):
32 RemovableHandle.next_id = max(RemovableHandle.next_id, self.
id + 1)
37 def __exit__(self, type, value, tb):
def unserializable_hook(f):
    """
    Decorator which marks a function as an unserializable hook.

    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.

    Args:
        f: the hook callable to mark.

    Returns:
        ``f`` itself, with ``f.__torch_unserializable__`` set to ``True``.
    """
    f.__torch_unserializable__ = True
    # A decorator must hand back the decorated object; without this the
    # decorated name would be rebound to ``None``.
    return f


def warn_if_has_hooks(tensor):
    """
    Warn once for each backward hook on ``tensor`` that will not be serialized.

    Hooks marked with :func:`unserializable_hook` are skipped silently.

    Args:
        tensor: object with a ``_backward_hooks`` attribute that is either
            falsy (``None`` / empty) or a mapping whose values are hook
            callables. Presumably a ``torch.Tensor`` -- confirm at call sites.
    """
    hooks = tensor._backward_hooks
    # ``None`` and an empty mapping both mean "no hooks registered".
    if not hooks:
        return
    for hook in hooks.values():
        # ``unserializable_hook`` sets the opt-out marker on the hook callable
        # itself. The mapping keys are (presumably) handle ids, so the check
        # must be made on the value, not on the key.
        if not hasattr(hook, "__torch_unserializable__"):
            warnings.warn("backward hook {} on tensor will not be "
                          "serialized. If this is expected, you can "
                          "decorate the function with @torch.utils.hooks.unserializable_hook "
                          "to suppress this warning".format(repr(hook)))