Caffe2 - C++ API
A deep learning, cross-platform ML framework
given_tensor_fill_op.h
1 #pragma once
2 
3 #include "caffe2/core/context.h"
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/operators/filler_op.h"
7 #include "caffe2/utils/cast.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <typename T, class Context>
13 class GivenTensorFillOp final : public FillerOp<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  explicit GivenTensorFillOp(const OperatorDef& operator_def, Workspace* ws)
17  : FillerOp<Context>(operator_def, ws) {
18  const ArgumentHelper helper(operator_def);
19  // GivenTensorFillOp can be provided with a "dtype" arg if float is
20  // is specified as T. Otherwise, "dtype" is ignored.
21  // In the ideal world, we would get rid of templating of T at all, but we
22  // need to provide backwards compatibility.
23  if (!std::is_same<T, float>::value || !helper.HasArgument("dtype")) {
24  ExtractValues<T>();
25  } else {
26  auto dtype = cast::GetCastDataType(helper, "dtype");
27  switch (dtype) {
28  case TensorProto_DataType_FLOAT:
29  ExtractValues<float>();
30  break;
31  case TensorProto_DataType_DOUBLE:
32  ExtractValues<double>();
33  break;
34  case TensorProto_DataType_BOOL:
35  ExtractValues<bool>();
36  break;
37  case TensorProto_DataType_INT32:
38  ExtractValues<int>();
39  break;
40  case TensorProto_DataType_INT64:
41  ExtractValues<int64_t>();
42  break;
43  case TensorProto_DataType_STRING:
44  ExtractValues<std::string>();
45  break;
46  case TensorProto_DataType_UNDEFINED:
47  CAFFE_THROW("Cannot have undefined 'dtype' argument");
48  default:
49  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
50  }
51  }
52  }
53 
54  bool Fill(Tensor* output) override {
55  return (this->*body_)(output);
56  }
57 
58  private:
59  template <typename Type>
60  void ExtractValues() {
61  auto source_values =
62  this->template GetRepeatedArgument<Type>("values");
63  ReinitializeTensor(&values_, {static_cast<int64_t>(source_values.size())}, at::dtype<Type>().device(CPU));
64  Type* values_data = values_.template mutable_data<Type>();
65  for (int i = 0; i < source_values.size(); i++) {
66  values_data[i] = static_cast<Type>(source_values[i]);
67  }
68  body_ = &GivenTensorFillOp::FillWithType<Type>;
69  }
70 
71  template <typename Type>
72  bool FillWithType(Tensor* output) {
73  DCHECK_EQ(output->numel(), values_.numel())
74  << "output size: " << output->numel()
75  << " given size: " << values_.numel();
76  auto* data = output->template mutable_data<Type>();
77  const Type* values_data = values_.template data<Type>();
78  if (output->numel()) {
79  context_.CopyItemsFromCPU(
80  TypeMeta::Make<Type>(), output->numel(), values_data, data);
81  }
82  return true;
83  }
84 
85  bool (GivenTensorFillOp::*body_)(Tensor* output);
86  Tensor values_;
87 };
88 } // namespace caffe2
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitializes a Tensor to the given dims and options if necessary; note that this will not do anything if ...
Definition: tensor.cc:127
A helper class to index into arguments.
Definition: proto_utils.h:200
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13