Caffe2 - C++ API
A deep-learning, cross-platform ML framework
init.cpp
1 #include <torch/python/init.h>
2 #include <torch/python.h>
3 
4 #include <torch/nn/module.h>
5 #include <torch/ordered_dict.h>
6 
7 #include <torch/csrc/utils/pybind.h>
8 
9 #include <string>
10 #include <vector>
11 
12 namespace py = pybind11;
13 
14 namespace pybind11 {
15 namespace detail {
16 #define ITEM_TYPE_CASTER(T, Name) \
17  template <> \
18  struct type_caster<typename torch::OrderedDict<std::string, T>::Item> { \
19  public: \
20  using Item = typename torch::OrderedDict<std::string, T>::Item; \
21  using PairCaster = make_caster<std::pair<std::string, T>>; \
22  PYBIND11_TYPE_CASTER(Item, _("Ordered" #Name "DictItem")); \
23  bool load(handle src, bool convert) { \
24  return PairCaster().load(src, convert); \
25  } \
26  static handle cast(Item src, return_value_policy policy, handle parent) { \
27  return PairCaster::cast( \
28  src.pair(), std::move(policy), std::move(parent)); \
29  } \
30  }
31 
32 ITEM_TYPE_CASTER(torch::Tensor, Tensor);
33 ITEM_TYPE_CASTER(std::shared_ptr<torch::nn::Module>, Module);
34 } // namespace detail
35 } // namespace pybind11
36 
37 namespace torch {
38 namespace python {
39 namespace {
40 template <typename T>
41 void bind_ordered_dict(py::module module, const char* dict_name) {
42  using ODict = OrderedDict<std::string, T>;
43  // clang-format off
44  py::class_<ODict>(module, dict_name)
45  .def("items", &ODict::items)
46  .def("keys", &ODict::keys)
47  .def("values", &ODict::values)
48  .def("__iter__", [](const ODict& dict) {
49  return py::make_iterator(dict.begin(), dict.end());
50  }, py::keep_alive<0, 1>())
51  .def("__len__", &ODict::size)
52  .def("__contains__", &ODict::contains)
53  .def("__getitem__", [](const ODict& dict, const std::string& key) {
54  return dict[key];
55  })
56  .def("__getitem__", [](const ODict& dict, size_t index) {
57  return dict[index];
58  });
59  // clang-format on
60 }
61 } // namespace
62 
63 void init_bindings(PyObject* module) {
64  py::module m = py::handle(module).cast<py::module>();
65  py::module cpp = m.def_submodule("cpp");
66 
67  bind_ordered_dict<Tensor>(cpp, "OrderedTensorDict");
68  bind_ordered_dict<std::shared_ptr<nn::Module>>(cpp, "OrderedModuleDict");
69 
70  py::module nn = cpp.def_submodule("nn");
71  add_module_bindings(
72  py::class_<nn::Module, std::shared_ptr<nn::Module>>(nn, "Module"));
73 }
74 } // namespace python
75 } // namespace torch
Definition: jit_type.h:17
An ordered dictionary implementation, akin to Python's OrderedDict.
Definition: ordered_dict.h:16