// NOTE(review): this text is a line-numbered source listing with many lines
// dropped (class heads, access specifiers, #else/#endif, closing braces), so
// it does not compile as-is. Comments below describe only what is visible.
//
// Header for the fused row-wise random quantization conversion operators.
// The first class template declared here is FloatToFusedRandRowwiseQuantizedOp
// (name taken from its destructor). It takes float input (DATA_FLOAT) and
// produces fused quantized output (DATA_FUSED_QUANTIZED).
//
// NOTE(review): as listed, FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL is
// #defined unconditionally and no MKL header is included, yet the code uses
// vslNewStream / VSLStreamStatePtr. The listing skips the lines around the
// #define, so it is presumably guarded by an MKL-availability check and paired
// with an MKL include -- confirm against the full file.
1 #ifndef CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_ 2 #define CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/operator.h" 9 #include "caffe2/operators/reducer_functors.h" 10 #include "caffe2/perfkernels/math.h" 11 #include "caffe2/utils/math.h" 15 #define FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL 20 template <
class Context>
// Caffe2 macro (defined elsewhere); presumably exposes the Context helpers
// to this operator's members.
23 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor. The variadic Args are presumably forwarded to the operator
// base class; the signature lines are missing from this listing.
// Operator arguments read here:
//   "bitwidth" (int32_t, default 8): bits per quantized value; only 1, 2, 4
//       or 8 are accepted (see the check below).
//   "random" (bool, default true): presumably selects random (stochastic)
//       vs. deterministic rounding -- confirm in RunOnDevice.
24 template <
class... Args>
27 bitwidth_(OperatorBase::GetSingleArgument<int32_t>(
"bitwidth", 8)),
28 random_(OperatorBase::GetSingleArgument<bool>(
"random",
true)) {
// Reject unsupported bit widths with "Unsupported bitwidth". The head of the
// enforcing macro call is missing from this listing (presumably
// CAFFE_ENFORCE).
30 bitwidth_ == 1 || bitwidth_ == 2 || bitwidth_ == 4 || bitwidth_ == 8,
31 "Unsupported bitwidth");
// Random-number source setup. MKL path: create a VSL stream seeded from the
// wall clock. A failure is only logged as a warning, and construction
// continues.
33 #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL 34 int status = vslNewStream(
37 std::chrono::system_clock::now().time_since_epoch().count());
38 if (status != VSL_STATUS_OK) {
39 LOG(WARNING) <<
"vslNewStream returns " << status;
// Non-MKL path (presumably the #else branch; those lines are missing here):
// seed the std::minstd_rand engine from the wall clock and set up a uniform
// float distribution over [0, 1).
// NOTE(review): the time_since_epoch().count() value is narrowed to the
// engine's seed type. dis_ could be built with std::make_unique instead of
// reset(new ...).
42 gen_.seed(std::chrono::system_clock::now().time_since_epoch().count());
43 dis_.reset(
new std::uniform_real_distribution<float>(0.0f, 1.0f));
// Destructor. MKL path: release the VSL stream and log a warning on failure.
// NOTE(review): as listed, the constructor only warns when vslNewStream
// fails, so this still calls vslDeleteStream on a stream that was never
// created. Consider recording whether creation succeeded.
48 ~FloatToFusedRandRowwiseQuantizedOp() {
50 #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL 51 int status = vslDeleteStream(&vslStream_);
52 if (status != VSL_STATUS_OK) {
53 LOG(WARNING) <<
"vslDeleteStream returns " << status;
// Operator entry point. Declared here and defined out of line (not part of
// this header listing).
59 bool RunOnDevice()
override;
// Input: float data. Output: fused quantized data.
62 INPUT_TAGS(DATA_FLOAT);
63 OUTPUT_TAGS(DATA_FUSED_QUANTIZED);
// Float scratch buffer; presumably holds the random draws consumed by
// RunOnDevice -- confirm.
68 std::vector<float> random_buffer_;
// MKL random stream: created in the constructor, deleted in the destructor.
70 #ifdef FUSED_ROWWISE_RANDOM_QUANTIZATION_USE_MKL 71 VSLStreamStatePtr vslStream_;
// Non-MKL random source: a [0, 1) uniform float distribution (dis_) driven by
// the minstd_rand engine (gen_) seeded in the constructor.
73 std::unique_ptr<std::uniform_real_distribution<float>> dis_;
74 std::minstd_rand gen_;
// Second operator class template. Its class head is missing from this
// listing; it is presumably the inverse conversion
// (FusedRandRowwiseQuantizedToFloatOp) -- confirm against the full file.
// It consumes fused quantized data (DATA_FUSED_QUANTIZED) and produces float
// data (DATA_FLOAT).
78 template <
class Context>
// Caffe2 macro (defined elsewhere); presumably exposes the Context helpers
// to this operator's members.
81 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Operator entry point. Declared here and defined out of line.
84 bool RunOnDevice()
override;
// Input: fused quantized data. Output: float data.
87 INPUT_TAGS(DATA_FUSED_QUANTIZED);
88 OUTPUT_TAGS(DATA_FLOAT);
// Closes the include guard opened at the top of this header.
93 #endif // CAFFE2_OPERATORS_FUSED_ROWWISE_RAND_CONVERSION_OPS_H_
A global dictionary that holds information about which Caffe2 modules have been loaded in the current ...