Caffe2 - C++ API
A deep learning, cross platform ML framework
six.h
1 #pragma once
2 
3 #include <pybind11/pybind11.h>
4 #include <torch/csrc/utils/structseq.h>
5 #include <torch/csrc/utils/object_ptr.h>
6 
7 namespace six {
8 
9 // Usually instances of PyStructSequence is also an instance of tuple
10 // but in some py2 environment it is not, so we have to manually check
11 // the name of the type to determine if it is a namedtupled returned
12 // by a pytorch operator.
13 
14 inline bool isStructSeq(pybind11::handle input) {
15  return pybind11::cast<std::string>(input.get_type().attr("__module__")) == "torch.return_types";
16 }
17 
18 inline bool isStructSeq(PyObject* obj) {
19  return isStructSeq(pybind11::handle(obj));
20 }
21 
22 inline bool isTuple(pybind11::handle input) {
23  if (PyTuple_Check(input.ptr())) {
24  return true;
25  }
26 #if PY_MAJOR_VERSION == 2
27  return isStructSeq(input);
28 #else
29  return false;
30 #endif
31 }
32 
33 inline bool isTuple(PyObject* obj) {
34  return isTuple(pybind11::handle(obj));
35 }
36 
37 // maybeAsTuple: if the input is a structseq, then convert it to a tuple
38 //
39 // On Python 3, structseq is a subtype of tuple, so these APIs could be used directly.
40 // But on Python 2, structseq is not a subtype of tuple, so we need to manually create a
41 // new tuple object from structseq.
42 inline THPObjectPtr maybeAsTuple(PyStructSequence *obj) {
43 #if PY_MAJOR_VERSION == 2
44  return THPObjectPtr(torch::utils::structseq_slice(obj, 0, Py_SIZE(obj)));
45 #else
46  Py_INCREF(obj);
47  return THPObjectPtr((PyObject *)obj);
48 #endif
49 }
50 
51 inline THPObjectPtr maybeAsTuple(PyObject *obj) {
52  if (isStructSeq(obj))
53  return maybeAsTuple((PyStructSequence *)obj);
54  Py_INCREF(obj);
55  return THPObjectPtr(obj);
56 }
57 
58 } // namespace six
Definition: six.h:7