Caffe2 - C++ API
A deep learning, cross platform ML framework
given_tensor_fill_op.h
1 
17 #pragma once
18 
19 #include "caffe2/core/context.h"
20 #include "caffe2/core/logging.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/operators/filler_op.h"
23 #include "caffe2/utils/cast.h"
24 #include "caffe2/utils/math.h"
25 
26 namespace caffe2 {
27 
28 template <typename T, class Context>
29 class GivenTensorFillOp final : public FillerOp<Context> {
30  public:
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32  GivenTensorFillOp(const OperatorDef& operator_def, Workspace* ws)
33  : FillerOp<Context>(operator_def, ws) {
34  const ArgumentHelper helper(operator_def);
35  // GivenTensorFillOp can be provided with a "dtype" arg if float is
36  // is specified as T. Otherwise, "dtype" is ignored.
37  // In the ideal world, we would get rid of templating of T at all, but we
38  // need to provide backwards compatibility.
39  if (!std::is_same<T, float>::value || !helper.HasArgument("dtype")) {
40  ExtractValues<T>();
41  } else {
42  auto dtype = cast::GetCastDataType(helper, "dtype");
43  switch (dtype) {
44  case TensorProto_DataType_FLOAT:
45  ExtractValues<float>();
46  break;
47  case TensorProto_DataType_DOUBLE:
48  ExtractValues<double>();
49  break;
50  case TensorProto_DataType_BOOL:
51  ExtractValues<bool>();
52  break;
53  case TensorProto_DataType_INT32:
54  ExtractValues<int>();
55  break;
56  case TensorProto_DataType_INT64:
57  ExtractValues<int64_t>();
58  break;
59  case TensorProto_DataType_STRING:
60  ExtractValues<std::string>();
61  break;
62  case TensorProto_DataType_UNDEFINED:
63  CAFFE_THROW("Cannot have undefined 'dtype' argument");
64  default:
65  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
66  }
67  }
68  }
69 
70  bool Fill(Tensor<Context>* output) override {
71  return (this->*body_)(output);
72  }
73 
74  private:
75  template <typename Type>
76  void ExtractValues() {
77  auto source_values =
78  OperatorBase::template GetRepeatedArgument<Type>("values");
79  values_.Resize(source_values.size());
80  Type* values_data = values_.template mutable_data<Type>();
81  for (int i = 0; i < source_values.size(); i++) {
82  values_data[i] = static_cast<Type>(source_values[i]);
83  }
84  body_ = &GivenTensorFillOp::FillWithType<Type>;
85  }
86 
87  template <typename Type>
88  bool FillWithType(Tensor<Context>* output) {
89  DCHECK_EQ(output->size(), values_.size())
90  << "output size: " << output->size()
91  << " given size: " << values_.size();
92  auto* data = output->template mutable_data<Type>();
93  const Type* values_data = values_.template data<Type>();
94  if (output->size()) {
95  context_.template Copy<Type, CPUContext, Context>(
96  output->size(), values_data, data);
97  }
98  return true;
99  }
100 
101  bool (GivenTensorFillOp::*body_)(Tensor<Context>* output);
102  TensorCPU values_;
103 };
104 } // namespace caffe2
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:109
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:609
A helper class to index into arguments.
Definition: proto_utils.h:198
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:304
Copyright (c) 2016-present, Facebook, Inc.