1 #include <torch/csrc/python_headers.h> 2 #include <torch/csrc/utils/tensor_new.h> 4 #include <torch/csrc/DynamicTypes.h> 5 #include <torch/csrc/Exceptions.h> 6 #include <torch/csrc/Size.h> 7 #include <torch/csrc/autograd/variable.h> 8 #include <torch/csrc/utils/auto_gil.h> 9 #include <torch/csrc/utils/cuda_lazy_init.h> 10 #include <torch/csrc/utils/numpy_stub.h> 11 #include <torch/csrc/utils/python_arg_parser.h> 12 #include <torch/csrc/utils/python_numbers.h> 13 #include <torch/csrc/utils/python_scalars.h> 14 #include <torch/csrc/utils/python_strings.h> 15 #include <torch/csrc/utils/tensor_numpy.h> 16 #include <torch/csrc/autograd/generated/variable_factories.h> 18 #include <ATen/ATen.h> 19 #include <ATen/InitialTensorOptions.h> 20 #include <c10/util/Exception.h> 21 #include <c10/util/Optional.h> 40 namespace torch {
namespace utils {
// Maximum nesting depth accepted when a shape is inferred from nested Python
// sequences; compute_sizes() raises ValueError once this is exceeded.
const int MAX_DIMS = 128;
44 void maybe_initialize_cuda(
const Type &type) {
46 torch::utils::cuda_lazy_init();
50 void maybe_initialize_cuda(
const Device device) {
51 if (device.is_cuda()) {
52 torch::utils::cuda_lazy_init();
// NOTE(review): the signature lines of the four helpers below are missing from
// this listing (the embedded source numbers skip 56, 62, 68 and 74), as are
// their closing braces. Each visible body calls maybe_initialize_cuda(type) and
// then forwards `sizes` (plus `fill_value` for the third) to a torch:: factory
// using type.options(device).
// Presumably these are dispatch_zeros, dispatch_ones, dispatch_full and
// new_with_sizes, which are called later in this file — confirm against the
// complete source before editing.
57 maybe_initialize_cuda(type);
59 return torch::zeros(sizes, type.
options(std::move(device)));
63 maybe_initialize_cuda(type);
65 return torch::ones(sizes, type.
options(std::move(device)));
69 maybe_initialize_cuda(type);
71 return torch::full(sizes, fill_value, type.
options(std::move(device)));
75 maybe_initialize_cuda(type);
77 return torch::empty(sizes, type.
options(std::move(device)));
80 Tensor new_with_storage(
const Type& type, Storage storage) {
81 auto tensor = at::empty({}, type.
options());
82 tensor.set_(std::move(storage));
// NOTE(review): the signature and the tail of this function are missing from
// the listing. The visible part raises TypeError when `other`'s type differs
// from `type`. Presumably this is the body of new_with_tensor, which the
// legacy constructors call — confirm against the complete source.
87 if (other.type() != type) {
88 throw TypeError(
"expected %s (got %s)", type.toString(), other.type().toString());
93 std::vector<int64_t> compute_sizes(PyObject* seq) {
94 std::vector<int64_t> sizes;
96 while (PySequence_Check(seq)) {
97 auto length = PySequence_Length(seq);
99 sizes.push_back(length);
100 if (sizes.size() > MAX_DIMS) {
101 throw ValueError(
"too many dimensions '%s'", Py_TYPE(seq)->tp_name);
103 if (length == 0)
break;
106 throw ValueError(
"could not determine the shape of object type '%s'", Py_TYPE(seq)->tp_name);
114 ScalarType infer_scalar_type(PyObject *obj) {
115 if (PyFloat_Check(obj)) {
118 return torch::tensors::get_default_tensor_type().scalarType();
120 if (THPUtils_checkLong(obj)) {
121 return ScalarType::Long;
123 if (PyBool_Check(obj)) {
125 return ScalarType::Byte;
127 if (THPVariable_Check(obj)) {
128 auto var =
reinterpret_cast<THPVariable*
>(obj)->cdata;
129 return var.scalar_type();
132 if (PyArray_Check(obj)) {
133 return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)obj));
135 if (PyArray_CheckScalar(obj)) {
136 return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)(PyArray_FromScalar(obj,
nullptr))));
139 if (THPUtils_checkString(obj)) {
140 throw TypeError(
"new(): invalid data type '%s'", Py_TYPE(obj)->tp_name);
142 if (PySequence_Check(obj)) {
144 auto length = PySequence_Length(obj);
147 if (length == 0)
return torch::tensors::get_default_tensor_type().scalarType();
148 for (
int i = 0; i < length; ++i) {
151 auto cur_item = handle.get();
152 if (cur_item == obj)
throw TypeError(
"new(): self-referential lists are incompatible");
153 ScalarType item_scalarType = infer_scalar_type(cur_item);
154 scalarType = (scalarType) ?
155 at::promoteTypes(*scalarType, item_scalarType) : item_scalarType;
156 if (scalarType == ScalarType::Double) {
163 AT_ERROR(
"Could not infer dtype of ", Py_TYPE(obj)->tp_name);
166 void recursive_store(
char* data, IntArrayRef sizes, IntArrayRef strides, int64_t dim,
167 ScalarType scalarType,
int elementSize, PyObject* obj) {
168 int64_t ndim = sizes.size();
170 torch::utils::store_scalar(data, scalarType, obj);
175 auto seq =
THPObjectPtr(PySequence_Fast(obj,
"not a sequence"));
177 auto seq_size = PySequence_Fast_GET_SIZE(seq.get());
179 throw ValueError(
"expected sequence of length %lld at dim %lld (got %lld)",
180 (
long long)n, (
long long)dim, (
long long)seq_size);
183 PyObject** items = PySequence_Fast_ITEMS(seq.get());
184 for (int64_t i = 0; i < n; i++) {
185 recursive_store(data, sizes, strides, dim + 1, scalarType, elementSize, items[i]);
186 data += strides[dim] * elementSize;
// Builds a tensor from the Python object `data`.
//
// Visible paths:
//   * strings are rejected with TypeError;
//   * an existing Variable is converted with `.to(device, scalar_type, ...)`;
//   * a NumPy array is wrapped via tensor_from_numpy() and then converted;
//   * anything else is treated as a nested sequence: the shape comes from
//     compute_sizes(), the dtype from infer_scalar_type() or `type`, and the
//     values are written into a freshly allocated tensor.
// When `type_inference` is true the scalar type is taken from the data,
// otherwise from `type`. An explicit `device_opt` always wins over the
// device derived from `type` (or, for Variables with inference, from the data).
//
// NOTE(review): this listing omits several source lines (the embedded numbers
// skip 191-195, 204-207, 229 and others), including most of the parameter
// list. The body reads `type`, `device_opt`, `data`, `copy_variables`,
// `copy_numpy` and `type_inference` — confirm their declarations upstream.
190 Tensor internal_new_from_data(
196 bool type_inference) {
197 if (THPUtils_checkString(data)) {
198 throw TypeError(
"new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
// Path 1: `data` is already a Variable.
201 if (THPVariable_Check(data)) {
202 auto var =
reinterpret_cast<THPVariable*
>(data)->cdata;
203 if (copy_variables) {
208 const auto& scalar_type = type_inference ? var.scalar_type() : type.scalarType();
209 auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() :
at::Device(torch::getDeviceType(type)));
211 maybe_initialize_cuda(device);
// presumably the trailing arguments are non_blocking=false and copy=copy_variables — confirm
212 return var.
to(device, scalar_type,
false, copy_variables);
// Path 2: `data` is a NumPy array.
216 if (PyArray_Check(data)) {
217 auto tensor = autograd::make_variable(tensor_from_numpy(data),
false);
218 const auto& scalar_type = type_inference ? tensor.scalar_type() : type.scalarType();
219 auto device = device_opt.has_value() ? *device_opt :
at::Device(type.device_type());
221 maybe_initialize_cuda(device);
222 return tensor.to(device, scalar_type,
false, copy_numpy);
// Path 3: generic (nested) Python sequence or scalar.
226 auto sizes = compute_sizes(data);
227 ScalarType scalar_type = type_inference ? infer_scalar_type(data) : type.scalarType();
228 auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(scalar_type)),
false);
// NOTE(review): the call whose arguments follow is not visible (source line 229
// is missing); presumably recursive_store(...) — confirm.
230 (
char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
231 scalar_type, tensor.element_size(), data);
232 auto device = device_opt.has_value() ? *device_opt :
at::Device(torch::getDeviceType(type));
234 maybe_initialize_cuda(device);
235 return tensor.to(device, scalar_type,
false,
false);
// Forwards to internal_new_from_data() with the three trailing flags set to
// true, true, false. Judging by the names used inside internal_new_from_data
// these are copy_variables, copy_numpy and type_inference — confirm.
// NOTE(review): the parameter list (source lines 239-241) is missing from this
// listing; the body uses `type`, `device` and `data`.
238 Tensor new_from_data_copy(
242 return internal_new_from_data(type, std::move(device), data,
true,
true,
false);
// Legacy `new(sequence)` path: raises TypeError unless `data` is a Python
// sequence, then forwards to internal_new_from_data() with all three trailing
// flags false (presumably copy_variables, copy_numpy, type_inference — confirm).
// NOTE(review): the parameter list (source lines 246-248) is missing from this
// listing; the body uses `type`, `device` and `data`.
245 Tensor legacy_new_from_sequence(
249 if (!PySequence_Check(data)) {
250 throw TypeError(
"new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name);
252 return internal_new_from_data(type, std::move(device), data,
false,
false,
false);
// NOTE(review): the signature of this function is missing from the listing.
// The visible body checks, when a device was supplied, that its device type
// equals type.device_type() and raises via AT_CHECK otherwise. Presumably this
// is check_legacy_ctor_device(type, device), which the legacy constructors
// below call — confirm against the complete source.
256 if (device.has_value()) {
257 AT_CHECK(type.device_type() == device.value().type(),
258 "legacy constructor for device type: ", type.device_type(),
259 " was passed device type: ", device.value().type(),
260 ", but device type must be: ", type.device_type());
264 Tensor legacy_sparse_tensor_ctor(
const Type& type, PyObject* args, PyObject* kwargs) {
265 static PythonArgParser parser({
266 "new(*, Device? device=None)",
267 "new(*, int64_t cdata)|hidden",
268 "new(Tensor indices, Tensor values, *, Device? device=None)",
269 "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)",
270 "new(IntArrayRef size, *, Device? device=None)",
272 ParsedArgs<4> parsed_args;
273 auto r = parser.parse(args, kwargs, parsed_args);
275 auto deviceOptional = r.deviceOptional(0);
276 check_legacy_ctor_device(type, deviceOptional);
277 return at::empty({0}, type.
options(r.deviceOptional(0)));
278 }
else if (r.idx == 1) {
279 auto cdata =
reinterpret_cast<void*
>(r.toInt64(0));
280 return type.unsafeTensorFromTH(cdata,
true);
281 }
else if (r.idx == 2) {
282 auto deviceOptional = r.deviceOptional(2);
283 check_legacy_ctor_device(type, deviceOptional);
285 return at::sparse_coo_tensor(r.tensor(0), r.tensor(1));
286 }
else if (r.idx == 3) {
287 auto deviceOptional = r.deviceOptional(3);
288 check_legacy_ctor_device(type, deviceOptional);
290 return at::sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
291 }
else if (r.idx == 4) {
292 PyObject* arg = r.pyobject(0);
293 auto deviceOptional = r.deviceOptional(1);
294 check_legacy_ctor_device(type, deviceOptional);
295 if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
298 return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
300 return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
302 throw std::runtime_error(
"new(): invalid arguments");
305 Tensor legacy_sparse_tensor_new(
const Type& type, PyObject* args, PyObject* kwargs) {
306 static PythonArgParser parser({
307 "new(*, Device? device=None)",
308 "new(*, int64_t cdata)|hidden",
309 "new(Tensor indices, Tensor values, *, Device? device=None)",
310 "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)",
311 "new(IntArrayRef size, *, Device? device=None)",
313 ParsedArgs<5> parsed_args;
314 auto r = parser.parse(args, kwargs, parsed_args);
316 auto deviceOptional = r.deviceOptional(0);
317 check_legacy_ctor_device(type, deviceOptional);
319 return at::empty({0}, type.
options());
320 }
else if (r.idx == 1) {
321 auto cdata =
reinterpret_cast<void*
>(r.toInt64(0));
322 return type.unsafeTensorFromTH(cdata,
true);
323 }
else if (r.idx == 2) {
326 auto deviceOptional = r.deviceOptional(2);
327 check_legacy_ctor_device(type, deviceOptional);
329 return at::sparse_coo_tensor(r.tensor(0), r.tensor(1));
330 }
else if (r.idx == 3) {
333 auto deviceOptional = r.deviceOptional(3);
334 check_legacy_ctor_device(type, deviceOptional);
336 return at::sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
337 }
else if (r.idx == 4) {
338 PyObject* arg = r.pyobject(0);
339 auto deviceOptional = r.deviceOptional(1);
340 check_legacy_ctor_device(type, deviceOptional);
341 if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
344 return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
346 return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
348 throw std::runtime_error(
"new(): invalid arguments");
352 const Type& typeWithDefault(PythonArgs& r, int64_t dtype_idx, int64_t device_idx,
const Type& type) {
353 const auto scalartype = r.scalartypeWithDefault(dtype_idx, type.scalarType());
354 const Device types_device_type(type.device_type());
355 const auto device_type = r.isNone(device_idx) ? types_device_type : r.device(device_idx).type();
356 return torch::getVariableType(scalartype, *torch::getLayout(type.backend()), device_type);
360 Tensor legacy_tensor_ctor(
const Type& type, PyObject* args, PyObject* kwargs) {
361 static PythonArgParser parser({
362 "new(*, Device? device=None)",
363 "new(Storage storage)",
364 "new(*, int64_t cdata)|hidden",
366 "new(IntArrayRef size, *, Device? device=None)",
367 "new(PyObject* data, *, Device? device=None)",
370 if (type.is_sparse()) {
371 return legacy_sparse_tensor_ctor(type, args, kwargs);
374 ParsedArgs<2> parsed_args;
375 auto r = parser.parse(args, kwargs, parsed_args);
377 auto deviceOptional = r.deviceOptional(0);
378 check_legacy_ctor_device(type, deviceOptional);
380 return at::empty({0}, type.
options());
381 }
else if (r.idx == 1) {
382 return new_with_storage(type, r.storage(0));
383 }
else if (r.idx == 2) {
384 auto cdata =
reinterpret_cast<void*
>(r.toInt64(0));
385 return type.unsafeTensorFromTH(cdata,
true);
386 }
else if (r.idx == 3) {
387 return new_with_tensor(type, r.tensor(0));
388 }
else if (r.idx == 4) {
389 PyObject* arg = r.pyobject(0);
390 auto deviceOptional = r.deviceOptional(1);
391 check_legacy_ctor_device(type, deviceOptional);
392 if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
395 return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
397 return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
398 }
else if (r.idx == 5) {
399 auto deviceOptional = r.deviceOptional(1);
400 check_legacy_ctor_device(type, deviceOptional);
401 return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
403 throw std::runtime_error(
"new(): invalid arguments");
406 Tensor legacy_tensor_new(
const Type& type, PyObject* args, PyObject* kwargs) {
407 static PythonArgParser parser({
408 "new(*, Device? device=None)",
409 "new(Storage storage)",
410 "new(*, int64_t cdata)|hidden",
412 "new(IntArrayRef size, *, Device? device=None)",
413 "new(PyObject* data, *, Device? device=None)",
416 if (type.is_sparse()) {
417 return legacy_sparse_tensor_new(type, args, kwargs);
420 ParsedArgs<3> parsed_args;
421 auto r = parser.parse(args, kwargs, parsed_args);
423 auto deviceOptional = r.deviceOptional(0);
424 check_legacy_ctor_device(type, deviceOptional);
426 return at::empty({0}, type.
options());
427 }
else if (r.idx == 1) {
428 return new_with_storage(type, r.storage(0));
429 }
else if (r.idx == 2) {
430 auto cdata =
reinterpret_cast<void*
>(r.toInt64(0));
431 return type.unsafeTensorFromTH(cdata,
true);
432 }
else if (r.idx == 3) {
433 return new_with_tensor(type, r.tensor(0));
434 }
else if (r.idx == 4) {
435 PyObject* arg = r.pyobject(0);
436 auto deviceOptional = r.deviceOptional(1);
437 check_legacy_ctor_device(type, deviceOptional);
438 if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
441 return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
443 return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
444 }
else if (r.idx == 5) {
445 auto deviceOptional = r.deviceOptional(1);
446 check_legacy_ctor_device(type, deviceOptional);
447 return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
449 throw std::runtime_error(
"new(): invalid arguments");
// Builds a tensor from `data` via internal_new_from_data(). The scalar type of
// `data` is inferred first: when it is Byte the tensor is created with the
// Byte variant of `type`, otherwise with `type` itself. All three trailing
// flags passed on are false.
// NOTE(review): the parameter list (source lines 453-457) and the closing
// braces are missing from this listing; the body uses `type`, `device` and
// `data` — confirm their declarations upstream.
452 Tensor indexing_tensor_from_data(
458 ScalarType scalar_type = infer_scalar_type(data);
459 if (scalar_type == ScalarType::Byte) {
460 auto& idx_type = type.toScalarType(scalar_type);
461 return internal_new_from_data(idx_type, std::move(device), data,
false,
false,
false);
463 return internal_new_from_data(type, std::move(device), data,
false,
false,
false);
467 Tensor sparse_coo_tensor_ctor(
const Type& default_type, PyObject* args, PyObject* kwargs) {
468 static PythonArgParser parser({
469 "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
470 "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
471 "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
474 ParsedArgs<6> parsed_args;
475 auto r = parser.parse(args, kwargs, parsed_args);
477 bool type_inference = r.isNone(2);
478 const auto& type = typeWithDefault(r, 2, 3, default_type);
479 const auto& values_type = type.toDense();
482 Tensor values = internal_new_from_data(values_type, r.deviceOptional(3), r.pyobject(1),
false,
true, type_inference);
483 const auto& indices_type = values.type().toScalarType(kLong);
484 Tensor indices = internal_new_from_data(indices_type, r.deviceOptional(3), r.pyobject(0),
false,
true,
false);
485 return at::sparse_coo_tensor(indices, values, values.
options().
layout(at::kSparse)).set_requires_grad(r.toBool(4));
486 }
else if (r.idx == 1) {
487 bool type_inference = r.isNone(3);
488 const auto& type = typeWithDefault(r, 3, 4, default_type);
489 const auto& values_type = type.toDense();
491 Tensor values = internal_new_from_data(values_type, r.deviceOptional(4), r.pyobject(1),
false,
true, type_inference);
492 const auto& indices_type = values.type().toScalarType(kLong);
493 Tensor indices = internal_new_from_data(indices_type, r.deviceOptional(4), r.pyobject(0),
false,
true,
false);
494 return at::sparse_coo_tensor(indices, values, r.intlist(2), values.
options().
layout(at::kSparse)).set_requires_grad(r.toBool(5));
495 }
else if (r.idx == 2) {
496 const auto& type = typeWithDefault(r, 1, 2, default_type);
498 return at::sparse_coo_tensor(r.intlist(0), type.
options().
layout(at::kSparse)).set_requires_grad(r.toBool(3));
500 throw std::runtime_error(
"sparse_coo_tensor(): invalid arguments");
503 Tensor tensor_ctor(
const Type& type, PyObject* args, PyObject* kwargs) {
504 static PythonArgParser parser({
505 "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
508 ParsedArgs<4> parsed_args;
509 auto r = parser.parse(args, kwargs, parsed_args);
511 PyObject* data = r.pyobject(0);
512 if (THPVariable_Check(data)) {
513 PyErr_WarnEx(PyExc_UserWarning,
514 "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " 515 "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).", 1);
518 bool type_inference = r.isNone(1);
519 bool args_requires_grad = r.toBool(3);
520 auto new_tensor = internal_new_from_data(
521 typeWithDefault(r, 1, 2, type),
527 new_tensor.detach_();
528 new_tensor.set_requires_grad(args_requires_grad);
531 throw std::runtime_error(
"tensor(): invalid arguments");
534 Tensor as_tensor(
const Type& type, PyObject* args, PyObject* kwargs) {
536 static PythonArgParser parser({
537 "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
540 ParsedArgs<3> parsed_args;
541 auto r = parser.parse(args, kwargs, parsed_args);
543 bool type_inference = r.isNone(1);
544 return internal_new_from_data(
545 typeWithDefault(r, 1, 2, type), r.deviceOptional(2), r.pyobject(0),
false,
false, type_inference);
547 throw std::runtime_error(
"tensor(): invalid arguments");
550 Tensor new_tensor(
const Type& type, PyObject* args, PyObject* kwargs) {
551 static PythonArgParser parser({
552 "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
555 ParsedArgs<4> parsed_args;
556 auto r = parser.parse(args, kwargs, parsed_args);
558 PyObject* data = r.pyobject(0);
559 if (THPVariable_Check(data)) {
560 PyErr_WarnEx(PyExc_UserWarning,
561 "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() " 562 "or sourceTensor.clone().detach().requires_grad_(True), rather than tensor.new_tensor(sourceTensor).", 1);
565 bool args_requires_grad = r.toBool(3);
566 auto new_tensor = new_from_data_copy(
567 typeWithDefault(r, 1, 2, type),
570 new_tensor.detach_();
571 new_tensor.set_requires_grad(args_requires_grad);
574 throw std::runtime_error(
"new_tensor(): invalid arguments");
577 Tensor new_empty(
const Type& type, PyObject* args, PyObject* kwargs) {
578 static PythonArgParser parser({
579 "new_empty(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
582 ParsedArgs<4> parsed_args;
583 auto r = parser.parse(args, kwargs, parsed_args);
585 const auto& actual_type = typeWithDefault(r, 1, 2, type);
586 return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
588 throw std::runtime_error(
"new_empty(): invalid arguments");
591 Tensor new_full(
const Type& type, PyObject* args, PyObject* kwargs) {
592 static PythonArgParser parser({
593 "new_full(IntArrayRef size, Scalar fill_value, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
596 ParsedArgs<5> parsed_args;
597 auto r = parser.parse(args, kwargs, parsed_args);
599 const auto& actual_type = typeWithDefault(r, 2, 3, type);
600 return dispatch_full(actual_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
602 throw std::runtime_error(
"new_full(): invalid arguments");
605 Tensor new_ones(
const Type& type, PyObject* args, PyObject* kwargs) {
606 static PythonArgParser parser({
607 "new_ones(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
610 ParsedArgs<4> parsed_args;
611 auto r = parser.parse(args, kwargs, parsed_args);
613 const auto& actual_type = typeWithDefault(r, 1, 2, type);
614 return dispatch_ones(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
616 throw std::runtime_error(
"new_ones(): invalid arguments");
619 Tensor new_zeros(
const Type& type, PyObject* args, PyObject* kwargs) {
620 static PythonArgParser parser({
621 "new_zeros(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
624 ParsedArgs<4> parsed_args;
625 auto r = parser.parse(args, kwargs, parsed_args);
627 const auto& actual_type = typeWithDefault(r, 1, 2, type);
628 return dispatch_zeros(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
630 throw std::runtime_error(
"new_zeros(): invalid arguments");
TensorOptions options() const
Returns the TensorOptions corresponding to this Tensor.
Scalar represents a 0-dimensional tensor which contains a single element.
TensorOptions options(int16_t device_index=-1) const
Constructs the TensorOptions from a type and a device_index.
Represents a compute device on which a tensor is located.
Backend
This legacy enum class defines the set of backends supported by old school, code generated Type-based...
An OptionalDeviceGuard is an RAII class that sets a device to some value on initialization, and resets the device to its original value on destruction.
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking=false)
Recursively casts all parameters to the given dtype and device.
C10_NODISCARD TensorOptions layout(c10::optional< Layout > layout) const noexcept
Sets the layout of the TensorOptions.