1 #ifndef CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_ 2 #define CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_ 4 #include "caffe2/core/operator.h" 5 #include "caffe2/utils/eigen_utils.h" 6 #include "caffe2/utils/math.h" 10 template <
typename Context>
15 min_(this->
template GetSingleArgument<float>(
"min", -3)),
16 max_(this->
template GetSingleArgument<float>(
"max", 3)),
17 shape_(this->
template GetRepeatedArgument<int64_t>(
"shape")) {}
19 USE_OPERATOR_FUNCTIONS(Context);
22 bool RunOnDevice()
override {
23 const auto& WI =
Input(0);
25 auto* Y = Output(0, shape_, at::dtype<float>());
26 float bin_interval = (max_ - min_) / 255.0;
28 for (
int i = 0; i < shape_.size(); i++) {
32 if (WI.template IsType<uint8_t>()) {
33 CAFFE_ENFORCE(total, WI.nbytes());
34 Xdata = WI.template data<uint8_t>();
36 CAFFE_ENFORCE(total, WI.template data<std::string>()[0].size());
37 Xdata =
reinterpret_cast<const uint8_t*
>(
38 WI.template data<std::string>()[0].c_str());
40 auto* Ydata = Y->template mutable_data<float>();
41 ConstEigenVectorMap<uint8_t> index(&Xdata[0], total);
42 EigenVectorMap<float> weights(&Ydata[0], total);
43 weights = (index.cast<
float>().array() * bin_interval) + min_;
50 std::vector<int64_t> shape_;
55 #endif // CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...