Caffe2 - C++ API
A deep learning, cross-platform ML framework
given_tensor_byte_string_to_uint8_fill_op.h
1 #pragma once
2 
3 #include "caffe2/core/context.h"
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/operators/filler_op.h"
7 #include "caffe2/utils/cast.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class GivenTensorByteStringToUInt8FillOp final : public FillerOp<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  explicit GivenTensorByteStringToUInt8FillOp(const OperatorDef& operator_def, Workspace* ws)
17  : FillerOp<Context>(operator_def, ws) {
18  const ArgumentHelper helper(operator_def);
19  if (!helper.HasArgument("dtype")) {
20  Extract();
21  } else {
22  auto dtype = cast::GetCastDataType(helper, "dtype");
23  switch (dtype) {
24  case TensorProto_DataType_STRING:
25  Extract();
26  break;
27  case TensorProto_DataType_UNDEFINED:
28  CAFFE_THROW("Cannot have undefined 'dtype' argument");
29  default:
30  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
31  }
32  }
33  }
34 
35  bool Fill(Tensor* output) override {
36  DCHECK_EQ(output->numel(), values_.numel())
37  << "output size: " << output->numel()
38  << " given size: " << values_.numel();
39  auto* data = output->template mutable_data<uint8_t>();
40  const uint8_t* values_data = values_.template data<uint8_t>();
41  if (output->numel()) {
42  context_.template CopySameDevice<uint8_t>(
43  output->numel(), values_data, data);
44  }
45  return true;
46  }
47 
48  private:
49  void Extract() {
50  auto source_values = this->template GetRepeatedArgument<string>("values");
51  DCHECK_EQ(source_values.size(), 1)
52  << "expected size: 1 "
53  << " given size: " << source_values.size();
54 
55  auto str = source_values[0];
56  ReinitializeTensor(&values_, {static_cast<int64_t>(str.size())}, at::dtype<uint8_t>().device(CPU));
57  uint8_t* values_data = values_.template mutable_data<uint8_t>();
58  for (int i = 0; i < str.size(); i++) {
59  values_data[i] = static_cast<uint8_t>(str[i]);
60  }
61 }
62 
63 Tensor values_;
64 };
65 } // namespace caffe2
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
Definition: tensor.cc:127
A helper class to index into arguments.
Definition: proto_utils.h:200
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13