1 #ifndef CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_ 2 #define CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/tensor_int8.h" 8 #include "caffe2/operators/filler_op.h" 9 #include "caffe2/utils/cast.h" 10 #include "caffe2/utils/math.h" 17 template <
class... Args>
20 scale_(this->
template GetSingleArgument<float>(
"Y_scale", 1.0)),
22 this->
template GetSingleArgument<int32_t>(
"Y_zero_point", 0)),
23 shape_(this->
template GetRepeatedArgument<int64_t>(
"shape")) {
27 bool RunOnDevice()
override {
28 auto* output = Outputs()[0]->template GetMutable<Int8TensorCPU>();
30 output->scale = scale_;
31 output->zero_point = zero_point_;
36 void ExtractValues() {
37 auto source_values = this->
template GetSingleArgument<string>(
"values",
"");
39 &values_, {
static_cast<int64_t
>(source_values.size())}, at::dtype<uint8_t>().device(CPU));
40 uint8_t* values_data = values_.template mutable_data<uint8_t>();
41 for (
int i = 0; i < source_values.size(); i++) {
42 values_data[i] =
static_cast<uint8_t
>(source_values[i]);
47 DCHECK_EQ(output->t.numel(), values_.numel())
48 <<
"output size: " << output->t.numel()
49 <<
" given size: " << values_.numel();
50 auto* data = output->t.template mutable_data<uint8_t>();
51 const uint8_t* values_data = values_.template data<uint8_t>();
52 if (output->t.numel()) {
53 context_.template CopySameDevice<uint8_t>(
54 output->t.numel(), values_data, data);
61 vector<int64_t> shape_;
67 template <
class... Args>
70 scale_(this->
template GetSingleArgument<float>(
"Y_scale", 1.0)),
72 this->
template GetSingleArgument<int32_t>(
"Y_zero_point", 0)),
73 shape_(this->
template GetRepeatedArgument<int64_t>(
"shape")) {
77 bool RunOnDevice()
override {
78 auto* output = Outputs()[0]->template GetMutable<Int8TensorCPU>();
79 output->t.Resize(shape_);
80 output->scale = scale_;
81 output->zero_point = zero_point_;
86 void ExtractValues() {
87 auto source_values = this->
template GetRepeatedArgument<int32_t>(
"values");
89 &values_, {
static_cast<int64_t
>(source_values.size())}, at::dtype<int32_t>().device(CPU));
90 auto* values_data = values_.template mutable_data<int32_t>();
91 for (
int i = 0; i < source_values.size(); i++) {
92 values_data[i] =
static_cast<int32_t
>(source_values[i]);
97 DCHECK_EQ(output->t.numel(), values_.numel())
98 <<
"output size: " << output->t.numel()
99 <<
" given size: " << values_.numel();
100 auto* data = output->t.template mutable_data<int32_t>();
101 const auto* values_data = values_.template data<int32_t>();
102 if (output->t.numel()) {
103 context_.template CopySameDevice<int32_t>(
104 output->t.numel(), values_data, data);
111 vector<int64_t> shape_;
118 #endif // CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_ void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...