1 #ifndef CAFFE2_OPERATORS_DISTANCE_OP_H_ 2 #define CAFFE2_OPERATORS_DISTANCE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
typename T,
class Context>
13 template <
class... Args>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 bool RunOnDevice()
override;
24 template <
typename T,
class Context>
27 template <
class... Args>
30 USE_OPERATOR_CONTEXT_FUNCTIONS;
// RunOnDevice override with an inline body. Only part of the body is visible
// in this listing (several statements and call arguments are absent), so the
// notes below describe the visible statements only: shape validation, creation
// of two outputs, and a sequence of math::Sub / math::Scale calls.
32 bool RunOnDevice()
override {
// NOTE(review): X and Y are used below but their definitions are not visible
// here -- presumably Input(0) and Input(1); confirm against the full file.
35 auto& dDistance =
Input(2);
// N: size of X's first dimension, or 1 when X has no dimensions.
// D: X.numel() / N, forced to 0 when N is 0 so no division by zero occurs.
37 int N = X.dim() > 0 ? X.dim32(0) : 1;
38 int D = N > 0 ? X.numel() / N : 0;
// X and Y must have the same number of dimensions and the same extent along
// every dimension.
39 CAFFE_ENFORCE(X.dim() == Y.dim());
40 for (
int i = 0; i < X.dim(); ++i) {
41 CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i));
// dDistance must be one-dimensional with exactly N entries.
43 CAFFE_ENFORCE(dDistance.dim() == 1);
44 CAFFE_ENFORCE(dDistance.dim32(0) == N);
// Outputs 0 and 1 are requested with the sizes of X and Y and element type T.
45 auto* dX = Output(0, X.sizes(), at::dtype<T>());
46 auto* dY = Output(1, Y.sizes(), at::dtype<T>());
// NOTE(review): most arguments of this math::Sub call are not visible; the
// visible one is dX's mutable buffer, so dX looks like the destination.
47 math::Sub<T, Context>(
51 dX->template mutable_data<T>(),
// One math::Scale call per i in [0, N). The visible arguments are a pointer to
// the i-th dDistance element and the i-th D-sized slice of dX, passed both as
// const data and as mutable data (i.e. the slice is updated in place).
53 for (
int i = 0; i < N; ++i) {
54 math::Scale<T, T, Context>(
56 dDistance.template data<T>() + i,
57 dX->template data<T>() + i * D,
58 dX->template mutable_data<T>() + i * D,
// Final math::Scale call reads from dX and writes to dY.
// NOTE(review): the element count and scale factor are not visible here.
62 math::Scale<T, T, Context>(
65 dX->template data<T>(),
66 dY->template mutable_data<T>(),
75 template <
typename T,
class Context>
78 template <
class... Args>
81 USE_OPERATOR_CONTEXT_FUNCTIONS;
83 bool RunOnDevice()
override;
89 template <
typename T,
class Context>
92 template <
class... Args>
95 USE_OPERATOR_CONTEXT_FUNCTIONS;
97 bool RunOnDevice()
override;
103 template <
typename T,
class Context>
106 template <
class... Args>
109 USE_OPERATOR_CONTEXT_FUNCTIONS;
111 bool RunOnDevice()
override;
114 INPUT_TAGS(X_IN, Y_IN);
115 OUTPUT_TAGS(DOT_OUT);
118 template <
typename T,
class Context>
121 template <
class... Args>
124 USE_OPERATOR_CONTEXT_FUNCTIONS;
126 bool RunOnDevice()
override;
129 INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
130 OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
133 template <
typename T,
class Context>
136 template <
class... Args>
139 pad_value_(this->
template GetSingleArgument<float>(
"pad_value", 0.0)),
140 replicate_(this->
template GetSingleArgument<bool>(
"replicate",
false)) {
142 USE_OPERATOR_CONTEXT_FUNCTIONS;
144 bool RunOnDevice()
override;
149 INPUT_TAGS(X_IN, Y_IN);
150 OUTPUT_TAGS(DOT_OUT);
153 template <
typename T,
class Context>
156 template <
class... Args>
159 USE_OPERATOR_CONTEXT_FUNCTIONS;
161 bool RunOnDevice()
override;
164 INPUT_TAGS(X_IN, Y_IN);
165 OUTPUT_TAGS(COS_OUT);
171 template <
typename T,
class Context>
174 template <
class... Args>
177 USE_OPERATOR_CONTEXT_FUNCTIONS;
179 bool RunOnDevice()
override;
182 INPUT_TAGS(X_IN, Y_IN, DER_COS_IN);
183 OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
189 template <
typename T,
class Context>
192 template <
class... Args>
195 pad_value_(this->
template GetSingleArgument<float>(
"pad_value", 0.0)),
196 replicate_(this->
template GetSingleArgument<bool>(
"replicate",
false)) {
198 USE_OPERATOR_CONTEXT_FUNCTIONS;
// RunOnDevice override with an inline body. Reads X, Y and dDot from the
// X_IN, Y_IN and DER_DOT_IN inputs and fills the DER_X_OUT / DER_Y_OUT
// outputs. Parts of the body (the DX/DY/DL/DS assignments and the conditions
// selecting between the alternative branches) are not visible in this
// listing; the notes below cover the visible statements only.
200 bool RunOnDevice()
override {
201 auto& X =
Input(X_IN);
202 auto& Y =
Input(Y_IN);
203 auto& dDot =
Input(DER_DOT_IN);
205 int N,
D, DX, DY, restD;
// N: size of X's first dimension, or 1 when X has no dimensions.
207 N = X.dim() > 0 ? X.dim32(0) : 1;
// NOTE(review): the assignments to DX and DY are not visible here --
// presumably the per-row element counts of X and Y; confirm in the full file.
// When replicate_ is set, one of DX / DY must be a multiple of the other.
215 CAFFE_ENFORCE(!replicate_ || DX % DY == 0 || DY % DX == 0);
// D is the smaller of the two widths; restD is how much wider the other is.
216 D = std::min(DX, DY);
217 restD = std::max(DX, DY) - D;
// X and Y must agree in dimension count and first-dimension size; dDot must
// be one-dimensional with N entries.
218 CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
219 CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0));
220 CAFFE_ENFORCE_EQ(dDot.dim(), 1);
221 CAFFE_ENFORCE_EQ(dDot.dim32(0), N);
// Outputs are requested with the sizes of X and Y and element type T.
222 auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<T>());
223 auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<T>());
225 const auto* X_data = X.template data<T>();
226 const auto* Y_data = Y.template data<T>();
227 const auto* dDot_data = dDot.template data<T>();
228 auto* dX_data = dX->template mutable_data<T>();
229 auto* dY_data = dY->template mutable_data<T>();
// Per-row processing: offsetX / offsetY locate row i inside the X-shaped and
// Y-shaped buffers respectively.
230 for (
int i = 0; i < N; ++i) {
231 auto offsetX = i * DX;
232 auto offsetY = i * DY;
// L_* / S_* pointers alias the current row of X/dX and Y/dY. Two alternative
// assignments follow: X as "L" and Y as "S", or the reverse.
// NOTE(review): the condition choosing between them, and the definitions of
// DL and DS, are not visible -- presumably L is the longer row and S the
// shorter one; confirm in the full file.
235 const T *L_data, *S_data;
236 T *dL_data, *dS_data;
239 L_data = X_data + offsetX;
240 S_data = Y_data + offsetY;
241 dL_data = dX_data + offsetX;
242 dS_data = dY_data + offsetY;
246 L_data = Y_data + offsetY;
247 S_data = X_data + offsetX;
248 dL_data = dY_data + offsetY;
249 dS_data = dX_data + offsetX;
// tmp_data is a DS-element scratch buffer; dS_data is first set to zero.
// For each of the DL / DS chunks j:
//   - S_data scaled by dDot_data[i] is written to chunk j of dL_data;
//   - chunk j of L_data scaled by dDot_data[i] goes into tmp_data, which is
//     then added (Axpy with factor 1.0) onto dS_data.
255 std::vector<T> tmp_data(DS);
256 math::Set<T, Context>(DS, 0.0, dS_data, &context_);
257 for (
int j = 0; j < DL / DS; j++) {
258 math::Scale<T, T, Context>(
259 DS, dDot_data[i], S_data, dL_data + j * DS, &context_);
260 math::Scale<T, T, Context>(
261 DS, dDot_data[i], L_data + j * DS, tmp_data.data(), &context_);
262 math::Axpy<T, Context>(DS, 1.0, tmp_data.data(), dS_data, &context_);
// Alternative path over the first D elements of the row: X scaled by
// dDot_data[i] is written to dY, and Y scaled by dDot_data[i] to dX.
// NOTE(review): the condition separating this path from the chunked path
// above is not visible in this listing.
265 math::Scale<T, T, Context>(
266 D, dDot_data[i], X_data + offsetX, dY_data + offsetY, &context_);
267 math::Scale<T, T, Context>(
268 D, dDot_data[i], Y_data + offsetY, dX_data + offsetX, &context_);
// Without replication and with differing widths, the restD trailing elements
// (past the first D) of one gradient row are filled with
// dDot_data[i] * pad_value_. rest_data points into dX or dY.
// NOTE(review): the declaration of rest_data and the condition picking dX vs
// dY are not visible here.
271 if (!replicate_ && DX != DY) {
274 rest_data = dX_data + offsetX + D;
276 rest_data = dY_data + offsetY + D;
278 auto pad_gradient = dDot_data[i] * pad_value_;
279 math::Set<T, Context>(restD, pad_gradient, rest_data, &context_);
289 INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
290 OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
295 #endif // CAFFE2_OPERATORS_DISTANCE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...