Caffe2 - C++ API
A deep learning, cross-platform ML framework
module.h
1 #pragma once
2 
3 #include <torch/nn/pimpl.h>
4 #include <torch/ordered_dict.h>
5 #include <torch/serialize/archive.h>
6 #include <torch/types.h>
7 
8 #include <ATen/ATen.h>
9 
10 #include <functional>
11 #include <iosfwd>
12 #include <map>
13 #include <memory>
14 #include <string>
15 #include <type_traits>
16 
17 namespace torch {
18 namespace nn {
19 
62 class TORCH_API Module : public std::enable_shared_from_this<Module> {
63  public:
64  using ModuleApplyFunction = std::function<void(Module&)>;
65  using ConstModuleApplyFunction = std::function<void(const Module&)>;
66  using NamedModuleApplyFunction =
67  std::function<void(const std::string&, Module&)>;
68  using ConstNamedModuleApplyFunction =
69  std::function<void(const std::string&, const Module&)>;
70  using ModulePointerApplyFunction =
71  std::function<void(const std::shared_ptr<Module>&)>;
72  using NamedModulePointerApplyFunction =
73  std::function<void(const std::string&, const std::shared_ptr<Module>&)>;
74 
76  explicit Module(std::string name);
77 
81  Module();
82 
83  virtual ~Module() = default;
84 
94  const std::string& name() const noexcept;
95 
112  virtual std::shared_ptr<Module> clone(
113  const optional<Device>& device = nullopt) const;
114 
125  void apply(const ModuleApplyFunction& function);
126 
137  void apply(const ConstModuleApplyFunction& function) const;
138 
152  void apply(
153  const NamedModuleApplyFunction& function,
154  const std::string& name_prefix = std::string());
155 
169  void apply(
170  const ConstNamedModuleApplyFunction& function,
171  const std::string& name_prefix = std::string()) const;
172 
183  void apply(const ModulePointerApplyFunction& function) const;
184 
200  void apply(
201  const NamedModulePointerApplyFunction& function,
202  const std::string& name_prefix = std::string()) const;
203 
206  std::vector<Tensor> parameters(bool recurse = true) const;
207 
210  OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const;
211 
214  std::vector<Tensor> buffers(bool recurse = true) const;
215 
218  OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const;
219 
231  std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const;
232 
246  OrderedDict<std::string, std::shared_ptr<Module>> named_modules(
247  const std::string& name_prefix = std::string(),
248  bool include_self = true) const;
249 
251  std::vector<std::shared_ptr<Module>> children() const;
252 
255  OrderedDict<std::string, std::shared_ptr<Module>> named_children() const;
256 
258  virtual void train(bool on = true);
259 
262  void eval();
263 
273  virtual bool is_training() const noexcept;
274 
281  virtual void to(
282  torch::Device device,
283  torch::Dtype dtype,
284  bool non_blocking = false);
285 
292  virtual void to(torch::Dtype dtype, bool non_blocking = false);
293 
300  virtual void to(torch::Device device, bool non_blocking = false);
301 
303  virtual void zero_grad();
304 
321  template <typename ModuleType>
322  typename ModuleType::ContainedType* as() noexcept;
323 
339  template <typename ModuleType>
340  const typename ModuleType::ContainedType* as() const noexcept;
341 
358  template <
359  typename ModuleType,
360  typename = torch::detail::disable_if_module_holder_t<ModuleType>>
361  ModuleType* as() noexcept;
362 
379  template <
380  typename ModuleType,
381  typename = torch::detail::disable_if_module_holder_t<ModuleType>>
382  const ModuleType* as() const noexcept;
383 
385  virtual void save(serialize::OutputArchive& archive) const;
386 
388  virtual void load(serialize::InputArchive& archive);
389 
397  virtual void pretty_print(std::ostream& stream) const;
398 
399  protected:
413  Tensor& register_parameter(
414  std::string name,
415  Tensor tensor,
416  bool requires_grad = true);
417 
431  Tensor& register_buffer(std::string name, Tensor tensor);
432 
445  template <typename ModuleType>
446  std::shared_ptr<ModuleType> register_module(
447  std::string name,
448  std::shared_ptr<ModuleType> module);
449 
464  template <typename ModuleType>
465  std::shared_ptr<ModuleType> register_module(
466  std::string name,
467  ModuleHolder<ModuleType> module_holder);
468 
469  private:
470  // Friend classes.
471 
472  template <typename Derived>
473  friend class Cloneable;
474 
476  TORCH_API friend std::ostream& operator<<(
477  std::ostream& stream,
478  const nn::Module& module);
479 
480  // Private methods.
481 
483  virtual void clone_(Module& other, const optional<Device>& device);
484 
486  template <typename... Ts>
487  void to_impl(Ts&&... ts);
488 
490  void pretty_print_recursive(
491  std::ostream& stream,
492  const std::string& indentation) const;
493 
496  void apply_to_submodules(
497  const NamedModulePointerApplyFunction& function,
498  const std::string& name_prefix = std::string()) const;
499 
501  std::shared_ptr<Module> shared_from_this_checked() const;
502 
504  OrderedDict<std::string, Tensor> parameters_;
505 
507  OrderedDict<std::string, Tensor> buffers_;
508 
510  OrderedDict<std::string, std::shared_ptr<Module>> children_;
511 
513  mutable optional<std::string> name_;
514 
516  bool is_training_{true};
517 };
518 
520 TORCH_API serialize::OutputArchive& operator<<(
521  serialize::OutputArchive& archive,
522  const std::shared_ptr<nn::Module>& module);
523 
525 TORCH_API serialize::InputArchive& operator>>(
526  serialize::InputArchive& archive,
527  const std::shared_ptr<nn::Module>& module);
528 
529 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
530 
531 template <typename ModuleType>
532 typename ModuleType::ContainedType* Module::as() noexcept {
533  // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
534  // `Linear`, since `LinearImpl` inherits `nn::Module`.
535  return as<typename ModuleType::ContainedType>();
536 }
537 
538 template <typename ModuleType>
539 const typename ModuleType::ContainedType* Module::as() const noexcept {
540  // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
541  // `Linear`, since `LinearImpl` inherits `nn::Module`.
542  return as<typename ModuleType::ContainedType>();
543 }
544 
545 template <typename ModuleType, typename>
546 ModuleType* Module::as() noexcept {
547  return dynamic_cast<ModuleType*>(this);
548 }
549 
550 template <typename ModuleType, typename>
551 const ModuleType* Module::as() const noexcept {
552  return dynamic_cast<const ModuleType*>(this);
553 }
554 
555 template <typename ModuleType>
556 std::shared_ptr<ModuleType> Module::register_module(
557  std::string name,
558  std::shared_ptr<ModuleType> module) {
559  AT_CHECK(!name.empty(), "Submodule name must not be empty");
560  AT_CHECK(
561  name.find('.') == std::string::npos,
562  "Submodule name must not contain a dot (got '",
563  name,
564  "')");
565  auto& base_module = children_.insert(std::move(name), std::move(module));
566  return std::dynamic_pointer_cast<ModuleType>(base_module);
567 }
568 
569 template <typename ModuleType>
570 std::shared_ptr<ModuleType> Module::register_module(
571  std::string name,
572  ModuleHolder<ModuleType> module_holder) {
573  return register_module(std::move(name), module_holder.ptr());
574 }
575 
576 template <typename... Ts>
577 void Module::to_impl(Ts&&... ts) {
578  // First call `to()` on every child module.
579  for (auto& child : children_) {
580  child.value()->to(ts...);
581  }
582  // Then move every parameter to the new dtype/device.
583  for (auto& parameter : parameters_) {
584  parameter->set_data(autograd::Variable(*parameter).data().to(ts...));
585  }
586  // Then move every buffer to the new dtype/device.
587  for (auto& buffer : buffers_) {
588  buffer->set_data(autograd::Variable(*buffer).data().to(ts...));
589  }
590 }
591 
592 } // namespace nn
593 } // namespace torch
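The listing above declares the interface only; the following usage sketch is not part of module.h, and the Net name, layer sizes, and forward() signature are illustrative assumptions. It shows the intended pattern: a concrete module derives from torch::nn::Module and hands its state to register_module() and register_parameter(), so that parameters(), named_parameters(), to(), and zero_grad() can traverse it recursively.

#include <torch/torch.h>
#include <iostream>

// Illustrative module (hypothetical name "Net"): submodules and parameters are
// registered so that the recursive traversals in nn::Module can find them.
struct Net : torch::nn::Module {
  Net(int64_t in, int64_t out)
      : linear(register_module("linear", torch::nn::Linear(in, out))) {
    bias = register_parameter("bias", torch::randn(out));
  }

  torch::Tensor forward(torch::Tensor input) {
    return linear->forward(input) + bias;
  }

  torch::nn::Linear linear;
  torch::Tensor bias;
};

int main() {
  auto net = std::make_shared<Net>(4, 5);
  // Keys are the registration paths: "bias", "linear.weight", "linear.bias".
  for (const auto& pair : net->named_parameters()) {
    std::cout << pair.key() << ": " << pair.value().sizes() << std::endl;
  }
}

Registration is what makes the recursion work: any tensor or submodule that is not registered is invisible to parameters(), to(), and serialization.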
Represents a compute device on which a tensor is located.
Definition: Device.h:30
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
Definition: pimpl.h:26
The clone() method in the base Module class does not have knowledge of the concrete runtime type of i...
Definition: cloneable.h:23
The base class for all modules in PyTorch.
Definition: module.h:62
A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
const std::shared_ptr< Contained > & ptr() const
Returns a shared pointer to the underlying module.
Definition: pimpl.h:100
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32
ModuleType::ContainedType * as() noexcept
Attempts to cast this Module to the given ModuleType.
Definition: module.h:532
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module.
Definition: module.h:556
An ordered dictionary implementation, akin to Python's OrderedDict.
Definition: ordered_dict.h:16
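As a second sketch (again not part of the listing; example(), the Linear sizes, and the file name are arbitrary), the as() cast defined at module.h:532, the to()/train()/eval() helpers, and the archive operator declared above can be exercised like this, assuming the stock torch::nn::Linear holder and serialize::OutputArchive::save_to():

#include <torch/torch.h>
#include <iostream>

void example() {
  // A type-erased pointer, as stored in the children_ dictionary above.
  std::shared_ptr<torch::nn::Module> module = torch::nn::Linear(3, 3).ptr();

  // as<Linear>() resolves to the holder's ContainedType (LinearImpl) and then
  // dynamic_casts, so it yields nullptr when the runtime type does not match.
  if (auto* linear = module->as<torch::nn::Linear>()) {
    std::cout << linear->weight.sizes() << std::endl;
  }

  module->to(torch::kFloat64);  // converts parameters and buffers recursively
  module->eval();               // is_training() is now false
  module->train();              // and true again

  // Serialization through the operator<< declared in module.h.
  torch::serialize::OutputArchive archive;
  archive << module;
  archive.save_to("linear.pt");
}

Because as() is implemented with dynamic_cast on a pointer, it degrades to nullptr rather than throwing when the runtime type does not match, which is why the result is checked before use.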