1 #include <torch/csrc/tensor/python_tensor.h> 3 #include <structmember.h> 4 #include <pybind11/pybind11.h> 6 #include <torch/csrc/Dtype.h> 7 #include <torch/csrc/DynamicTypes.h> 8 #include <torch/csrc/Exceptions.h> 9 #include <torch/csrc/Layout.h> 10 #include <torch/csrc/autograd/variable.h> 11 #include <torch/csrc/autograd/python_variable.h> 12 #include <torch/csrc/autograd/generated/VariableType.h> 13 #include <torch/csrc/autograd/utils/wrap_outputs.h> 14 #include <torch/csrc/utils/cuda_enabled.h> 15 #include <torch/csrc/utils/cuda_lazy_init.h> 16 #include <torch/csrc/utils/python_strings.h> 17 #include <torch/csrc/utils/tensor_new.h> 18 #include <torch/csrc/utils/tensor_types.h> 20 #include <ATen/ATen.h> 24 #include <type_traits> 27 namespace torch {
namespace tensors {
// Fragment of the PyTensorType::aten_type() lazy-initialization path — the
// enclosing function signature is elided in this source (lossy extraction).
46 torch::utils::cuda_lazy_init();
// Look up the non-variable base Type for this (backend, scalar_type) pair;
// the *Opt variant returns null when the backend is unavailable in this build.
48 auto* baseType = globalContext().getNonVariableTypeOpt(static_cast<at::Backend>(backend), static_cast<at::ScalarType>(scalar_type));
// Wrap the base type in its autograd VariableType; cache nullptr when the
// backend is missing so callers can raise a friendly "type not available".
49 aten_type_ = baseType ? torch::autograd::VariableType::getVariableTypeFromBaseType(*baseType) :
nullptr;
// PyTensorType is reinterpret-cast to/from PyTypeObject by CPython machinery
// elsewhere in this file, which is only well-defined for standard layout.
55 static_assert(std::is_standard_layout<PyTensorType>::value,
"PyTensorType must be standard layout");
// Process-wide default tensor type (e.g. the Type behind torch.FloatTensor).
// Written by set_default_tensor_type(); read by get_default_tensor_type().
58 static at::Type* default_tensor_type;
// Forward declaration: installs the initialized PyTensorType objects into
// their host modules (torch, torch.cuda, ...); defined later in this file.
60 static void py_bind_tensor_types(
const std::vector<PyTensorType>& tensor_types);
// Fragment of unavailable_type(): builds the TypeError raised when a tensor
// type (typically a CUDA type) cannot be constructed. The function signature
// is elided in this source (lossy extraction).
// BUG FIX: the ternary branches were inverted — the "Torch not compiled with
// CUDA enabled" hint must be appended when CUDA support is ABSENT, not when
// it is present.
63 const char* cuda_msg = torch::utils::cuda_enabled() ?
"" :
". Torch not compiled with CUDA enabled.";
// type.name is the fully-qualified Python name, e.g. "torch.cuda.FloatTensor".
64 return TypeError(
"type %s not available%s", type.name, cuda_msg);
// tp_new for the legacy tensor classes (torch.FloatTensor(...) etc.).
// NOTE(review): interior lines (HANDLE_TH_ERRORS, the tensor_type lookup,
// the null check, and the closing brace) are elided in this source.
67 static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) {
// Resolve the at::Type backing this Python type; null when unavailable.
70 auto aten_type = tensor_type.aten_type();
// Raise "type X not available" (e.g. a CUDA type in a CPU-only build).
72 throw unavailable_type(tensor_type);
// Build the tensor via the legacy ctor and wrap the Variable for Python.
74 return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(*aten_type, args, kwargs));
// __instancecheck__ hook on the metaclass: isinstance(x, torch.FloatTensor)
// holds when x is a Variable whose at::Type matches this PyTensorType.
// NOTE(review): interior lines (the `var` unwrap, returns, closing braces)
// are elided in this source.
78 static PyObject* Tensor_instancecheck(
PyTensorType*
self, PyObject* arg) {
80 if (THPVariable_Check(arg)) {
// Identity comparison on Type pointers; aten_type() may be null when the
// backend is unavailable, in which case no variable can match.
88 if (&var.type() ==
self->aten_type()) {
// Property-getter fragments for the metaclass (dtype / layout / is_sparse);
// the surrounding function signatures are elided in this source.
// Tensor_dtype: wrap the cached THPDtype* as a Python object.
97 return torch::autograd::utils::wrap(self->dtype);
// Tensor_layout: wrap the cached THPLayout* as a Python object.
101 return torch::autograd::utils::wrap(self->layout);
// Tensor_is_sparse: presumably returns False for strided (dense) layouts
// and True otherwise — elided lines make this a TODO to confirm.
113 if (self->layout->layout == at::Layout::Strided) {
// Method table installed on the tensor-type metaclass.
// NOTE(review): the null sentinel entry and closing "};" are elided here.
120 static struct PyMethodDef metaclass_methods[] = {
121 {
"__instancecheck__", (PyCFunction)Tensor_instancecheck, METH_O,
nullptr},
// CPython getter signature alias used by the property table below.
125 typedef PyObject *(*getter)(PyObject *,
void *);
// Read-only properties on the metaclass: torch.FloatTensor.dtype, .layout,
// .is_cuda, .is_sparse. The setter, doc, and closure slots are all null.
// NOTE(review): the null sentinel entry and closing "};" are elided here.
127 static struct PyGetSetDef metaclass_properties[] = {
128 {
"dtype", (getter)Tensor_dtype,
nullptr,
nullptr,
nullptr},
129 {
"layout", (getter)Tensor_layout,
nullptr,
nullptr,
nullptr},
130 {
"is_cuda", (getter)Tensor_is_cuda,
nullptr,
nullptr,
nullptr},
131 {
"is_sparse", (getter)Tensor_is_sparse,
nullptr,
nullptr,
nullptr},
// Shared metaclass for every generated tensor type; statically allocated.
135 static PyTypeObject metaclass;
// One-time initialization of the metaclass.
// NOTE(review): the error-handling body of the PyType_Ready check and the
// closing brace are elided in this source.
137 static void py_initialize_metaclass(PyTypeObject& metaclass) {
// Statically-allocated type object: pin refcnt so CPython never frees it.
138 ((PyObject*)&metaclass)->ob_refcnt = 1;
139 metaclass.tp_basicsize =
sizeof(PyTypeObject);
140 metaclass.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
141 metaclass.tp_methods = metaclass_methods;
142 metaclass.tp_getset = metaclass_properties;
143 metaclass.tp_name =
"torch.tensortype";
// Derive from `type` so instances of this metaclass are themselves types.
144 metaclass.tp_base = &PyType_Type;
145 if (PyType_Ready(&metaclass) < 0) {
// Initializes one statically-allocated tensor type object (its storage is a
// PyTensorType elsewhere; only the leading PyTypeObject is set up here) and
// merges `tp_dict` (attributes shared with torch.Tensor) into its dict.
// NOTE(review): interior lines (tp_name/tp_basicsize assignments, error
// handling, closing brace) are elided in this source.
150 static void py_initialize_tensor_type(PyTypeObject& type,
const char* name, PyObject* tp_dict) {
156 memset(&type, 0,
sizeof(PyTypeObject));
// Static allocation: pin the refcount and install our metaclass.
157 ((PyObject*)&type)->ob_refcnt = 1;
158 ((PyObject*)&type)->ob_type = &metaclass;
160 type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
162 type.tp_new = Tensor_new;
163 if (PyType_Ready(&type) < 0) {
// Merge without overriding entries PyType_Ready already populated (flag 0).
166 if (PyDict_Merge(type.tp_dict, tp_dict, 0) < 0) {
// Maps a backend to the Python module that hosts its tensor classes.
// NOTE(review): the `switch (backend)` line and closing braces are elided
// in this source.
171 static const char* get_module(
Backend backend) {
173 case Backend::CPU:
return "torch";
174 case Backend::CUDA:
return "torch.cuda";
175 case Backend::SparseCPU:
return "torch.sparse";
176 case Backend::SparseCUDA:
return "torch.cuda.sparse";
// Unreachable for the declared backends; fail loudly on anything else.
177 default: AT_ERROR(
"invalid backend: ", toString(backend));
// Builds the fully-qualified Python name for a tensor class, e.g.
// "torch.cuda.FloatTensor" from (CUDA, Float).
// NOTE(review): the `return ss.str();` and closing brace are elided here.
181 static std::string get_name(
Backend backend, ScalarType scalarType) {
182 std::ostringstream ss;
183 ss << get_module(backend) <<
"." << toString(scalarType) <<
"Tensor";
// Fragment of the helper that fetches the Python storage class (e.g.
// torch.FloatStorage) matching an at::Type — the enclosing signature and
// error checks are elided in this source.
188 auto module_name = get_module(type.backend());
189 auto module_obj =
THPObjectPtr(PyImport_ImportModule(module_name));
// Storage classes are named "<ScalarType>Storage" inside the host module.
192 auto storage_name = std::string(toString(type.scalarType())) +
"Storage";
193 THPObjectPtr storage(PyObject_GetAttrString(module_obj.get(), storage_name.c_str()));
194 if (!storage.get()) {
195 throw TypeError(
"couldn't find storage object %s", storage_name.c_str());
// Fragment of set_type(): fills a PyTensorType's metadata fields for one
// (backend, scalarType) pair — the enclosing signature is elided here.
// aten_type_ starts null; it is resolved lazily (see aten_type() above).
202 type_obj.aten_type_ =
nullptr;
203 type_obj.backend =
static_cast<int>(backend);
204 type_obj.scalar_type =
static_cast<int>(scalarType);
// Cached Python wrapper objects for the layout/dtype properties.
205 type_obj.layout = torch::getLayout(backend);
206 type_obj.dtype = torch::getDtype(scalarType);
207 type_obj.is_cuda = (backend == at::Backend::CUDA || backend == at::Backend::SparseCUDA);
// Copies the Python class name into the fixed-size PyTensorType::name
// buffer, truncating if necessary and always NUL-terminating (strncpy does
// not terminate when the source fills the buffer, hence the explicit write).
210 static void set_name(
PyTensorType& type_obj,
const std::string& name) {
211 size_t n =
sizeof(type_obj.name);
212 strncpy(type_obj.name, name.c_str(), n);
213 type_obj.name[n - 1] =
'\0';
// Fragment of get_tensor_dict(): builds the attribute dict shared by all
// legacy tensor classes by merging torch.Tensor's dict with its base's.
// NOTE(review): the tensor_class lookup, `res` creation, error-handling
// bodies, and return are elided in this source.
223 auto tensor_type = (PyTypeObject*)tensor_class.get();
224 AT_CHECK(tensor_type->tp_base,
"missing base type for Tensor");
// Flag 0: do not override keys already present in res.
229 if (PyDict_Merge(res.get(), tensor_type->tp_dict, 0) < 0) {
232 if (PyDict_Merge(res.get(), tensor_type->tp_base->tp_dict, 0) < 0) {
// All generated tensor types live here; the vector must never reallocate
// after initialization since Python holds raw pointers into it.
239 static std::vector<PyTensorType> tensor_types;
// Populates tensor_types with one entry per declared (backend, scalar type)
// combination. NOTE(review): the loop's closing braces are elided here.
241 static void initialize_aten_types(std::vector<PyTensorType>& tensor_types) {
243 auto declared_types = torch::utils::all_declared_types();
244 tensor_types.resize(declared_types.size());
246 for (
size_t i = 0, end = declared_types.size(); i != end; i++) {
247 auto& tensor_type = tensor_types[i];
248 Backend backend = declared_types[i].first;
249 ScalarType scalar_type = declared_types[i].second;
250 set_type(tensor_type, backend, scalar_type);
251 set_name(tensor_type, get_name(backend, scalar_type));
// Public entry point: builds every legacy tensor class, registers it in its
// module, and sets the initial default tensor type to CPU float.
// NOTE(review): closing braces are elided in this source.
255 void initialize_python_bindings() {
258 initialize_aten_types(tensor_types);
263 py_initialize_metaclass(metaclass);
// Shared attribute dict copied into each generated class (see get_tensor_dict).
268 auto tensor_dict = get_tensor_dict();
271 for (
auto& tensor_type : tensor_types) {
272 py_initialize_tensor_type(tensor_type.py_type, tensor_type.name, tensor_dict.get());
278 py_bind_tensor_types(tensor_types);
// Historical default: torch.FloatTensor (variable CPU float type).
281 set_default_tensor_type(at::globalContext().getVariableType(at::Backend::CPU, at::kFloat));
// Installs each generated tensor class into its host module (derived from
// the class's qualified name) and records it in torch._tensor_classes.
// NOTE(review): error-handling bodies and closing braces are elided here.
284 static void py_bind_tensor_types(
const std::vector<PyTensorType>& tensor_types) {
285 auto torch_module =
THPObjectPtr(PyImport_ImportModule(
"torch"));
288 auto tensor_classes =
THPObjectPtr(PyObject_GetAttrString(torch_module.get(),
"_tensor_classes"));
291 for (
auto& tensor_type : tensor_types) {
// Split "torch.cuda.FloatTensor" into module ("torch.cuda") and
// class name ("FloatTensor") at the last dot.
292 auto name = std::string(tensor_type.name);
293 auto idx = name.rfind(
'.');
294 auto type_name = name.substr(idx + 1);
295 auto module_name = name.substr(0, idx);
297 auto module_obj =
THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
// The PyTensorType doubles as a PyObject (standard-layout; see static_assert).
// NOTE(review): PyModule_AddObject steals a reference on success — the
// statically-pinned refcount makes that safe here.
300 PyObject* type_obj = (PyObject*)&tensor_type;
302 if (PyModule_AddObject(module_obj.get(), type_name.c_str(), type_obj) < 0) {
305 if (PySet_Add(tensor_classes.get(), type_obj) < 0) {
// True when `obj` is one of the statically-registered tensor type objects
// (compared by address). NOTE(review): the lambda headers and closing
// braces are elided in this source.
311 static bool PyTensorType_Check(PyObject* obj) {
312 auto it = std::find_if(tensor_types.begin(), tensor_types.end(),
314 return (PyObject*)&x == obj;
316 return it != tensor_types.end();
// Fragment of get_tensor_type(): finds the registered type matching a
// (dtype, layout, is_cuda) triple; enclosing signature is elided here.
320 auto it = std::find_if(tensor_types.begin(), tensor_types.end(),
322 return x.dtype == dtype && x.layout == layout && x.is_cuda == is_cuda;
324 if (it == tensor_types.end()) {
// Implements torch.set_default_tensor_type(obj) for a tensor-class argument.
// NOTE(review): the `type` assignment, the else/error branch, the null
// check on aten_type, and closing braces are elided in this source.
330 void py_set_default_tensor_type(PyObject* obj) {
332 if (PyTensorType_Check(obj)) {
// Null when the backend is unavailable (e.g. CUDA in a CPU-only build).
337 auto aten_type = type->aten_type();
339 throw unavailable_type(*type);
341 set_default_tensor_type(*aten_type);
// Implements torch.set_default_dtype(obj): keeps the current default's
// backend/layout and swaps only the scalar type.
// NOTE(review): the `type` declaration, else/error branch, null check, and
// closing braces are elided in this source (lossy extraction).
// BUG FIX: `auto ¤t_default` was mojibake — the HTML entity "&curren;"
// swallowed the "&curren" prefix of "&current_default" during extraction;
// the identifier is used two lines below, so the reference is restored.
344 void py_set_default_dtype(PyObject* obj) {
346 if (THPDtype_Check(obj)) {
347 auto &current_default = get_default_tensor_type();
// Look up the registered type with the requested dtype but the current
// default's layout and device kind (CUDA vs CPU).
348 type = &get_tensor_type((
THPDtype*)obj, torch::getLayout(current_default.backend()),
349 torch::getDeviceType(current_default) == at::Device::Type::CUDA);
// Null when the backend is unavailable in this build.
353 auto aten_type = type->aten_type();
355 throw unavailable_type(*type);
357 set_default_tensor_type(*aten_type);
// Sets the process-wide default tensor type after validating it, updates
// ATen's default dtype, and re-points torch.Storage at the matching
// storage class. NOTE(review): closing braces, the `storage` lookup, and
// error-handling bodies are elided in this source.
360 void set_default_tensor_type(
const at::Type& type) {
// Only floating-point defaults are allowed (autograd/init conventions).
361 if (!at::isFloatingType(type.scalarType())) {
362 throw TypeError(
"only floating-point types are supported as the default type");
364 if (!type.is_variable() && !type.is_undefined()) {
365 throw TypeError(
"only variable types are supported");
367 if (type.is_sparse()) {
368 throw TypeError(
"only dense types are supported as the default type");
// The global stores a mutable pointer; the const_cast reflects that the
// legacy Type registry hands out const references.
374 default_tensor_type =
const_cast<Type*
>(&type);
375 at::set_default_dtype(default_tensor_type->typeMeta());
377 auto torch_module =
THPObjectPtr(PyImport_ImportModule(
"torch"));
// Keep torch.Storage in sync with the new default scalar type.
380 if (PyObject_SetAttrString(torch_module.get(),
"Storage", storage) != 0) {
// Returns the current default tensor type; asserts that
// initialize_python_bindings() / set_default_tensor_type() ran first.
// NOTE(review): the closing brace is elided in this source.
386 at::Type& get_default_tensor_type() {
387 AT_ASSERT(default_tensor_type);
388 return *default_tensor_type;
// Stray documentation fragments appended by the extraction tool (not code);
// commented out so they cannot be parsed as C++:
//   Backend — this legacy enum class defines the set of backends supported
//   by old-school, code-generated Type-based dispatch.
//   (Unrelated fragment:) Flush-To-Zero and Denormals-Are-Zero mode.