3 #include <torch/csrc/python_headers.h> 4 #include <pybind11/pybind11.h> 5 #include <pybind11/stl.h> 7 #include <torch/csrc/autograd/python_function.h> 8 #include <torch/csrc/autograd/python_cpp_function.h> 15 template <>
struct type_caster<
std::shared_ptr<torch::autograd::Function>> {
17 PYBIND11_TYPE_CASTER(std::shared_ptr<torch::autograd::Function>, _(
"std::shared_ptr<torch::autograd::Function>"));
19 bool load(handle src,
bool) {
20 if (!THPFunction_Check(src.ptr()))
return false;
21 value = THPFunction_asFunction((
THPFunction*)src.ptr());
24 static handle cast(std::shared_ptr<torch::autograd::Function> src, return_value_policy , handle ) {
25 auto fn = functionToPyObject(std::move(src));