1 #include <torch/extension.h> 7 Net(int64_t in, int64_t out) : in_(in), out_(out) {
17 return fc->forward(x);
33 void add_new_buffer(
const std::string& name,
torch::Tensor tensor) {
37 void add_new_submodule(
const std::string& name) {
42 torch::nn::Linear fc{
nullptr};
46 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
47 torch::python::bind_module<Net>(m,
"Net")
48 .def(py::init<int64_t, int64_t>())
49 .def(
"set_bias", &Net::set_bias)
50 .def(
"get_bias", &Net::get_bias)
51 .def(
"add_new_parameter", &Net::add_new_parameter)
52 .def(
"add_new_buffer", &Net::add_new_buffer)
53 .def(
"add_new_submodule", &Net::add_new_submodule);
Tensor & register_parameter(std::string name, Tensor tensor, bool requires_grad=true)
Registers a parameter with this Module.
void reset() override
reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.
const std::string & name() const noexcept
Returns the name of the Module.
Tensor & register_buffer(std::string name, Tensor tensor)
Registers a buffer with this Module.
The clone() method in the base Module class does not have knowledge of the concrete runtime type of its subclasses.
std::shared_ptr< ModuleType > register_module(std::string name, std::shared_ptr< ModuleType > module)
Registers a submodule with this Module.