17 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_ 18 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_ 20 #include "caffe2/core/context.h" 21 #include "caffe2/core/operator.h" 22 #include "caffe2/utils/math.h" 33 template <
typename T,
class Context,
class Engine=DefaultEngine>
36 USE_OPERATOR_CONTEXT_FUNCTIONS;
41 bool RunOnDevice()
override {
42 const auto& X =
Input(0);
43 const auto& U =
Input(1);
44 const auto& V =
Input(2);
45 const auto& b =
Input(3);
50 CAFFE_ENFORCE_GE(X.dim(), 1);
51 CAFFE_ENFORCE_GE(U.dim(), 2);
52 CAFFE_ENFORCE_GE(V.dim(), 2);
53 if (X.dim() > 2 || U.dim() > 2 || V.dim() > 2) {
54 VLOG(1) <<
"Using legacy support for arbitrary input and weight " 57 CAFFE_ENFORCE_EQ(b.dim(), 1);
59 int M = X.dim() > 1 ? X.dim32(0) : 1;
61 int K = X.numel() / M;
64 int middle = U.dim32(0);
65 CAFFE_ENFORCE_EQ(K, V.dim32(0));
66 CAFFE_ENFORCE_EQ(N, b.dim32(0));
67 std::vector<int64_t> dims;
70 multi_buffer_.Resize(M, middle);
73 multi_buffer_.Resize(middle);
75 auto* Y = Output(0, dims, at::dtype<T>());
79 T* multi_buffer_data = multi_buffer_.template mutable_data<T>();
81 math::Gemm<T, Context, Engine>(
82 CblasNoTrans, CblasNoTrans, M, middle, K, 1, X.template data<T>(),
83 V.template data<T>(), 0, multi_buffer_data,
85 math::Gemm<T, Context, Engine>(
86 CblasNoTrans, CblasTrans, M, N, middle, 1, multi_buffer_data,
87 U.template data<T>(), 0, Y->template mutable_data<T>(),
90 if (bias_multiplier_.numel() != M) {
92 bias_multiplier_.Resize(M);
93 math::Set<T, Context>(
94 M,
static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
97 math::Gemm<T, Context, Engine>(
98 CblasNoTrans, CblasNoTrans, M, N, 1, 1,
99 bias_multiplier_.template data<T>(), b.template data<T>(), 1,
100 Y->template mutable_data<T>(), &context_);
105 Tensor bias_multiplier_{Context::GetDeviceType()};
106 Tensor multi_buffer_{Context::GetDeviceType()};
109 template <
typename T,
class Context,
class Engine=DefaultEngine>
112 USE_OPERATOR_CONTEXT_FUNCTIONS;
117 bool RunOnDevice()
override {
118 const auto& X =
Input(0);
119 const auto& U =
Input(1);
120 const auto& V =
Input(2);
121 const auto& dY =
Input(3);
122 DCHECK_GE(X.dim(), 1);
123 DCHECK_GE(U.dim(), 2);
124 DCHECK_GE(V.dim(), 2);
125 DCHECK_LE(dY.dim(), 2);
127 int M = X.dim() > 1 ? X.dim32(0) : 1;
129 int K = X.numel() / M;
132 int middle = U.dim32(1);
133 DCHECK_EQ(K, V.dim32(0));
135 DCHECK_EQ(M, dY.dim32(0));
136 DCHECK_EQ(N, dY.dim32(1));
138 DCHECK_EQ(X.dim(), 1);
139 DCHECK_EQ(N, dY.numel());
142 auto* dU = Output(0, U.sizes(), at::dtype<T>());
143 auto* dV = Output(1, V.sizes(), at::dtype<T>());
144 auto* db = Output(2, {N}, at::dtype<T>());
148 du_buffer_.Resize(N, middle);
149 T* du_buffer_data = du_buffer_.template mutable_data<T>();
150 math::Gemm<T, Context, Engine>(
151 CblasNoTrans, CblasNoTrans, M, middle, K, 1,
152 X.template data<T>(), V.template data<T>(),
155 math::Gemm<T, Context, Engine>(
156 CblasTrans, CblasNoTrans, N, middle, M, 1,
157 dY.template data<T>(), du_buffer_data,
158 0, dU->template mutable_data<T>(),
162 dv_buffer_.Resize(M, middle);
163 T* dv_buffer_data = dv_buffer_.template mutable_data<T>();
164 math::Gemm<T, Context, Engine>(
165 CblasNoTrans, CblasNoTrans, M, middle, N, 1,
166 dY.template data<T>(), U.template data<T>(),
169 math::Gemm<T, Context, Engine>(
170 CblasTrans, CblasNoTrans, K, middle, M, 1,
171 dY.template data<T>(), du_buffer_data,
172 0, dV->template mutable_data<T>(),
174 if (bias_multiplier_.numel() != M) {
176 bias_multiplier_.Resize(M);
177 math::Set<T, Context>(
178 M,
static_cast<T>(1),
179 bias_multiplier_.template mutable_data<T>(),
183 math::Gemv<T, Context>(
184 CblasTrans, M, N, 1, dY.template data<T>(),
185 bias_multiplier_.template data<T>(), 0,
186 db->template mutable_data<T>(),
189 if (OutputSize() == 4) {
190 auto* dX = Output(3, X.sizes(), at::dtype<T>());
191 dx_buffer_.Resize(M, middle);
192 T* dx_buffer_data = dx_buffer_.template mutable_data<T>();
193 math::Gemm<T, Context, Engine>(
194 CblasNoTrans, CblasNoTrans, M, middle, N, 1,
195 dY.template data<T>(), U.template data<T>(),
198 math::Gemm<T, Context, Engine>(
199 CblasNoTrans, CblasTrans, M, K, middle, 1,
200 dx_buffer_data, V.template data<T>(),
201 0, dX->template mutable_data<T>(),
209 Tensor bias_multiplier_{Context::GetDeviceType()};
210 Tensor du_buffer_{Context::GetDeviceType()};
211 Tensor dv_buffer_{Context::GetDeviceType()};
212 Tensor dx_buffer_{Context::GetDeviceType()};
217 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...