1 #include <ATen/core/dispatch/KernelRegistration.h> 2 #include "caffe2/operators/elementwise_ops_utils.h" 3 #include "caffe2/operators/experimental/c10/schemas/mul.h" 4 #include "caffe2/utils/math.h" 5 #include "caffe2/core/tensor.h" 13 template <
class DataType>
// Templated kernel body; the registration at the end of this file refers to it
// as caffe2::mul_op_cpu_impl<float> for caffe2::ops::Mul.
// NOTE(review): only part of the signature and body is visible in this
// excerpt; the comments below describe the visible statements only.
//
// legacy_broadcast selects between the two shape-computation branches below.
18 bool legacy_broadcast,
// Read-only, DataType-typed pointers to the contents of A and B.
24 const DataType* A_data =
A.template data<DataType>();
25 const DataType* B_data =
B.template data<DataType>();
// Dimension lists for A and B (as int); each branch below fills them in.
26 std::vector<int> A_dims;
27 std::vector<int> B_dims;
29 if (legacy_broadcast) {
// Legacy path. Per the message below, in-place output is only accepted with
// the first tensor here (the enforcing condition itself is not visible).
"In-place is allowed only with the first tensor when " 33 "legacy-broadcasting");
// A is described as a single dimension holding all of its elements.
// NOTE(review): the condition guarding this case is not visible — confirm.
36 A_dims = {
static_cast<int>(
A.numel())};
// Otherwise obtain (pre, n, post) from ComputeLegacyBroadcastSizes; the three
// values are narrowed to int, and B is described as {n, 1}.
// NOTE(review): the target of the {pre, n, post} initializer is not visible;
// presumably A_dims — confirm.
40 std::tie(pre, n, post) =
41 caffe2::elementwise_ops_utils::ComputeLegacyBroadcastSizes(
44 static_cast<int>(pre), static_cast<int>(n),
static_cast<int>(post)};
45 B_dims = {
static_cast<int>(n), 1};
// Non-legacy path (presumably the else-branch — the `else` is not visible):
// append the sizes of A and B to A_dims / B_dims element by element.
48 std::copy(
A.sizes().cbegin(),
A.sizes().cend(), std::back_inserter(A_dims));
49 std::copy(
B.sizes().cbegin(),
B.sizes().cend(), std::back_inserter(B_dims));
// Output dimensions as computed by ComputeBinaryBroadcastForwardDims
// (its arguments are not visible here).
50 const std::vector<int> C_dims =
51 caffe2::elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
// The computed output dims must equal A's dims in this case.
// NOTE(review): presumably guarded by A.is_same(C) — the guard is not visible.
54 CAFFE_ENFORCE_EQ(C_dims, A_dims);
55 }
else if (
B.is_same(
C)) {
// B and C are the same tensor: the computed output dims must equal B's dims.
56 CAFFE_ENFORCE_EQ(C_dims, B_dims);
// Writable, DataType-typed pointer to C's contents.
// NOTE(review): C_data is not referenced in any visible line; the call tail
// below queries C.mutable_data<DataType>() again — confirm whether C_data
// is actually used.
61 auto* C_data =
C.template mutable_data<DataType>();
// Trailing arguments of a call whose beginning is not visible: the output
// buffer of C and a pointer to `context` cast to CPUContext*.
70 C.mutable_data<DataType>(),
71 static_cast<CPUContext*>(&context));
// Registers the float instantiation of caffe2::mul_op_cpu_impl as the kernel
// for the caffe2::ops::Mul operator, using CPUTensorId() as the dispatch key.
77 C10_REGISTER_KERNEL(caffe2::ops::Mul)
78 .kernel<decltype(caffe2::mul_op_cpu_impl<float>), &caffe2::mul_op_cpu_impl<float>>()
79 .dispatchKey(CPUTensorId());
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
Virtual interface for the Context class in Caffe2.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...