1 #include <ATen/core/dispatch/KernelRegistration.h> 2 #include "caffe2/operators/elementwise_ops_utils.h" 3 #include "caffe2/operators/experimental/c10/schemas/add.h" 4 #include "caffe2/utils/math.h" 12 template <
class DataType>
17 bool legacy_broadcast,
23 const DataType* A_data =
A.template data<DataType>();
24 const DataType* B_data =
B.template data<DataType>();
25 std::vector<int> A_dims;
26 std::vector<int> B_dims;
28 if (legacy_broadcast) {
31 "In-place is allowed only with the first tensor when " 32 "legacy-broadcasting");
35 A_dims = {
static_cast<int>(
A.numel())};
39 std::tie(pre, n, post) =
40 caffe2::elementwise_ops_utils::ComputeLegacyBroadcastSizes(
43 static_cast<int>(pre), static_cast<int>(n),
static_cast<int>(post)};
44 B_dims = {
static_cast<int>(n), 1};
47 std::copy(
A.sizes().cbegin(),
A.sizes().cend(), std::back_inserter(A_dims));
48 std::copy(
B.sizes().cbegin(),
B.sizes().cend(), std::back_inserter(B_dims));
49 const std::vector<int> C_dims =
50 caffe2::elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
53 CAFFE_ENFORCE_EQ(C_dims, A_dims);
54 }
else if (
B.is_same(
C)) {
55 CAFFE_ENFORCE_EQ(C_dims, B_dims);
60 auto* C_data =
C.template mutable_data<DataType>();
69 C.mutable_data<DataType>(),
70 static_cast<CPUContext*>(&context));
76 C10_REGISTER_KERNEL(caffe2::ops::Add)
77 .kernel<decltype(caffe2::add_op_cpu_impl<float>), &caffe2::add_op_cpu_impl<float>>()
78 .dispatchKey(CPUTensorId());
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
Virtual interface for the Context class in Caffe2.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...