Caffe2 - C++ API
A deep learning, cross platform ML framework
torch::nn::AnyModule Class Reference

Stores a type-erased Module.

#include <any.h>

Data Structures

struct  Placeholder
 
class  Value
 A simplified implementation of std::any that stores a type-erased object, whose concrete value can be retrieved at runtime by checking whether the typeid() of a requested type matches the typeid() of the stored object.
 
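The idea behind Value, recovering a type-erased object only when the requested type's typeid() matches the stored one, can be illustrated with a minimal standalone sketch. The class name TypedValue and its try_get()/get()/type_info() members are hypothetical and are not the actual AnyModule::Value interface.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <typeinfo>
#include <utility>

// Hypothetical sketch: a type-erased value that can be read back only by
// naming the exact stored type (checked with typeid()).
class TypedValue {
 public:
  template <typename T>
  explicit TypedValue(T value)
      : content_(std::make_unique<Holder<T>>(std::move(value))) {}

  // Returns a pointer to the stored object if T matches exactly, else nullptr.
  template <typename T>
  T* try_get() {
    if (content_->type_info() == typeid(T)) {
      return &static_cast<Holder<T>&>(*content_).value;
    }
    return nullptr;
  }

  // Like try_get(), but throws on a type mismatch.
  template <typename T>
  T get() {
    if (T* p = try_get<T>()) return *p;
    throw std::runtime_error("TypedValue: requested type does not match stored type");
  }

  const std::type_info& type_info() const { return content_->type_info(); }

 private:
  struct Placeholder {
    virtual ~Placeholder() = default;
    virtual const std::type_info& type_info() const = 0;
  };

  template <typename T>
  struct Holder : Placeholder {
    explicit Holder(T v) : value(std::move(v)) {}
    const std::type_info& type_info() const override { return typeid(T); }
    T value;
  };

  std::unique_ptr<Placeholder> content_;
};
```

Because this sketch owns its content through std::unique_ptr, it is move-only, unlike std::any.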

Public Member Functions

 AnyModule ()=default
 A default-constructed AnyModule is in an empty state.
 
template<typename ModuleType >
 AnyModule (std::shared_ptr< ModuleType > module)
 Constructs an AnyModule from a shared_ptr to concrete module object.
 
template<typename ModuleType , typename = torch::detail::enable_if_module_t<ModuleType>>
 AnyModule (ModuleType &&module)
 Constructs an AnyModule from a concrete module object.
 
template<typename ModuleType >
 AnyModule (const ModuleHolder< ModuleType > &module_holder)
 Constructs an AnyModule from a module holder.
 
 AnyModule (AnyModule &&)=default
 Move construction and assignment are allowed and follow the default move behavior of std::unique_ptr.
 
AnyModule & operator= (AnyModule &&)=default
 
 AnyModule (const AnyModule &other)
 Creates a shallow copy of an AnyModule.
 
AnyModule & operator= (const AnyModule &other)
 
AnyModule clone (optional< Device > device=nullopt) const
 Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty.
 
template<typename ModuleType >
AnyModule & operator= (std::shared_ptr< ModuleType > module)
 Assigns a module to the AnyModule (to circumvent the explicit constructor).
 
template<typename... ArgumentTypes>
Value any_forward (ArgumentTypes &&...arguments)
 Invokes forward() on the contained module with the given arguments, and returns its return value as a Value.
 
template<typename ReturnType = torch::Tensor, typename... ArgumentTypes>
ReturnType forward (ArgumentTypes &&...arguments)
 Invokes forward() on the contained module with the given arguments, and casts the returned Value to the supplied ReturnType (which defaults to torch::Tensor).
 
template<typename T , typename = torch::detail::enable_if_module_t<T>>
T & get ()
 Attempts to cast the underlying module to the given module type.
 
template<typename T , typename = torch::detail::enable_if_module_t<T>>
const T & get () const
 Attempts to cast the underlying module to the given module type.
 
template<typename T , typename ContainedType = typename T::ContainedType>
T get () const
 Returns the contained module in a nn::ModuleHolder subclass if possible (i.e. if T has a constructor for the underlying module type).
 
std::shared_ptr< Module > ptr () const
 Returns a std::shared_ptr whose dynamic type is that of the underlying module.
 
template<typename T , typename = torch::detail::enable_if_module_t<T>>
std::shared_ptr< T > ptr () const
 Like ptr(), but casts the pointer to the given type.
 
const std::type_info & type_info () const
 Returns the type_info object of the contained value.
 
bool is_empty () const noexcept
 Returns true if the AnyModule does not contain a module.
 

Detailed Description

Stores a type-erased Module.

The PyTorch C++ API does not impose an interface on the signature of forward() in Module subclasses. This gives you complete freedom to design your forward() methods to your liking. However, this also means there is no unified base type you could store in order to call forward() polymorphically for any module. This is where the AnyModule comes in. Instead of inheritance, it relies on type erasure for polymorphism.
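To make the type-erasure idea concrete, here is a minimal standalone sketch in C++17 that needs no LibTorch. It is not AnyModule's implementation: ErasedModule is a hypothetical name, std::any stands in for AnyModule::Value, and the sketch assumes each stored type has exactly one non-overloaded, non-const, non-template forward() that returns a value. Arguments are packed into std::any and unpacked against the stored forward()'s parameter types at runtime.

```cpp
#include <any>
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical sketch: stores objects whose forward() signatures differ,
// without a common base class in the stored types themselves.
class ErasedModule {
 public:
  template <typename M>
  explicit ErasedModule(M module)
      : holder_(std::make_shared<Holder<M, decltype(&M::forward)>>(std::move(module))) {}

  // Packs the arguments into std::any; the holder checks their types at runtime.
  template <typename Ret, typename... Args>
  Ret forward(Args&&... args) {
    std::vector<std::any> packed{std::any(std::forward<Args>(args))...};
    return std::any_cast<Ret>(holder_->call(packed));
  }

 private:
  struct Base {
    virtual ~Base() = default;
    virtual std::any call(std::vector<std::any>& args) = 0;
  };

  template <typename M, typename Signature>
  struct Holder;

  // Partial specialization that recovers the parameter types of M::forward.
  template <typename M, typename R, typename... Params>
  struct Holder<M, R (M::*)(Params...)> : Base {
    explicit Holder(M m) : module(std::move(m)) {}

    std::any call(std::vector<std::any>& args) override {
      if (args.size() != sizeof...(Params)) {
        throw std::invalid_argument("wrong number of arguments to forward()");
      }
      return unpack(args, std::index_sequence_for<Params...>{});
    }

    template <std::size_t... I>
    std::any unpack(std::vector<std::any>& args, std::index_sequence<I...>) {
      // std::any_cast throws std::bad_any_cast unless each stored type is exactly Params[I].
      return module.forward(std::any_cast<std::decay_t<Params>>(args[I])...);
    }

    M module;
  };

  std::shared_ptr<Base> holder_;
};
```

For example, with a hypothetical `struct Doubler { int forward(int x) { return 2 * x; } };`, `ErasedModule(Doubler{}).forward<int>(21)` returns 42, while `forward<int>(2.5)` throws std::bad_any_cast because the argument is stored as a double.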

An AnyModule can store any nn::Module subclass that provides a forward() method. This forward() may accept any types and return any type. Once stored in an AnyModule, you can invoke the underlying module's forward() by calling AnyModule::forward() with the arguments you would supply to the stored module (though see one important limitation below). Example:

.. code-block:: cpp

   struct GenericTrainer {
     torch::nn::AnyModule module;

     void train(torch::Tensor input) {
       module.forward(input);
     }
   };

   GenericTrainer trainer1{torch::nn::AnyModule(torch::nn::Linear(3, 4))};
   GenericTrainer trainer2{torch::nn::AnyModule(torch::nn::Conv2d(3, 4, 2))};

As AnyModule erases the static type of the stored module (and its forward() method) to achieve polymorphism, type checking of arguments is moved to runtime. That is, passing an argument with an incorrect type to an AnyModule will compile, but throw an exception at runtime:

.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   // Linear takes a tensor as input, but we are passing an integer.
   // This will compile, but throw a torch::Error exception at runtime.
   module.forward(123);

.. attention:: One noteworthy limitation of AnyModule is that its forward() method does not support implicit conversion of argument types. For example, if the stored module's forward() method accepts a float and you call any_module.forward(3.4) (where 3.4 is a double), this will throw an exception.
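The limitation follows from the runtime check comparing exact types rather than convertible ones. The standard std::any behaves the same way, so it serves as an analogy here; this sketch does not use AnyModule, and the helpers holds_exactly and as_float_argument are hypothetical.

```cpp
#include <any>
#include <cassert>
#include <typeinfo>

// Hypothetical helper: true only if the erased value's type is exactly T.
template <typename T>
bool holds_exactly(const std::any& value) {
  return value.type() == typeid(T);
}

// Hypothetical helper: recovers a float argument the way a forward(float)
// parameter would need it. std::any_cast<float> throws std::bad_any_cast when
// the stored type is double, even though an ordinary function call would
// convert double to float implicitly.
inline float as_float_argument(const std::any& value) {
  return std::any_cast<float>(value);
}
```

By the same reasoning, passing a float literal such as 3.4f, or casting the argument to float explicitly, should avoid the mismatch.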

The return type of the AnyModule's forward() method is controlled via the first template argument to AnyModule::forward(). It defaults to torch::Tensor. To change it, you can write any_module.forward<int>(), for example.

.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   auto output = module.forward(torch::ones({2, 3}));

   struct IntModule : torch::nn::Module {
     int forward(int x) { return x; }
   };
   torch::nn::AnyModule int_module(IntModule{});
   int int_output = int_module.forward<int>(5);

The only other method an AnyModule provides access to on the stored module is clone(). However, you may acquire a handle on the module via .ptr(), which returns a shared_ptr<nn::Module>. Further, if you know the concrete type of the stored module, you can get a concrete handle to it using .get<T>() where T is the concrete module type.

.. code-block:: cpp

   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
   std::shared_ptr<torch::nn::Module> ptr = module.ptr();
   torch::nn::Linear linear(module.get<torch::nn::Linear>());

Definition at line 108 of file any.h.

Constructor & Destructor Documentation

torch::nn::AnyModule::AnyModule ( AnyModule &&  )
default

Move construction and assignment are allowed and follow the default move behavior of std::unique_ptr.
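With std::unique_ptr semantics, moving transfers ownership and leaves the source empty. A minimal sketch, assuming a hypothetical Slot class that owns its content through std::unique_ptr and exposes an is_empty() query modeled on AnyModule::is_empty(); a std::string stands in for the stored module.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical type-erased slot that owns its content via std::unique_ptr.
class Slot {
 public:
  Slot() = default;
  explicit Slot(std::string name)
      : content_(std::make_unique<std::string>(std::move(name))) {}

  Slot(Slot&&) = default;             // transfers ownership, like std::unique_ptr
  Slot& operator=(Slot&&) = default;  // the moved-from Slot is left empty

  bool is_empty() const noexcept { return content_ == nullptr; }
  const std::string& name() const { return *content_; }

 private:
  std::unique_ptr<std::string> content_;
};
```

Since a moved-from std::unique_ptr is guaranteed to be null, the defaulted move operations are enough to make the moved-from Slot report is_empty().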

Member Function Documentation

template<typename... ArgumentTypes>
AnyModule::Value torch::nn::AnyModule::any_forward ( ArgumentTypes &&...  arguments)

Invokes forward() on the contained module with the given arguments, and returns its return value as a Value.

Use this method when chaining AnyModules in a loop.

Definition at line 455 of file any.h.
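Because any_forward() returns a type-erased Value rather than a concrete type, a loop can feed each stage's output into the next stage without naming the intermediate types. A standalone sketch of that pattern, using std::any and std::function as stand-ins for Value and AnyModule; the Stage alias and run_chain function are hypothetical.

```cpp
#include <any>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stage type: accepts and returns an erased value.
using Stage = std::function<std::any(const std::any&)>;

// Threads the erased output of each stage into the next one; only the
// caller needs to know the type of the final result.
inline std::any run_chain(const std::vector<Stage>& stages, std::any input) {
  for (const Stage& stage : stages) {
    input = stage(input);
  }
  return input;
}
```

Each stage still recovers its own input with an exact-type cast, so a mismatch between adjacent stages surfaces at runtime rather than at compile time.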

AnyModule torch::nn::AnyModule::clone ( optional< Device > device = nullopt) const
inline

Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty.

Definition at line 443 of file any.h.
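The difference between the copy constructor (a shallow copy) and clone() (a deep copy) can be sketched with std::shared_ptr: a shallow copy shares the stored module, while a deep copy owns a new one. Handle and Counter are hypothetical types, and this sketch omits the optional Device argument that AnyModule::clone() accepts.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for a module with mutable state.
struct Counter {
  int calls = 0;
  int forward() { return ++calls; }
};

// Hypothetical handle with AnyModule-like copy semantics.
class Handle {
 public:
  Handle() = default;
  explicit Handle(Counter module) : module_(std::make_shared<Counter>(module)) {}

  // Copying shares the same module (shallow copy).
  Handle(const Handle&) = default;
  Handle& operator=(const Handle&) = default;

  // clone() copies the module itself, or returns an empty Handle if this one is empty.
  Handle clone() const {
    Handle copy;
    if (module_) copy.module_ = std::make_shared<Counter>(*module_);
    return copy;
  }

  bool is_empty() const noexcept { return module_ == nullptr; }
  Counter& get() { return *module_; }

 private:
  std::shared_ptr<Counter> module_;
};
```

After calling forward() through the original handle, a shallow copy observes the updated state, whereas a clone taken beforehand does not.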

template<typename ReturnType , typename... ArgumentTypes>
ReturnType torch::nn::AnyModule::forward ( ArgumentTypes &&...  arguments)

Invokes forward() on the contained module with the given arguments, and casts the returned Value to the supplied ReturnType (which defaults to torch::Tensor).

Definition at line 466 of file any.h.
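A defaulted ReturnType template parameter means callers name the result type only when it differs from the default. A standalone sketch of that pattern with std::any in place of Value; cast_result is a hypothetical helper, and double replaces torch::Tensor as the default so the example compiles without LibTorch.

```cpp
#include <any>
#include <cassert>
#include <string>

// Hypothetical helper mirroring forward<ReturnType = torch::Tensor>(), with
// double as the default so this compiles without LibTorch.
template <typename ReturnType = double>
ReturnType cast_result(const std::any& erased_result) {
  // Throws std::bad_any_cast if the erased result is not exactly ReturnType.
  return std::any_cast<ReturnType>(erased_result);
}
```

Calling cast_result(r) uses the default, while cast_result<std::string>(r) overrides it, just as any_module.forward<int>(...) overrides torch::Tensor.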

template<typename T , typename >
T & torch::nn::AnyModule::get ( )

Attempts to cast the underlying module to the given module type.

Throws an exception if the types do not match.

Definition at line 472 of file any.h.

template<typename T , typename >
const T & torch::nn::AnyModule::get ( ) const

Attempts to cast the underlying module to the given module type.

Throws an exception if the types do not match.

Definition at line 478 of file any.h.
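A checked cast of this kind can be sketched as a typeid() comparison against the stored object's dynamic type, throwing on mismatch. Module, Linear, Conv, and ModuleRef below are hypothetical standalone types, not the LibTorch classes, and the choice of std::logic_error is an assumption.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <typeinfo>
#include <utility>

// Hypothetical polymorphic base and two concrete "modules".
struct Module { virtual ~Module() = default; };
struct Linear : Module { int in_features = 3; };
struct Conv : Module { int channels = 4; };

// Hypothetical holder exposing get<T>() overloads that check the dynamic type first.
class ModuleRef {
 public:
  explicit ModuleRef(std::shared_ptr<Module> module) : module_(std::move(module)) {}

  template <typename T>
  T& get() {
    check<T>();
    return static_cast<T&>(*module_);
  }

  template <typename T>
  const T& get() const {
    check<T>();
    return static_cast<const T&>(*module_);
  }

 private:
  template <typename T>
  void check() const {
    const Module& stored = *module_;  // typeid on a polymorphic lvalue yields its dynamic type
    if (typeid(stored) != typeid(T)) {
      throw std::logic_error("ModuleRef::get(): requested type does not match stored type");
    }
  }

  std::shared_ptr<Module> module_;
};
```

Because the comparison is on exact types, asking for a base class or a further-derived class of the stored module also throws in this sketch.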

template<typename T , typename ContainedType >
T torch::nn::AnyModule::get ( ) const

Returns the contained module in a nn::ModuleHolder subclass if possible (i.e. if T has a constructor for the underlying module type).

Definition at line 484 of file any.h.
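The ContainedType default template parameter makes this overload viable only for types that declare a ContainedType member, as nn::ModuleHolder subclasses do. A standalone sketch of the idea; Module, LinearImpl, Holder, Linear, and Erased are hypothetical types, and guarding the conversion with std::dynamic_pointer_cast is an assumption rather than the documented implementation.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <utility>

// Hypothetical module base and implementation type.
struct Module { virtual ~Module() = default; };
struct LinearImpl : Module { int out_features = 4; };

// Hypothetical holder, loosely modeled on nn::ModuleHolder: exposes ContainedType.
template <typename Contained>
class Holder {
 public:
  using ContainedType = Contained;
  explicit Holder(std::shared_ptr<Contained> impl) : impl_(std::move(impl)) {}
  Contained* operator->() const { return impl_.get(); }

 private:
  std::shared_ptr<Contained> impl_;
};
using Linear = Holder<LinearImpl>;

class Erased {
 public:
  explicit Erased(std::shared_ptr<Module> module) : module_(std::move(module)) {}

  // Substitution fails (removing this overload) unless T::ContainedType exists.
  template <typename T, typename ContainedType = typename T::ContainedType>
  T get() const {
    auto typed = std::dynamic_pointer_cast<ContainedType>(module_);
    if (!typed) throw std::logic_error("Erased::get(): stored module is not T::ContainedType");
    return T(std::move(typed));  // the holder shares ownership of the stored module
  }

 private:
  std::shared_ptr<Module> module_;
};
```

Since the returned holder shares the stored module, changes made through it are visible the next time the holder is retrieved.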

template<typename ModuleType >
AnyModule & torch::nn::AnyModule::operator= ( std::shared_ptr< ModuleType >  module)

Assigns a module to the AnyModule (to circumvent the explicit constructor).

Definition at line 450 of file any.h.
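Because the shared_ptr constructor is explicit, assigning a std::shared_ptr directly would not compile through the constructor alone; a templated assignment operator fills that gap. A minimal sketch with hypothetical types (Module, LinearImpl, Slot), not the LibTorch classes.

```cpp
#include <cassert>
#include <memory>
#include <utility>

struct Module { virtual ~Module() = default; };
struct LinearImpl : Module {};

// Hypothetical slot with an explicit converting constructor.
class Slot {
 public:
  Slot() = default;

  template <typename ModuleType>
  explicit Slot(std::shared_ptr<ModuleType> module) : module_(std::move(module)) {}

  // Without this overload, `slot = std::make_shared<LinearImpl>()` would not
  // compile, because an explicit constructor cannot perform implicit conversions.
  template <typename ModuleType>
  Slot& operator=(std::shared_ptr<ModuleType> module) {
    return *this = Slot(std::move(module));  // reuses the implicit move assignment
  }

  bool is_empty() const noexcept { return module_ == nullptr; }
  std::shared_ptr<Module> ptr() const { return module_; }

 private:
  std::shared_ptr<Module> module_;
};
```

The explicit constructor still prevents accidental conversions in function arguments, while assignment stays convenient.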

std::shared_ptr< Module > torch::nn::AnyModule::ptr ( ) const
inline

Returns a std::shared_ptr whose dynamic type is that of the underlying module.

Definition at line 488 of file any.h.
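A base-class std::shared_ptr keeps the dynamic type of the object it points to, so the concrete type can be recovered with std::dynamic_pointer_cast, much like the typed ptr<T>() overload listed among the public member functions. A standalone sketch with hypothetical types; unlike this sketch, which returns nullptr on a mismatch, the real typed overload may perform a checked cast that throws.

```cpp
#include <cassert>
#include <memory>
#include <typeinfo>
#include <utility>

struct Module { virtual ~Module() = default; };
struct LinearImpl : Module { int in_features = 3; };
struct ConvImpl : Module {};

class Slot {
 public:
  explicit Slot(std::shared_ptr<Module> module) : module_(std::move(module)) {}

  // Static type is Module; the dynamic type is whatever was stored.
  std::shared_ptr<Module> ptr() const { return module_; }

  // Hypothetical typed variant: returns nullptr when the stored module is not a T.
  template <typename T>
  std::shared_ptr<T> ptr() const { return std::dynamic_pointer_cast<T>(module_); }

 private:
  std::shared_ptr<Module> module_;
};
```

A call to ptr() with no template argument selects the non-template overload, since T cannot be deduced for the template.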


The documentation for this class was generated from the following file: any.h