1 #include <torch/extension.h> 4 return x.sigmoid() + y.sigmoid();
10 torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(
true));
13 return tensor_.mm(weights);
24 return tensor.has_value();
27 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
28 m.def(
"sigmoid_add", &sigmoid_add,
"sigmoid(x) + sigmoid(y)");
30 "function_taking_optional",
31 &function_taking_optional,
32 "function_taking_optional");
33 py::class_<MatrixMultiplier>(m,
"MatrixMultiplier")
34 .def(py::init<int, int>())
35 .def(
"forward", &MatrixMultiplier::forward)
36 .def(
"get", &MatrixMultiplier::get);