1 """Functionality for Python <-> C++ frontend inter-op.""" 8 A wrapper around a C++ OrderedDict that dynamically evaluates the 9 OrderedDict getter on a bound C++ module, such that new changes on the C++ 10 side are picked up. Otherwise accessing e.g. ``cpp_module._parameters`` just 11 once would get a frozen copy of the parameters at the time of access. 12 ``torch.nn.Module`` accesses ``_parameters`` et al. via ``self.__dict__`` so 13 using properties does not work. 16 def __init__(self, cpp_module, attr):
28 return self.cpp_dict.items()
31 return self.cpp_dict.keys()
34 return self.cpp_dict.values()
37 return self.cpp_dict.__iter__()
40 return self.cpp_dict.__len__()
def __contains__(self, key):
    """Return whether ``key`` is present in the wrapped C++ OrderedDict.

    ``self.cpp_dict`` is looked up on every call rather than cached.

    Uses the ``in`` operator instead of calling ``__contains__`` directly.
    That is the idiomatic spelling, it always yields a ``bool``, and it
    falls back to iteration if the underlying container does not define
    ``__contains__``.
    """
    return key in self.cpp_dict
def __getitem__(self, key):
    """Look up ``key`` in the wrapped C++ OrderedDict and return its value.

    ``self.cpp_dict`` is fetched fresh on each call. Any exception raised
    by the underlying container's lookup propagates unchanged.
    """
    mapping = self.cpp_dict
    return mapping[key]
51 A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and 55 def __init__(self, cpp_module):
59 super(ModuleWrapper, self).__init__()
63 for attr
in dir(cpp_module):
65 if not attr.startswith(
"_"):
66 setattr(self, attr, getattr(self.
cpp_module, attr))
69 for param
in self.parameters():
72 param.data = fn(param.data)
73 if param._grad
is not None:
74 param._grad.data = fn(param._grad.data)
76 for buf
in self.buffers():
77 buf.data = fn(buf.data)
83 return self.cpp_module.training
def training(self, mode):
    """Forward ``mode`` to the wrapped C++ module's ``train`` method.

    NOTE(review): presumably registered as the setter of a ``training``
    property; the decorator is outside this view. Returns ``None``.
    """
    wrapped_module = self.cpp_module
    wrapped_module.train(mode)
90 return self.cpp_module.__repr__()