3 #include <torch/detail/static.h> 4 #include <torch/nn/cloneable.h> 5 #include <torch/nn/module.h> 6 #include <torch/nn/modules/any.h> 7 #include <torch/nn/pimpl.h> 8 #include <torch/types.h> 10 #include <c10/util/Exception.h> 16 #include <type_traits> 93 using Iterator = std::vector<AnyModule>::iterator;
94 using ConstIterator = std::vector<AnyModule>::const_iterator;
99 template <
typename... Modules>
101 modules_.reserve(
sizeof...(Modules));
109 auto clone = std::make_shared<SequentialImpl>();
110 for (
const auto& module : modules_) {
111 clone->push_back(module.clone(device));
122 stream <<
"torch::nn::Sequential";
152 template <
typename ReturnType =
Tensor,
typename... InputTypes>
154 AT_CHECK(!
is_empty(),
"Cannot call forward() on an empty Sequential");
156 auto iterator = modules_.begin();
157 auto input = iterator->any_forward(std::forward<InputTypes>(inputs)...);
159 for (++iterator; iterator != modules_.end(); ++iterator) {
160 input = iterator->any_forward(std::move(input));
165 if (
auto* return_value = input.template try_get<ReturnType>()) {
166 return std::move(*return_value);
169 "The type of the return value is ",
171 ", but you asked for type ",
176 template <
typename ModuleType>
177 void push_back(std::shared_ptr<ModuleType> module_ptr) {
181 !std::is_same<SequentialImpl, ModuleType>::value,
182 "Sequential is not nestable");
184 torch::detail::is_module<ModuleType>::value,
185 "Can only add objects derived from nn::Module to Sequential");
188 "Can only add modules with a forward() method to Sequential");
197 template <
typename M,
typename = torch::detail::enable_if_module_t<M>>
200 using Type =
typename std::remove_reference<M>::type;
202 push_back(std::make_shared<Type>(std::forward<M>(module)));
207 template <
typename M>
213 template <
typename Container>
214 void extend(
const Container& container) {
215 for (
const auto& module : container) {
222 return modules_.begin();
227 return modules_.begin();
232 return modules_.end();
236 ConstIterator
end()
const {
237 return modules_.end();
243 template <
typename T>
246 torch::detail::is_module<T>::value,
247 "Can only call Sequential::at with an nn::Module type");
248 AT_CHECK(index <
size(),
"Index out of range");
249 return modules_[index].get<
T>();
255 template <
typename T>
256 const T&
at(
size_t index)
const {
258 torch::detail::is_module<T>::value,
259 "Can only call Sequential::at with an nn::Module type");
260 AT_CHECK(index <
size(),
"Index out of range");
261 return modules_[index].get<
T>();
267 std::shared_ptr<Module>
ptr(
size_t index)
const {
268 AT_CHECK(index <
size(),
"Index out of range");
269 return modules_[index].ptr();
275 template <
typename T>
276 std::shared_ptr<T>
ptr(
size_t index)
const {
278 torch::detail::is_module<T>::value,
279 "Can only call Sequential::ptr with an nn::Module type");
280 AT_CHECK(index <
size(),
"Index out of range");
281 return modules_[index].ptr<
T>();
292 return modules_.size();
305 template <
typename First,
typename Second,
typename... Rest>
306 void push_back(First&& first, Second&& second, Rest&&... rest) {
310 push_back(std::forward<Second>(second), std::forward<Rest>(rest)...);
315 modules_.push_back(std::move(any_module));
316 const auto index = modules_.size() - 1;
326 std::vector<AnyModule> modules_;
333 TORCH_MODULE(Sequential);
bool is_empty() const noexcept
True if there are no modules in the Sequential.
std::shared_ptr< Module > operator[](size_t index) const
Like ptr(index).
void push_back(std::shared_ptr< ModuleType > module_ptr)
Adds a new (boxed) Module to the Sequential container.
Iterator begin()
Returns an iterator to the start of the Sequential.
void push_back(M &&module)
Adds a new Module to the Sequential container, moving or copying it into a shared_ptr internally...
std::shared_ptr< Module > ptr(size_t index) const
Attempts to return a std::shared_ptr whose dynamic type is that of the underlying module at the given index.
size_t size() const noexcept
The current size of the Sequential container.
std::vector< std::shared_ptr< Module > > modules(bool include_self=true) const
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position.
void pretty_print(std::ostream &stream) const override
Pretty prints the Sequential module into the given stream.
Detects if a type T has a forward() method.
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
ConstIterator begin() const
Returns a const iterator to the start of the Sequential.
ReturnType forward(InputTypes &&...inputs)
Feeds inputs to the first module and then chains outputs to inputs, returning the last output...
void push_back(const ModuleHolder< M > &module_holder)
Unwraps the contained module of a ModuleHolder and adds it to the Sequential.
The clone() method in the base Module class does not have knowledge of the concrete runtime type of its subclasses.
T & at(size_t index)
Attempts to return the module at the given index as the requested type.
const std::shared_ptr< Contained > & ptr() const
Returns a shared pointer to the underlying module.
ConstIterator end() const
Returns a const iterator to the end of the Sequential.
const T & at(size_t index) const
Attempts to return the module at the given index as the requested type.
Stores a type erased Module.
void reset() override
reset() is empty for Sequential, since it does not have parameters of its own.
std::string demangle(const char *name)
Utility to demangle a C++ symbol name.
A list of Modules that acts as a Module itself.
SequentialImpl(Modules &&...modules)
Constructs the Sequential from a variadic list of modules.
std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const override
Special cloning function for Sequential because it does not use reset().
Iterator end()
Returns an iterator to the end of the Sequential.
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module.
std::shared_ptr< T > ptr(size_t index) const
Attempts to return a std::shared_ptr whose type is the one provided.
void extend(const Container &container)
Iterates over the container and calls push_back() on each value.