Caffe2 - C++ API
A deep learning, cross platform ML framework
pimpl.h
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/detail/static.h>
5 #include <torch/serialize/archive.h>
6 #include <torch/types.h>
7 
8 #include <torch/csrc/utils/variadic.h>
9 
10 #include <memory>
11 #include <type_traits>
12 #include <utility>
13 
14 namespace torch {
15 namespace detail {
16 // Dump all the template metaprogramming in this file.
17 #include <torch/csrc/api/include/torch/nn/pimpl-inl.h>
18 } // namespace detail
19 
20 namespace nn {
21 
25 template <typename Contained>
27  protected:
31  std::shared_ptr<Contained> impl_;
32 
33  public:
34  using ContainedType = Contained;
35 
42  ModuleHolder() : impl_(default_construct()) {
43  static_assert(
44  std::is_default_constructible<Contained>::value,
45  "You are trying to default construct a module which has "
46  "no default constructor. Use = nullptr to give it the empty state "
47  "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`).");
48  }
49 
53  /* implicit */ ModuleHolder(std::nullptr_t) : impl_(nullptr) {}
54 
57  template <
58  typename Head,
59  typename... Tail,
60  typename = typename std::enable_if<
62  (sizeof...(Tail) == 0))>::type>
63  explicit ModuleHolder(Head&& head, Tail&&... tail)
64  : impl_(new Contained(
65  std::forward<Head>(head),
66  std::forward<Tail>(tail)...)) {}
67 
70  /* implicit */ ModuleHolder(std::shared_ptr<Contained> module)
71  : impl_(std::move(module)) {}
72 
75  explicit operator bool() const noexcept {
76  return !is_empty();
77  }
78 
80  Contained* operator->() {
81  return get();
82  }
83 
85  const Contained* operator->() const {
86  return get();
87  }
88 
90  Contained& operator*() {
91  return *get();
92  }
93 
95  const Contained& operator*() const {
96  return *get();
97  }
98 
100  const std::shared_ptr<Contained>& ptr() const {
101  AT_CHECK(!is_empty(), "Accessing empty ModuleHolder");
102  return impl_;
103  }
104 
106  Contained* get() {
107  AT_CHECK(!is_empty(), "Accessing empty ModuleHolder");
108  return impl_.get();
109  }
110 
112  const Contained* get() const {
113  AT_CHECK(!is_empty(), "Accessing empty ModuleHolder");
114  return impl_.get();
115  }
116 
118  template <typename... Args>
119  auto operator()(Args&&... args)
120  -> torch::detail::return_type_of_forward_t<Contained, Args...> {
121  // This will not compile if the module does not have a `forward()` method
122  // (as expected).
123  // NOTE: `std::forward` is qualified to prevent VS2017 emitting
124  // error C2872: 'std': ambiguous symbol
125  return impl_->forward(::std::forward<Args>(args)...);
126  }
127 
131  template <typename Arg>
132  auto operator[](Arg&& arg) -> decltype((*impl_)[::std::forward<Arg>(arg)]) {
133  return (*impl_)[::std::forward<Arg>(arg)];
134  }
135 
137  bool is_empty() const noexcept {
138  return impl_ == nullptr;
139  }
140 
141  private:
149 
150  template <
151  typename T = Contained,
152  typename = torch::enable_if_t<std::is_default_constructible<T>::value>>
153  std::shared_ptr<Contained> default_construct() {
154  return std::make_shared<Contained>();
155  }
156 
157  template <typename T = Contained>
158  torch::disable_if_t<
159  std::is_default_constructible<T>::value,
160  std::shared_ptr<Contained>>
161  default_construct() {
162  return nullptr;
163  }
164 };
165 
167 template <typename ModuleType>
168 std::ostream& operator<<(
169  std::ostream& stream,
170  const nn::ModuleHolder<ModuleType>& module) {
171  return stream << *module;
172 }
173 
175 template <typename ModuleType>
176 serialize::OutputArchive& operator<<(
177  serialize::OutputArchive& archive,
178  const nn::ModuleHolder<ModuleType>& module) {
179  return archive << module.ptr();
180 }
181 
183 template <typename ModuleType>
184 serialize::InputArchive& operator>>(
185  serialize::InputArchive& archive,
187  return archive >> module.ptr();
188 }
189 
190 } // namespace nn
191 } // namespace torch
192 
/// Defines a class `Name` deriving publicly from
/// `torch::nn::ModuleHolder<Impl>` and inheriting all of its constructors.
///
/// The base is spelled with a leading `::` so the expansion is correct even
/// when the macro is used inside a namespace that declares its own `torch`
/// (an unqualified `torch::` would be looked up relative to the invoking
/// namespace and could resolve to the wrong entity).
#define TORCH_MODULE_IMPL(Name, Impl) \
  class Name : public ::torch::nn::ModuleHolder<Impl> { /* NOLINT */ \
   public: \
    using ::torch::nn::ModuleHolder<Impl>::ModuleHolder; \
  }
200 
/// Convenience wrapper: declares the holder class `Name` for the
/// implementation class `Name##Impl` (e.g. `Linear` for `LinearImpl`).
#define TORCH_MODULE(Name) \
  TORCH_MODULE_IMPL(Name, Name##Impl)
ModuleHolder()
Default constructs the contained module if it has a default constructor, else produces a static error...
Definition: pimpl.h:42
bool is_empty() const noexcept
Returns true if the ModuleHolder does not contain a module.
Definition: pimpl.h:137
auto operator()(Args &&...args) -> torch::detail::return_type_of_forward_t< Contained, Args... >
Calls the forward() method of the contained module.
Definition: pimpl.h:119
ModuleHolder(std::nullptr_t)
Constructs the ModuleHolder with an empty contained value.
Definition: pimpl.h:53
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
Definition: pimpl.h:26
const Contained * operator->() const
Forwards to the contained module.
Definition: pimpl.h:85
auto operator[](Arg &&arg) -> decltype((*impl_)[::std::forward< Arg >(arg)])
Forwards to the subscript operator of the contained module.
Definition: pimpl.h:132
Definition: jit_type.h:17
Contained * operator->()
Forwards to the contained module.
Definition: pimpl.h:80
Contained & operator*()
Returns a reference to the contained module.
Definition: pimpl.h:90
const std::shared_ptr< Contained > & ptr() const
Returns a shared pointer to the underlying module.
Definition: pimpl.h:100
ModuleHolder(Head &&head, Tail &&...tail)
Constructs the ModuleHolder with a contained module, forwarding all arguments to its constructor...
Definition: pimpl.h:63
std::shared_ptr< Contained > impl_
The module pointer this class wraps.
Definition: pimpl.h:31
ModuleHolder(std::shared_ptr< Contained > module)
Constructs the ModuleHolder from a pointer to the contained type.
Definition: pimpl.h:70
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32
const Contained & operator*() const
Returns a const reference to the contained module.
Definition: pimpl.h:95