Caffe2 - C++ API
A deep learning, cross platform ML framework
pybind.h
1 #pragma once
2 
3 #include <torch/csrc/python_headers.h>
4 
5 #include <torch/csrc/DynamicTypes.h>
6 #include <torch/csrc/THP.h>
7 #include <torch/csrc/autograd/variable.h>
8 #include <ATen/core/interned_strings.h>
9 #include <ATen/core/ivalue.h>
10 #include <torch/csrc/jit/pybind_utils.h>
11 #include <torch/csrc/jit/tracer.h>
12 #include <torch/csrc/utils/pybind.h>
13 
14 #include <pybind11/functional.h>
15 #include <pybind11/pybind11.h>
16 #include <pybind11/stl.h>
17 
18 namespace py = pybind11;
19 
20 namespace pybind11 {
21 namespace detail {
22 
23 template <>
24 struct type_caster<torch::jit::IValue> {
25  public:
26  PYBIND11_TYPE_CASTER(torch::jit::IValue, _("IValue"));
27 
28  bool load(handle src, bool) {
29  try {
30  value = torch::jit::toIValue(src);
31  return true;
32  } catch (std::exception& e) {
33  return false;
34  }
35  }
36 
37  static handle cast(
39  return_value_policy /* policy */,
40  handle /* parent */) {
41  return torch::jit::toPyObject(std::move(src)).release();
42  }
43 };
44 
45 template <>
46 struct type_caster<torch::jit::Symbol> {
47  public:
48  PYBIND11_TYPE_CASTER(torch::jit::Symbol, _("Symbol"));
49 
50  bool load(handle src, bool) {
51  // TODO: Is there a way to py::cast that doesn't raise an exception on
52  // failure? Can we catch pybind11::cast_error here instead?
53  std::string src_str;
54  try {
55  src_str = py::cast<std::string>(src);
56  } catch (std::exception& e) {
57  return false;
58  }
59  value = torch::jit::Symbol::fromQualString(src_str);
60  return true;
61  }
62 
63  static handle cast(
65  return_value_policy /* policy */,
66  handle /* parent */) {
67  return py::cast(std::string(src.toQualString()), return_value_policy::copy)
68  .release();
69  }
70 };
71 
72 template <>
73 struct type_caster<torch::jit::AttributeKind> {
74  public:
75  PYBIND11_TYPE_CASTER(torch::jit::AttributeKind, _("AttributeKind"));
76 
77  bool load(handle src, bool) {
78  return false;
79  }
80 
81  static handle cast(
82  torch::jit::AttributeKind src,
83  return_value_policy /* policy */,
84  handle /* parent */) {
85  return py::cast(
86  std::string(torch::jit::toString(src)),
87  return_value_policy::copy)
88  .release();
89  }
90 };
91 
92 // See https://github.com/pybind/pybind11/issues/637
93 using ListCasterBase = pybind11::detail::
94  list_caster<std::vector<torch::jit::Node*>, torch::jit::Node*>;
95 template <>
96 struct type_caster<std::vector<torch::jit::Node*>> : ListCasterBase {
97  static handle cast(
98  const std::vector<torch::jit::Node*>& src,
99  return_value_policy,
100  handle parent) {
101  return ListCasterBase::cast(src, return_value_policy::reference, parent);
102  }
103  static handle cast(
104  const std::vector<torch::jit::Node*>* src,
105  return_value_policy pol,
106  handle parent) {
107  return cast(*src, pol, parent);
108  }
109 };
110 
111 } // namespace detail
112 } // namespace pybind11
113 
114 namespace torch {
115 namespace jit {
116 
117 static inline py::tuple tuple_tail(const py::tuple& tup) {
118  py::tuple r(tup.size() - 1);
119  for (size_t i = 1; i < tup.size(); i++) {
120  r[i - 1] = tup[i];
121  }
122  return r;
123 }
124 
125 } // namespace jit
126 } // namespace torch
Definition: jit_type.h:17