Caffe2 - C++ API
A deep learning, cross platform ML framework
Device.cpp
1 #include <torch/csrc/Device.h>
2 
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/utils/object_ptr.h>
5 #include <torch/csrc/utils/python_arg_parser.h>
6 #include <torch/csrc/utils/python_strings.h>
7 #include <torch/csrc/utils/python_numbers.h>
8 #include <torch/csrc/utils/pybind.h>
9 
10 #include <ATen/Device.h>
11 #include <c10/util/Exception.h>
12 
13 #include <cstring>
14 #include <limits>
15 #include <structmember.h>
16 #include <sstream>
17 
18 PyObject *THPDevice_New(const at::Device& device)
19 {
20  auto type = (PyTypeObject*)&THPDeviceType;
21  auto self = THPObjectPtr{type->tp_alloc(type, 0)};
22  if (!self) throw python_error();
23  auto self_ = reinterpret_cast<THPDevice*>(self.get());
24  self_->device = device;
25  return self.release();
26 }
27 
28 PyObject *THPDevice_repr(THPDevice *self)
29 {
30  std::ostringstream oss;
31  oss << "device(type=\'" << self->device.type() << "\'";
32  if (self->device.has_index()) {
33  oss << ", index=" << self->device.index();
34  }
35  oss << ")";
36  return THPUtils_packString(oss.str().c_str());
37 }
38 
39 PyObject *THPDevice_str(THPDevice *self)
40 {
41  std::ostringstream oss;
42  oss << self->device;
43  return THPUtils_packString(oss.str().c_str());
44 }
45 
46 PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
47 {
48  HANDLE_TH_ERRORS
49  static torch::PythonArgParser parser({
50  "Device(Device device)",
51  "Device(std::string type, int64_t? index=-1)"
52  });
53  torch::ParsedArgs<2> parsed_args;
54  auto r = parser.parse(args, kwargs, parsed_args);
55  if (r.idx == 0) {
56  auto device = r.device(0);
57  return THPDevice_New(device);
58  } else if (r.idx == 1) {
59  auto as_device = r.device(0); // this works, because device can take strings
60  auto device_type = r.string(0);
61  if (as_device.has_index()) {
62  throw std::runtime_error("type (string) must not include an index because index "
63  "was passed explicitly: " + device_type);
64  }
65  int32_t device_index = -1;
66  if (!r.isNone(1)) {
67  device_index = r.toInt64(1);
68  // -1 is allowed in ATen/C++, to mean the default device, but not in
69  // Python.
70  AT_CHECK(device_index >= 0, "Device index must not be negative");
71  }
72  at::Device device(as_device.type(), device_index);
73  return THPDevice_New(device);
74  }
75  Py_RETURN_NONE;
76  END_HANDLE_TH_ERRORS
77 }
78 
79 PyObject *THPDevice_type(THPDevice *self)
80 {
81  HANDLE_TH_ERRORS
82  std::ostringstream oss;
83  oss << self->device.type();
84  return THPUtils_packString(oss.str().c_str());
85  Py_RETURN_NONE;
86  END_HANDLE_TH_ERRORS
87 }
88 
89 PyObject *THPDevice_index(THPDevice *self)
90 {
91  HANDLE_TH_ERRORS
92  if (self->device.has_index()) {
93  return THPUtils_packInt64(self->device.index());
94  } else {
95  Py_RETURN_NONE;
96  }
97  END_HANDLE_TH_ERRORS
98 }
99 
100 static Py_ssize_t THPDevice_hash(THPDevice *self)
101 {
102  HANDLE_TH_ERRORS
103  return static_cast<Py_ssize_t>(std::hash<at::Device>{}(self->device) % std::numeric_limits<Py_ssize_t>::max());
104  END_HANDLE_TH_ERRORS_RET(-1)
105 }
106 
107 PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) {
108  HANDLE_TH_ERRORS
109  if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
110  // Py_RETURN_NOTIMPLEMENTED not in python 2.
111  Py_INCREF(Py_NotImplemented);
112  return Py_NotImplemented;
113  }
114  THPDevice *da = reinterpret_cast<THPDevice*>(a);
115  THPDevice *db = reinterpret_cast<THPDevice*>(b);
116 
117  switch(op) {
118  case Py_EQ:
119  if (da->device == db->device) {
120  Py_RETURN_TRUE;
121  } else {
122  Py_RETURN_FALSE;
123  }
124  case Py_NE:
125  if (da->device == db->device) {
126  Py_RETURN_FALSE;
127  } else {
128  Py_RETURN_TRUE;
129  }
130  case Py_LT:
131  case Py_LE:
132  case Py_GT:
133  case Py_GE:
134  throw torch::TypeError("comparison not implemented");
135  default:
136  throw torch::TypeError("unexpected comparison op");
137  }
138  END_HANDLE_TH_ERRORS
139 }
140 
141 PyObject *THPDevice_reduce(THPDevice *self)
142 {
143  HANDLE_TH_ERRORS
144  auto ret = THPObjectPtr{PyTuple_New(2)};
145  if (!ret) throw python_error();
146 
147  py::object torch_module = py::module::import("torch");
148  py::object torch_device = torch_module.attr("device");
149  PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr());
150 
151  THPObjectPtr args;
152  std::ostringstream oss;
153  oss << self->device.type();
154  if (self->device.has_index()) {
155  args = THPObjectPtr{Py_BuildValue("(si)", oss.str().c_str(), self->device.index())};
156  } else {
157  args = THPObjectPtr{Py_BuildValue("(s)", oss.str().c_str())};
158  }
159  if (!args) throw python_error();
160  PyTuple_SET_ITEM(ret.get(), 1, args.release());
161 
162  return ret.release();
163  END_HANDLE_TH_ERRORS
164 }
165 
166 typedef PyObject *(*getter)(PyObject *, void *);
167 
168 static struct PyGetSetDef THPDevice_properties[] = {
169  {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
170  {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
171  {nullptr}
172 };
173 
174 static PyMethodDef THPDevice_methods[] = {
175  {"__reduce__", (PyCFunction)THPDevice_reduce, METH_NOARGS, nullptr},
176  {nullptr} /* Sentinel */
177 };
178 
179 PyTypeObject THPDeviceType = {
180  PyVarObject_HEAD_INIT(nullptr, 0)
181  "torch.device", /* tp_name */
182  sizeof(THPDevice), /* tp_basicsize */
183  0, /* tp_itemsize */
184  nullptr, /* tp_dealloc */
185  nullptr, /* tp_print */
186  nullptr, /* tp_getattr */
187  nullptr, /* tp_setattr */
188  nullptr, /* tp_reserved */
189  (reprfunc)THPDevice_repr, /* tp_repr */
190  nullptr, /* tp_as_number */
191  nullptr, /* tp_as_sequence */
192  nullptr, /* tp_as_mapping */
193  (hashfunc)THPDevice_hash, /* tp_hash */
194  nullptr, /* tp_call */
195  (reprfunc)THPDevice_str, /* tp_str */
196  nullptr, /* tp_getattro */
197  nullptr, /* tp_setattro */
198  nullptr, /* tp_as_buffer */
199  Py_TPFLAGS_DEFAULT, /* tp_flags */
200  nullptr, /* tp_doc */
201  nullptr, /* tp_traverse */
202  nullptr, /* tp_clear */
203  (richcmpfunc)THPDevice_rc, /* tp_richcompare */
204  0, /* tp_weaklistoffset */
205  nullptr, /* tp_iter */
206  nullptr, /* tp_iternext */
207  THPDevice_methods, /* tp_methods */
208  nullptr, /* tp_members */
209  THPDevice_properties, /* tp_getset */
210  nullptr, /* tp_base */
211  nullptr, /* tp_dict */
212  nullptr, /* tp_descr_get */
213  nullptr, /* tp_descr_set */
214  0, /* tp_dictoffset */
215  nullptr, /* tp_init */
216  nullptr, /* tp_alloc */
217  THPDevice_pynew, /* tp_new */
218 };
219 
220 void THPDevice_init(PyObject *module)
221 {
222  if (PyType_Ready(&THPDeviceType) < 0) {
223  throw python_error();
224  }
225  Py_INCREF(&THPDeviceType);
226  if (PyModule_AddObject(module, "device", (PyObject *)&THPDeviceType) != 0) {
227  throw python_error();
228  }
229 }
Represents a compute device on which a tensor is located.
Definition: Device.h:30