Caffe2 - C++ API
A deep learning, cross platform ML framework
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
cloneable.h
1 #pragma once
2 
3 #include <torch/nn/module.h>
4 #include <torch/types.h>
5 #include <torch/utils.h>
6 
7 #include <c10/core/TensorOptions.h>
8 #include <c10/util/Exception.h>
9 
10 #include <memory>
11 #include <utility>
12 
13 namespace torch {
14 namespace nn {
22 template <typename Derived>
23 class Cloneable : public virtual Module {
24  public:
25  using Module::Module;
26 
29  virtual void reset() = 0;
30 
34  std::shared_ptr<Module> clone(
35  const optional<Device>& device = nullopt) const override {
36  NoGradGuard no_grad;
37 
38  const auto& self = static_cast<const Derived&>(*this);
39  auto copy = std::make_shared<Derived>(self);
40  copy->parameters_.clear();
41  copy->buffers_.clear();
42  copy->children_.clear();
43  copy->reset();
44  AT_CHECK(
45  copy->parameters_.size() == parameters_.size(),
46  "The cloned module does not have the same number of "
47  "parameters as the original module after calling reset(). "
48  "Are you sure you called register_parameter() inside reset() "
49  "and not the constructor?");
50  for (const auto& parameter : parameters_) {
51  auto data = autograd::Variable(*parameter).data().clone();
52  copy->parameters_[parameter.key()].set_data(
53  device ? data.to(*device) : data);
54  }
55  AT_CHECK(
56  copy->buffers_.size() == buffers_.size(),
57  "The cloned module does not have the same number of "
58  "buffers as the original module after calling reset(). "
59  "Are you sure you called register_buffer() inside reset() "
60  "and not the constructor?");
61  for (const auto& buffer : buffers_) {
62  auto data = autograd::Variable(*buffer).data().clone();
63  copy->buffers_[buffer.key()].set_data(device ? data.to(*device) : data);
64  }
65  AT_CHECK(
66  copy->children_.size() == children_.size(),
67  "The cloned module does not have the same number of "
68  "child modules as the original module after calling reset(). "
69  "Are you sure you called register_module() inside reset() "
70  "and not the constructor?");
71  for (const auto& child : children_) {
72  copy->children_[child.key()]->clone_(*child.value(), device);
73  }
74  return copy;
75  }
76 
77  private:
78  void clone_(Module& other, const optional<Device>& device) final {
79  // Here we are *pretty* certain that `other's` type is `Derived` (because it
80  // was registered under the same name as `this`), but you never know what
81  // crazy things `reset()` does, so `dynamic_cast` just to be safe.
82  auto clone = std::dynamic_pointer_cast<Derived>(other.clone(device));
83  AT_CHECK(
84  clone != nullptr,
85  "Attempted to clone submodule, but it is of a "
86  "different type than the submodule it was to be cloned into");
87  static_cast<Derived&>(*this) = std::move(*clone);
88  }
89 };
90 
91 } // namespace nn
92 } // namespace torch
virtual void reset()=0
reset() must perform initialization of all members with reference semantics, most importantly paramet...
std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const override
Performs a recursive "deep copy" of the Module, such that all parameters and submodules in the cloned...
Definition: cloneable.h:34
virtual std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodule...
Definition: module.cpp:78
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
Definition: cloneable.h:23
The base class for all modules in PyTorch.
Definition: module.h:62
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
Module()
Constructs the module without immediate knowledge of the submodule's name.
Definition: module.cpp:46