4 #include <torch/detail/static.h> 5 #include <torch/serialize/archive.h> 6 #include <torch/types.h> 8 #include <torch/csrc/utils/variadic.h> 11 #include <type_traits> 17 #include <torch/csrc/api/include/torch/nn/pimpl-inl.h> 25 template <
typename Contained>
31 std::shared_ptr<Contained>
impl_;
34 using ContainedType = Contained;
44 std::is_default_constructible<Contained>::value,
45 "You are trying to default construct a module which has " 46 "no default constructor. Use = nullptr to give it the empty state " 47 "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`).");
60 typename =
typename std::enable_if<
62 (
sizeof...(Tail) == 0))>::type>
64 : impl_(
new Contained(
65 std::forward<Head>(head),
66 std::forward<Tail>(tail)...)) {}
71 : impl_(
std::move(module)) {}
75 explicit operator bool() const noexcept {
100 const std::shared_ptr<Contained>&
ptr()
const {
101 AT_CHECK(!is_empty(),
"Accessing empty ModuleHolder");
107 AT_CHECK(!is_empty(),
"Accessing empty ModuleHolder");
112 const Contained*
get()
const {
113 AT_CHECK(!is_empty(),
"Accessing empty ModuleHolder");
118 template <
typename... Args>
120 -> torch::detail::return_type_of_forward_t<Contained, Args...> {
125 return impl_->forward(::std::forward<Args>(args)...);
131 template <
typename Arg>
132 auto operator[](Arg&& arg) -> decltype((*impl_)[::std::forward<Arg>(arg)]) {
133 return (*impl_)[::std::forward<Arg>(arg)];
138 return impl_ ==
nullptr;
151 typename T = Contained,
152 typename = torch::enable_if_t<std::is_default_constructible<T>::value>>
153 std::shared_ptr<Contained> default_construct() {
154 return std::make_shared<Contained>();
157 template <
typename T = Contained>
159 std::is_default_constructible<T>::value,
160 std::shared_ptr<Contained>>
161 default_construct() {
167 template <
typename ModuleType>
168 std::ostream& operator<<(
169 std::ostream& stream,
171 return stream << *module;
175 template <
typename ModuleType>
179 return archive << module.
ptr();
183 template <
typename ModuleType>
187 return archive >> module.
ptr();
195 #define TORCH_MODULE_IMPL(Name, Impl) \ 196 class Name : public torch::nn::ModuleHolder<Impl> { \ 198 using torch::nn::ModuleHolder<Impl>::ModuleHolder; \ 202 #define TORCH_MODULE(Name) TORCH_MODULE_IMPL(Name, Name##Impl) ModuleHolder()
Default constructs the contained module if it has a default constructor, else produces a static error...
bool is_empty() const noexcept
Returns true if the ModuleHolder does not contain a module.
auto operator()(Args &&...args) -> torch::detail::return_type_of_forward_t< Contained, Args... >
Calls the forward() method of the contained module.
ModuleHolder(std::nullptr_t)
Constructs the ModuleHolder with an empty contained value.
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
const Contained * operator->() const
Forwards to the contained module.
auto operator[](Arg &&arg) -> decltype((*impl_)[::std::forward< Arg >(arg)])
Forwards to the subscript operator of the contained module.
Contained * operator->()
Forwards to the contained module.
Contained & operator*()
Returns a reference to the contained module.
const std::shared_ptr< Contained > & ptr() const
Returns a shared pointer to the underlying module.
ModuleHolder(Head &&head, Tail &&...tail)
Constructs the ModuleHolder with a contained module, forwarding all arguments to its constructor...
std::shared_ptr< Contained > impl_
The module pointer this class wraps.
ModuleHolder(std::shared_ptr< Contained > module)
Constructs the ModuleHolder from a pointer to the contained type.
const Contained & operator*() const
Returns a const reference to the contained module.