Caffe2 - C++ API
A deep learning, cross platform ML framework
tensor_desc.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/WindowsTorchApiMacro.h>
5 #include <c10/util/Exception.h>
6 #include <ATen/core/jit_type.h>
7 #include <torch/csrc/utils/hash.h>
8 
9 #include <algorithm>
10 #include <iostream>
11 #include <vector>
12 
13 namespace torch {
14 namespace jit {
15 namespace fuser {
16 
17 // type information needed by the compiler for input/outputs
18 // contiguity[i] is true if the dim i is contiguous with dim i + 1.
19 // contiguity.back() == true means strides.back() == 1.
20 struct TORCH_API TensorDesc {
21  at::ScalarType scalar_type;
22  std::vector<bool> contiguity;
23 
24  TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity)
25  : scalar_type{type}, contiguity{contiguity} {
26  if (contiguity.size() == 0) {
27  nDim_ = 0;
28  } else {
29  nDim_ = std::count(contiguity.begin(), contiguity.end(), false) +
30  (lastIsContiguous() ? 1 : 0);
31  }
32  }
33 
34  // Delegating constructors
35  TensorDesc(
36  const at::ScalarType& type,
37  const at::IntArrayRef& sizes,
38  const at::IntArrayRef& strides)
39  : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
40 
41  TensorDesc(const at::Tensor& t)
42  : TensorDesc(t.scalar_type(), t.sizes(), t.strides()) {}
43 
44  TensorDesc(const c10::CompleteTensorTypePtr& type)
45  : TensorDesc(type->scalarType(), type->sizes(), type->strides()) {}
46 
47  // number of dimensions after contiguity compression
48  size_t nDim() const {
49  return nDim_;
50  }
51 
52  // True iff innermost stride is 1
53  bool lastIsContiguous() const {
54  return (contiguity.size() == 0 || contiguity.back());
55  }
56 
57  static std::vector<bool> findContiguous(
58  const at::IntArrayRef& sizes,
59  const at::IntArrayRef& strides) {
60  AT_ASSERT(sizes.size() == strides.size());
61  std::vector<bool> cont(sizes.size());
62  for (size_t i = 0; i < sizes.size(); ++i) {
63  const auto expected_stride =
64  (i + 1 < sizes.size()) ? sizes[i + 1] * strides[i + 1] : 1;
65  cont[i] = (strides[i] == expected_stride);
66  }
67  return cont;
68  }
69 
70  bool operator==(const TensorDesc& desc) const {
71  return scalar_type == desc.scalar_type && contiguity == desc.contiguity;
72  }
73 
74  bool operator!=(const TensorDesc& desc) const {
75  return !(*this == desc);
76  }
77 
78  static size_t hash(const TensorDesc& spec) {
79  return torch::get_hash(
80  spec.scalar_type,
81  spec.nDim_,
82  std::hash<std::vector<bool>>{}(spec.contiguity));
83  }
84 
85  private:
86  size_t nDim_;
87 };
88 
89 inline std::ostream& operator<<(std::ostream& out, const TensorDesc& d) {
90  out << d.scalar_type << "[";
91  for (const auto b : d.contiguity)
92  out << b << ";";
93  out << "]";
94  return out;
95 }
96 
97 } // namespace fuser
98 } // namespace jit
99 } // namespace torch
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17