Caffe2 - C++ API
A deep learning, cross-platform ML framework
int8_reshape_op.h
1 #ifndef CAFFE2_OPERATORS_INT8_RESHAPE_OP_H_
2 #define CAFFE2_OPERATORS_INT8_RESHAPE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/core/tensor_int8.h"
7 #include "caffe2/operators/quantized/int8_utils.h"
8 #include "caffe2/operators/reshape_op.h"
9 
10 namespace caffe2 {
11 
12 namespace int8 {
13 
14 class Int8ReshapeOp final : public ReshapeOp<uint8_t, CPUContext> {
15  public:
16  template <class... Args>
17  explicit Int8ReshapeOp(Args&&... args)
18  : ReshapeOp(std::forward<Args>(args)...) {}
19 
20  bool RunOnDevice() override {
21  if (InputSize() == 2) {
22  return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
23  }
24  CAFFE_ENFORCE(
25  OperatorBase::HasArgument("shape"), "Argument `shape` is missing.");
26  return this->template DoRunWithType<int64_t>();
27  }
28 
29  template <typename T>
30  bool DoRunWithType() {
31  auto& X = Inputs()[0]->Get<Int8TensorCPU>();
32  auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
33  int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
34  auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
35  CHECK_EQ(Y_offset, X.zero_point);
36  CHECK_EQ(Y_scale, X.scale);
37  Y->scale = Y_scale;
38  Y->zero_point = Y_offset;
39  DoRunWithTypeImpl<T>(X.t, &Y->t);
40  return true;
41  }
42 };
43 
44 } // namespace int8
45 
46 } // namespace caffe2
47 
48 #endif // CAFFE2_OPERATORS_INT8_RESHAPE_OP_H_
const Tensor & Input(int idx, DeviceType type=CPUContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70