1 #include <torch/python/init.h> 2 #include <torch/python.h> 4 #include <torch/nn/module.h> 5 #include <torch/ordered_dict.h> 7 #include <torch/csrc/utils/pybind.h> 16 #define ITEM_TYPE_CASTER(T, Name) \ 18 struct type_caster<typename torch::OrderedDict<std::string, T>::Item> { \ 20 using Item = typename torch::OrderedDict<std::string, T>::Item; \ 21 using PairCaster = make_caster<std::pair<std::string, T>>; \ 22 PYBIND11_TYPE_CASTER(Item, _("Ordered" #Name "DictItem")); \ 23 bool load(handle src, bool convert) { \ 24 return PairCaster().load(src, convert); \ 26 static handle cast(Item src, return_value_policy policy, handle parent) { \ 27 return PairCaster::cast( \ 28 src.pair(), std::move(policy), std::move(parent)); \ 33 ITEM_TYPE_CASTER(std::shared_ptr<torch::nn::Module>, Module);
41 void bind_ordered_dict(py::module module,
const char* dict_name) {
44 py::class_<ODict>(module, dict_name)
45 .def(
"items", &ODict::items)
46 .def(
"keys", &ODict::keys)
47 .def(
"values", &ODict::values)
48 .def(
"__iter__", [](
const ODict& dict) {
49 return py::make_iterator(dict.begin(), dict.end());
50 }, py::keep_alive<0, 1>())
51 .def(
"__len__", &ODict::size)
52 .def(
"__contains__", &ODict::contains)
53 .def(
"__getitem__", [](
const ODict& dict,
const std::string& key) {
56 .def(
"__getitem__", [](
const ODict& dict,
size_t index) {
63 void init_bindings(PyObject* module) {
64 py::module m = py::handle(module).cast<py::module>();
65 py::module cpp = m.def_submodule(
"cpp");
67 bind_ordered_dict<Tensor>(cpp,
"OrderedTensorDict");
68 bind_ordered_dict<std::shared_ptr<nn::Module>>(cpp,
"OrderedModuleDict");
70 py::module nn = cpp.def_submodule(
"nn");
72 py::class_<nn::Module, std::shared_ptr<nn::Module>>(nn,
"Module"));
An ordered dictionary implementation, akin to Python's OrderedDict.