Caffe2 - C++ API
A deep-learning, cross-platform ML framework
extension.cpp
1 #include <torch/extension.h>
2 
3 torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
4  return x.sigmoid() + y.sigmoid();
5 }
6 
8  MatrixMultiplier(int A, int B) {
9  tensor_ =
10  torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
11  }
12  torch::Tensor forward(torch::Tensor weights) {
13  return tensor_.mm(weights);
14  }
15  torch::Tensor get() const {
16  return tensor_;
17  }
18 
19  private:
20  torch::Tensor tensor_;
21 };
22 
23 bool function_taking_optional(c10::optional<torch::Tensor> tensor) {
24  return tensor.has_value();
25 }
26 
27 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
28  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
29  m.def(
30  "function_taking_optional",
31  &function_taking_optional,
32  "function_taking_optional");
33  py::class_<MatrixMultiplier>(m, "MatrixMultiplier")
34  .def(py::init<int, int>())
35  .def("forward", &MatrixMultiplier::forward)
36  .def("get", &MatrixMultiplier::get);
37 }
Definition: static.cpp:52
Definition: static.cpp:58