1 #include <torch/nn/module.h> 3 #include <torch/ordered_dict.h> 5 #include <torch/csrc/autograd/generated/VariableType.h> 7 #include <c10/util/Exception.h> 21 std::string join_name(
const std::string& name_prefix,
const std::string& name) {
22 size_t total_size = name.size();
23 if (!name_prefix.empty()) {
24 total_size += name_prefix.size() + 1;
26 std::string full_name;
27 full_name.reserve(total_size);
28 if (!name_prefix.empty()) {
29 full_name += name_prefix;
30 full_name.push_back(
'.');
37 std::vector<Tensor>& vector,
39 vector.reserve(vector.size() + dict.
size());
40 for (
const auto& item : dict) {
41 vector.push_back(item.value());
47 : parameters_(
"Parameter"), buffers_(
"Buffer"), children_(
"Submodule") {}
50 name_ = std::move(name);
64 if (!name_.has_value()) {
68 if (name_->find(
"struct ") == 0) {
69 name_->erase(name_->begin(), name_->begin() + 7);
70 }
else if (name_->find(
"class ") == 0) {
71 name_->erase(name_->begin(), name_->begin() + 6);
73 #endif // defined(_WIN32) 80 "clone() has not been implemented for ",
82 ". Subclass torch::nn::Cloneable<",
84 "> instead of torch::nn::Module to inherit the ability to clone.");
90 [&
function](
const std::string&,
const std::shared_ptr<Module>& module) {
98 [&
function](
const std::string&,
const std::shared_ptr<Module>& module) {
104 const NamedModuleApplyFunction&
function,
105 const std::string& name_prefix) {
106 function(name_prefix, *
this);
109 const std::string& name,
const std::shared_ptr<Module>& module) {
110 function(
name, *module);
116 const ConstNamedModuleApplyFunction&
function,
117 const std::string& name_prefix)
const {
118 function(name_prefix, *
this);
121 const std::string& name,
const std::shared_ptr<Module>& module) {
122 function(
name, *module);
128 function(shared_from_this_checked());
130 [&
function](
const std::string&,
const std::shared_ptr<Module>& module) {
136 const NamedModulePointerApplyFunction&
function,
137 const std::string& name_prefix)
const {
139 name_prefix, shared_from_this_checked());
140 apply_to_submodules(
function, name_prefix);
145 return parameters_.values();
147 std::vector<Tensor> result;
149 [&result](
const Module& module) { extend(result, module.parameters_); });
158 apply([&result](
const std::string& name,
const Module& module) {
159 for (
const auto& parameter : module.parameters_) {
160 result.insert(join_name(name, parameter.key()), parameter.value());
168 return buffers_.values();
170 std::vector<Tensor> result;
171 apply([&result](
const Module& module) { extend(result, module.buffers_); });
179 apply([&result](
const std::string& name,
const Module& module) {
180 for (
const auto& buffer : module.buffers_) {
181 result.insert(join_name(name, buffer.key()), buffer.value());
188 std::vector<std::shared_ptr<Module>> result;
190 apply([&result](
const std::shared_ptr<Module>& module) {
191 result.push_back(module);
195 [&result](
const std::string&,
const std::shared_ptr<Module>& module) {
196 result.push_back(module);
203 const std::string& name_prefix,
204 bool include_self)
const {
209 const std::string& key,
const std::shared_ptr<Module>& module) {
210 result.
insert(key, module);
216 const std::string& key,
const std::shared_ptr<Module>& module) {
217 result.
insert(key, module);
225 return children_.values();
234 for (
auto& child : children_) {
235 child.value()->train(on);
245 to_impl(device, dtype, non_blocking);
249 to_impl(dtype, non_blocking);
253 to_impl(device, non_blocking);
261 for (
auto& child : children_) {
262 child.value()->zero_grad();
264 for (
auto& parameter : parameters_) {
265 auto& grad = parameter->grad();
266 if (grad.defined()) {
267 grad = grad.detach();
274 for (
const auto& parameter : parameters_) {
275 archive.
write(parameter.key(), parameter.value());
277 for (
const auto& buffer : buffers_) {
278 archive.
write(buffer.key(), buffer.value(),
true);
280 for (
const auto& child : children_) {
282 child.value()->save(child_archive);
283 archive.
write(child.key(), child_archive);
288 for (
auto& parameter : parameters_) {
289 archive.
read(parameter.key(), parameter.value());
291 for (
auto& buffer : buffers_) {
292 archive.
read(buffer.key(), buffer.value(),
true);
294 for (
const auto& child : children_) {
296 archive.
read(child.key(), child_archive);
297 child.value()->load(child_archive);
304 bool requires_grad) {
305 AT_CHECK(!name.empty(),
"Parameter name must not be empty");
307 name.find(
'.') == std::string::npos,
308 "Parameter name must not contain a dot (got '",
311 tensor.set_requires_grad(requires_grad);
312 return parameters_.insert(std::move(name), std::move(tensor));
316 AT_CHECK(!name.empty(),
"Buffer name must not be empty");
318 name.find(
'.') == std::string::npos,
319 "Buffer name must not contain a dot (got '",
322 return buffers_.insert(std::move(name), std::move(tensor));
329 void Module::pretty_print_recursive(
330 std::ostream& stream,
331 const std::string& indentation)
const {
333 if (!children_.is_empty()) {
335 const std::string next_indentation = indentation +
" ";
336 for (
const auto& child : children_) {
337 stream << next_indentation <<
"(" << child.key() <<
"): ";
338 child.value()->pretty_print_recursive(stream, next_indentation);
341 stream << indentation <<
")";
347 void Module::apply_to_submodules(
348 const NamedModulePointerApplyFunction&
function,
349 const std::string& name_prefix)
const {
350 for (
const auto& child : children_) {
351 auto qualified_name = join_name(name_prefix, child.key());
352 function(qualified_name, child.value());
353 child.value()->apply_to_submodules(
function, qualified_name);
357 std::shared_ptr<Module> Module::shared_from_this_checked()
const {
358 std::shared_ptr<const Module> ptr;
360 ptr = shared_from_this();
361 }
catch (
const std::bad_weak_ptr& e) {
363 "It looks like you attempted to retrieve your top-level module " 364 "as a shared_ptr, but it is not stored in a shared_ptr. " 365 "Use std::make_shared<",
367 "> instead of creating your module on " 368 "the stack, or alternatively do not try to access your top-level " 369 "module at all by passing /*include_self=*/false " 370 "to modules() or named_modules()");
372 return std::const_pointer_cast<
Module>(ptr);
376 module.pretty_print_recursive(stream,
"");
382 const std::shared_ptr<nn::Module>& module) {
383 AT_CHECK(module !=
nullptr,
"Cannot serialize empty module");
384 module->save(archive);
390 const std::shared_ptr<nn::Module>& module) {
391 AT_CHECK(module !=
nullptr,
"Cannot deserialize empty module");
392 module->load(archive);
virtual void pretty_print(std::ostream &stream) const
Streams a pretty representation of the Module into the given stream.
size_t size() const noexcept
Returns the number of items currently stored in the OrderedDict.
virtual void save(serialize::OutputArchive &archive) const
Serializes the Module into the given OutputArchive.
std::vector< Tensor > buffers(bool recurse=true) const
Returns the buffers of this Module and if recurse is true, also recursively of every submodule...
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
OrderedDict< std::string, Tensor > named_parameters(bool recurse=true) const
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is tr...
Value & insert(K &&key, V &&value)
Inserts a new (key, value) pair into the OrderedDict.
virtual void load(serialize::InputArchive &archive)
Deserializes the Module from the given InputArchive.
const std::string & name() const noexcept
Returns the name of the Module.
virtual bool is_training() const noexcept
True if the module is in training mode.
std::vector< std::shared_ptr< Module > > children() const
Returns the direct submodules of this Module.
virtual void zero_grad()
Recursively zeros out the grad value of each registered parameter.
std::vector< std::shared_ptr< Module > > modules(bool include_self=true) const
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true...
Represents a compute device on which a tensor is located.
std::vector< Tensor > parameters(bool recurse=true) const
Returns the parameters of this Module and if recurse is true, also recursively of every submodule...
virtual void train(bool on=true)
Enables "training" mode.
Tensor & register_buffer(std::string name, Tensor tensor)
Registers a buffer with this Module.
virtual std::shared_ptr< Module > clone(const optional< Device > &device=nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodule...
TORCH_API friend std::ostream & operator<<(std::ostream &stream, const nn::Module &module)
Pretty prints the given Module into the ostream.
OrderedDict< std::string, std::shared_ptr< Module > > named_children() const
Returns an OrderedDict of the direct submodules of this Module and their keys.
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking=false)
Recursively casts all parameters to the given dtype and device.
The base class for all modules in PyTorch.
OrderedDict< std::string, Tensor > named_buffers(bool recurse=true) const
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true ...
void write(const std::string &key, const Tensor &tensor, bool is_buffer=false)
Writes a (key, tensor) pair to the OutputArchive, and marks it as being or not being a buffer (non-di...
std::string demangle(const char *name)
Utility to demangle a C++ symbol name.
void apply(const ModuleApplyFunction &function)
Applies the function to the Module and recursively to every submodule.
Module()
Constructs the module without immediate knowledge of the submodule's name.
void eval()
Calls train(false) to enable "eval" mode.
An ordered dictionary implementation, akin to Python's OrderedDict.
OrderedDict< std::string, std::shared_ptr< Module > > named_modules(const std::string &name_prefix=std::string(), bool include_self=true) const
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys...