3 #include "caffe2/core/context.h" 4 #include "caffe2/core/logging.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/operators/filler_op.h" 7 #include "caffe2/utils/cast.h" 8 #include "caffe2/utils/math.h" 12 template <
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
19 if (!helper.HasArgument(
"dtype")) {
22 auto dtype = cast::GetCastDataType(helper,
"dtype");
24 case TensorProto_DataType_STRING:
27 case TensorProto_DataType_UNDEFINED:
28 CAFFE_THROW(
"Cannot have undefined 'dtype' argument");
30 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
// Fill: copies the bytes stored in values_ into `output`, treating both
// tensors as uint8_t buffers. The tail of this method (closing braces and
// return statement) is not visible in this excerpt.
35 bool Fill(
Tensor* output)
override {
// Checks that `output` and values_ hold the same number of elements.
// NOTE(review): DCHECK_* is presumably a debug-only check -- confirm. If it is
// compiled out, a size mismatch is not caught before the copy below, which
// copies output->numel() elements regardless of values_.numel().
36 DCHECK_EQ(output->numel(), values_.numel())
37 <<
"output size: " << output->numel()
38 <<
" given size: " << values_.numel();
39 auto* data = output->template mutable_data<uint8_t>();
40 const uint8_t* values_data = values_.template data<uint8_t>();
// Skip the copy entirely when the output has zero elements.
41 if (output->numel()) {
// Copy count is taken from the output tensor, not from values_.
42 context_.template CopySameDevice<uint8_t>(
43 output->numel(), values_data, data);
// Reads the repeated string argument "values" and stores the bytes of its
// first entry in values_ as a one-dimensional uint8_t tensor on CPU.
// The enclosing function's header and closing braces are not visible in this
// excerpt.
50 auto source_values = this->
template GetRepeatedArgument<string>(
"values");
// Expects exactly one string in "values".
// NOTE(review): source_values[0] is indexed below unconditionally. If this
// DCHECK is inactive (presumably in non-debug builds -- confirm), an empty
// "values" argument is not guarded against.
51 DCHECK_EQ(source_values.size(), 1)
52 <<
"expected size: 1 " 53 <<
" given size: " << source_values.size();
// NOTE(review): `auto` here presumably makes a copy of the first string; a
// const reference would avoid it -- confirm the element type.
55 auto str = source_values[0];
// (Re)initialize values_ with a single dimension of str.size() elements.
56 ReinitializeTensor(&values_, {
static_cast<int64_t
>(str.size())}, at::dtype<uint8_t>().device(CPU));
57 uint8_t* values_data = values_.template mutable_data<uint8_t>();
// Copy the string into the tensor one character at a time, casting each
// char to uint8_t.
// NOTE(review): `int i` is compared against str.size(), which is presumably
// unsigned -- a signed/unsigned comparison.
58 for (
int i = 0; i < str.size(); i++) {
59 values_data[i] =
static_cast<uint8_t
>(str[i]);
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to the given dims and options if necessary; note that this will not do anything if ...
A helper class to index into arguments.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...