Caffe2 - Python API
A deep learning, cross-platform ML framework
cpp.py
1 """Functionality for Python <-> C++ frontend inter-op."""
2 
3 from torch import nn
4 
5 
class OrderedDictWrapper(object):
    """
    Lazily proxies an OrderedDict attribute of a bound C++ module.

    Every access re-fetches the dict from the C++ module, so mutations made
    on the C++ side are always visible. A plain attribute grab (e.g.
    ``cpp_module._parameters``) would freeze a copy of the dict at the moment
    of access; and because ``torch.nn.Module`` reads ``_parameters`` et al.
    straight out of ``self.__dict__``, a property cannot be used instead.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        # Re-resolve on every access so C++-side updates are reflected.
        return getattr(self.cpp_module, self.attr)

    # Dunder lookups bypass ``getattr`` and cannot be attached dynamically,
    # hence the explicit delegation below.

    def items(self):
        return self.cpp_dict.items()

    def keys(self):
        return self.cpp_dict.keys()

    def values(self):
        return self.cpp_dict.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
47 
48 
class ModuleWrapper(nn.Module):
    """
    A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and
    delegates all access.
    """

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor.
        self.cpp_module = cpp_module
        super(ModuleWrapper, self).__init__()
        # Replace the Python-side dicts installed by ``nn.Module.__init__``
        # with live views onto the C++ module's own dicts, so C++-side
        # changes to parameters/buffers/submodules are always visible.
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")
        self._buffers = OrderedDictWrapper(cpp_module, "_buffers")
        self._modules = OrderedDictWrapper(cpp_module, "_modules")
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn):
        # Override of ``nn.Module._apply``: apply ``fn`` in place to every
        # parameter (and its gradient, if present) and every buffer.
        for param in self.parameters():
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    @property
    def training(self):
        # The C++ module owns the training flag; always read it from there.
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Forward to the C++ ``train()`` so the flag is updated on the C++
        # side. NOTE(review): presumably this also propagates to C++
        # submodules — confirm against the C++ frontend implementation.
        self.cpp_module.train(mode)

    def __repr__(self):
        # Delegate so the printed representation matches the C++ module's.
        return self.cpp_module.__repr__()