Caffe2 - C++ API
A deep learning, cross-platform ML framework
pybind.h
1 #pragma once
2 
3 #include <torch/csrc/python_headers.h>
4 #include <pybind11/pybind11.h>
5 #include <pybind11/stl.h>
6 
7 #include <torch/csrc/autograd/python_function.h>
8 #include <torch/csrc/autograd/python_cpp_function.h>
9 
10 namespace py = pybind11;
11 
12 namespace pybind11 { namespace detail {
13 
14 // handle Python <-> torch::autograd::Function conversions
15 template <> struct type_caster<std::shared_ptr<torch::autograd::Function>> {
16 public:
17  PYBIND11_TYPE_CASTER(std::shared_ptr<torch::autograd::Function>, _("std::shared_ptr<torch::autograd::Function>"));
18 
19  bool load(handle src, bool) {
20  if (!THPFunction_Check(src.ptr())) return false;
21  value = THPFunction_asFunction((THPFunction*)src.ptr());
22  return true;
23  }
24  static handle cast(std::shared_ptr<torch::autograd::Function> src, return_value_policy /* policy */, handle /* parent */) {
25  auto fn = functionToPyObject(std::move(src));
26  return handle(fn);
27  }
28 };
29 
30 
31 }} // namespace pybind11::detail