Caffe2 - C++ API
A deep learning, cross-platform ML framework
filler_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/experimental/c10/schemas/filler.h"
3 #include "caffe2/utils/math.h"
4 #include "caffe2/core/tensor.h"
5 #include <c10/core/Tensor.h>
6 
8 using caffe2::Tensor;
10 using std::vector;
11 using c10::ivalue::TensorList;
12 
13 namespace caffe2 {
14 namespace {
15 void filler_init(
16  ArrayRef<at::Tensor> inputs,
17  const at::Tensor& output_,
18  ArrayRef<int64_t> shape,
19  ArrayRef<int64_t> extra_shape,
20  bool input_as_shape) {
21  Tensor output{C10Tensor(output_)};
22  if (inputs.size()) {
23  auto real_shape = vector<int64_t>{};
24  if (input_as_shape) {
25  // Shape input must be in CPU context
26  Tensor input(inputs[0]);
27  CAFFE_ENFORCE_EQ(
28  input.dim(),
29  1,
30  "When input_as_shape is true, the input must be a 1D tensor of "
31  "data type int64_t");
32  auto* shape_data = input.template data<int64_t>();
33  real_shape.insert(
34  real_shape.end(), shape_data, shape_data + input.dim32(0));
35  } else {
36  Tensor input(inputs[0]);
37  real_shape.insert(
38  real_shape.end(), input.sizes().begin(), input.sizes().end());
39  }
40  real_shape.insert(real_shape.end(), extra_shape.begin(), extra_shape.end());
41  output.Resize(real_shape);
42  } else {
43  output.Resize(shape);
44  }
45 }
46 
47 template <class Type, class Context>
48 void given_tensor_fill_op_cpu_impl(
49  ArrayRef<at::Tensor> inputs,
50  const at::Tensor& output_,
51  ArrayRef<int64_t> shape,
52  ArrayRef<int64_t> extra_shape,
53  bool input_as_shape,
54  const at::Tensor& values_) {
55  Tensor output{C10Tensor(output_)};
56  Tensor values{C10Tensor(values_)};
57  CPUContext context;
58 
59  filler_init(inputs, output_, shape, extra_shape, input_as_shape);
60 
61  // TODO T might not be the correct type to call, since float allows others.
62 
63  DCHECK_EQ(output.numel(), values.numel())
64  << "output size: " << output.numel()
65  << " given size: " << values.numel();
66  auto* data = output.template mutable_data<Type>();
67  const Type* values_data = values.template data<Type>();
68  if (output.numel()) {
69  context.CopySameDevice(output.numel(), values_data, data);
70  }
71 }
72 
73 void constant_fill_op_cpu_impl(
74  ArrayRef<at::Tensor> inputs,
75  const at::Tensor& output_,
76  ArrayRef<int64_t> shape,
77  ArrayRef<int64_t> extra_shape,
78  bool input_as_shape,
79  int64_t dtype,
80  c10::Scalar value) {
81  Tensor output{C10Tensor(output_)};
82  CPUContext context;
83 
84  filler_init(inputs, output_, shape, extra_shape, input_as_shape);
85 
86  if (output.numel()) {
87  if (dtype == caffe2::TensorProto_DataType_FLOAT) {
88  caffe2::math::Set<float, CPUContext>(
89  output.numel(),
90  value.toDouble(),
91  output.template mutable_data<float>(),
92  static_cast<CPUContext*>(&context));
93  } else if (dtype == caffe2::TensorProto_DataType_INT32) {
94  caffe2::math::Set<int32_t, CPUContext>(
95  output.numel(),
96  value.toInt(),
97  output.template mutable_data<int32_t>(),
98  static_cast<CPUContext*>(&context));
99  } else if (dtype == caffe2::TensorProto_DataType_INT64) {
100  caffe2::math::Set<int64_t, CPUContext>(
101  output.numel(),
102  value.toInt(),
103  output.template mutable_data<int64_t>(),
104  static_cast<CPUContext*>(&context));
105  } else {
106  throw std::logic_error(
107  "Unimplemented data type for ConstantFill: " +
108  c10::guts::to_string(dtype));
109  }
110  }
111 }
112 
113 void uniform_fill_op_cpu_impl(
114  ArrayRef<at::Tensor> inputs,
115  const at::Tensor& output_,
116  ArrayRef<int64_t> shape,
117  ArrayRef<int64_t> extra_shape,
118  bool input_as_shape,
119  double min,
120  double max) {
121  Tensor output{C10Tensor(output_)};
122  CPUContext context;
123 
124  filler_init(inputs, output_, shape, extra_shape, input_as_shape);
125 
126  if (inputs.size() == 3) {
127  CAFFE_ENFORCE_EQ(1, Tensor(inputs[1]).numel(), "min blob must be scalar");
128  CAFFE_ENFORCE_EQ(1, Tensor(inputs[2]).numel(), "max blob must be scalar");
129  min = *Tensor(inputs[1]).template data<float>();
130  max = *Tensor(inputs[2]).template data<float>();
131  if (min > max) {
132  auto shape = output.sizes().vec();
133  shape[0] = 0;
134  output.Resize(shape);
135  output.template mutable_data<float>();
136  return;
137  }
138  }
139  caffe2::math::RandUniform<float, CPUContext>(
140  output.numel(),
141  min,
142  max,
143  output.template mutable_data<float>(),
144  static_cast<CPUContext*>(&context));
145 }
146 } // namespace
147 } // namespace caffe2
148 
149 namespace c10 {
150 C10_REGISTER_KERNEL(caffe2::ops::ConstantFill)
151  .kernel<decltype(caffe2::constant_fill_op_cpu_impl), &caffe2::constant_fill_op_cpu_impl>()
152  .dispatchKey(CPUTensorId());
153 
154 C10_REGISTER_KERNEL(caffe2::ops::UniformFill)
155  .kernel<decltype(caffe2::uniform_fill_op_cpu_impl), &caffe2::uniform_fill_op_cpu_impl>()
156  .dispatchKey(CPUTensorId());
157 
158 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill)
159  .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>()
160  .dispatchKey(CPUTensorId());
161 
162 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorIntFill)
163  .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>()
164  .dispatchKey(CPUTensorId());
165 
166 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorInt64Fill)
167  .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>()
168  .dispatchKey(CPUTensorId());
169 } // namespace c10
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Definition: tensor.h:25
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHandle).kernel<...>().dispatchKey(...);
Definition: alias_info.h:7