Caffe2 - C++ API
A deep learning, cross platform ML framework
fc_cpu.cc
1 #include "caffe2/core/context.h"
2 #include <ATen/core/dispatch/KernelRegistration.h>
3 #include "caffe2/core/operator.h"
4 #include "caffe2/operators/experimental/c10/schemas/fc.h"
5 #include "caffe2/utils/conversions.h"
6 #include "caffe2/utils/math.h"
7 #include "caffe2/core/tensor.h"
8 
10 using caffe2::Tensor;
11 
12 namespace caffe2 {
13 namespace {
14 
15 struct Cache final : public c10::KernelCache {
16  vector<int64_t> Y_shape_cache_;
17  at::Tensor bias_multiplier_ = at::Tensor(C10Tensor(Tensor()));
18 };
19 
20 template <class DataType, class Context>
21 void fc_op_cpu_impl(
22  const at::Tensor& X_,
23  const at::Tensor& W_,
24  const at::Tensor& b_,
25  const at::Tensor& Y_,
26  int64_t axis,
27  int64_t axis_w,
28  Cache* cache) {
29  Tensor X{C10Tensor(X_)};
30  Tensor W{C10Tensor(W_)};
31  Tensor b{C10Tensor(b_)};
32  Tensor Y{C10Tensor(Y_)};
33  CPUContext context;
34 
35  constexpr bool TransposeWeight = true;
36 
37  CAFFE_ENFORCE(b.dim() == 1, b.dim());
38  // batch size
39  const auto canonical_axis = X.canonical_axis_index(axis);
40  const auto M = X.size_to_dim(canonical_axis);
41  const auto K = X.size_from_dim(canonical_axis);
42  const auto canonical_axis_w = W.canonical_axis_index(axis_w);
43  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
44  : W.size_from_dim(canonical_axis_w);
45 
46  auto dimErrorString = [&]() {
47  return c10::str(
48  "Dimension mismatch: ",
49  "X: ",
50  X.sizes(),
51  ", W: ",
52  W.sizes(),
53  ", b: ",
54  b.sizes(),
55  ", axis: ",
56  axis,
57  ", M: ",
58  M,
59  ", N: ",
60  N,
61  ", K: ",
62  K);
63  };
64 
65  // Error checking
66  CAFFE_ENFORCE(M == X.numel() / K, dimErrorString());
67  CAFFE_ENFORCE(K == W.numel() / N, dimErrorString());
68  CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
69  CAFFE_ENFORCE(N == b.numel(), dimErrorString());
70 
71  cache->Y_shape_cache_ = X.sizes().vec();
72  // This is an invariant of canonical_axis, so we can DCHECK.
73  DCHECK_LE(canonical_axis + 1, cache->Y_shape_cache_.size());
74  cache->Y_shape_cache_.resize(canonical_axis + 1);
75  cache->Y_shape_cache_[canonical_axis] = N;
76  Y.Resize(cache->Y_shape_cache_);
77  CAFFE_ENFORCE(M * N == Y.numel(), dimErrorString());
78 
79  if (X.numel() == 0) {
80  // skip the rest of the computation if X is empty
81  Y.template mutable_data<DataType>();
82  return;
83  }
84 
85  // default to FLOAT as math.h does.
86  caffe2::TensorProto::DataType math_type = caffe2::TensorProto_DataType_FLOAT;
87  if (caffe2::fp16_type<DataType>()) {
88  math_type = caffe2::TensorProto_DataType_FLOAT16;
89  }
90 
91  // W * x
92  caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
93  CblasNoTrans,
94  TransposeWeight ? CblasTrans : CblasNoTrans,
95  M,
96  N,
97  K,
98  1,
99  X.template data<DataType>(),
100  W.template data<DataType>(),
101  0,
102  Y.template mutable_data<DataType>(),
103  static_cast<Context*>(&context),
104  math_type);
105  // Add bias term
106  Tensor bias_multiplier(cache->bias_multiplier_);
107  ReinitializeTensor(&bias_multiplier, {M}, at::dtype<DataType>().device(CPU));
108  caffe2::math::Set<DataType, Context>(
109  M,
110  caffe2::convert::To<float, DataType>(1),
111  bias_multiplier.template mutable_data<DataType>(),
112  static_cast<Context*>(&context));
113  caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
114  CblasNoTrans,
115  CblasNoTrans,
116  M,
117  N,
118  1,
119  1,
120  bias_multiplier.template data<DataType>(),
121  b.template data<DataType>(),
122  1,
123  Y.template mutable_data<DataType>(),
124  static_cast<Context*>(&context),
125  math_type);
126 }
127 } // namespace
128 } // namespace caffe2
129 
namespace c10 {
// Registers the float CPU instantiation of fc_op_cpu_impl as the kernel for
// the caffe2::ops::FullyConnected operator, dispatched for CPU tensors.
// withCache<caffe2::Cache>() attaches the Cache struct defined above so the
// kernel receives scratch state (output-shape vector, bias multiplier) that
// can be reused across calls.
// NOTE(review): only DataType=float is registered here; other dtypes would
// need their own registration.
C10_REGISTER_KERNEL(caffe2::ops::FullyConnected)
    .withCache<caffe2::Cache>()
    .kernel<decltype(caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>()
    .dispatchKey(CPUTensorId());
} // namespace c10
Definition: any.cpp:108
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
Definition: tensor.cc:127
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
Definition: tensor.h:25
Virtual interface for the Context class in Caffe2.
Definition: context_base.h:32
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
A kernel can keep around a cache to have better performance when it's called multiple times...
Definition: KernelCache.h:15