Caffe2 - C++ API
A deep learning, cross platform ML framework
int8_slice_op.h
1 #ifndef CAFFE2_OPERATORS_INT8_SLICE_OP_H_
2 #define CAFFE2_OPERATORS_INT8_SLICE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/core/tensor_int8.h"
7 #include "caffe2/operators/quantized/int8_utils.h"
8 #include "caffe2/operators/slice_op.h"
9 
10 namespace caffe2 {
11 
12 namespace int8 {
13 
14 class Int8SliceOp final : public SliceOp<CPUContext> {
15  public:
16  template <class... Args>
17  explicit Int8SliceOp(Args&&... args) : SliceOp(std::forward<Args>(args)...) {}
18 
19  bool RunOnDevice() override {
20  if (InputSize() > 1) {
21  return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
22  } else {
23  return DoRunWithType<int64_t>();
24  }
25  }
26 
27  template <typename SIndex>
28  bool DoRunWithType() {
29  if (InputSize() > 1) {
30  ReinitializeAndCopyFrom(&starts_host_, at::dtype<SIndex>().device(CPU), Input(1));
31  ReinitializeAndCopyFrom(&ends_host_, at::dtype<SIndex>().device(CPU), Input(2));
32  } else {
33  if (!statically_inited_) {
34  CAFFE_ENFORCE(HasArgument("starts"));
35  CAFFE_ENFORCE(HasArgument("ends"));
36  CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
37 
39  &starts_host_, {static_cast<int64_t>(starts_.size())}, at::dtype<SIndex>().device(CPU));
41  &ends_host_, {static_cast<int64_t>(ends_.size())}, at::dtype<SIndex>().device(CPU));
42 
43  memcpy(
44  starts_host_.template mutable_data<SIndex>(),
45  starts_.data(),
46  sizeof(SIndex) * starts_.size());
47  memcpy(
48  ends_host_.template mutable_data<SIndex>(),
49  ends_.data(),
50  sizeof(SIndex) * ends_.size());
51  statically_inited_ = true;
52  }
53  }
54 
55  auto& X = Inputs()[0]->Get<Int8TensorCPU>();
56  auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
57  int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
58  auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
59  CHECK_EQ(Y_offset, X.zero_point);
60  CHECK_EQ(Y_scale, X.scale);
61  Y->scale = Y_scale;
62  Y->zero_point = Y_offset;
63 
64  return SliceImpl<SIndex, CPUContext>(
65  &Y->t, X.t, starts_host_, ends_host_, &context_);
66  }
67 };
68 
69 } // namespace int8
70 
71 } // namespace caffe2
72 
73 #endif // CAFFE2_OPERATORS_INT8_SLICE_OP_H_
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
Definition: tensor.cc:127
const Tensor & Input(int idx, DeviceType type=CPUContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70