Caffe2 - C++ API
A deep learning, cross-platform ML framework
mul_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/elementwise_ops_utils.h"
3 #include "caffe2/operators/experimental/c10/schemas/mul.h"
4 #include "caffe2/utils/math.h"
5 #include "caffe2/core/tensor.h"
6 
8 using caffe2::Tensor;
9 
10 namespace caffe2 {
11 namespace {
12 
13 template <class DataType>
14 void mul_op_cpu_impl(
15  const at::Tensor& A_,
16  const at::Tensor& B_,
17  const at::Tensor& C_,
18  bool legacy_broadcast,
19  int64_t axis) {
20  Tensor A{C10Tensor(A_)};
21  Tensor B{C10Tensor(B_)};
22  Tensor C{C10Tensor(C_)};
23  CPUContext context;
24  const DataType* A_data = A.template data<DataType>();
25  const DataType* B_data = B.template data<DataType>();
26  std::vector<int> A_dims;
27  std::vector<int> B_dims;
28 
29  if (legacy_broadcast) {
30  CAFFE_ENFORCE(
31  !B.is_same(C),
32  "In-place is allowed only with the first tensor when "
33  "legacy-broadcasting");
34  C.ResizeLike(A);
35  if (B.numel() == 1) {
36  A_dims = {static_cast<int>(A.numel())};
37  B_dims = {1};
38  } else {
39  size_t pre, n, post;
40  std::tie(pre, n, post) =
41  caffe2::elementwise_ops_utils::ComputeLegacyBroadcastSizes(
42  A, B, axis);
43  A_dims = {
44  static_cast<int>(pre), static_cast<int>(n), static_cast<int>(post)};
45  B_dims = {static_cast<int>(n), 1};
46  }
47  } else {
48  std::copy(A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
49  std::copy(B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
50  const std::vector<int> C_dims =
51  caffe2::elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
52  A_dims, B_dims);
53  if (A.is_same(C)) {
54  CAFFE_ENFORCE_EQ(C_dims, A_dims);
55  } else if (B.is_same(C)) {
56  CAFFE_ENFORCE_EQ(C_dims, B_dims);
57  } else {
58  C.Resize(C_dims);
59  }
60  }
61  auto* C_data = C.template mutable_data<DataType>();
62 
63  caffe2::math::Mul(
64  A_dims.size(),
65  A_dims.data(),
66  B_dims.size(),
67  B_dims.data(),
68  A.data<DataType>(),
69  B.data<DataType>(),
70  C.mutable_data<DataType>(),
71  static_cast<CPUContext*>(&context));
72 }
73 } // namespace
74 } // namespace caffe2
75 
76 namespace c10 {
77 C10_REGISTER_KERNEL(caffe2::ops::Mul)
78  .kernel<decltype(caffe2::mul_op_cpu_impl<float>), &caffe2::mul_op_cpu_impl<float>>()
79  .dispatchKey(CPUTensorId());
80 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Definition: tensor.h:25
Virtual interface for the Context class in Caffe2.
Definition: context_base.h:32
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
Definition: static.cpp:58
To register your own kernel for an operator, do the following in exactly one .cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7