1 #include <torch/csrc/Device.h> 3 #include <torch/csrc/Exceptions.h> 4 #include <torch/csrc/utils/object_ptr.h> 5 #include <torch/csrc/utils/python_arg_parser.h> 6 #include <torch/csrc/utils/python_strings.h> 7 #include <torch/csrc/utils/python_numbers.h> 8 #include <torch/csrc/utils/pybind.h> 10 #include <ATen/Device.h> 11 #include <c10/util/Exception.h> 15 #include <structmember.h> 18 PyObject *THPDevice_New(
const at::Device& device)
// THPDevice_New: stores `device` in a Python object of type THPDeviceType and
// returns it to the caller via self.release().
//
// NOTE(review): this text is a lossy extraction. The embedded source line
// numbers skip 19 and 21-22, and nothing after 25 is shown. So the opening
// brace, the declaration of `self` (used below but never defined here) and the
// closing brace are missing. Presumably `self` is an owning THPObjectPtr
// created from type->tp_alloc(type, 0) and null-checked. TODO confirm upstream.
20 auto type = (PyTypeObject*)&THPDeviceType;
// View the generic object pointer as the THPDevice struct so the C++ device
// can be written into it.
23 auto self_ =
reinterpret_cast<THPDevice*
>(
self.get());
24 self_->device = device;
// release() presumably detaches the object from its owning wrapper, so the
// caller receives the reference.
25 return self.release();
// Body of the repr function. This is presumably THPDevice_repr, which
// THPDeviceType installs via (reprfunc)THPDevice_repr. It formats
//   device(type='<type>'
// and, when the device has an index, appends ", index=<n>". The text is then
// returned as a Python string. (In the literals, \' is just a single quote.)
// NOTE(review): source lines 26-29 (presumably the signature) and 34-35 are
// absent from this extraction. Whatever they appended, e.g. a closing ")", is
// not visible here. TODO confirm against upstream.
30 std::ostringstream oss;
31 oss <<
"device(type=\'" <<
self->device.type() <<
"\'";
// The index is only printed when one is set on the device.
32 if (self->device.has_index()) {
33 oss <<
", index=" <<
self->device.index();
36 return THPUtils_packString(oss.str().c_str());
// Body of the str function. This is presumably THPDevice_str, which
// THPDeviceType installs via (reprfunc)THPDevice_str. It returns the contents
// of `oss` as a Python string.
// NOTE(review): source line 42 is missing from this extraction, as is the
// signature. Line 42 is presumably the statement that streams the device into
// `oss`. As shown, `oss` is never written, so an empty string would be
// returned. TODO confirm against upstream.
41 std::ostringstream oss;
43 return THPUtils_packString(oss.str().c_str());
46 PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs)
// THPDevice_pynew: constructor with the (PyTypeObject*, args, kwargs) shape.
// It accepts either an existing device, or a device-type string plus an index.
// It returns a new object built by THPDevice_New.
// NOTE(review): several source lines are not visible in this extraction:
//   - 47-49 and 52-53: presumably the opening of the parser declaration and
//     the parsed-args buffer;
//   - 55: presumably the `r.idx == 0` test;
//   - everything after 73.
// The two signatures handed to the parser:
50 "Device(Device device)",
51 "Device(std::string type, int64_t? index=-1)" 54 auto r = parser.parse(args, kwargs, parsed_args);
// First signature: the argument is already a device; wrap a copy of it.
56 auto device = r.device(0);
57 return THPDevice_New(device);
58 }
else if (r.idx == 1) {
// Second signature: argument 0 is read twice. As a device, it supplies the
// type and reveals any embedded index. As the raw string, it is used in the
// error message.
59 auto as_device = r.device(0);
60 auto device_type = r.string(0);
// Reject a type string that already carries an index.
61 if (as_device.has_index()) {
62 throw std::runtime_error(
"type (string) must not include an index because index " 63 "was passed explicitly: " + device_type);
// Initialised to -1; overwritten below when an index is read.
65 int32_t device_index = -1;
// NOTE(review): r.toInt64(1) presumably yields an int64_t, but it is stored
// into an int32_t. Out-of-range values are therefore narrowed BEFORE the sign
// check below. On the usual two's-complement targets, 2**32 would wrap to 0
// and be accepted. Consider range-checking the 64-bit value first.
// Line 66, presumably a guard that skips this read when no index was given,
// is missing from this extraction.
67 device_index = r.toInt64(1);
// Negative indices coming from Python are rejected.
70 AT_CHECK(device_index >= 0,
"Device index must not be negative");
72 at::Device device(as_device.type(), device_index);
73 return THPDevice_New(device);
// Body of the `type` property getter. This is presumably THPDevice_type, which
// is registered under "type" in THPDevice_properties. It streams the device
// type into a string and returns it as a Python string.
// NOTE(review): the signature and the surrounding lines (74-81, 85+) are
// missing from this extraction.
82 std::ostringstream oss;
83 oss <<
self->device.type();
84 return THPUtils_packString(oss.str().c_str());
89 PyObject *THPDevice_index(
THPDevice *
self)
// THPDevice_index: getter for the `index` property (see THPDevice_properties).
// When the device has an index, it is returned as a Python integer via
// THPUtils_packInt64.
// NOTE(review): lines 90-91 and 94-99 are missing from this extraction, so the
// result for a device without an index is not visible here. It is presumably
// None. TODO confirm against upstream.
92 if (self->device.has_index()) {
93 return THPUtils_packInt64(self->device.index());
100 static Py_ssize_t THPDevice_hash(
THPDevice *
self)
// THPDevice_hash: tp_hash slot, installed via (hashfunc)THPDevice_hash.
// It hashes the wrapped at::Device with std::hash<at::Device>.
// The size_t hash is reduced modulo the largest Py_ssize_t. The cast result is
// therefore in [0, PY_SSIZE_T_MAX) and can never be -1, which CPython reserves
// to mean "hashing failed".
103 return static_cast<Py_ssize_t
>(std::hash<at::Device>{}(
self->device) % std::numeric_limits<Py_ssize_t>::max());
// Presumably closes a C++-exception guard and returns -1 (the tp_hash error
// value) when an exception is caught.
// NOTE(review): the opening of that guard (lines 101-102) is not visible in
// this extraction.
104 END_HANDLE_TH_ERRORS_RET(-1)
107 PyObject *THPDevice_rc(PyObject *a, PyObject *b,
int op) {
// THPDevice_rc: tp_richcompare slot, installed via (richcmpfunc)THPDevice_rc.
// If either operand is not a device object, it returns NotImplemented so that
// Python can try the reflected operation.
109 if (!THPDevice_Check(a) || !THPDevice_Check(b)) {
// The returned singleton must be handed out as a new reference.
111 Py_INCREF(Py_NotImplemented);
112 return Py_NotImplemented;
// NOTE(review): lines 113-118, 120-124 and 126+ are missing from this
// extraction. They include the declarations of `da`/`db` and the dispatch on
// `op`. Both visible branches test the wrapped devices for equality;
// presumably one serves Py_EQ and the other Py_NE. TODO confirm upstream.
119 if (da->device == db->device) {
125 if (da->device == db->device) {
141 PyObject *THPDevice_reduce(
THPDevice *
self)
// THPDevice_reduce: __reduce__ implementation (see THPDevice_methods), used
// for pickling. It fills the tuple `ret` with (torch.device, args), where args
// is (type_str,) or (type_str, index) depending on whether an index is set.
// NOTE(review): lines 142-146, 150-151, 156, 158-159 and 161+ are missing from
// this extraction. `ret` and `args` are used below but are declared in those
// lines. Presumably `ret` is a 2-tuple created there. TODO confirm upstream.
// Slot 0: the callable that rebuilds the object on unpickling.
147 py::object torch_module = py::module::import(
"torch");
148 py::object torch_device = torch_module.attr(
"device");
// PyTuple_SET_ITEM steals a reference. The pybind11 object therefore gives up
// its own reference via release().
149 PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr());
// Slot 1: the constructor arguments, starting from the device type name.
152 std::ostringstream oss;
153 oss <<
self->device.type();
154 if (self->device.has_index()) {
// With an index: (type_str, index).
155 args =
THPObjectPtr{Py_BuildValue(
"(si)", oss.str().c_str(),
self->device.index())};
// NOTE(review): line 156 is missing; it is presumably the `} else {` that makes
// the following assignment the no-index case: (type_str,).
157 args =
THPObjectPtr{Py_BuildValue(
"(s)", oss.str().c_str())};
// NOTE(review): confirm that the missing lines 158-159 check `args` for NULL.
// Otherwise a failed Py_BuildValue would put NULL into the tuple.
160 PyTuple_SET_ITEM(ret.get(), 1, args.release());
162 return ret.release();
// Function-pointer type used to cast the THPDevice_* getters into the form
// stored in PyGetSetDef. It presumably mirrors CPython's own `getter` typedef.
// Repeating an identical typedef is legal in C++.
166 typedef PyObject *(*getter)(PyObject *,
void *);
// Attribute table for THPDeviceType. It exposes the `type` and `index`
// properties as read-only: the setter, doc and closure entries are nullptr.
// NOTE(review): lines 171-173 are missing from this extraction. CPython
// requires this table to end with an entry whose name is NULL; confirm that
// sentinel exists upstream.
168 static struct PyGetSetDef THPDevice_properties[] = {
169 {
"type", (getter)THPDevice_type,
nullptr,
nullptr,
nullptr},
170 {
"index", (getter)THPDevice_index,
nullptr,
nullptr,
nullptr},
// Method table for THPDeviceType. It registers __reduce__ (THPDevice_reduce)
// as a no-argument method (METH_NOARGS) with no docstring.
// NOTE(review): lines 176-178 are missing from this extraction. They
// presumably hold the NULL sentinel entry and the closing brace.
174 static PyMethodDef THPDevice_methods[] = {
175 {
"__reduce__", (PyCFunction)THPDevice_reduce, METH_NOARGS,
nullptr},
// Python type object that THPDevice_init registers on a module as "device".
// It uses a positional initializer, and only some entries survived extraction.
// Judging by the casts, the visible entries fill the repr, hash, str and
// rich-comparison slots. THPDevice_properties presumably fills tp_getset.
// NOTE(review): most slots are missing from this extraction, including
// wherever THPDevice_methods and THPDevice_pynew are installed. Verify the slot
// order against upstream before editing.
179 PyTypeObject THPDeviceType = {
180 PyVarObject_HEAD_INIT(
nullptr, 0)
189 (reprfunc)THPDevice_repr,
193 (hashfunc)THPDevice_hash,
195 (reprfunc)THPDevice_str,
203 (richcmpfunc)THPDevice_rc,
209 THPDevice_properties,
220 void THPDevice_init(PyObject *module)
// THPDevice_init: finalises THPDeviceType with PyType_Ready and adds it to
// `module` under the attribute name "device".
// NOTE(review): the bodies of both failure branches (lines 223-224 and 227+)
// are missing from this extraction. They presumably raise. TODO confirm
// upstream.
222 if (PyType_Ready(&THPDeviceType) < 0) {
// PyModule_AddObject steals a reference on success. Take one for the module
// before handing over the statically allocated type object.
225 Py_INCREF(&THPDeviceType);
226 if (PyModule_AddObject(module,
"device", (PyObject *)&THPDeviceType) != 0) {
Represents a compute device on which a tensor is located.