3 #include <torch/nn/module.h> 4 #include <torch/types.h> 5 #include <torch/utils.h> 7 #include <c10/core/TensorOptions.h> 8 #include <c10/util/Exception.h> 22 template <
typename Derived>
29 virtual void reset() = 0;
38 const auto&
self =
static_cast<const Derived&
>(*this);
39 auto copy = std::make_shared<Derived>(
self);
40 copy->parameters_.clear();
41 copy->buffers_.clear();
42 copy->children_.clear();
45 copy->parameters_.size() == parameters_.size(),
46 "The cloned module does not have the same number of " 47 "parameters as the original module after calling reset(). " 48 "Are you sure you called register_parameter() inside reset() " 49 "and not the constructor?");
50 for (
const auto& parameter : parameters_) {
52 copy->parameters_[parameter.key()].set_data(
53 device ? data.to(*device) : data);
56 copy->buffers_.size() == buffers_.size(),
57 "The cloned module does not have the same number of " 58 "buffers as the original module after calling reset(). " 59 "Are you sure you called register_buffer() inside reset() " 60 "and not the constructor?");
61 for (
const auto& buffer : buffers_) {
63 copy->buffers_[buffer.key()].set_data(device ? data.to(*device) : data);
66 copy->children_.size() == children_.size(),
67 "The cloned module does not have the same number of " 68 "child modules as the original module after calling reset(). " 69 "Are you sure you called register_module() inside reset() " 70 "and not the constructor?");
71 for (
const auto& child : children_) {
72 copy->children_[child.key()]->clone_(*child.value(), device);
82 auto clone = std::dynamic_pointer_cast<Derived>(other.
clone(device));
85 "Attempted to clone submodule, but it is of a " 86 "different type than the submodule it was to be cloned into");
87 static_cast<Derived&
>(*this) = std::move(*
clone);
virtual void reset()=0
reset() must perform initialization of all members with reference semantics, most importantly paramet...
std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const override
Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned...
virtual std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodule...
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
The base class for all modules in PyTorch.
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Module()
Constructs the module without immediate knowledge of the submodule's name.