Caffe2 - C++ API
A deep learning, cross-platform ML framework
quant_decode_op.h
1 #ifndef QUANT_DECODE_OP_H_
2 #define QUANT_DECODE_OP_H_
3 
#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <utility>

#include <c10/util/typeid.h>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
8 
9 namespace caffe2 {
10 
11 namespace {
12 
13 template <class CodebookT, class CodeT>
14 void Decode(
15  const Tensor& codebook,
16  const Tensor& codes,
17  /* optional */ const Tensor* const decoded_grad,
18  Tensor* const output,
19  bool resizeOnly) {
20  CAFFE_ENFORCE(codebook.IsType<CodebookT>());
21 
22  auto* cb_ptr = codebook.data<CodebookT>();
23  int cb_size = codebook.numel();
24 
25  CAFFE_ENFORCE(codes.IsType<CodeT>());
26  auto* code_ptr = codes.data<CodeT>();
27 
28  if (decoded_grad == nullptr) {
29  // Forward pass: decode and store codebook values in output.
30  output->ResizeLike(codes);
31  auto* out_ptr = output->template mutable_data<CodebookT>();
32  if (resizeOnly) {
33  return;
34  }
35 
36  int sz = output->numel();
37  for (int i = 0; i < sz; i++) {
38  DCHECK_LE(*code_ptr, cb_size);
39  *out_ptr++ = cb_ptr[*code_ptr++];
40  }
41  } else {
42  // Backward pass: decode and accumulate gradient w.r.t. codebook values.
43  CAFFE_ENFORCE_EQ(codes.numel(), decoded_grad->numel());
44  auto* gradient_ptr = decoded_grad->data<CodebookT>();
45  auto* const gradient_end = gradient_ptr + decoded_grad->numel();
46 
47  CAFFE_ENFORCE_EQ(cb_size, output->numel());
48  auto* out_ptr = output->template mutable_data<CodebookT>();
49  while (gradient_ptr < gradient_end) {
50  DCHECK_LE(*code_ptr, cb_size);
51  out_ptr[*code_ptr++] += *gradient_ptr++;
52  }
53  }
54 }
55 
// Builds one entry of the decoder dispatch table: maps the pair of runtime
// type ids (codebook dtype, codes dtype) to a thunk invoking the matching
// Decode<CodebookT, CodeT> instantiation.
#define REGISTER_DECODER(codebookType, codesType)              \
  {                                                            \
    {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()}, \
        [](const Tensor& cb,                                   \
           const Tensor& cd,                                   \
           const Tensor* grad,                                 \
           Tensor* decoded,                                    \
           bool resize_only) {                                 \
          Decode<codebookType, codesType>(                     \
              cb, cd, grad, decoded, resize_only);             \
        }                                                      \
  }
68 
69 inline void DecodeGeneral(
70  const Tensor& codebook,
71  const Tensor& codes,
72  const Tensor* gradient,
73  Tensor* outDecoded,
74  bool resizeOnly) {
75  const static std::map<
76  std::pair<TypeIdentifier, TypeIdentifier>,
77  std::function<void(
78  const Tensor& codebook,
79  const Tensor& codes,
80  const Tensor* gradient,
81  Tensor* outDecoded,
82  bool resizeOnly)>>
83  gDecoderMapper = {REGISTER_DECODER(float, uint8_t),
84  REGISTER_DECODER(float, uint16_t),
85  REGISTER_DECODER(float, int32_t)};
86 
87  gDecoderMapper.at({codebook.dtype().id(), codes.dtype().id()})(
88  codebook, codes, gradient, outDecoded, resizeOnly);
89 }
90 
91 } // namespace
92 
// Decode tensors based on the given codebook.
// The codebook is generated by model_quantize.py.
95 
// Controls how often QuantDecodeOp performs the actual decode:
//  RUN_ALWAYS: decode the codes into the outputs on every invocation.
//  RUN_ONCE:   decode on the first run only; later runs pass resizeOnly=true
//              to DecodeGeneral, so outputs are merely resized/allocated.
enum class QuantDecodeRunTy {
  RUN_ALWAYS,
  RUN_ONCE,
};
100 
101 template <QuantDecodeRunTy QuantDecodeRun>
102 class QuantDecodeOp final : public Operator<CPUContext> {
103  public:
104  USE_OPERATOR_FUNCTIONS(CPUContext);
105  template <class... Args>
106  explicit QuantDecodeOp(Args&&... args)
107  : Operator<CPUContext>(std::forward<Args>(args)...) {}
108 
109  ~QuantDecodeOp() {}
110 
111  bool RunOnDevice() override {
112  CAFFE_ENFORCE_GT(InputSize(), 1);
113  // first input is the codebook
114  CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);
115 
116  const auto& codebook = Input(0);
117  CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());
118 
119  for (int i = 0; i < OutputSize(); i++) {
120  auto& ci = Input(i + 1);
121  auto* co = Output(i);
122 
123  DecodeGeneral(
124  codebook,
125  ci,
126  nullptr,
127  co,
128  /*resizeOnly=*/QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE &&
129  hasRun_);
130  }
131  hasRun_ = true;
132  return true;
133  }
134 
135  private:
136  bool hasRun_{false};
137 };
138 
139 class QuantDecodeGradientOp final : public Operator<CPUContext> {
140  public:
141  USE_OPERATOR_FUNCTIONS(CPUContext);
142  template <class... Args>
143  explicit QuantDecodeGradientOp(Args&&... args)
144  : Operator<CPUContext>(std::forward<Args>(args)...) {}
145  ~QuantDecodeGradientOp() {}
146 
147  bool RunOnDevice() override {
148  // Inputs: 1 codebook, n tensors of codes, and n corresponding gradients.
149  CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
150  const int num_code_tensors = (InputSize() - 1) / 2;
151  CAFFE_ENFORCE_EQ(OutputSize(), 1);
152 
153  const auto& codebook = Input(0);
154  CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());
155 
156  auto* gradient = Output(0, codebook.sizes(), at::dtype<float>());
157  auto* gradient_ptr = gradient->template mutable_data<float>();
158  std::fill(gradient_ptr, gradient_ptr + gradient->numel(), 0);
159 
160  for (int i = 0; i < num_code_tensors; i++) {
161  auto& codes_i = Input(i + 1);
162  auto& output_gradient_i = Input(i + num_code_tensors + 1);
163  DecodeGeneral(codebook, codes_i, &output_gradient_i, gradient, false);
164  }
165  return true;
166  }
167 };
168 
169 } // namespace caffe2
170 #endif // QUANT_DECODE_OP_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement.
Definition: context.h:40
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.
Definition: blob.h:13