1 #include "caffe2/core/context.h" 2 #include <ATen/core/dispatch/KernelRegistration.h> 3 #include "caffe2/core/operator.h" 4 #include "caffe2/operators/experimental/c10/schemas/fc.h" 5 #include "caffe2/utils/conversions.h" 6 #include "caffe2/utils/math.h" 7 #include "caffe2/core/tensor.h" 16 vector<int64_t> Y_shape_cache_;
// NOTE(review): Fragmentary extraction of the CPU fully-connected kernel
// (registered below as caffe2::fc_op_cpu_impl). The leading integers on many
// lines are line numbers from the original file. Many original lines are not
// present here, including:
//   - the function signature,
//   - several Gemm/Set arguments,
//   - the closing braces.
// This text therefore does not compile as shown.
// The names used below (X, W, b, Y, axis, axis_w, cache, context) presumably
// come from the missing signature -- TODO confirm against the full source.
20 template <
class DataType,
class Context>
// Compile-time switch. When true:
//   - N is read from the leading dims of W (W.size_to_dim below);
//   - the first Gemm is passed CblasTrans for W.
35 constexpr
bool TransposeWeight =
true;
// The bias tensor must be 1-D.
37 CAFFE_ENFORCE(b.dim() == 1, b.dim());
// View X as an M x K matrix split at its canonical axis:
//   - M from X.size_to_dim(axis),
//   - K from X.size_from_dim(axis).
39 const auto canonical_axis = X.canonical_axis_index(axis);
40 const auto M = X.size_to_dim(canonical_axis);
41 const auto K = X.size_from_dim(canonical_axis);
// N (output features) is read from W split at its own canonical axis. Which
// side of the split it comes from depends on TransposeWeight.
42 const auto canonical_axis_w = W.canonical_axis_index(axis_w);
43 const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
44 : W.size_from_dim(canonical_axis_w);
// Builds the "Dimension mismatch: ..." message used by the shape checks
// below. The rest of the lambda body is not visible here.
46 auto dimErrorString = [&]() {
48 "Dimension mismatch: ",
// Cross-check that X, W and b agree on M, K and N.
66 CAFFE_ENFORCE(
M == X.numel() / K, dimErrorString());
67 CAFFE_ENFORCE(K == W.numel() / N, dimErrorString());
68 CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
69 CAFFE_ENFORCE(N == b.numel(), dimErrorString());
// Output shape: X's dims up to and including the canonical axis, with that
// last dim replaced by N. The shape vector is stored in the kernel cache,
// presumably so it is not reallocated on every call.
71 cache->Y_shape_cache_ = X.sizes().vec();
73 DCHECK_LE(canonical_axis + 1, cache->Y_shape_cache_.size());
74 cache->Y_shape_cache_.resize(canonical_axis + 1);
75 cache->Y_shape_cache_[canonical_axis] = N;
76 Y.Resize(cache->Y_shape_cache_);
// The resized output must hold exactly M * N elements.
77 CAFFE_ENFORCE(
M * N == Y.numel(), dimErrorString());
// Requests Y's mutable DataType buffer. The surrounding lines are not visible
// here; presumably this forces the allocation -- TODO confirm.
81 Y.template mutable_data<DataType>();
// Use FLOAT16 math when DataType is a half-precision type, FLOAT otherwise.
// Where math_type is consumed is not visible here (presumably the Gemm below).
86 caffe2::TensorProto::DataType math_type = caffe2::TensorProto_DataType_FLOAT;
87 if (caffe2::fp16_type<DataType>()) {
88 math_type = caffe2::TensorProto_DataType_FLOAT16;
// Y = X * op(W), where op(W) is W transposed when TransposeWeight is set.
// The size, alpha and beta arguments are on lines missing from this listing.
92 caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
94 TransposeWeight ? CblasTrans : CblasNoTrans,
99 X.template data<DataType>(),
100 W.template data<DataType>(),
102 Y.template mutable_data<DataType>(),
103 static_cast<Context*>(&context),
// Bias: build a Tensor from the cached bias_multiplier_ and fill it with ones.
// The element count passed to Set is on a missing line (presumably M).
106 Tensor bias_multiplier(cache->bias_multiplier_);
108 caffe2::math::Set<DataType, Context>(
110 caffe2::convert::To<float, DataType>(1),
111 bias_multiplier.template mutable_data<DataType>(),
112 static_cast<Context*
>(&context));
// Second Gemm: bias_multiplier (ones) times b, written into Y. Presumably this
// adds b to every row of Y (beta = 1); the size, alpha and beta arguments are
// not visible -- TODO confirm.
113 caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
120 bias_multiplier.template data<DataType>(),
121 b.template data<DataType>(),
123 Y.template mutable_data<DataType>(),
124 static_cast<Context*>(&context),
// (The remaining Gemm arguments and the end of the function are not shown.)
131 C10_REGISTER_KERNEL(caffe2::ops::FullyConnected)
132 .withCache<caffe2::Cache>()
133 .kernel<decltype(caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::fc_op_cpu_impl<float, caffe2::CPUContext>>()
134 .dispatchKey(CPUTensorId());
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
Virtual interface for the Context class in Caffe2.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
A kernel can keep around a cache to have better performance when it's called multiple times...