4 #include <torch/csrc/WindowsTorchApiMacro.h> 5 #include <c10/util/Exception.h> 6 #include <ATen/core/jit_type.h> 7 #include <torch/csrc/utils/hash.h> 21 at::ScalarType scalar_type;
22 std::vector<bool> contiguity;
24 TensorDesc(
const at::ScalarType& type,
const std::vector<bool>& contiguity)
25 : scalar_type{type}, contiguity{contiguity} {
26 if (contiguity.size() == 0) {
29 nDim_ = std::count(contiguity.begin(), contiguity.end(),
false) +
30 (lastIsContiguous() ? 1 : 0);
36 const at::ScalarType& type,
39 : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
42 : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {}
// Builds a descriptor from a complete tensor type by delegating to the
// (scalar type, sizes, strides) constructor; presumably that constructor
// derives the per-dimension contiguity flags via findContiguous(sizes, strides)
// — TODO confirm against its full signature.
// NOTE(review): `type` is dereferenced without a null check, so callers must
// pass a non-null pointer.
// NOTE(review): this single-argument constructor is not `explicit`, so it also
// acts as an implicit conversion from CompleteTensorTypePtr — confirm intended.
44 TensorDesc(
const c10::CompleteTensorTypePtr& type)
45 : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}
53 bool lastIsContiguous()
const {
54 return (contiguity.size() == 0 || contiguity.back());
57 static std::vector<bool> findContiguous(
60 AT_ASSERT(sizes.
size() == strides.
size());
61 std::vector<bool> cont(sizes.
size());
62 for (
size_t i = 0; i < sizes.
size(); ++i) {
63 const auto expected_stride =
64 (i + 1 < sizes.
size()) ? sizes[i + 1] * strides[i + 1] : 1;
65 cont[i] = (strides[i] == expected_stride);
70 bool operator==(
const TensorDesc& desc)
const {
71 return scalar_type == desc.scalar_type && contiguity == desc.contiguity;
74 bool operator!=(
const TensorDesc& desc)
const {
75 return !(*
this == desc);
78 static size_t hash(
const TensorDesc& spec) {
79 return torch::get_hash(
82 std::hash<std::vector<bool>>{}(spec.contiguity));
89 inline std::ostream& operator<<(std::ostream& out,
const TensorDesc& d) {
90 out << d.scalar_type <<
"[";
91 for (
const auto b : d.contiguity)
constexpr size_t size() const
size - Get the array size.