Caffe2 - C++ API
A deep-learning, cross-platform ML framework
add_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/elementwise_ops_utils.h"
3 #include "caffe2/operators/experimental/c10/schemas/add.h"
4 #include "caffe2/utils/math.h"
5 
7 using caffe2::Tensor;
8 
9 namespace caffe2 {
10 namespace {
11 
12 template <class DataType>
13 void add_op_cpu_impl(
14  const at::Tensor& A_,
15  const at::Tensor& B_,
16  const at::Tensor& C_,
17  bool legacy_broadcast,
18  int64_t axis) {
19  Tensor A{C10Tensor(A_)};
20  Tensor B{C10Tensor(B_)};
21  Tensor C{C10Tensor(C_)};
22  CPUContext context;
23  const DataType* A_data = A.template data<DataType>();
24  const DataType* B_data = B.template data<DataType>();
25  std::vector<int> A_dims;
26  std::vector<int> B_dims;
27 
28  if (legacy_broadcast) {
29  CAFFE_ENFORCE(
30  !B.is_same(C),
31  "In-place is allowed only with the first tensor when "
32  "legacy-broadcasting");
33  C.ResizeLike(A);
34  if (B.numel() == 1) {
35  A_dims = {static_cast<int>(A.numel())};
36  B_dims = {1};
37  } else {
38  size_t pre, n, post;
39  std::tie(pre, n, post) =
40  caffe2::elementwise_ops_utils::ComputeLegacyBroadcastSizes(
41  A, B, axis);
42  A_dims = {
43  static_cast<int>(pre), static_cast<int>(n), static_cast<int>(post)};
44  B_dims = {static_cast<int>(n), 1};
45  }
46  } else {
47  std::copy(A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
48  std::copy(B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
49  const std::vector<int> C_dims =
50  caffe2::elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
51  A_dims, B_dims);
52  if (A.is_same(C)) {
53  CAFFE_ENFORCE_EQ(C_dims, A_dims);
54  } else if (B.is_same(C)) {
55  CAFFE_ENFORCE_EQ(C_dims, B_dims);
56  } else {
57  C.Resize(C_dims);
58  }
59  }
60  auto* C_data = C.template mutable_data<DataType>();
61 
62  caffe2::math::Add(
63  A_dims.size(),
64  A_dims.data(),
65  B_dims.size(),
66  B_dims.data(),
67  A.data<DataType>(),
68  B.data<DataType>(),
69  C.mutable_data<DataType>(),
70  static_cast<CPUContext*>(&context));
71 }
72 } // namespace
73 } // namespace caffe2
74 
// Registers add_op_cpu_impl<float> as the kernel for the c10 schema
// caffe2::ops::Add under the CPU dispatch key. Only the float instantiation
// is registered in this file; NOTE(review): other dtypes presumably need a
// separate registration elsewhere — confirm before relying on them.
75 namespace c10 {
76 C10_REGISTER_KERNEL(caffe2::ops::Add)
// kernel<FunctionType, FunctionPointer>: the type is taken via decltype so it
// always matches the instantiation actually passed as the pointer.
77  .kernel<decltype(caffe2::add_op_cpu_impl<float>), &caffe2::add_op_cpu_impl<float>>()
78  .dispatchKey(CPUTensorId());
79 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Definition: tensor.h:25
Virtual interface for the Context class in Caffe2.
Definition: context_base.h:32
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
Definition: static.cpp:58
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7