// Decode<CodebookT, CodeT> -- the template header begins at the end of the
// next line.
//
// With decoded_grad == nullptr the visible code resizes `output` like `codes`
// and writes cb_ptr[code] for every code. The second half (original lines
// 43-51) adds each element of decoded_grad into out_ptr[code].
//
// NOTE(review): this listing embeds the original file's line numbers in each
// line and omits several original lines (the function name and remaining
// parameters, closing braces, whatever sits on original lines 32-34, and the
// branch keyword between the two halves -- presumably `else`; confirm against
// the real file). Comments below describe only what is visible.
1 #ifndef QUANT_DECODE_OP_H_ 2 #define QUANT_DECODE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/core/tensor.h" 7 #include <c10/util/typeid.h> 13 template <
class CodebookT,
class CodeT>
// decoded_grad is compared against nullptr below, so null is an accepted value.
17 const Tensor*
const decoded_grad,
// The codebook tensor must hold exactly CodebookT elements.
20 CAFFE_ENFORCE(codebook.IsType<CodebookT>());
22 auto* cb_ptr = codebook.data<CodebookT>();
23 int cb_size = codebook.numel();
// The codes tensor must hold exactly CodeT elements.
25 CAFFE_ENFORCE(codes.IsType<CodeT>());
26 auto* code_ptr = codes.data<CodeT>();
// No incoming gradient: plain table lookup.
28 if (decoded_grad ==
nullptr) {
// Output takes the shape of `codes`; its elements are CodebookT values.
30 output->ResizeLike(codes);
31 auto* out_ptr = output->template mutable_data<CodebookT>();
36 int sz = output->numel();
37 for (
int i = 0; i < sz; i++) {
// Valid codebook indexes are [0, cb_size); a code equal to cb_size is already
// one past the end, so the comparison must be strict.
// NOTE(review): DCHECK_* is normally compiled out of release builds, and a
// negative code (possible when CodeT is signed) is not rejected here.
38 DCHECK_LT(*code_ptr, cb_size);
39 *out_ptr++ = cb_ptr[*code_ptr++];
// Gradient path: one incoming gradient element per code.
43 CAFFE_ENFORCE_EQ(codes.numel(), decoded_grad->numel());
44 auto* gradient_ptr = decoded_grad->data<CodebookT>();
45 auto*
const gradient_end = gradient_ptr + decoded_grad->numel();
// `output` must have exactly one slot per codebook entry.
47 CAFFE_ENFORCE_EQ(cb_size, output->numel());
48 auto* out_ptr = output->template mutable_data<CodebookT>();
49 while (gradient_ptr < gradient_end) {
// Same strict bound as the lookup path: the code indexes out_ptr, which has
// cb_size elements.
50 DCHECK_LT(*code_ptr, cb_size);
// Adds into the existing slot; nothing in the visible lines clears `output`
// first.
51 out_ptr[*code_ptr++] += *gradient_ptr++;
// REGISTER_DECODER(codebookType, codesType): expands to a map entry. The key is
// the pair {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()} and the
// value is a lambda that forwards its arguments to
// Decode<codebookType, codesType>.
//
// DecodeGeneral (signature starts at the end of the next line): picks the
// Decode instantiation at run time from the dtype ids of `codebook` and
// `codes`, then calls it with (codebook, codes, gradient, outDecoded,
// resizeOnly).
//
// NOTE(review): this listing omits parts of the macro body, the parameter list
// of DecodeGeneral and the mapped type of the map (original lines 57, 63,
// 66-68, 70-74, 77-82); only the visible pieces are described.
56 #define REGISTER_DECODER(codebookType, codesType) \ 58 {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()}, \ 59 [](const Tensor& codebook_, \ 60 const Tensor& codes_, \ 61 const Tensor* gradient_, \ 62 Tensor* outDecoded_, \ 64 Decode<codebookType, codesType>( \ 65 codebook_, codes_, gradient_, outDecoded_, resizeOnly_); \ 69 inline void DecodeGeneral(
// Function-local constant table keyed by (codebook type id, codes type id).
75 const static std::map<
76 std::pair<TypeIdentifier, TypeIdentifier>,
// Registered combinations: float codebook with uint8_t, uint16_t or int32_t
// codes.
83 gDecoderMapper = {REGISTER_DECODER(
float, uint8_t),
84 REGISTER_DECODER(
float, uint16_t),
85 REGISTER_DECODER(
float, int32_t)};
// NOTE(review): std::map::at throws std::out_of_range when the dtype pair is
// not registered; no CAFFE_ENFORCE-style message is produced in the visible
// lines.
87 gDecoderMapper.at({codebook.dtype().id(), codes.dtype().id()})(
88 codebook, codes, gradient, outDecoded, resizeOnly);
// Scoped enum used as the non-type template parameter of the operator that
// follows. NOTE(review): the enumerator list is not visible in this listing;
// RUN_ONCE is the only enumerator referenced by the visible code.
96 enum class QuantDecodeRunTy {
// Forward decode operator, parameterised on QuantDecodeRunTy.
// NOTE(review): the class header, constructor, data members and the end of
// RunOnDevice (original lines 102-110, 122-127, 129 onward) are not visible in
// this listing; presumably this is the non-gradient counterpart of
// QuantDecodeGradientOp -- confirm against the real file.
101 template <QuantDecodeRunTy QuantDecodeRun>
// Variadic template, presumably for a forwarding constructor whose declaration
// is not shown.
105 template <
class... Args>
// RunOnDevice: validates the input/output layout, then walks the
// (codes, output) pairs.
111 bool RunOnDevice()
override {
// At least two inputs are required.
112 CAFFE_ENFORCE_GT(InputSize(), 1);
// Exactly one more input than outputs: Input(0) is read as the codebook below
// and Input(i + 1) is paired with Output(i).
114 CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);
116 const auto& codebook = Input(0);
// Only float codebooks are accepted; the dtype name is passed as the failure
// message.
117 CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());
119 for (
int i = 0; i < OutputSize(); i++) {
120 auto& ci = Input(i + 1);
121 auto* co = Output(i);
// Incomplete expression: the call it belongs to and its right-hand operand are
// missing from this listing.
128 QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE &&
// QuantDecodeGradientOp: builds the gradient with respect to the codebook from
// the code tensors and the matching output-gradient tensors.
// NOTE(review): the class header, constructor and the tail of RunOnDevice
// (original lines 143-144, 146, 148, 164 onward) are not visible in this
// listing.
// Variadic template, presumably for a forwarding constructor whose declaration
// is not shown.
142 template <
class... Args>
145 ~QuantDecodeGradientOp() {}
147 bool RunOnDevice()
override {
// Input count must be odd and at least 3: one codebook, then
// num_code_tensors code tensors, then num_code_tensors gradient tensors
// (see the indexing in the loop below).
149 CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
150 const int num_code_tensors = (InputSize() - 1) / 2;
// Single output: the codebook gradient.
151 CAFFE_ENFORCE_EQ(OutputSize(), 1);
153 const auto& codebook = Input(0);
// Only float codebooks are accepted; the dtype name is passed as the failure
// message.
154 CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());
// The output has the codebook's sizes and float dtype, and is zero-filled
// before the loop because the gradient path of Decode adds into it with `+=`.
156 auto* gradient = Output(0, codebook.sizes(), at::dtype<float>());
157 auto* gradient_ptr = gradient->template mutable_data<float>();
158 std::fill(gradient_ptr, gradient_ptr + gradient->numel(), 0);
160 for (
int i = 0; i < num_code_tensors; i++) {
// Codes are inputs 1..num_code_tensors; the gradient for code tensor i is
// input i + num_code_tensors + 1.
161 auto& codes_i = Input(i + 1);
162 auto& output_gradient_i = Input(i + num_code_tensors + 1);
// A non-null gradient pointer makes DecodeGeneral take the gradient path;
// resizeOnly is passed as false.
163 DecodeGeneral(codebook, codes_i, &output_gradient_i, gradient,
false);
170 #endif // QUANT_DECODE_OP_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...