Caffe2 - C++ API
A deep learning, cross platform ML framework
torch::jit::script::Module Struct Reference

Public Member Functions

 TH_DISALLOW_COPY_AND_ASSIGN (Module)
 
void set_optimized (bool o)
 
bool is_optimized () const
 
IValue forward (std::vector< IValue > inputs)
 
void register_buffer (const std::string &name, autograd::Variable v)
 
void register_parameter (const std::string &name, autograd::Variable v, bool is_buffer)
 
void register_attribute (const std::string &name, const TypePtr type, IValue ivalue)
 
void register_module (const std::string &name, std::shared_ptr< Module > module)
 
Method & create_method (const std::string &name, std::shared_ptr< Graph > graph, std::vector< IValue * > member_inputs)
 
Method & create_method (const std::string &name, std::function< void(Method &)> creator)
 
IValue * parameter_slot (const std::string &name) const
 
void set_parameter (const std::string &name, at::Tensor v)
 
autograd::Variable get_parameter (const std::string &name) const
 
autograd::Variable get_buffer (const std::string &name) const
 
Method & get_method (const std::string &name) const
 
std::shared_ptr< Module > get_module (const std::string &name) const
 
const torch::OrderedDict< std::string, NamedModule > & get_modules () const
 
const torch::OrderedDict< std::string, NamedIValue > & get_parameters () const
 
const torch::OrderedDict< std::string, NamedIValue > & get_attributes () const
 
const torch::OrderedDict< std::string, std::unique_ptr< Method > > & get_methods () const
 
NamedIValue * find_parameter (const std::string &name)
 
NamedIValue * find_attribute (const std::string &name)
 
NamedIValue * find_buffer (const std::string &name)
 
NamedModule * find_module (const std::string &name)
 
Method * find_method (const std::string &name)
 
void apply (std::function< void(Module &)> fn)
 
void train (bool on=true)
 Enables "training" mode.
 
void eval ()
 Calls train(false) to enable "eval" mode.
 
bool is_training ()
 True if the module is in training mode.
 
TORCH_API void to (at::Device device, at::ScalarType dtype, bool non_blocking=false)
 Recursively casts all parameters to the given dtype and device.
 
TORCH_API void to (at::ScalarType dtype, bool non_blocking=false)
 Recursively casts all parameters to the given dtype.
 
TORCH_API void to (at::Device device, bool non_blocking=false)
 Recursively moves all parameters to the given device.
 
template<typename... Types>
IValue run_method (const std::string &method_name, Types &&...args)
 Run a method from this module.
 
void save (std::ostream &out, const ExtraFilesMap &extra_files=ExtraFilesMap())
 
void save (const std::string &filename, const ExtraFilesMap &extra_files=ExtraFilesMap())
 
void copy_into (ModuleLookup module_lookup, std::unordered_map< IValue *, IValue * > &parameter_remap, std::vector< std::string > names={}) const
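The listing pairs get_* and find_* lookups for parameters, buffers, modules and methods. The sketch below models one common way such a pair behaves, assuming (this page does not say so) that find_* hands back a pointer that is nullptr when the name is absent, while get_* treats a missing name as an error. ToyModule and its std::map of doubles are hypothetical stand-ins for the real NamedIValue/autograd::Variable storage, not libtorch code.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a module's parameter table; not libtorch code.
// Parameters are plain doubles here instead of autograd::Variable.
struct ToyModule {
  std::map<std::string, double> parameters;

  // find_*-style lookup: pointer to the stored value, or nullptr if absent.
  double* find_parameter(const std::string& name) {
    auto it = parameters.find(name);
    return it == parameters.end() ? nullptr : &it->second;
  }

  // get_*-style lookup: assumes the name exists; std::map::at throws
  // std::out_of_range otherwise.
  double get_parameter(const std::string& name) const {
    return parameters.at(name);
  }

  // Writes through the slot pointer; this toy asserts that the name exists.
  void set_parameter(const std::string& name, double v) {
    double* slot = find_parameter(name);
    assert(slot != nullptr && "set_parameter expects an existing name");
    *slot = v;
  }
};
```

The pointer-returning form suits optional lookups (`if (auto* p = m.find_parameter("bias")) ...`), while the throwing form suits names the caller knows must exist.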
 

Detailed Description

Definition at line 375 of file module.h.

Member Function Documentation

void torch::jit::script::Module::eval ( )  [inline]

Calls train(false) to enable "eval" mode.

Do not override this method; override train() instead.

Definition at line 538 of file module.h.
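eval() carries no logic of its own: it only calls train(false), so any change to how the mode is set belongs in train(). The sketch below shows that delegation with a hypothetical ToyModule, not libtorch's class. It also assumes, which this page does not state, that train() propagates the mode to every submodule, in the same way the to() overloads recurse into submodules.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical toy module tree illustrating the documented contract that
// eval() is just train(false); not libtorch's implementation.
struct ToyModule {
  std::vector<std::shared_ptr<ToyModule>> children;
  bool training = true;

  // Sets the mode on every submodule (assumed recursion), then on this module.
  void train(bool on = true) {
    for (auto& child : children) child->train(on);
    training = on;
  }

  // No logic of its own: customizing train() automatically covers eval().
  void eval() { train(false); }

  bool is_training() const { return training; }
};
```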

template<typename... Types>
IValue torch::jit::script::Module::run_method ( const std::string & method_name, Types &&... args )  [inline]

Run a method from this module.

For example:

IValue output = module->run_method("relu_script", a, b);

To compile a module from a source string, see torch::jit::compile.

Parameters
method_name  The name of the method to run.
args  Arguments to be passed to the method.
Returns
An IValue containing the return value (or values, if it is a tuple) from the method.

Definition at line 591 of file module.h.
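run_method accepts any number of arguments and forwards them to a method looked up by name. A minimal sketch of that variadic forwarding pattern follows, assuming (not confirmed by this page) that each argument is wrapped in an IValue, the results are packed into one std::vector, and that vector is passed to the named method. ToyIValue, ToyMethod and ToyModule are hypothetical stand-ins, not libtorch types.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for IValue: holds either an integer or a string.
using ToyIValue = std::variant<std::int64_t, std::string>;
// A method takes all of its inputs as one vector of values.
using ToyMethod = std::function<ToyIValue(std::vector<ToyIValue>)>;

struct ToyModule {
  std::map<std::string, ToyMethod> methods;

  // In this toy, an unknown name throws std::out_of_range (from map::at).
  ToyMethod& get_method(const std::string& name) { return methods.at(name); }

  // Wraps each argument in a ToyIValue, packs them into a single vector,
  // and invokes the named method with that vector.
  template <typename... Types>
  ToyIValue run_method(const std::string& method_name, Types&&... args) {
    return get_method(method_name)({ToyIValue(std::forward<Types>(args))...});
  }
};
```

With an "add" method registered that sums its two integer inputs, `m.run_method("add", 2, 3)` returns a ToyIValue holding 5; the call site never builds the vector by hand.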

void torch::jit::script::Module::to ( at::Device device, at::ScalarType dtype, bool non_blocking = false )

Recursively casts all parameters to the given dtype and device.

If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

Definition at line 98 of file module.cpp.
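The non_blocking rule above reduces to a small predicate. The sketch encodes it under one stated reading of "vice versa" (source on the GPU, destination in pinned host memory); the Location enum and copy_is_async function are hypothetical illustrations, not ATen types or functions.

```cpp
#include <cassert>

// Hypothetical description of where data lives; not an ATen type.
enum class Location { PageableHost, PinnedHost, Gpu };

// True only when non_blocking is set and the copy goes between pinned host
// memory and the GPU, in either direction. In every other case the flag is
// ignored and the copy is synchronous with respect to the host.
bool copy_is_async(bool non_blocking, Location src, Location dst) {
  const bool pinned_to_gpu = src == Location::PinnedHost && dst == Location::Gpu;
  const bool gpu_to_pinned = src == Location::Gpu && dst == Location::PinnedHost;
  return non_blocking && (pinned_to_gpu || gpu_to_pinned);
}
```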

void torch::jit::script::Module::to ( at::ScalarType dtype, bool non_blocking = false )

Recursively casts all parameters to the given dtype.

If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

Definition at line 102 of file module.cpp.

void torch::jit::script::Module::to ( at::Device device, bool non_blocking = false )

Recursively moves all parameters to the given device.

If non_blocking is true and the source is in pinned memory and the destination is on the GPU, or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.

Definition at line 106 of file module.cpp.
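The three to() overloads differ only in whether they change the device, the dtype, or both. One way to share a single recursive implementation is to route all three into a helper that takes both as std::optional and keeps each parameter's current device or dtype when an argument is absent. The sketch assumes such a helper (to_impl is a hypothetical name here) and uses toy enums in place of at::Device and at::ScalarType; it illustrates the dispatch and recursion only, not ATen's actual tensor copy.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <vector>

// Hypothetical stand-ins for at::Device, at::ScalarType and a parameter.
enum class Device { CPU, CUDA };
enum class ScalarType { Float, Double };
struct ToyParam { Device device; ScalarType dtype; };

struct ToyModule {
  std::vector<ToyParam> parameters;
  std::vector<std::shared_ptr<ToyModule>> children;

  // All three overloads funnel into one helper.
  void to(Device device, ScalarType dtype, bool non_blocking = false) {
    to_impl(device, dtype, non_blocking);
  }
  void to(ScalarType dtype, bool non_blocking = false) {
    to_impl(std::nullopt, dtype, non_blocking);
  }
  void to(Device device, bool non_blocking = false) {
    to_impl(device, std::nullopt, non_blocking);
  }

 private:
  // Recurses into submodules first, then updates this module's parameters.
  // An absent device or dtype keeps each parameter's current value.
  void to_impl(std::optional<Device> device, std::optional<ScalarType> dtype,
               bool non_blocking) {
    for (auto& child : children) child->to_impl(device, dtype, non_blocking);
    for (auto& p : parameters) {
      // No real data is copied in this toy, so non_blocking has no effect.
      p.device = device.value_or(p.device);
      p.dtype = dtype.value_or(p.dtype);
    }
  }
};
```

Using scoped enums keeps the overloads unambiguous: neither type converts implicitly to the other or to bool, so `to(ScalarType::Double)` and `to(Device::CUDA)` each match exactly one overload.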

