Caffe2 - C++ API
A deep learning, cross platform ML framework
module.cpp
1 #include <torch/nn/module.h>
2 
3 #include <torch/ordered_dict.h>
4 
5 #include <torch/csrc/autograd/generated/VariableType.h>
6 
7 #include <c10/util/Exception.h>
8 
9 #include <algorithm>
10 #include <functional>
11 #include <map>
12 #include <ostream>
13 #include <string>
14 #include <typeinfo>
15 
16 namespace torch {
17 namespace nn {
18 namespace {
/// Joins `name_prefix` and `name` with a dot ("prefix.name"); an empty
/// prefix yields `name` unchanged. Used to build hierarchical module keys.
std::string join_name(const std::string& name_prefix, const std::string& name) {
  if (name_prefix.empty()) {
    return name;
  }
  std::string joined;
  // Reserve exactly: prefix + '.' + name.
  joined.reserve(name_prefix.size() + 1 + name.size());
  joined += name_prefix;
  joined.push_back('.');
  joined += name;
  return joined;
}
35 
36 void extend(
37  std::vector<Tensor>& vector,
39  vector.reserve(vector.size() + dict.size());
40  for (const auto& item : dict) {
41  vector.push_back(item.value());
42  }
43 }
44 } // namespace
45 
47  : parameters_("Parameter"), buffers_("Buffer"), children_("Submodule") {}
48 
49 Module::Module(std::string name) : Module() {
50  name_ = std::move(name);
51 }
52 
53 const std::string& Module::name() const noexcept {
54  // If the name optional is empty at this point, we grab the name of the
55  // dynamic type via RTTI. Note that we cannot do this in the constructor,
56  // because in the constructor of a base class `this` always refers to the base
57  // type. Inheritance effectively does not work in constructors. Also this note
58  // from http://en.cppreference.com/w/cpp/language/typeid:
59  // If typeid is used on an object under construction or destruction (in a
60  // destructor or in a constructor, including constructor's initializer list
61  // or default member initializers), then the std::type_info object referred
62  // to by this typeid represents the class that is being constructed or
63  // destroyed even if it is not the most-derived class.
64  if (!name_.has_value()) {
65  name_ = c10::demangle(typeid(*this).name());
66 #if defined(_WIN32)
67  // Windows adds "struct" or "class" as a prefix.
68  if (name_->find("struct ") == 0) {
69  name_->erase(name_->begin(), name_->begin() + 7);
70  } else if (name_->find("class ") == 0) {
71  name_->erase(name_->begin(), name_->begin() + 6);
72  }
73 #endif // defined(_WIN32)
74  }
75  return *name_;
76 }
77 
78 std::shared_ptr<Module> Module::clone(const optional<Device>& device) const {
79  AT_ERROR(
80  "clone() has not been implemented for ",
81  name(),
82  ". Subclass torch::nn::Cloneable<",
83  name(),
84  "> instead of torch::nn::Module to inherit the ability to clone.");
85 }
86 
87 void Module::apply(const ModuleApplyFunction& function) {
88  function(*this);
89  apply_to_submodules(
90  [&function](const std::string&, const std::shared_ptr<Module>& module) {
91  function(*module);
92  });
93 }
94 
95 void Module::apply(const ConstModuleApplyFunction& function) const {
96  function(*this);
97  apply_to_submodules(
98  [&function](const std::string&, const std::shared_ptr<Module>& module) {
99  function(*module);
100  });
101 }
102 
104  const NamedModuleApplyFunction& function,
105  const std::string& name_prefix) {
106  function(/*name=*/name_prefix, *this);
107  apply_to_submodules(
108  [&function](
109  const std::string& name, const std::shared_ptr<Module>& module) {
110  function(name, *module);
111  },
112  name_prefix);
113 }
114 
116  const ConstNamedModuleApplyFunction& function,
117  const std::string& name_prefix) const {
118  function(/*name=*/name_prefix, *this);
119  apply_to_submodules(
120  [&function](
121  const std::string& name, const std::shared_ptr<Module>& module) {
122  function(name, *module);
123  },
124  name_prefix);
125 }
126 
127 void Module::apply(const ModulePointerApplyFunction& function) const {
128  function(shared_from_this_checked());
129  apply_to_submodules(
130  [&function](const std::string&, const std::shared_ptr<Module>& module) {
131  function(module);
132  });
133 }
134 
136  const NamedModulePointerApplyFunction& function,
137  const std::string& name_prefix) const {
138  function(
139  /*name=*/name_prefix, shared_from_this_checked());
140  apply_to_submodules(function, name_prefix);
141 }
142 
143 std::vector<Tensor> Module::parameters(bool recurse) const {
144  if (!recurse) {
145  return parameters_.values();
146  }
147  std::vector<Tensor> result;
148  apply(
149  [&result](const Module& module) { extend(result, module.parameters_); });
150  return result;
151 }
152 
154  if (!recurse) {
155  return parameters_;
156  }
158  apply([&result](const std::string& name, const Module& module) {
159  for (const auto& parameter : module.parameters_) {
160  result.insert(join_name(name, parameter.key()), parameter.value());
161  }
162  });
163  return result;
164 }
165 
166 std::vector<Tensor> Module::buffers(bool recurse) const {
167  if (!recurse) {
168  return buffers_.values();
169  }
170  std::vector<Tensor> result;
171  apply([&result](const Module& module) { extend(result, module.buffers_); });
172  return result;
173 }
175  if (!recurse) {
176  return buffers_;
177  }
179  apply([&result](const std::string& name, const Module& module) {
180  for (const auto& buffer : module.buffers_) {
181  result.insert(join_name(name, buffer.key()), buffer.value());
182  }
183  });
184  return result;
185 }
186 
187 std::vector<std::shared_ptr<Module>> Module::modules(bool include_self) const {
188  std::vector<std::shared_ptr<Module>> result;
189  if (include_self) {
190  apply([&result](const std::shared_ptr<Module>& module) {
191  result.push_back(module);
192  });
193  } else {
194  apply_to_submodules(
195  [&result](const std::string&, const std::shared_ptr<Module>& module) {
196  result.push_back(module);
197  });
198  }
199  return result;
200 }
201 
203  const std::string& name_prefix,
204  bool include_self) const {
206  if (include_self) {
207  apply(
208  [&result](
209  const std::string& key, const std::shared_ptr<Module>& module) {
210  result.insert(key, module);
211  },
212  name_prefix);
213  } else {
214  apply_to_submodules(
215  [&result](
216  const std::string& key, const std::shared_ptr<Module>& module) {
217  result.insert(key, module);
218  },
219  name_prefix);
220  }
221  return result;
222 }
223 
224 std::vector<std::shared_ptr<Module>> Module::children() const {
225  return children_.values();
226 }
227 
229  const {
230  return children_;
231 }
232 
233 void Module::train(bool on) {
234  for (auto& child : children_) {
235  child.value()->train(on);
236  }
237  is_training_ = on;
238 }
239 
240 void Module::eval() {
241  train(/*on=*/false);
242 }
243 
244 void Module::to(torch::Device device, torch::Dtype dtype, bool non_blocking) {
245  to_impl(device, dtype, non_blocking);
246 }
247 
248 void Module::to(torch::Dtype dtype, bool non_blocking) {
249  to_impl(dtype, non_blocking);
250 }
251 
252 void Module::to(torch::Device device, bool non_blocking) {
253  to_impl(device, non_blocking);
254 }
255 
256 bool Module::is_training() const noexcept {
257  return is_training_;
258 }
259 
261  for (auto& child : children_) {
262  child.value()->zero_grad();
263  }
264  for (auto& parameter : parameters_) {
265  auto& grad = parameter->grad();
266  if (grad.defined()) {
267  grad = grad.detach();
268  grad.zero_();
269  }
270  }
271 }
272 
274  for (const auto& parameter : parameters_) {
275  archive.write(parameter.key(), parameter.value());
276  }
277  for (const auto& buffer : buffers_) {
278  archive.write(buffer.key(), buffer.value(), /*is_buffer=*/true);
279  }
280  for (const auto& child : children_) {
281  serialize::OutputArchive child_archive;
282  child.value()->save(child_archive);
283  archive.write(child.key(), child_archive);
284  }
285 }
286 
288  for (auto& parameter : parameters_) {
289  archive.read(parameter.key(), parameter.value());
290  }
291  for (auto& buffer : buffers_) {
292  archive.read(buffer.key(), buffer.value(), /*is_buffer=*/true);
293  }
294  for (const auto& child : children_) {
295  serialize::InputArchive child_archive;
296  archive.read(child.key(), child_archive);
297  child.value()->load(child_archive);
298  }
299 }
300 
302  std::string name,
303  Tensor tensor,
304  bool requires_grad) {
305  AT_CHECK(!name.empty(), "Parameter name must not be empty");
306  AT_CHECK(
307  name.find('.') == std::string::npos,
308  "Parameter name must not contain a dot (got '",
309  name,
310  "')");
311  tensor.set_requires_grad(requires_grad);
312  return parameters_.insert(std::move(name), std::move(tensor));
313 }
314 
315 Tensor& Module::register_buffer(std::string name, Tensor tensor) {
316  AT_CHECK(!name.empty(), "Buffer name must not be empty");
317  AT_CHECK(
318  name.find('.') == std::string::npos,
319  "Buffer name must not contain a dot (got '",
320  name,
321  "')");
322  return buffers_.insert(std::move(name), std::move(tensor));
323 }
324 
325 void Module::pretty_print(std::ostream& stream) const {
326  stream << name();
327 }
328 
329 void Module::pretty_print_recursive(
330  std::ostream& stream,
331  const std::string& indentation) const {
332  pretty_print(stream);
333  if (!children_.is_empty()) {
334  stream << "(\n";
335  const std::string next_indentation = indentation + " ";
336  for (const auto& child : children_) {
337  stream << next_indentation << "(" << child.key() << "): ";
338  child.value()->pretty_print_recursive(stream, next_indentation);
339  stream << '\n';
340  }
341  stream << indentation << ")";
342  }
343 }
344 
345 void Module::clone_(Module& other, const optional<Device>& device) {}
346 
347 void Module::apply_to_submodules(
348  const NamedModulePointerApplyFunction& function,
349  const std::string& name_prefix) const {
350  for (const auto& child : children_) {
351  auto qualified_name = join_name(name_prefix, child.key());
352  function(qualified_name, child.value());
353  child.value()->apply_to_submodules(function, qualified_name);
354  }
355 }
356 
357 std::shared_ptr<Module> Module::shared_from_this_checked() const {
358  std::shared_ptr<const Module> ptr;
359  try {
360  ptr = shared_from_this();
361  } catch (const std::bad_weak_ptr& e) {
362  AT_ERROR(
363  "It looks like you attempted to retrieve your top-level module "
364  "as a shared_ptr, but it is not stored in a shared_ptr. "
365  "Use std::make_shared<",
366  name(),
367  "> instead of creating your module on "
368  "the stack, or alternatively do not try to access your top-level "
369  "module at all by passing /*include_self=*/false "
370  "to modules() or named_modules()");
371  }
372  return std::const_pointer_cast<Module>(ptr);
373 }
374 
375 std::ostream& operator<<(std::ostream& stream, const nn::Module& module) {
376  module.pretty_print_recursive(stream, "");
377  return stream;
378 }
379 
381  serialize::OutputArchive& archive,
382  const std::shared_ptr<nn::Module>& module) {
383  AT_CHECK(module != nullptr, "Cannot serialize empty module");
384  module->save(archive);
385  return archive;
386 }
387 
388 serialize::InputArchive& operator>>(
389  serialize::InputArchive& archive,
390  const std::shared_ptr<nn::Module>& module) {
391  AT_CHECK(module != nullptr, "Cannot deserialize empty module");
392  module->load(archive);
393  return archive;
394 }
395 } // namespace nn
396 } // namespace torch
virtual void pretty_print(std::ostream &stream) const
Streams a pretty representation of the Module into the given stream.
Definition: module.cpp:325
size_t size() const noexcept
Returns the number of items currently stored in the OrderedDict.
Definition: ordered_dict.h:419
virtual void save(serialize::OutputArchive &archive) const
Serializes the Module into the given OutputArchive.
Definition: module.cpp:273
std::vector< Tensor > buffers(bool recurse=true) const
Returns the buffers of this Module and if recurse is true, also recursively of every submodule...
Definition: module.cpp:166
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
Definition: module.cpp:301
OrderedDict< std::string, Tensor > named_parameters(bool recurse=true) const
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is tr...
Definition: module.cpp:153
Value & insert(K &&key, V &&value)
Inserts a new (key, value) pair into the OrderedDict.
Definition: ordered_dict.h:357
virtual void load(serialize::InputArchive &archive)
Deserializes the Module from the given InputArchive.
Definition: module.cpp:287
const std::string & name() const noexcept
Returns the name of the Module.
Definition: module.cpp:53
virtual bool is_training() const noexcept
True if the module is in training mode.
Definition: module.cpp:256
std::vector< std::shared_ptr< Module > > children() const
Returns the direct submodules of this Module.
Definition: module.cpp:224
virtual void zero_grad()
Recursively zeros out the grad value of each registered parameter.
Definition: module.cpp:260
std::vector< std::shared_ptr< Module > > modules(bool include_self=true) const
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true...
Definition: module.cpp:187
void read(const std::string &key, Tensor &tensor, bool is_buffer=false)
Reads a tensor associated with a given key.
Represents a compute device on which a tensor is located.
Definition: Device.h:30
std::vector< Tensor > parameters(bool recurse=true) const
Returns the parameters of this Module and if recurse is true, also recursively of every submodule...
Definition: module.cpp:143
virtual void train(bool on=true)
Enables "training" mode.
Definition: module.cpp:233
Tensor & register_buffer(std::string name, Tensor tensor)
Registers a buffer with this Module.
Definition: module.cpp:315
virtual std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodule...
Definition: module.cpp:78
TORCH_API friend std::ostream & operator<<(std::ostream &stream, const nn::Module &module)
Pretty prints the given Module into the ostream.
Definition: module.cpp:375
OrderedDict< std::string, std::shared_ptr< Module > > named_children() const
Returns an OrderedDict of the direct submodules of this Module and their keys.
Definition: module.cpp:228
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking=false)
Recursively casts all parameters to the given dtype and device.
Definition: module.cpp:244
The base class for all modules in PyTorch.
Definition: module.h:62
Definition: jit_type.h:17
OrderedDict< std::string, Tensor > named_buffers(bool recurse=true) const
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true ...
Definition: module.cpp:174
void write(const std::string &key, const Tensor &tensor, bool is_buffer=false)
Writes a (key, tensor) pair to the OutputArchive, and marks it as being or not being a buffer (non-di...
std::string demangle(const char *name)
Utility to demangle a C++ symbol name.
Definition: Type.cpp:23
void apply(const ModuleApplyFunction &function)
Applies the function to the Module and recursively to every submodule.
Definition: module.cpp:87
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32
Module()
Constructs the module without immediate knowledge of the submodule's name.
Definition: module.cpp:46
void eval()
Calls train(false) to enable "eval" mode.
Definition: module.cpp:240
An ordered dictionary implementation, akin to Python's OrderedDict.
Definition: ordered_dict.h:16
OrderedDict< std::string, std::shared_ptr< Module > > named_modules(const std::string &name_prefix=std::string(), bool include_self=true) const
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys...
Definition: module.cpp:202