Caffe2 - C++ API
A deep learning, cross platform ML framework
qtensor.h
1 #ifndef CAFFE2_CORE_QTENSOR_H_
2 #define CAFFE2_CORE_QTENSOR_H_
3 
4 #include <algorithm>
5 #include <climits>
6 #include <cstddef>
7 #include <vector>
8 
9 #include "caffe2/core/common.h"
10 #include "caffe2/core/context.h"
11 #include "caffe2/core/tensor.h"
12 #include <c10/util/typeid.h>
13 
14 namespace caffe2 {
15 
16 template <class Context>
17 class C10_EXPORT QTensor {
18  public:
19  QTensor() {}
20  virtual ~QTensor() {}
49  // TODO: changing at::ArrayRef<int> to at::ArrayRef<int64_t>?
50  explicit QTensor(
51  at::ArrayRef<int> dims,
52  const unsigned char precision,
53  const bool signbit = false)
54  : precision_(precision), signed_(signbit) {
55  Resize(dims);
56  }
57 
58  void Resize(at::ArrayRef<int> dim_source) {
59  if (dims_ != dim_source) {
60  size_t source_size = std::accumulate(
61  dim_source.begin(), dim_source.end(), 1, std::multiplies<int>());
62  if ((source_size * (precision_ + signed_)) > capacity_) {
63  data_ptr_.clear();
64  capacity_ = 0;
65  }
66  dims_ = dim_source.vec();
67  size_ = source_size;
68  }
69  }
70 
71  void
72  SetBitAtIndex(const unsigned char bit, const size_t index, const bool value) {
73  // Get the mutable data at bit depth `bit`.
74  unsigned char* d = mutable_data();
75 
76  CAFFE_ENFORCE(
77  bit < precision_ + signed_,
78  "Attempted to a set a bit that is not allocated.");
79  CAFFE_ENFORCE(bit * aligned_size() < capacity_);
80 
81  auto idx = (aligned_size() * bit) / CHAR_BIT;
82  d = &d[idx];
83 
84  idx = index / CHAR_BIT;
85  auto shift = CHAR_BIT - (index % CHAR_BIT) - 1;
86 
87  if (value) {
88  d[idx] |= 1 << shift;
89  } else {
90  d[idx] &= ~(1 << shift);
91  }
92  }
93 
94  bool GetBitAtIndex(const unsigned char bit, const size_t index) const {
95  // Get the data at bit depth `bit`
96  const unsigned char* d = data();
97  auto idx = (aligned_size() * bit) / CHAR_BIT;
98  d = &d[idx];
99 
100  idx = index / CHAR_BIT;
101  auto shift = CHAR_BIT - (index % CHAR_BIT) - 1;
102 
103  return d[idx] & (1 << shift);
104  }
105 
106  void SetPrecision(const unsigned char precision) {
107  precision_ = precision;
108  data_ptr_.clear();
109  }
110 
111  void SetSigned(const bool make_signed = true) {
112  signed_ = make_signed;
113  data_ptr_.clear();
114  }
115 
116  void SetScale(const double scale) {
117  scale_ = scale;
118  }
119 
120  void SetBias(const double bias) {
121  bias_ = bias;
122  }
123 
124  unsigned char* mutable_data() {
125  if (!data_ptr_) {
126  data_ptr_ = Context::New(nbytes());
127  capacity_ = nbytes() * CHAR_BIT;
128  }
129  CAFFE_ENFORCE(capacity_ == nbytes() * CHAR_BIT);
130  return static_cast<unsigned char*>(data_ptr_.get());
131  }
132 
133  inline const unsigned char* data() const {
134  return static_cast<unsigned char*>(data_ptr_.get());
135  }
136 
137  inline size_t size() const {
138  return size_;
139  }
140 
141  inline unsigned char alignment() const {
142  return alignment_;
143  }
144 
145  inline unsigned char precision() const {
146  return precision_;
147  }
148 
149  inline at::ArrayRef<int> sizes() const {
150  return dims_;
151  }
152 
153  // TODO: deprecate?
154  inline at::ArrayRef<int> dims() const {
155  return dims_;
156  }
157 
158  inline bool is_signed() const {
159  return signed_;
160  }
161 
165  inline int ndim() const {
166  return dims_.size();
167  }
168 
169  inline size_t aligned_size() const {
170  return alignment_ * ((size_ + alignment_ - 1) / alignment_);
171  }
172 
173  inline size_t nbytes() const {
174  return (aligned_size() * (precision_ + signed_)) / CHAR_BIT;
175  }
176 
177  inline double scale() const {
178  return scale_;
179  }
180 
181  inline double bias() const {
182  return bias_;
183  }
184 
188  inline int dim32(const int i) const {
189  DCHECK_LT(i, dims_.size()) << "Exceeding ndim limit " << dims_.size();
190  DCHECK_GE(i, 0) << "Cannot have negative index";
191  CAFFE_ENFORCE_LT(dims_[i], std::numeric_limits<int>::max());
192  return static_cast<int>(dims_[i]);
193  }
194 
206  inline int canonical_axis_index(int axis_index) const {
207  CAFFE_ENFORCE_GE(axis_index, -ndim());
208  CAFFE_ENFORCE_LT(axis_index, ndim());
209  if (axis_index < 0) {
210  return axis_index + ndim();
211  }
212  return axis_index;
213  }
214 
218  inline int64_t size_from_dim(int k) const {
219  int64_t r = 1;
220  for (int i = k; i < dims_.size(); ++i) {
221  r *= dims_[i];
222  }
223  return r;
224  }
225 
229  inline int64_t size_to_dim(int k) const {
230  CAFFE_ENFORCE(k < dims_.size());
231  int64_t r = 1;
232  for (int i = 0; i < k; ++i) {
233  r *= dims_[i];
234  }
235  return r;
236  }
237 
238  protected:
239  std::vector<int> dims_;
240  size_t size_ = 0;
241 
242  // Precision in bits.
243  unsigned char precision_ = CHAR_BIT;
244  // Bit alignment.
245  unsigned char alignment_ = CHAR_BIT;
246 
247  // Allocated data.
248  at::DataPtr data_ptr_;
249 
250  // value = scale_ * (x + bias_)
251  double scale_;
252  double bias_;
253  bool signed_ = false;
254 
255  // Capacity in bits.
256  size_t capacity_ = 0;
257 };
258 
259 } // namespace caffe2
260 #endif // CAFFE2_CORE_QTENSOR_H_
int64_t size_to_dim(int k) const
Product of all dims up to (but not including) dimension k.
Definition: qtensor.h:229
int canonical_axis_index(int axis_index) const
Returns the 'canonical' version of a (usually) user-specified axis, allowing for negative indexing (e.g., -1 for the last axis).
Definition: qtensor.h:206
int ndim() const
Returns the number of dimensions of the data.
Definition: qtensor.h:165
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
int dim32(const int i) const
Returns the i-th dimension of the qtensor in int.
Definition: qtensor.h:188
QTensor(at::ArrayRef< int > dims, const unsigned char precision, const bool signbit=false)
Creates a quantized tensor of the given dimension.
Definition: qtensor.h:50
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
int64_t size_from_dim(int k) const
Return product of all dimensions starting from K.
Definition: qtensor.h:218