Caffe2 - C++ API
A deep learning, cross platform ML framework
tensor_new.cpp
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/tensor_new.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Size.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/cuda_lazy_init.h>
#include <torch/csrc/utils/numpy_stub.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_scalars.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_numpy.h>
#include <torch/csrc/autograd/generated/variable_factories.h>

#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <stdexcept>
#include <vector>

using at::Backend;
using at::Device;
using at::IntArrayRef;
using at::kCPU;
using at::kCUDA;
using at::kLong;
using at::Scalar;
using at::ScalarType;
using at::Storage;
using at::Tensor;
using at::TensorOptions;
using at::Type;
using c10::optional;

namespace torch { namespace utils {
namespace {
const int MAX_DIMS = 128;

void maybe_initialize_cuda(const Type &type) {
  if (type.is_cuda()) {
    torch::utils::cuda_lazy_init();
  }
}

void maybe_initialize_cuda(const Device device) {
  if (device.is_cuda()) {
    torch::utils::cuda_lazy_init();
  }
}

Tensor dispatch_zeros(const Type& type, optional<Device> device, IntArrayRef sizes) {
  maybe_initialize_cuda(type);
  AutoNoGIL no_gil;
  return torch::zeros(sizes, type.options(std::move(device)));
}

Tensor dispatch_ones(const Type& type, optional<Device> device, IntArrayRef sizes) {
  maybe_initialize_cuda(type);
  AutoNoGIL no_gil;
  return torch::ones(sizes, type.options(std::move(device)));
}

Tensor dispatch_full(const Type& type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) {
  maybe_initialize_cuda(type);
  AutoNoGIL no_gil;
  return torch::full(sizes, fill_value, type.options(std::move(device)));
}

Tensor new_with_sizes(const Type& type, optional<Device> device, IntArrayRef sizes) {
  maybe_initialize_cuda(type);
  AutoNoGIL no_gil;
  return torch::empty(sizes, type.options(std::move(device)));
}

Tensor new_with_storage(const Type& type, Storage storage) {
  auto tensor = at::empty({}, type.options());
  tensor.set_(std::move(storage));
  return tensor;
}

Tensor new_with_tensor(const Type& type, const Tensor& other) {
  if (other.type() != type) {
    throw TypeError("expected %s (got %s)", type.toString(), other.type().toString());
  }
  return other.slice();
}

std::vector<int64_t> compute_sizes(PyObject* seq) {
  std::vector<int64_t> sizes;
  THPObjectPtr handle;
  while (PySequence_Check(seq)) {
    auto length = PySequence_Length(seq);
    if (length < 0) throw python_error();
    sizes.push_back(length);
    if (sizes.size() > MAX_DIMS) {
      throw ValueError("too many dimensions '%s'", Py_TYPE(seq)->tp_name);
    }
    if (length == 0) break;
    handle = THPObjectPtr(PySequence_GetItem(seq, 0));
    if (!handle) {
      throw ValueError("could not determine the shape of object type '%s'", Py_TYPE(seq)->tp_name);
    }
    seq = handle.get();
  }

  return sizes;
}
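// An informal example of the shape probing above: for a nested Python list
// such as [[1., 2., 3.], [4., 5., 6.]], compute_sizes walks the first element
// of each nesting level and returns {2, 3}. Ragged input (e.g. [[1., 2.], [3.]])
// is not rejected here; it only fails later, in recursive_store, when a row's
// actual length does not match the probed size.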

ScalarType infer_scalar_type(PyObject *obj) {
  if (PyFloat_Check(obj)) {
    // this is always guaranteed to be a floating-point type, and makes it more
    // convenient to write e.g. torch.tensor(0.) than torch.tensor(0., dtype=torch.Tensor.dtype).
    return torch::tensors::get_default_tensor_type().scalarType();
  }
  if (THPUtils_checkLong(obj)) {
    return ScalarType::Long;
  }
  if (PyBool_Check(obj)) {
    // TODO: infer Bool when we have Bool ScalarType
    return ScalarType::Byte;
  }
  if (THPVariable_Check(obj)) {
    auto var = reinterpret_cast<THPVariable*>(obj)->cdata;
    return var.scalar_type();
  }
#ifdef USE_NUMPY
  if (PyArray_Check(obj)) {
    return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)obj));
  }
  if (PyArray_CheckScalar(obj)) {
    return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)(PyArray_FromScalar(obj, nullptr))));
  }
#endif
  if (THPUtils_checkString(obj)) {
    throw TypeError("new(): invalid data type '%s'", Py_TYPE(obj)->tp_name);
  }
  if (PySequence_Check(obj)) {
    c10::optional<ScalarType> scalarType;
    auto length = PySequence_Length(obj);
    if (length < 0) throw python_error();
    // match NumPy semantics, except use default tensor type instead of double.
    if (length == 0) return torch::tensors::get_default_tensor_type().scalarType();
    for (int i = 0; i < length; ++i) {
      THPObjectPtr handle(PySequence_GetItem(obj, i));
      if (!handle) throw python_error();
      auto cur_item = handle.get();
      if (cur_item == obj) throw TypeError("new(): self-referential lists are incompatible");
      ScalarType item_scalarType = infer_scalar_type(cur_item);
      scalarType = (scalarType) ?
          at::promoteTypes(*scalarType, item_scalarType) : item_scalarType;
      if (scalarType == ScalarType::Double) {
        // this won't change (unless we hit undefined, but that will fail later).
        return *scalarType;
      }
    }
    return *scalarType;
  }
  AT_ERROR("Could not infer dtype of ", Py_TYPE(obj)->tp_name);
}
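// A minimal sketch of the inference rules above, assuming the default tensor
// type is torch.FloatTensor:
//   infer_scalar_type(1.0)       -> Float  (Python float maps to the default type)
//   infer_scalar_type(1)         -> Long
//   infer_scalar_type(True)      -> Byte   (no Bool ScalarType yet)
//   infer_scalar_type([1, 2.0])  -> Float  (element-wise at::promoteTypes)
//   infer_scalar_type([])        -> Float  (NumPy-style, but default type instead of double)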

void recursive_store(char* data, IntArrayRef sizes, IntArrayRef strides, int64_t dim,
                     ScalarType scalarType, int elementSize, PyObject* obj) {
  int64_t ndim = sizes.size();
  if (dim == ndim) {
    torch::utils::store_scalar(data, scalarType, obj);
    return;
  }

  auto n = sizes[dim];
  auto seq = THPObjectPtr(PySequence_Fast(obj, "not a sequence"));
  if (!seq) throw python_error();
  auto seq_size = PySequence_Fast_GET_SIZE(seq.get());
  if (seq_size != n) {
    throw ValueError("expected sequence of length %lld at dim %lld (got %lld)",
      (long long)n, (long long)dim, (long long)seq_size);
  }

  PyObject** items = PySequence_Fast_ITEMS(seq.get());
  for (int64_t i = 0; i < n; i++) {
    recursive_store(data, sizes, strides, dim + 1, scalarType, elementSize, items[i]);
    data += strides[dim] * elementSize;
  }
}
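// Rough walk-through of the recursion above for data = [[1., 2.], [3., 4.]]
// written into a freshly allocated contiguous Float tensor (sizes {2, 2},
// strides {2, 1}, elementSize 4): dim 0 iterates over the two rows, advancing
// data by 2 * 4 bytes per row; dim 1 iterates over each row's scalars,
// advancing by 1 * 4 bytes and writing via store_scalar, so the buffer ends up
// holding 1 2 3 4 in order.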

Tensor internal_new_from_data(
    const Type& type,
    c10::optional<Device> device_opt,
    PyObject* data,
    bool copy_variables,
    bool copy_numpy,
    bool type_inference) {
  if (THPUtils_checkString(data)) {
    throw TypeError("new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
  }

  if (THPVariable_Check(data)) {
    auto var = reinterpret_cast<THPVariable*>(data)->cdata;
    if (copy_variables) {
      var = var.detach();
    }
    // infer the scalar type and device type; it's not expected to infer the layout since these constructors
    // are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
    const auto& scalar_type = type_inference ? var.scalar_type() : type.scalarType();
    auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(torch::getDeviceType(type)));
    AutoNoGIL no_gil;
    maybe_initialize_cuda(device);
    return var.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables);
  }

#ifdef USE_NUMPY
  if (PyArray_Check(data)) {
    auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
    const auto& scalar_type = type_inference ? tensor.scalar_type() : type.scalarType();
    auto device = device_opt.has_value() ? *device_opt : at::Device(type.device_type());
    AutoNoGIL no_gil;
    maybe_initialize_cuda(device);
    return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy);
  }
#endif

  auto sizes = compute_sizes(data);
  ScalarType scalar_type = type_inference ? infer_scalar_type(data) : type.scalarType();
  auto tensor = autograd::make_variable(at::empty(sizes, at::initialTensorOptions().dtype(scalar_type)), /*requires_grad=*/false);
  recursive_store(
      (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0,
      scalar_type, tensor.element_size(), data);
  auto device = device_opt.has_value() ? *device_opt : at::Device(torch::getDeviceType(type));
  AutoNoGIL no_gil;
  maybe_initialize_cuda(device);
  return tensor.to(device, scalar_type, /*non_blocking=*/false, /*copy=*/false);
}
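// Reading aid for the three paths above:
//  * existing Variable: optionally detach (copy_variables) and go through
//    Tensor::to, which copies only when copy_variables is true or the
//    dtype/device actually changes;
//  * NumPy array (USE_NUMPY): wrap the array's memory via tensor_from_numpy,
//    then convert according to copy_numpy / dtype / device;
//  * any other Python sequence or scalar: probe the shape with compute_sizes,
//    infer the dtype if requested, fill a fresh CPU tensor element by element
//    with recursive_store, then move it to the target device.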

Tensor new_from_data_copy(
    const Type& type,
    c10::optional<Device> device,
    PyObject* data) {
  return internal_new_from_data(type, std::move(device), data, true, true, false);
}

Tensor legacy_new_from_sequence(
    const Type& type,
    c10::optional<Device> device,
    PyObject* data) {
  if (!PySequence_Check(data)) {
    throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name);
  }
  return internal_new_from_data(type, std::move(device), data, false, false, false);
}

void check_legacy_ctor_device(const Type& type, c10::optional<Device> device) {
  if (device.has_value()) {
    AT_CHECK(type.device_type() == device.value().type(),
             "legacy constructor for device type: ", type.device_type(),
             " was passed device type: ", device.value().type(),
             ", but device type must be: ", type.device_type());
  }
}

Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new(*, Device? device=None)",
    "new(*, int64_t cdata)|hidden",
    "new(Tensor indices, Tensor values, *, Device? device=None)",
    "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)",
    "new(IntArrayRef size, *, Device? device=None)",
  });
  ParsedArgs<4> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    auto deviceOptional = r.deviceOptional(0);
    check_legacy_ctor_device(type, deviceOptional);
    return at::empty({0}, type.options(r.deviceOptional(0)));
  } else if (r.idx == 1) {
    auto cdata = reinterpret_cast<void*>(r.toInt64(0));
    return type.unsafeTensorFromTH(cdata, true);
  } else if (r.idx == 2) {
    auto deviceOptional = r.deviceOptional(2);
    check_legacy_ctor_device(type, deviceOptional);
    at::OptionalDeviceGuard device_guard(deviceOptional);
    return at::sparse_coo_tensor(r.tensor(0), r.tensor(1));
  } else if (r.idx == 3) {
    auto deviceOptional = r.deviceOptional(3);
    check_legacy_ctor_device(type, deviceOptional);
    at::OptionalDeviceGuard device_guard(deviceOptional);
    return at::sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
  } else if (r.idx == 4) {
    PyObject* arg = r.pyobject(0);
    auto deviceOptional = r.deviceOptional(1);
    check_legacy_ctor_device(type, deviceOptional);
    if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
      // new(sequence) binds to this signature but should be treated differently
      // unless the sequence is a torch.Size
      return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
    }
    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
  }
  throw std::runtime_error("new(): invalid arguments");
}

Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new(*, Device? device=None)",
    "new(*, int64_t cdata)|hidden",
    "new(Tensor indices, Tensor values, *, Device? device=None)",
    "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)",
    "new(IntArrayRef size, *, Device? device=None)",
  });
  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    auto deviceOptional = r.deviceOptional(0);
    check_legacy_ctor_device(type, deviceOptional);
    at::OptionalDeviceGuard device_guard(deviceOptional);
    return at::empty({0}, type.options());
  } else if (r.idx == 1) {
    auto cdata = reinterpret_cast<void*>(r.toInt64(0));
    return type.unsafeTensorFromTH(cdata, true);
  } else if (r.idx == 2) {
    // Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't
    // have a device (we should infer it).
    auto deviceOptional = r.deviceOptional(2);
    check_legacy_ctor_device(type, deviceOptional);
    at::OptionalDeviceGuard device_guard(deviceOptional);
    return at::sparse_coo_tensor(r.tensor(0), r.tensor(1));
  } else if (r.idx == 3) {
    // Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't
    // have a device (we should infer it).
    auto deviceOptional = r.deviceOptional(3);
    check_legacy_ctor_device(type, deviceOptional);
    at::OptionalDeviceGuard device_guard(deviceOptional);
    return at::sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
  } else if (r.idx == 4) {
    PyObject* arg = r.pyobject(0);
    auto deviceOptional = r.deviceOptional(1);
    check_legacy_ctor_device(type, deviceOptional);
    if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
      // new(sequence) binds to this signature but should be treated differently
      // unless the sequence is a torch.Size
      return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
    }
    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
  }
  throw std::runtime_error("new(): invalid arguments");
}

// NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
const Type& typeWithDefault(PythonArgs& r, int64_t dtype_idx, int64_t device_idx, const Type& type) {
  const auto scalartype = r.scalartypeWithDefault(dtype_idx, type.scalarType());
  const Device types_device_type(type.device_type());
  const auto device_type = r.isNone(device_idx) ? types_device_type : r.device(device_idx).type();
  return torch::getVariableType(scalartype, *torch::getLayout(type.backend()), device_type);
}
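// Illustrative example of the defaulting above: for x.new_ones(3) where x is a
// torch.cuda.DoubleTensor and neither dtype nor device is passed, both parsed
// slots are None, so the call resolves to the variable type for
// (Double, strided, CUDA), i.e. the same type as x itself.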
} // namespace

Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new(*, Device? device=None)",
    "new(Storage storage)",
    "new(*, int64_t cdata)|hidden",
    "new(Tensor other)",
    "new(IntArrayRef size, *, Device? device=None)",
    "new(PyObject* data, *, Device? device=None)",
  });

  if (type.is_sparse()) {
    return legacy_sparse_tensor_ctor(type, args, kwargs);
  }

  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    auto deviceOptional = r.deviceOptional(0);
    check_legacy_ctor_device(type, deviceOptional);
    at::OptionalDeviceGuard device_guard(deviceOptional);
    return at::empty({0}, type.options());
  } else if (r.idx == 1) {
    return new_with_storage(type, r.storage(0));
  } else if (r.idx == 2) {
    auto cdata = reinterpret_cast<void*>(r.toInt64(0));
    return type.unsafeTensorFromTH(cdata, true);
  } else if (r.idx == 3) {
    return new_with_tensor(type, r.tensor(0));
  } else if (r.idx == 4) {
    PyObject* arg = r.pyobject(0);
    auto deviceOptional = r.deviceOptional(1);
    check_legacy_ctor_device(type, deviceOptional);
    if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
      // new(sequence) binds to this signature but should be treated differently
      // unless the sequence is a torch.Size
      return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
    }
    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
  } else if (r.idx == 5) {
    auto deviceOptional = r.deviceOptional(1);
    check_legacy_ctor_device(type, deviceOptional);
    return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
  }
  throw std::runtime_error("new(): invalid arguments");
}
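// The size-vs-sequence special case above is what gives the legacy Python
// constructor its familiar behaviour (informal examples):
//   torch.FloatTensor(2, 3)               -> uninitialized 2x3 tensor (treated as sizes)
//   torch.FloatTensor([2, 3])             -> 1-D tensor holding the values 2., 3.
//   torch.FloatTensor(torch.Size([2, 3])) -> uninitialized 2x3 tensor, because a
//                                            torch.Size is treated as sizes, not data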

Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new(*, Device? device=None)",
    "new(Storage storage)",
    "new(*, int64_t cdata)|hidden",
    "new(Tensor other)", // this doesn't have a dtype/device because it creates an alias.
    "new(IntArrayRef size, *, Device? device=None)",
    "new(PyObject* data, *, Device? device=None)",
  });

  if (type.is_sparse()) {
    return legacy_sparse_tensor_new(type, args, kwargs);
  }

  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    auto deviceOptional = r.deviceOptional(0);
    check_legacy_ctor_device(type, deviceOptional);
    at::OptionalDeviceGuard device_guard(deviceOptional);
    return at::empty({0}, type.options());
  } else if (r.idx == 1) {
    return new_with_storage(type, r.storage(0));
  } else if (r.idx == 2) {
    auto cdata = reinterpret_cast<void*>(r.toInt64(0));
    return type.unsafeTensorFromTH(cdata, true);
  } else if (r.idx == 3) {
    return new_with_tensor(type, r.tensor(0));
  } else if (r.idx == 4) {
    PyObject* arg = r.pyobject(0);
    auto deviceOptional = r.deviceOptional(1);
    check_legacy_ctor_device(type, deviceOptional);
    if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
      // new(sequence) binds to this signature but should be treated differently
      // unless the sequence is a torch.Size
      return legacy_new_from_sequence(type, deviceOptional, r.pyobject(0));
    }
    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
  } else if (r.idx == 5) {
    auto deviceOptional = r.deviceOptional(1);
    check_legacy_ctor_device(type, deviceOptional);
    return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
  }
  throw std::runtime_error("new(): invalid arguments");
}

Tensor indexing_tensor_from_data(
    const Type& type,
    c10::optional<Device> device,
    PyObject* data) {
  // Specific to tensor indexing, converts an indexing list to an
  // indexing tensor (type Byte or Long)
  ScalarType scalar_type = infer_scalar_type(data);
  if (scalar_type == ScalarType::Byte) {
    auto& idx_type = type.toScalarType(scalar_type);
    return internal_new_from_data(idx_type, std::move(device), data, false, false, false);
  } else {
    return internal_new_from_data(type, std::move(device), data, false, false, false);
  }
}

Tensor sparse_coo_tensor_ctor(const Type& default_type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
    "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
    "sparse_coo_tensor(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  });

  ParsedArgs<6> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    bool type_inference = r.isNone(2);
    const auto& type = typeWithDefault(r, 2, 3, default_type);
    const auto& values_type = type.toDense();
    at::OptionalDeviceGuard device_guard(r.deviceOptional(3));
    // if no dtype provided, infer type based on value type.
    Tensor values = internal_new_from_data(values_type, r.deviceOptional(3), r.pyobject(1), false, true, type_inference);
    const auto& indices_type = values.type().toScalarType(kLong);
    Tensor indices = internal_new_from_data(indices_type, r.deviceOptional(3), r.pyobject(0), false, true, false);
    return at::sparse_coo_tensor(indices, values, values.options().layout(at::kSparse)).set_requires_grad(r.toBool(4));
  } else if (r.idx == 1) {
    bool type_inference = r.isNone(3);
    const auto& type = typeWithDefault(r, 3, 4, default_type);
    const auto& values_type = type.toDense();
    at::OptionalDeviceGuard device_guard(r.deviceOptional(4));
    Tensor values = internal_new_from_data(values_type, r.deviceOptional(4), r.pyobject(1), false, true, type_inference);
    const auto& indices_type = values.type().toScalarType(kLong);
    Tensor indices = internal_new_from_data(indices_type, r.deviceOptional(4), r.pyobject(0), false, true, false);
    return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5));
  } else if (r.idx == 2) {
    const auto& type = typeWithDefault(r, 1, 2, default_type);
    at::OptionalDeviceGuard device_guard(r.deviceOptional(2));
    return at::sparse_coo_tensor(r.intlist(0), type.options().layout(at::kSparse)).set_requires_grad(r.toBool(3));
  }
  throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
}
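// Python-level view of the first overload above (illustrative):
//   i = [[0, 1, 1], [2, 0, 2]]
//   v = [3.0, 4.0, 5.0]
//   torch.sparse_coo_tensor(i, v)
// builds a Long indices tensor and a values tensor (default dtype) with
// internal_new_from_data, then assembles a sparse COO tensor whose size is
// inferred from the indices, here 2x3.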

Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  });

  ParsedArgs<4> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    PyObject* data = r.pyobject(0);
    if (THPVariable_Check(data)) {
      PyErr_WarnEx(PyExc_UserWarning,
        "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
        "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).", 1);
    }

    bool type_inference = r.isNone(1);
    bool args_requires_grad = r.toBool(3);
    auto new_tensor = internal_new_from_data(
        typeWithDefault(r, 1, 2, type),
        r.deviceOptional(2),
        data,
        true,
        true,
        type_inference);
    new_tensor.detach_(); // ensure new_tensor is a leaf node
    new_tensor.set_requires_grad(args_requires_grad);
    return new_tensor;
  }
  throw std::runtime_error("tensor(): invalid arguments");
}
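// The warning above corresponds to the following Python idioms (informal):
//   y = torch.tensor(x)                          # warns, copies x
//   y = x.clone().detach()                       # preferred copy
//   y = x.clone().detach().requires_grad_(True)  # preferred copy that tracks grad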

Tensor as_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
  // TODO: add requires_grad once we decide on semantics for sharing data.
  static PythonArgParser parser({
    "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)",
  });

  ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    bool type_inference = r.isNone(1);
    return internal_new_from_data(
        typeWithDefault(r, 1, 2, type), r.deviceOptional(2), r.pyobject(0), false, false, type_inference);
  }
  throw std::runtime_error("as_tensor(): invalid arguments");
}
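// Note the flags passed above: copy_variables=false and copy_numpy=false, so
// torch.as_tensor(x) reuses x's storage when x is already a tensor or NumPy
// array with a compatible dtype and device, whereas torch.tensor(x) always
// copies.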

Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  });

  ParsedArgs<4> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    PyObject* data = r.pyobject(0);
    if (THPVariable_Check(data)) {
      PyErr_WarnEx(PyExc_UserWarning,
        "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
        "or sourceTensor.clone().detach().requires_grad_(True), rather than tensor.new_tensor(sourceTensor).", 1);
    }

    bool args_requires_grad = r.toBool(3);
    auto new_tensor = new_from_data_copy(
        typeWithDefault(r, 1, 2, type),
        r.deviceOptional(2),
        data);
    new_tensor.detach_(); // ensure new_tensor is a leaf node
    new_tensor.set_requires_grad(args_requires_grad);
    return new_tensor;
  }
  throw std::runtime_error("new_tensor(): invalid arguments");
}

Tensor new_empty(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new_empty(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<4> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    const auto& actual_type = typeWithDefault(r, 1, 2, type);
    return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
  }
  throw std::runtime_error("new_empty(): invalid arguments");
}

Tensor new_full(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new_full(IntArrayRef size, Scalar fill_value, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    const auto& actual_type = typeWithDefault(r, 2, 3, type);
    return dispatch_full(actual_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
  }
  throw std::runtime_error("new_full(): invalid arguments");
}

Tensor new_ones(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new_ones(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<4> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    const auto& actual_type = typeWithDefault(r, 1, 2, type);
    return dispatch_ones(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
  }
  throw std::runtime_error("new_ones(): invalid arguments");
}

Tensor new_zeros(const Type& type, PyObject* args, PyObject* kwargs) {
  static PythonArgParser parser({
    "new_zeros(IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<4> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.idx == 0) {
    const auto& actual_type = typeWithDefault(r, 1, 2, type);
    return dispatch_zeros(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
  }
  throw std::runtime_error("new_zeros(): invalid arguments");
}
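// Python-level view of the new_* family above (illustrative): each method keeps
// the dtype and device of `self` unless explicitly overridden, e.g.
//   x = torch.ones(2, dtype=torch.float64)
//   x.new_zeros(3)                             # tensor([0., 0., 0.], dtype=torch.float64)
//   x.new_full((2, 2), 7, dtype=torch.int32)   # 2x2 int32 tensor filled with 7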

}} // namespace torch::utils