Caffe2 - C++ API
A deep learning, cross-platform ML framework
int8_given_tensor_fill_op.h
1 #ifndef CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_
2 #define CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/tensor_int8.h"
8 #include "caffe2/operators/filler_op.h"
9 #include "caffe2/utils/cast.h"
10 #include "caffe2/utils/math.h"
11 
12 namespace caffe2 {
13 namespace int8 {
14 
15 class Int8GivenTensorFillOp final : public Operator<CPUContext> {
16  public:
17  template <class... Args>
18  explicit Int8GivenTensorFillOp(Args&&... args)
19  : Operator<CPUContext>(std::forward<Args>(args)...),
20  scale_(this->template GetSingleArgument<float>("Y_scale", 1.0)),
21  zero_point_(
22  this->template GetSingleArgument<int32_t>("Y_zero_point", 0)),
23  shape_(this->template GetRepeatedArgument<int64_t>("shape")) {
24  ExtractValues();
25  }
26 
27  bool RunOnDevice() override {
28  auto* output = Outputs()[0]->template GetMutable<Int8TensorCPU>();
29  ReinitializeTensor(&output->t, shape_, at::dtype<uint8_t>().device(CPU));
30  output->scale = scale_;
31  output->zero_point = zero_point_;
32  return Fill(output);
33  }
34 
35  private:
36  void ExtractValues() {
37  auto source_values = this->template GetSingleArgument<string>("values", "");
39  &values_, {static_cast<int64_t>(source_values.size())}, at::dtype<uint8_t>().device(CPU));
40  uint8_t* values_data = values_.template mutable_data<uint8_t>();
41  for (int i = 0; i < source_values.size(); i++) {
42  values_data[i] = static_cast<uint8_t>(source_values[i]);
43  }
44  }
45 
46  bool Fill(Int8TensorCPU* output) {
47  DCHECK_EQ(output->t.numel(), values_.numel())
48  << "output size: " << output->t.numel()
49  << " given size: " << values_.numel();
50  auto* data = output->t.template mutable_data<uint8_t>();
51  const uint8_t* values_data = values_.template data<uint8_t>();
52  if (output->t.numel()) {
53  context_.template CopySameDevice<uint8_t>(
54  output->t.numel(), values_data, data);
55  }
56  return true;
57  }
58 
59  float scale_;
60  int32_t zero_point_;
61  vector<int64_t> shape_;
62  Tensor values_;
63 };
64 
66  public:
67  template <class... Args>
68  explicit Int8GivenIntTensorFillOp(Args&&... args)
69  : Operator<CPUContext>(std::forward<Args>(args)...),
70  scale_(this->template GetSingleArgument<float>("Y_scale", 1.0)),
71  zero_point_(
72  this->template GetSingleArgument<int32_t>("Y_zero_point", 0)),
73  shape_(this->template GetRepeatedArgument<int64_t>("shape")) {
74  ExtractValues();
75  }
76 
77  bool RunOnDevice() override {
78  auto* output = Outputs()[0]->template GetMutable<Int8TensorCPU>();
79  output->t.Resize(shape_);
80  output->scale = scale_;
81  output->zero_point = zero_point_;
82  return Fill(output);
83  }
84 
85  private:
86  void ExtractValues() {
87  auto source_values = this->template GetRepeatedArgument<int32_t>("values");
89  &values_, {static_cast<int64_t>(source_values.size())}, at::dtype<int32_t>().device(CPU));
90  auto* values_data = values_.template mutable_data<int32_t>();
91  for (int i = 0; i < source_values.size(); i++) {
92  values_data[i] = static_cast<int32_t>(source_values[i]);
93  }
94  }
95 
96  bool Fill(Int8TensorCPU* output) {
97  DCHECK_EQ(output->t.numel(), values_.numel())
98  << "output size: " << output->t.numel()
99  << " given size: " << values_.numel();
100  auto* data = output->t.template mutable_data<int32_t>();
101  const auto* values_data = values_.template data<int32_t>();
102  if (output->t.numel()) {
103  context_.template CopySameDevice<int32_t>(
104  output->t.numel(), values_data, data);
105  }
106  return true;
107  }
108 
109  float scale_;
110  int32_t zero_point_;
111  vector<int64_t> shape_;
112  Tensor values_;
113 };
114 
115 } // namespace int8
116 } // namespace caffe2
117 
118 #endif // CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to the given dims and options if necessary. Note that this will not do anything if the Tensor already has the correct size and data type.
Definition: tensor.cc:127
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13