Caffe2 - C++ API
A deep learning, cross platform ML framework
cpp_frontend_extension.cpp
1 #include <torch/extension.h>
2 
3 #include <cstddef>
4 #include <string>
5 
6 struct Net : torch::nn::Cloneable<Net> {
7  Net(int64_t in, int64_t out) : in_(in), out_(out) {
8  reset();
9  }
10 
11  void reset() override {
12  fc = register_module("fc", torch::nn::Linear(in_, out_));
13  buffer = register_buffer("buf", torch::eye(5));
14  }
15 
16  torch::Tensor forward(torch::Tensor x) {
17  return fc->forward(x);
18  }
19 
20  void set_bias(torch::Tensor bias) {
21  torch::NoGradGuard guard;
22  fc->bias.set_(bias);
23  }
24 
25  torch::Tensor get_bias() const {
26  return fc->bias;
27  }
28 
29  void add_new_parameter(const std::string& name, torch::Tensor tensor) {
30  register_parameter(name, tensor);
31  }
32 
33  void add_new_buffer(const std::string& name, torch::Tensor tensor) {
34  register_buffer(name, tensor);
35  }
36 
37  void add_new_submodule(const std::string& name) {
38  register_module(name, torch::nn::Linear(fc->options));
39  }
40 
41  int64_t in_, out_;
42  torch::nn::Linear fc{nullptr};
43  torch::Tensor buffer;
44 };
45 
46 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
47  torch::python::bind_module<Net>(m, "Net")
48  .def(py::init<int64_t, int64_t>())
49  .def("set_bias", &Net::set_bias)
50  .def("get_bias", &Net::get_bias)
51  .def("add_new_parameter", &Net::add_new_parameter)
52  .def("add_new_buffer", &Net::add_new_buffer)
53  .def("add_new_submodule", &Net::add_new_submodule);
54 }
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
Definition: module.cpp:301
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
const std::string & name() const noexcept
Returns the name of the Module.
Definition: module.cpp:53
Tensor & register_buffer(std::string name, Tensor tensor)
Registers a buffer with this Module.
Definition: module.cpp:315
The clone() method in the base Module class does not have knowledge of the concrete runtime type of its subclasses.
Definition: cloneable.h:23
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module.
Definition: module.h:556