Caffe2 - C++ API
A deep learning, cross platform ML framework
fused_rowwise_random_quantization_ops.h
1 #ifndef CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
2 #define CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
3 
4 #include <chrono>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 #include "caffe2/operators/reducer_functors.h"
10 #include "caffe2/perfkernels/math.h"
11 #include "caffe2/utils/math.h"
12 
13 #ifdef CAFFE2_USE_MKL
14 #include <mkl.h>
15 #define FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
16 #endif
17 
18 namespace caffe2 {
19 
20 template <class Context>
22  public:
23  USE_OPERATOR_CONTEXT_FUNCTIONS;
24  template <class... Args>
25  explicit FloatToFusedRandRowwiseQuantizedOp(Args&&... args)
26  : Operator<Context>(std::forward<Args>(args)...),
27  bitwidth_(OperatorBase::GetSingleArgument<int32_t>("bitwidth", 8)),
28  random_(OperatorBase::GetSingleArgument<bool>("random", true)) {
29  CAFFE_ENFORCE(
30  bitwidth_ == 1 || bitwidth_ == 2 || bitwidth_ == 4 || bitwidth_ == 8,
31  "Unsupported bitwidth");
32  if (random_) {
33 #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
34  int status = vslNewStream(
35  &vslStream_,
36  VSL_BRNG_MT19937,
37  std::chrono::system_clock::now().time_since_epoch().count());
38  if (status != VSL_STATUS_OK) {
39  LOG(WARNING) << "vslNewStream returns " << status;
40  }
41 #else
42  gen_.seed(std::chrono::system_clock::now().time_since_epoch().count());
43  dis_.reset(new std::uniform_real_distribution<float>(0.0f, 1.0f));
44 #endif
45  }
46  }
47 
48  ~FloatToFusedRandRowwiseQuantizedOp() {
49  if (random_) {
50 #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
51  int status = vslDeleteStream(&vslStream_);
52  if (status != VSL_STATUS_OK) {
53  LOG(WARNING) << "vslDeleteStream returns " << status;
54  }
55 #endif
56  }
57  }
58 
59  bool RunOnDevice() override;
60 
61  private:
62  INPUT_TAGS(DATA_FLOAT);
63  OUTPUT_TAGS(DATA_FUSED_QUANTIZED);
64 
65  protected:
66  size_t bitwidth_{8};
67  bool random_{true};
68  std::vector<float> random_buffer_;
69 
70 #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL
71  VSLStreamStatePtr vslStream_;
72 #else
73  std::unique_ptr<std::uniform_real_distribution<float>> dis_;
74  std::minstd_rand gen_;
75 #endif
76 };
77 
78 template <class Context>
80  public:
81  USE_OPERATOR_CONTEXT_FUNCTIONS;
82  USE_SIMPLE_CTOR_DTOR(FusedRandRowwiseQuantizedToFloatOp)
83 
84  bool RunOnDevice() override;
85 
86  private:
87  INPUT_TAGS(DATA_FUSED_QUANTIZED);
88  OUTPUT_TAGS(DATA_FLOAT);
89 };
90 
91 } // namespace caffe2
92 
93 #endif // CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13