3 #include <torch/detail/static.h>     4 #include <torch/nn/cloneable.h>     5 #include <torch/nn/module.h>     6 #include <torch/nn/modules/any.h>     7 #include <torch/nn/pimpl.h>     8 #include <torch/types.h>    10 #include <c10/util/Exception.h>    16 #include <type_traits>    93   using Iterator = std::vector<AnyModule>::iterator;
    94   using ConstIterator = std::vector<AnyModule>::const_iterator;
    99   template <
typename... Modules>
   101     modules_.reserve(
sizeof...(Modules));
   109     auto clone = std::make_shared<SequentialImpl>();
   110     for (
const auto& module : modules_) {
   111       clone->push_back(module.clone(device));
   122     stream << 
"torch::nn::Sequential";
   152   template <
typename ReturnType = 
Tensor, 
typename... InputTypes>
   154     AT_CHECK(!
is_empty(), 
"Cannot call forward() on an empty Sequential");
   156     auto iterator = modules_.begin();
   157     auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);
   159     for (++iterator; iterator != modules_.end(); ++iterator) {
   160       input = iterator->any_forward(std::move(input));
   165     if (
auto* return_value = input.template try_get<ReturnType>()) {
   166       return std::move(*return_value);
   169         "The type of the return value is ",
   171         ", but you asked for type ",
   176   template <
typename ModuleType>
   177   void push_back(std::shared_ptr<ModuleType> module_ptr) {
   181         !std::is_same<SequentialImpl, ModuleType>::value,
   182         "Sequential is not nestable");
   184         torch::detail::is_module<ModuleType>::value,
   185         "Can only add objects derived from nn::Module to Sequential");
   188         "Can only add modules with a forward() method to Sequential");
   197   template <
typename M, 
typename = torch::detail::enable_if_module_t<M>>
   200     using Type = 
typename std::remove_reference<M>::type;
   202     push_back(std::make_shared<Type>(std::forward<M>(module)));
   207   template <
typename M>
   213   template <
typename Container>
   214   void extend(
const Container& container) {
   215     for (
const auto& module : container) {
   222     return modules_.begin();
   227     return modules_.begin();
   232     return modules_.end();
   236   ConstIterator 
end()
 const {
   237     return modules_.end();
   243   template <
typename T>
   246         torch::detail::is_module<T>::value,
   247         "Can only call Sequential::at with an nn::Module type");
   248     AT_CHECK(index < 
size(), 
"Index out of range");
   249     return modules_[index].get<
T>();
   255   template <
typename T>
   256   const T& 
at(
size_t index)
 const {
   258         torch::detail::is_module<T>::value,
   259         "Can only call Sequential::at with an nn::Module type");
   260     AT_CHECK(index < 
size(), 
"Index out of range");
   261     return modules_[index].get<
T>();
   267   std::shared_ptr<Module> 
ptr(
size_t index)
 const {
   268     AT_CHECK(index < 
size(), 
"Index out of range");
   269     return modules_[index].ptr();
   275   template <
typename T>
   276   std::shared_ptr<T> 
ptr(
size_t index)
 const {
   278         torch::detail::is_module<T>::value,
   279         "Can only call Sequential::ptr with an nn::Module type");
   280     AT_CHECK(index < 
size(), 
"Index out of range");
   281     return modules_[index].ptr<
T>();
   292     return modules_.size();
   305   template <
typename First, 
typename Second, 
typename... Rest>
   306   void push_back(First&& first, Second&& second, Rest&&... rest) {
   310     push_back(std::forward<Second>(second), std::forward<Rest>(rest)...);
   315     modules_.push_back(std::move(any_module));
   316     const auto index = modules_.size() - 1;
   326   std::vector<AnyModule> modules_;
   333 TORCH_MODULE(Sequential);
 
bool is_empty() const noexcept
True if there are no modules in the Sequential. 
 
std::shared_ptr< Module > operator[](size_t index) const 
Like ptr(index). 
 
void push_back(std::shared_ptr< ModuleType > module_ptr)
Adds a new (boxed) Module to the Sequential container. 
 
Iterator begin()
Returns an iterator to the start of the Sequential. 
 
void push_back(M &&module)
Adds a new Module to the Sequential container, moving or copying it into a shared_ptr internally. This lets callers pass modules by value and have the container box them.
 
std::shared_ptr< Module > ptr(size_t index) const 
Attempts to return a std::shared_ptr whose dynamic type is that of the underlying module at the given index. Raises an error if the index is out of range.
 
size_t size() const noexcept
The current size of the Sequential container. 
 
std::vector< std::shared_ptr< Module > > modules(bool include_self=true) const 
Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also includes this module itself.
 
void pretty_print(std::ostream &stream) const override
Pretty prints the Sequential module into the given stream. 
 
Detects if a type T has a forward() method. 
 
A ModuleHolder is essentially a wrapper around std::shared_ptr<M>, where M is an nn::Module subclass.
 
ConstIterator begin() const 
Returns a const iterator to the start of the Sequential. 
 
ReturnType forward(InputTypes &&...inputs)
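Feeds inputs to the first module and then chains each module's output into the next module's input, returning the last output. The result is returned as ReturnType (Tensor by default). An error is raised if the Sequential is empty or if the last output is not of type ReturnType.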
 
void push_back(const ModuleHolder< M > &module_holder)
Unwraps the contained module of a ModuleHolder and adds it to the Sequential. 
 
The clone() method in the base Module class does not have knowledge of the concrete runtime type of its subclasses.
 
T & at(size_t index)
Attempts to return the module at the given index as the requested type. Raises an error if the index is out of range.
 
const std::shared_ptr< Contained > & ptr() const 
Returns a shared pointer to the underlying module. 
 
ConstIterator end() const 
Returns a const iterator to the end of the Sequential. 
 
const T & at(size_t index) const 
Attempts to return the module at the given index as the requested type. Raises an error if the index is out of range.
 
Stores a type-erased Module.
 
void reset() override
reset() is empty for Sequential, since it does not have parameters of its own. 
 
std::string demangle(const char *name)
Utility to demangle a C++ symbol name. 
 
A list of Modules that acts as a Module itself. 
 
SequentialImpl(Modules &&...modules)
Constructs the Sequential from a variadic list of modules. 
 
std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const override
Special cloning function for Sequential because it does not use reset(). 
 
Iterator end()
Returns an iterator to the end of the Sequential. 
 
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module. 
 
std::shared_ptr< T > ptr(size_t index) const 
Attempts to return a std::shared_ptr whose type is the one provided. Raises an error if the index is out of range.
 
void extend(const Container &container)
Iterates over the container and calls push_back() on each value.