Stores a type-erased Module. More...
#include <any.h>
Data Structures

struct Placeholder

class Value
    A simplified implementation of std::any which stores a type-erased object, whose concrete value can be retrieved at runtime by checking whether the typeid() of a requested type matches the typeid() of the stored object. More...
Public Member Functions

AnyModule() = default
    A default-constructed AnyModule is in an empty state.

template<typename ModuleType>
AnyModule(std::shared_ptr<ModuleType> module)
    Constructs an AnyModule from a shared_ptr to a concrete module object.

template<typename ModuleType, typename = torch::detail::enable_if_module_t<ModuleType>>
AnyModule(ModuleType &&module)
    Constructs an AnyModule from a concrete module object.

template<typename ModuleType>
AnyModule(const ModuleHolder<ModuleType> &module_holder)
    Constructs an AnyModule from a module holder.

AnyModule(AnyModule &&) = default
    Move construction and assignment are allowed, and follow the default behavior of move for std::unique_ptr. More...

AnyModule &operator=(AnyModule &&) = default

AnyModule(const AnyModule &other)
    Creates a shallow copy of an AnyModule.

AnyModule &operator=(const AnyModule &other)

AnyModule clone(optional<Device> device = nullopt) const
    Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty. More...

template<typename ModuleType>
AnyModule &operator=(std::shared_ptr<ModuleType> module)
    Assigns a module to the AnyModule (to circumvent the explicit constructor). More...

template<typename... ArgumentTypes>
Value any_forward(ArgumentTypes &&... arguments)
    Invokes forward() on the contained module with the given arguments, and returns the return value as a Value. More...

template<typename ReturnType = torch::Tensor, typename... ArgumentTypes>
ReturnType forward(ArgumentTypes &&... arguments)
    Invokes forward() on the contained module with the given arguments, and casts the returned Value to the supplied ReturnType (which defaults to torch::Tensor). More...

template<typename T, typename = torch::detail::enable_if_module_t<T>>
T &get()
    Attempts to cast the underlying module to the given module type. More...

template<typename T, typename = torch::detail::enable_if_module_t<T>>
const T &get() const
    Attempts to cast the underlying module to the given module type. More...

template<typename T, typename ContainedType = typename T::ContainedType>
T get() const
    Returns the contained module in a nn::ModuleHolder subclass if possible (i.e. if T has a constructor for the underlying module type). More...

std::shared_ptr<Module> ptr() const
    Returns a std::shared_ptr whose dynamic type is that of the underlying module. More...

template<typename T, typename = torch::detail::enable_if_module_t<T>>
std::shared_ptr<T> ptr() const
    Like ptr(), but casts the pointer to the given type.

const std::type_info &type_info() const
    Returns the type_info object of the contained value.

bool is_empty() const noexcept
    Returns true if the AnyModule does not contain a module.
Stores a type-erased Module.
The PyTorch C++ API does not impose an interface on the signature of forward() in Module subclasses. This gives you complete freedom to design your forward() methods to your liking. However, this also means there is no unified base type you could store in order to call forward() polymorphically for any module. This is where AnyModule comes in. Instead of inheritance, it relies on type erasure for polymorphism.

An AnyModule can store any nn::Module subclass that provides a forward() method. This forward() may accept any types and return any type. Once stored in an AnyModule, you can invoke the underlying module's forward() by calling AnyModule::forward() with the arguments you would supply to the stored module (though see one important limitation below). Example:
.. code-block:: cpp

   struct GenericTrainer {
     torch::nn::AnyModule module;

     void train(torch::Tensor input) {
       module.forward(input);
     }
   };

   GenericTrainer trainer1{torch::nn::Linear(3, 4)};
   GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
As AnyModule erases the static type of the stored module (and of its forward() method) to achieve polymorphism, type checking of arguments is moved to runtime. That is, passing an argument with an incorrect type to an AnyModule will compile, but throw an exception at runtime:
.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   // Linear takes a tensor as input, but we are passing an integer.
   // This will compile, but throw a torch::Error exception at runtime.
   module.forward(123);
.. attention:: One noteworthy limitation of AnyModule is that its forward() method does not support implicit conversion of argument types. For example, if the stored module's forward() method accepts a float and you call any_module.forward(3.4) (where 3.4 is a double), this will throw an exception.
The return type of the AnyModule's forward() method is controlled via the first template argument to AnyModule::forward(). It defaults to torch::Tensor. To change it, you can write any_module.forward<int>(), for example.
.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   auto output = module.forward(torch::ones({2, 3}));

   struct IntModule : torch::nn::Module {
     int forward(int x) { return x; }
   };
   torch::nn::AnyModule int_module(IntModule{});
   int int_output = int_module.forward<int>(5);
The only other method an AnyModule provides access to on the stored module is clone(). However, you may acquire a handle on the module via .ptr(), which returns a shared_ptr<nn::Module>. Further, if you know the concrete type of the stored module, you can get a concrete handle to it using .get<T>(), where T is the concrete module type.
.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   std::shared_ptr<nn::Module> ptr = module.ptr();
   torch::nn::Linear linear(module.get<torch::nn::Linear>());
AnyModule::AnyModule(AnyModule &&) = default

Move construction and assignment are allowed, and follow the default behavior of move for std::unique_ptr.
AnyModule::Value torch::nn::AnyModule::any_forward(ArgumentTypes &&... arguments)
ReturnType torch::nn::AnyModule::forward(ArgumentTypes &&... arguments)

T &torch::nn::AnyModule::get()

const T &torch::nn::AnyModule::get() const

T torch::nn::AnyModule::get() const

Returns the contained module in a nn::ModuleHolder subclass if possible (i.e. if T has a constructor for the underlying module type).

AnyModule &torch::nn::AnyModule::operator=(std::shared_ptr<ModuleType> module)