1 #ifndef QUANT_DECODE_OP_H_     2 #define QUANT_DECODE_OP_H_     4 #include "caffe2/core/context.h"     5 #include "caffe2/core/operator.h"     6 #include "caffe2/core/tensor.h"     7 #include <c10/util/typeid.h>    13 template <
class CodebookT, 
class CodeT>
    17      const Tensor* 
const decoded_grad,
    20   CAFFE_ENFORCE(codebook.IsType<CodebookT>());
    22   auto* cb_ptr = codebook.data<CodebookT>();
    23   int cb_size = codebook.numel();
    25   CAFFE_ENFORCE(codes.IsType<CodeT>());
    26   auto* code_ptr = codes.data<CodeT>();
    28   if (decoded_grad == 
nullptr) {
    30     output->ResizeLike(codes);
    31     auto* out_ptr = output->template mutable_data<CodebookT>();
    36     int sz = output->numel();
    37     for (
int i = 0; i < sz; i++) {
    38       DCHECK_LE(*code_ptr, cb_size);
    39       *out_ptr++ = cb_ptr[*code_ptr++];
    43     CAFFE_ENFORCE_EQ(codes.numel(), decoded_grad->numel());
    44     auto* gradient_ptr = decoded_grad->data<CodebookT>();
    45     auto* 
const gradient_end = gradient_ptr + decoded_grad->numel();
    47     CAFFE_ENFORCE_EQ(cb_size, output->numel());
    48     auto* out_ptr = output->template mutable_data<CodebookT>();
    49     while (gradient_ptr < gradient_end) {
    50       DCHECK_LE(*code_ptr, cb_size);
    51       out_ptr[*code_ptr++] += *gradient_ptr++;
    56 #define REGISTER_DECODER(codebookType, codesType)                      \    58     {TypeMeta::Id<codebookType>(), TypeMeta::Id<codesType>()},         \    59         [](const Tensor& codebook_,                                    \    60            const Tensor& codes_,                                       \    61            const Tensor* gradient_,                                    \    62            Tensor* outDecoded_,                                        \    64           Decode<codebookType, codesType>(                             \    65               codebook_, codes_, gradient_, outDecoded_, resizeOnly_); \    69 inline void DecodeGeneral(
    75   const static std::map<
    76       std::pair<TypeIdentifier, TypeIdentifier>,
    83       gDecoderMapper = {REGISTER_DECODER(
float, uint8_t),
    84                         REGISTER_DECODER(
float, uint16_t),
    85                         REGISTER_DECODER(
float, int32_t)};
    87   gDecoderMapper.at({codebook.dtype().id(), codes.dtype().id()})(
    88       codebook, codes, gradient, outDecoded, resizeOnly);
    96 enum class QuantDecodeRunTy {
   101 template <QuantDecodeRunTy QuantDecodeRun>
   105   template <
class... Args>
   111   bool RunOnDevice()
 override {
   112     CAFFE_ENFORCE_GT(InputSize(), 1);
   114     CAFFE_ENFORCE_EQ(InputSize(), OutputSize() + 1);
   116     const auto& codebook = Input(0);
   117     CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());
   119     for (
int i = 0; i < OutputSize(); i++) {
   120       auto& ci = Input(i + 1);
   121       auto* co = Output(i);
   128           QuantDecodeRun == QuantDecodeRunTy::RUN_ONCE &&
   142   template <
class... Args>
   145   ~QuantDecodeGradientOp() {}
   147   bool RunOnDevice()
 override {
   149     CAFFE_ENFORCE(InputSize() >= 3 && InputSize() % 2 == 1);
   150     const int num_code_tensors = (InputSize() - 1) / 2;
   151     CAFFE_ENFORCE_EQ(OutputSize(), 1);
   153     const auto& codebook = Input(0);
   154     CAFFE_ENFORCE(codebook.template IsType<float>(), codebook.dtype().name());
   156     auto* gradient = Output(0, codebook.sizes(), at::dtype<float>());
   157     auto* gradient_ptr = gradient->template mutable_data<float>();
   158     std::fill(gradient_ptr, gradient_ptr + gradient->numel(), 0);
   160     for (
int i = 0; i < num_code_tensors; i++) {
   161       auto& codes_i = Input(i + 1);
   162       auto& output_gradient_i = Input(i + num_code_tensors + 1);
   163       DecodeGeneral(codebook, codes_i, &output_gradient_i, gradient, 
false);
   170 #endif // QUANT_DECODE_OP_H_ 
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...