3 #include <torch/nn/pimpl.h> 4 #include <torch/ordered_dict.h> 5 #include <torch/serialize/archive.h> 6 #include <torch/types.h> 15 #include <type_traits> 62 class TORCH_API
Module :
public std::enable_shared_from_this<Module> {
64 using ModuleApplyFunction = std::function<void(Module&)>;
65 using ConstModuleApplyFunction = std::function<void(const Module&)>;
66 using NamedModuleApplyFunction =
67 std::function<void(const std::string&, Module&)>;
68 using ConstNamedModuleApplyFunction =
69 std::function<void(const std::string&, const Module&)>;
70 using ModulePointerApplyFunction =
71 std::function<void(const std::shared_ptr<Module>&)>;
72 using NamedModulePointerApplyFunction =
73 std::function<void(const std::string&, const std::shared_ptr<Module>&)>;
76 explicit Module(std::string name);
83 virtual ~
Module() =
default;
94 const std::string& name()
const noexcept;
112 virtual std::shared_ptr<Module> clone(
125 void apply(
const ModuleApplyFunction&
function);
137 void apply(
const ConstModuleApplyFunction&
function)
const;
153 const NamedModuleApplyFunction&
function,
154 const std::string& name_prefix = std::string());
170 const ConstNamedModuleApplyFunction&
function,
171 const std::string& name_prefix = std::string())
const;
183 void apply(
const ModulePointerApplyFunction&
function)
const;
201 const NamedModulePointerApplyFunction&
function,
202 const std::string& name_prefix = std::string())
const;
206 std::vector<Tensor> parameters(
bool recurse =
true)
const;
214 std::vector<Tensor> buffers(
bool recurse =
true)
const;
231 std::vector<std::shared_ptr<Module>> modules(
bool include_self =
true)
const;
247 const std::string& name_prefix = std::string(),
248 bool include_self =
true)
const;
251 std::vector<std::shared_ptr<Module>> children()
const;
258 virtual void train(
bool on =
true);
273 virtual bool is_training()
const noexcept;
284 bool non_blocking =
false);
292 virtual void to(torch::Dtype dtype,
bool non_blocking =
false);
300 virtual void to(
torch::Device device,
bool non_blocking =
false);
303 virtual void zero_grad();
321 template <
typename ModuleType>
322 typename ModuleType::ContainedType* as() noexcept;
339 template <
typename ModuleType>
340 const typename ModuleType::ContainedType* as()
const noexcept;
360 typename = torch::detail::disable_if_module_holder_t<ModuleType>>
361 ModuleType* as() noexcept;
381 typename = torch::detail::disable_if_module_holder_t<ModuleType>>
382 const ModuleType* as()
const noexcept;
397 virtual void pretty_print(std::ostream& stream)
const;
413 Tensor& register_parameter(
431 Tensor& register_buffer(std::string name,
Tensor tensor);
445 template <
typename ModuleType>
446 std::shared_ptr<ModuleType> register_module(
448 std::shared_ptr<ModuleType> module);
464 template <
typename ModuleType>
465 std::shared_ptr<ModuleType> register_module(
472 template <
typename Derived>
476 TORCH_API
friend std::ostream& operator<<(
477 std::ostream& stream,
486 template <
typename... Ts>
487 void to_impl(Ts&&... ts);
490 void pretty_print_recursive(
491 std::ostream& stream,
492 const std::string& indentation)
const;
496 void apply_to_submodules(
497 const NamedModulePointerApplyFunction&
function,
498 const std::string& name_prefix = std::string())
const;
501 std::shared_ptr<Module> shared_from_this_checked()
const;
516 bool is_training_{
true};
522 const std::shared_ptr<nn::Module>& module);
527 const std::shared_ptr<nn::Module>& module);
531 template <
typename ModuleType>
535 return as<typename ModuleType::ContainedType>();
538 template <
typename ModuleType>
539 const typename ModuleType::ContainedType*
Module::as() const noexcept {
542 return as<typename ModuleType::ContainedType>();
545 template <
typename ModuleType,
typename>
547 return dynamic_cast<ModuleType*
>(
this);
550 template <
typename ModuleType,
typename>
552 return dynamic_cast<const ModuleType*
>(
this);
555 template <
typename ModuleType>
558 std::shared_ptr<ModuleType> module) {
559 AT_CHECK(!name.empty(),
"Submodule name must not be empty");
561 name.find(
'.') == std::string::npos,
562 "Submodule name must not contain a dot (got '",
565 auto& base_module = children_.insert(std::move(name), std::move(module));
566 return std::dynamic_pointer_cast<ModuleType>(base_module);
569 template <
typename ModuleType>
573 return register_module(std::move(name), module_holder.
ptr());
576 template <
typename... Ts>
577 void Module::to_impl(Ts&&... ts) {
579 for (
auto& child : children_) {
580 child.value()->to(ts...);
583 for (
auto& parameter : parameters_) {
587 for (
auto& buffer : buffers_) {
Represents a compute device on which a tensor is located.
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
The base class for all modules in PyTorch.
Variable: A Variable augments a Tensor with the ability to interact in our autograd machinery...
const std::shared_ptr< Contained > & ptr() const
Returns a shared pointer to the underlying module.
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...
ModuleType::ContainedType * as() noexcept
Attempts to cast this Module to the given ModuleType.
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module.
An ordered dictionary implementation, akin to Python's OrderedDict.