Caffe2 - C++ API
A deep learning, cross platform ML framework
sequential.h
1 #pragma once
2 
3 #include <torch/detail/static.h>
4 #include <torch/nn/cloneable.h>
5 #include <torch/nn/module.h>
6 #include <torch/nn/modules/any.h>
7 #include <torch/nn/pimpl.h>
8 #include <torch/types.h>
9 
10 #include <c10/util/Exception.h>
11 
12 #include <cstdint>
13 #include <memory>
14 #include <ostream>
15 #include <string>
16 #include <type_traits>
17 #include <utility>
18 #include <vector>
19 
20 namespace torch {
21 namespace nn {
22 
91 class SequentialImpl : public Cloneable<SequentialImpl> {
92  public:
93  using Iterator = std::vector<AnyModule>::iterator;
94  using ConstIterator = std::vector<AnyModule>::const_iterator;
95 
96  SequentialImpl() = default;
97 
99  template <typename... Modules>
100  explicit SequentialImpl(Modules&&... modules) {
101  modules_.reserve(sizeof...(Modules));
102  push_back(std::forward<Modules>(modules)...);
103  }
104 
107  std::shared_ptr<Module> clone(
108  const optional<Device>& device = nullopt) const override {
109  auto clone = std::make_shared<SequentialImpl>();
110  for (const auto& module : modules_) {
111  clone->push_back(module.clone(device));
112  }
113  return clone;
114  }
115 
118  void reset() override {}
119 
121  void pretty_print(std::ostream& stream) const override {
122  stream << "torch::nn::Sequential";
123  }
124 
152  template <typename ReturnType = Tensor, typename... InputTypes>
153  ReturnType forward(InputTypes&&... inputs) {
154  AT_CHECK(!is_empty(), "Cannot call forward() on an empty Sequential");
155 
156  auto iterator = modules_.begin();
157  auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);
158 
159  for (++iterator; iterator != modules_.end(); ++iterator) {
160  input = iterator->any_forward(std::move(input));
161  }
162 
163  // Check the return value and give a nice error message if the requsted
164  // return type was incorrect.
165  if (auto* return_value = input.template try_get<ReturnType>()) {
166  return std::move(*return_value);
167  }
168  AT_ERROR(
169  "The type of the return value is ",
170  c10::demangle(input.type_info().name()),
171  ", but you asked for type ",
172  c10::demangle(typeid(ReturnType).name()));
173  }
174 
176  template <typename ModuleType>
177  void push_back(std::shared_ptr<ModuleType> module_ptr) {
178  // Nesting Sequential doesn't work because `forward()`'s return type is
179  // templatized, so it'll give a nasty compiler error.
180  static_assert(
181  !std::is_same<SequentialImpl, ModuleType>::value,
182  "Sequential is not nestable");
183  static_assert(
184  torch::detail::is_module<ModuleType>::value,
185  "Can only add objects derived from nn::Module to Sequential");
186  static_assert(
188  "Can only add modules with a forward() method to Sequential");
189  push_back(AnyModule(std::move(module_ptr)));
190  }
191 
197  template <typename M, typename = torch::detail::enable_if_module_t<M>>
198  void push_back(M&& module) {
199  // Need to get rid of any reference components for make_unique.
200  using Type = typename std::remove_reference<M>::type;
201  // Here we move (or copy) the module into a new shared_ptr.
202  push_back(std::make_shared<Type>(std::forward<M>(module)));
203  }
204 
207  template <typename M>
208  void push_back(const ModuleHolder<M>& module_holder) {
209  push_back(module_holder.ptr());
210  }
211 
213  template <typename Container>
214  void extend(const Container& container) {
215  for (const auto& module : container) {
216  push_back(module);
217  }
218  }
219 
221  Iterator begin() {
222  return modules_.begin();
223  }
224 
226  ConstIterator begin() const {
227  return modules_.begin();
228  }
229 
231  Iterator end() {
232  return modules_.end();
233  }
234 
236  ConstIterator end() const {
237  return modules_.end();
238  }
239 
243  template <typename T>
244  T& at(size_t index) {
245  static_assert(
246  torch::detail::is_module<T>::value,
247  "Can only call Sequential::at with an nn::Module type");
248  AT_CHECK(index < size(), "Index out of range");
249  return modules_[index].get<T>();
250  }
251 
255  template <typename T>
256  const T& at(size_t index) const {
257  static_assert(
258  torch::detail::is_module<T>::value,
259  "Can only call Sequential::at with an nn::Module type");
260  AT_CHECK(index < size(), "Index out of range");
261  return modules_[index].get<T>();
262  }
263 
267  std::shared_ptr<Module> ptr(size_t index) const {
268  AT_CHECK(index < size(), "Index out of range");
269  return modules_[index].ptr();
270  }
271 
275  template <typename T>
276  std::shared_ptr<T> ptr(size_t index) const {
277  static_assert(
278  torch::detail::is_module<T>::value,
279  "Can only call Sequential::ptr with an nn::Module type");
280  AT_CHECK(index < size(), "Index out of range");
281  return modules_[index].ptr<T>();
282  }
283 
285  std::shared_ptr<Module> operator[](size_t index) const {
286  // This is the only method we can call without a type.
287  return ptr(index);
288  }
289 
291  size_t size() const noexcept {
292  return modules_.size();
293  }
294 
296  bool is_empty() const noexcept {
297  return size() == 0;
298  }
299 
300  private:
305  template <typename First, typename Second, typename... Rest>
306  void push_back(First&& first, Second&& second, Rest&&... rest) {
307  push_back(std::forward<First>(first));
308  // Recursively calls this method, until the parameter pack only thas this
309  // entry left. Then calls `push_back()` a final time (above).
310  push_back(std::forward<Second>(second), std::forward<Rest>(rest)...);
311  }
312 
314  void push_back(AnyModule any_module) {
315  modules_.push_back(std::move(any_module));
316  const auto index = modules_.size() - 1;
317  register_module(std::to_string(index), modules_[index].ptr());
318  }
319 
321  void push_back() {}
322 
323  // Box the AnyModules to give Sequential reference semantics, like the rest of
324  // the API. Note that this is not required otherwise, this could just be a
325  // `vector<AnyModule>`.
326  std::vector<AnyModule> modules_;
327 };
328 
333 TORCH_MODULE(Sequential);
334 } // namespace nn
335 } // namespace torch
Definition: any.cpp:108
bool is_empty() const noexcept
True if there are no modules in the Sequential.
Definition: sequential.h:296
std::shared_ptr< Module > operator[](size_t index) const
Like ptr(index).
Definition: sequential.h:285
void push_back(std::shared_ptr< ModuleType > module_ptr)
Adds a new (boxed) Module to the Sequential container.
Definition: sequential.h:177
Iterator begin()
Returns an iterator to the start of the Sequential.
Definition: sequential.h:221
void push_back(M &&module)
Adds a new Module to the Sequential container, moving or copying it into a shared_ptr internally...
Definition: sequential.h:198
std::shared_ptr< Module > ptr(size_t index) const
Attempts to return a std::shared_ptr whose dynamic type is that of the underlying module at the given...
Definition: sequential.h:267
size_t size() const noexcept
The current size of the Sequential container.
Definition: sequential.h:291
std::vector< std::shared_ptr< Module > > modules(bool include_self=true) const
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true...
Definition: module.cpp:187
void pretty_print(std::ostream &stream) const override
Pretty prints the Sequential module into the given stream.
Definition: sequential.h:121
Detects if a type T has a forward() method.
Definition: static.h:19
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
Definition: pimpl.h:26
ConstIterator begin() const
Returns a const iterator to the start of the Sequential.
Definition: sequential.h:226
ReturnType forward(InputTypes &&...inputs)
Feeds inputs to the first module and then chains outputs to inputs, returning the last output...
Definition: sequential.h:153
void push_back(const ModuleHolder< M > &module_holder)
Unwraps the contained module of a ModuleHolder and adds it to the Sequential.
Definition: sequential.h:208
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
Definition: cloneable.h:23
T & at(size_t index)
Attempts to return the module at the given index as the requested type.
Definition: sequential.h:244
Definition: jit_type.h:17
const std::shared_ptr< Contained > & ptr() const
Returns a shared pointer to the underlying module.
Definition: pimpl.h:100
ConstIterator end() const
Returns a const iterator to the end of the Sequential.
Definition: sequential.h:236
const T & at(size_t index) const
Attempts to return the module at the given index as the requested type.
Definition: sequential.h:256
Stores a type erased Module.
Definition: any.h:108
void reset() override
reset() is empty for Sequential, since it does not have parameters of its own.
Definition: sequential.h:118
std::string demangle(const char *name)
Utility to demangle a C++ symbol name.
Definition: Type.cpp:23
A list of Modules that acts as a Module itself.
Definition: sequential.h:91
SequentialImpl(Modules &&...modules)
Constructs the Sequential from a variadic list of modules.
Definition: sequential.h:100
std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const override
Special cloning function for Sequential because it does not use reset().
Definition: sequential.h:107
Iterator end()
Returns an iterator to the end of the Sequential.
Definition: sequential.h:231
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module.
Definition: module.h:556
std::shared_ptr< T > ptr(size_t index) const
Attempts to return a std::shared_ptr whose type is the one provided.
Definition: sequential.h:276
void extend(const Container &container)
Iterates over the container and calls push_back() on each value.
Definition: sequential.h:214