1 #include <ATen/core/dispatch/KernelRegistration.h> 2 #include "caffe2/operators/experimental/c10/schemas/filler.h" 3 #include "caffe2/utils/math.h" 4 #include "caffe2/core/tensor.h" 5 #include <c10/core/Tensor.h> 11 using c10::ivalue::TensorList;
// NOTE(review): the line naming this function is not visible in this excerpt.
// The parameter list matches the filler_init(...) calls made by the kernels
// further down, so this is presumably filler_init -- confirm against the full
// file. The visible statements build a shape vector and resize the output.
16 ArrayRef<at::Tensor> inputs,
18 ArrayRef<int64_t> shape,
19 ArrayRef<int64_t> extra_shape,
20 bool input_as_shape) {
// Build a caffe2 Tensor named `output` from `output_` by way of C10Tensor.
// NOTE(review): presumably this shares storage with `output_`, since the
// resized tensor is never copied back -- confirm.
21 Tensor output{C10Tensor(output_)};
// Accumulates the dimensions that are eventually handed to output.Resize().
23 auto real_shape = vector<int64_t>{};
// First half of an error message (the check that emits it starts on a line
// not visible here), followed by reading the input's elements as int64_t.
30 "When input_as_shape is true, the input must be a 1D tensor of " 32 auto* shape_data = input.template data<int64_t>();
// Arguments of a call whose first line is not visible (presumably
// real_shape.insert): the range is input.dim32(0) values starting at
// shape_data, placed at the end of real_shape.
34 real_shape.end(), shape_data, shape_data + input.dim32(0));
// Same pattern, but the range is the input tensor's own sizes().
38 real_shape.end(), input.sizes().begin(), input.sizes().end());
// Append the extra_shape dimensions after whatever was collected above.
40 real_shape.insert(real_shape.end(), extra_shape.begin(), extra_shape.end());
// Apply the assembled shape to the output tensor.
41 output.Resize(real_shape);
// Kernel body for the GivenTensorFill family of operators: lets filler_init
// size the output, then copies the elements of `values` into it.
//
// Template parameters:
//   Type    - element type used for both the source and destination buffers.
//   Context - NOTE(review): the declaration of `context` is not visible in
//             this excerpt; presumably it is an instance of Context -- confirm.
//
// Instantiated further down for float, int and int64_t with caffe2::CPUContext.
// Several parameter lines (including those declaring `output_`, `values_` and
// `input_as_shape`) are not visible in this excerpt.
47 template <
class Type,
class Context>
48 void given_tensor_fill_op_cpu_impl(
49 ArrayRef<at::Tensor> inputs,
51 ArrayRef<int64_t> shape,
52 ArrayRef<int64_t> extra_shape,
// Build caffe2 Tensor objects from the `output_` / `values_` arguments.
55 Tensor output{C10Tensor(output_)};
56 Tensor values{C10Tensor(values_)};
// Size the output from the inputs / shape arguments before any data access.
59 filler_init(inputs, output_, shape, extra_shape, input_as_shape);
// The copy below transfers output.numel() elements, so both tensors must hold
// the same number of elements.
// NOTE(review): DCHECK_* checks are typically compiled out of optimized
// builds; if that applies here, a size mismatch would go undetected before the
// raw copy -- consider whether an always-on enforce is wanted.
63 DCHECK_EQ(output.numel(), values.numel())
64 <<
"output size: " << output.numel()
65 <<
" given size: " << values.numel();
// Writable destination pointer and read-only source pointer, both typed Type.
66 auto* data = output.template mutable_data<Type>();
67 const Type* values_data = values.template data<Type>();
// Copy output.numel() elements from values_data into data.
69 context.CopySameDevice(output.numel(), values_data, data);
// Kernel body for ConstantFill: lets filler_init size the output, then fills
// it through caffe2::math::Set, choosing the element type from `dtype`.
//
// The visible branches handle FLOAT, INT32 and INT64. Several parameter lines
// and the leading arguments of each Set call are not visible in this excerpt.
//
// Throws std::logic_error with "Unimplemented data type for ConstantFill: "
// followed by the dtype value (presumably from the final else branch, whose
// opening line is not visible -- confirm).
73 void constant_fill_op_cpu_impl(
74 ArrayRef<at::Tensor> inputs,
76 ArrayRef<int64_t> shape,
77 ArrayRef<int64_t> extra_shape,
// Build a caffe2 Tensor named `output` from `output_` by way of C10Tensor.
81 Tensor output{C10Tensor(output_)};
// Size the output from the inputs / shape arguments before filling it.
84 filler_init(inputs, output_, shape, extra_shape, input_as_shape);
// float fill.
87 if (dtype == caffe2::TensorProto_DataType_FLOAT) {
88 caffe2::math::Set<float, CPUContext>(
91 output.template mutable_data<float>(),
92 static_cast<CPUContext*>(&context));
93 }
// 32-bit integer fill.
else if (dtype == caffe2::TensorProto_DataType_INT32) {
94 caffe2::math::Set<int32_t, CPUContext>(
97 output.template mutable_data<int32_t>(),
98 static_cast<CPUContext*>(&context));
99 }
// 64-bit integer fill.
else if (dtype == caffe2::TensorProto_DataType_INT64) {
100 caffe2::math::Set<int64_t, CPUContext>(
103 output.template mutable_data<int64_t>(),
104 static_cast<CPUContext*>(&context));
// Reports a dtype that none of the visible branches handles.
106 throw std::logic_error(
107 "Unimplemented data type for ConstantFill: " +
108 c10::guts::to_string(dtype));
// Kernel body for UniformFill: lets filler_init size the output, optionally
// reads the bounds from input tensors, then fills the output with
// caffe2::math::RandUniform<float, CPUContext>.
//
// Several parameter lines (including those declaring `output_`, `min`, `max`,
// `input_as_shape` and `context`) are not visible in this excerpt.
113 void uniform_fill_op_cpu_impl(
114 ArrayRef<at::Tensor> inputs,
116 ArrayRef<int64_t> shape,
117 ArrayRef<int64_t> extra_shape,
// Build a caffe2 Tensor named `output` from `output_` by way of C10Tensor.
121 Tensor output{C10Tensor(output_)};
// Size the output from the inputs / shape arguments before filling it.
124 filler_init(inputs, output_, shape, extra_shape, input_as_shape);
// With exactly three inputs, inputs[1] and inputs[2] supply min and max.
126 if (inputs.size() == 3) {
// Each bound tensor must hold exactly one element.
127 CAFFE_ENFORCE_EQ(1,
Tensor(inputs[1]).numel(),
"min blob must be scalar");
128 CAFFE_ENFORCE_EQ(1,
Tensor(inputs[2]).numel(),
"max blob must be scalar");
// Overwrite min / max with the single float stored in each bound tensor.
129 min = *
Tensor(inputs[1]).template data<float>();
130 max = *
Tensor(inputs[2]).template data<float>();
// This local `shape` shadows the function parameter of the same name.
// NOTE(review): the lines around this resize are not visible in this excerpt;
// presumably this handles a degenerate min/max range -- confirm.
132 auto shape = output.sizes().vec();
134 output.Resize(shape);
// Return value is discarded; the call is made only for its side effect.
// NOTE(review): presumably this forces allocation of the buffer -- confirm.
135 output.template mutable_data<float>();
// Fill the float output buffer; the leading arguments of this call are on
// lines not visible in this excerpt.
139 caffe2::math::RandUniform<float, CPUContext>(
143 output.template mutable_data<float>(),
144 static_cast<CPUContext*>(&context));
// Kernel registrations: each statement binds one of the caffe2::ops operators
// to one of the kernel functions above under the CPUTensorId() dispatch key.

// ConstantFill -> constant_fill_op_cpu_impl.
150 C10_REGISTER_KERNEL(caffe2::ops::ConstantFill)
151 .kernel<decltype(caffe2::constant_fill_op_cpu_impl), &caffe2::constant_fill_op_cpu_impl>()
152 .dispatchKey(CPUTensorId());
// UniformFill -> uniform_fill_op_cpu_impl.
154 C10_REGISTER_KERNEL(caffe2::ops::UniformFill)
155 .kernel<decltype(caffe2::uniform_fill_op_cpu_impl), &caffe2::uniform_fill_op_cpu_impl>()
156 .dispatchKey(CPUTensorId());
// GivenTensorFill -> given_tensor_fill_op_cpu_impl instantiated for float.
158 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorFill)
159 .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<float, caffe2::CPUContext>>()
160 .dispatchKey(CPUTensorId());
// GivenTensorIntFill -> given_tensor_fill_op_cpu_impl instantiated for int.
162 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorIntFill)
163 .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int, caffe2::CPUContext>>()
164 .dispatchKey(CPUTensorId());
// GivenTensorInt64Fill -> given_tensor_fill_op_cpu_impl instantiated for int64_t.
166 C10_REGISTER_KERNEL(caffe2::ops::GivenTensorInt64Fill)
167 .kernel<decltype(caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>), &caffe2::given_tensor_fill_op_cpu_impl<int64_t, caffe2::CPUContext>>()
168 .dispatchKey(CPUTensorId());
Scalar represents a 0-dimensional tensor which contains a single element.
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...