1 #include "caffe2/operators/distance_op.h" 2 #include "caffe2/utils/eigen_utils.h" 3 #ifdef CAFFE2_USE_MKLDNN 4 #include <caffe2/ideep/operators/operator_fallback_ideep.h> 5 #include <caffe2/ideep/utils/ideep_operator.h> 11 bool SquaredL2DistanceOp<float, CPUContext>::RunOnDevice() {
  CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
  for (int i = 0; i < X.dim(); ++i) {
    CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i));
  }
  int N = X.dim() > 0 ? X.dim32(0) : 1;
  auto* distance = Output(0, {N}, at::dtype<float>());
  int D = N > 0 ? X.numel() / N : 0;
  float* distance_data = distance->template mutable_data<float>();
  const float* X_data = X.data<float>();
  const float* Y_data = Y.data<float>();
  for (int i = 0; i < N; ++i) {
    float Xscale, Yscale, cross;
    math::Dot<float, CPUContext>(
        D, X_data + i * D, X_data + i * D, &Xscale, &context_);
    math::Dot<float, CPUContext>(
        D, Y_data + i * D, Y_data + i * D, &Yscale, &context_);
    math::Dot<float, CPUContext>(
        D, X_data + i * D, Y_data + i * D, &cross, &context_);
    distance_data[i] = (Xscale + Yscale) * 0.5 - cross;
  }
  return true;
}
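// Editorial note on SquaredL2DistanceOp above: the three Dot calls rely on the
// identity 0.5 * ||x - y||^2 == 0.5 * (||x||^2 + ||y||^2) - <x, y>.
// Quick illustrative check (values are mine, not from the original source):
// for x = [1, 2, 3], y = [1, 0, 1]:
//   0.5 * (14 + 2) - 4 = 4, which equals 0.5 * ((1-1)^2 + (2-0)^2 + (3-1)^2).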
template <>
bool L1DistanceOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& Y = Input(1);

  CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
  for (int i = 0; i < X.dim(); ++i) {
    CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i));
  }
  int N = X.dim() > 0 ? X.dim32(0) : 1;
  auto* distance = Output(0, {N}, at::dtype<float>());
  int D = N > 0 ? X.numel() / N : 0;

  const float* X_data = X.data<float>();
  const float* Y_data = Y.data<float>();

  for (int i = 0; i < N; ++i) {
    (distance->template mutable_data<float>())[i] =
        (ConstEigenVectorMap<float>(X_data + i * D, D).array() -
         ConstEigenVectorMap<float>(Y_data + i * D, D).array())
            .abs()
            .sum();
  }
  return true;
}
template <>
bool L1DistanceGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& Y = Input(1);
  auto& dDistance = Input(2);

  CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
  for (int i = 0; i < X.dim(); ++i) {
    CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i));
  }
  int N = X.dim() > 0 ? X.dim32(0) : 1;
  int D = N > 0 ? X.numel() / N : 0;
  CAFFE_ENFORCE(X.dim() == Y.dim());
  for (int i = 0; i < X.dim(); ++i) {
    CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i));
  }
  CAFFE_ENFORCE(dDistance.dim() == 1);
  CAFFE_ENFORCE(dDistance.dim32(0) == N);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  auto* dY = Output(1, Y.sizes(), at::dtype<float>());

  for (int i = 0; i < N; ++i) {
    auto offset = i * D;
    for (int j = 0; j < D; ++j) {
      const float temp =
          (X.data<float>())[offset + j] - (Y.data<float>())[offset + j];
      const float kEps = 1e-12f;
      if (temp < -kEps) {
        dX->template mutable_data<float>()[offset + j] =
            -(dDistance.data<float>())[i];
        dY->template mutable_data<float>()[offset + j] =
            (dDistance.data<float>())[i];
      } else if (temp > kEps) {
        dX->template mutable_data<float>()[offset + j] =
            (dDistance.data<float>())[i];
        dY->template mutable_data<float>()[offset + j] =
            -(dDistance.data<float>())[i];
      } else {
        dX->template mutable_data<float>()[offset + j] = 0;
        dY->template mutable_data<float>()[offset + j] = 0;
      }
    }
  }
  return true;
}
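// Editorial note on L1DistanceGradientOp above: the branches implement a
// subgradient of |x_j - y_j|, i.e. dX[j] = dDistance * sign(x_j - y_j) and
// dY[j] = -dX[j], with the sign treated as 0 whenever |x_j - y_j| <= kEps to
// avoid an arbitrary choice at the non-differentiable point.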
template <>
bool CosineSimilarityOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(X_IN);
  auto& Y = Input(Y_IN);

  CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
  for (int i = 0; i < X.dim(); ++i) {
    CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i));
  }
  const int N = X.dim() > 0 ? X.dim32(0) : 1;
  const int D = X.size_from_dim(1);
  auto* result = Output(COS_OUT, {N}, at::dtype<float>());
  float* result_data = result->template mutable_data<float>();
  const float* X_data = X.data<float>();
  const float* Y_data = Y.data<float>();
  float X2, Y2;
  const float kEps = 1e-12f;
  for (int i = 0; i < N; ++i) {
    auto offset = i * D;
    math::Dot<float, CPUContext>(
        D, X_data + offset, X_data + offset, &X2, &context_);
    math::Dot<float, CPUContext>(
        D, Y_data + offset, Y_data + offset, &Y2, &context_);
    math::Dot<float, CPUContext>(
        D, X_data + offset, Y_data + offset, result_data + i, &context_);
    result_data[i] /= std::sqrt(std::max(X2, kEps) * std::max(Y2, kEps));
  }
  return true;
}
template <>
bool CosineSimilarityGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(X_IN);
  auto& Y = Input(Y_IN);
  auto& dCos = Input(DER_COS_IN);

  const int N = X.dim() > 0 ? X.dim32(0) : 1;
  const int D = X.size_from_dim(1);
  CAFFE_ENFORCE(X.dim() == Y.dim());
  for (int i = 0; i < X.dim(); ++i) {
    CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i));
  }
  CAFFE_ENFORCE(dCos.dim() == 1);
  CAFFE_ENFORCE(dCos.dim32(0) == N);
  auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<float>());
  auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<float>());

  const auto* X_data = X.template data<float>();
  const auto* Y_data = Y.template data<float>();
  const auto* dCos_data = dCos.template data<float>();
  auto* dX_data = dX->template mutable_data<float>();
  auto* dY_data = dY->template mutable_data<float>();
  float XN, YN, XY;
  const float kEps = 1e-12f;
  for (int i = 0; i < N; ++i) {
    auto offset = i * D;

    // ||x||
    math::Dot<float, CPUContext>(
        D, X_data + offset, X_data + offset, &XN, &context_);
    XN = std::sqrt(std::max(XN, kEps));
    // ||y||
    math::Dot<float, CPUContext>(
        D, Y_data + offset, Y_data + offset, &YN, &context_);
    YN = std::sqrt(std::max(YN, kEps));
    // ||x|| * ||y||
    float XYN = XN * YN;
    // x^T y
    math::Dot<float, CPUContext>(
        D, X_data + offset, Y_data + offset, &XY, &context_);

    math::Scale<float, float, CPUContext>(
        D, dCos_data[i] / XYN, Y_data + offset, dX_data + offset, &context_);
    math::Axpy(
        D,
        -dCos_data[i] * XY / (XN * XN * XYN),
        X_data + offset,
        dX_data + offset,
        &context_);

    math::Scale<float, float, CPUContext>(
        D, dCos_data[i] / XYN, X_data + offset, dY_data + offset, &context_);
    math::Axpy(
        D,
        -dCos_data[i] * XY / (YN * YN * XYN),
        Y_data + offset,
        dY_data + offset,
        &context_);
  }

  return true;
}
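// Editorial note on CosineSimilarityGradientOp above: the Scale/Axpy pairs
// follow from differentiating cos(x, y) = <x, y> / (||x|| * ||y||):
//   d cos / d x = y / (||x|| * ||y||) - <x, y> * x / (||x||^2 * ||x|| * ||y||)
// and symmetrically for y; each term is then scaled by the incoming dCos.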
template <>
bool DotProductOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(X_IN);
  auto& Y = Input(Y_IN);

  CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
  for (int i = 0; i < X.dim(); ++i) {
    CAFFE_ENFORCE_EQ(X.dim32(i), Y.dim32(i), "dimension at ", i);
  }
  int N, D;
  if (X.numel() > 0) {
    N = X.dim() > 0 ? X.dim32(0) : 1;
    D = X.numel() / N;
  } else {
    N = 0;
    D = 0;
  }
  auto* result = Output(DOT_OUT, {N}, at::dtype<float>());
  float* result_data = result->template mutable_data<float>();
  const float* X_data = X.template data<float>();
  const float* Y_data = Y.template data<float>();
  for (int i = 0; i < N; ++i) {
    auto offset = i * D;
    math::Dot<float, CPUContext>(
        D, X_data + offset, Y_data + offset, result_data + i, &context_);
  }
  return true;
}
vector<TensorShape> TensorInferenceForDotProduct(
    const OperatorDef& /* def */,
    const vector<TensorShape>& in) {
  CAFFE_ENFORCE_GT(in.size(), 0);

  vector<int64_t> dims(1);
  dims[0] = in[0].dims().size() > 0 ? in[0].dims(0) : 1;
  return vector<TensorShape>{CreateTensorShape(dims, in[0].data_type())};
}

OpSchema::Cost CostInferenceForDotProduct(
    const OperatorDef& def,
    const vector<TensorShape>& in) {
  std::vector<TensorShape> out = TensorInferenceForDotProduct(def, in);
  CAFFE_ENFORCE_GT(out.size(), 0);
  CAFFE_ENFORCE_EQ(out[0].dims().size(), 1);

  struct OpSchema::Cost c = PointwiseCostInference<2>(def, in);
  c.bytes_written = out[0].dims(0) * sizeof(out[0].data_type());
  c.params_bytes = 0;
  return c;
}
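// Editorial note on CostInferenceForDotProduct above (approximate; based on my
// reading of how PointwiseCostInference is used elsewhere in Caffe2, not on
// upstream comments): the template argument 2 charges roughly two FLOPs per
// input element (one multiply and one add per paired element of X and Y),
// while bytes_written is set from the single N-element output.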
template <>
bool DotProductGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(X_IN);
  auto& Y = Input(Y_IN);
  auto& dDot = Input(DER_DOT_IN);

  int N, D;
  if (X.numel() > 0) {
    N = X.dim() > 0 ? X.dim32(0) : 1;
    D = X.numel() / N;
  } else {
    N = 0;
    D = 0;
  }
  CAFFE_ENFORCE(X.dim() == Y.dim());
  for (int i = 0; i < X.dim(); ++i) {
    CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i));
  }
  CAFFE_ENFORCE(dDot.dim() == 1);
  CAFFE_ENFORCE(dDot.dim32(0) == N);
  auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<float>());
  auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<float>());

  const auto* X_data = X.template data<float>();
  const auto* Y_data = Y.template data<float>();
  const auto* dDot_data = dDot.template data<float>();
  auto* dX_data = dX->template mutable_data<float>();
  auto* dY_data = dY->template mutable_data<float>();
  for (int i = 0; i < N; ++i) {
    auto offset = i * D;
    math::Scale<float, float, CPUContext>(
        D, dDot_data[i], X_data + offset, dY_data + offset, &context_);
    math::Scale<float, float, CPUContext>(
        D, dDot_data[i], Y_data + offset, dX_data + offset, &context_);
  }
  return true;
}
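// Editorial note on DotProductGradientOp above: since Z_i = sum_j X_ij * Y_ij,
// the gradients are simply dX_ij = dZ_i * Y_ij and dY_ij = dZ_i * X_ij, which
// is exactly what the two row-wise Scale calls compute.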
template <>
bool DotProductWithPaddingOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(X_IN);
  auto& Y = Input(Y_IN);

  CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
  CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0));

  int N, D, DX, DY, restD;
  if (X.numel() > 0) {
    N = X.dim() > 0 ? X.dim32(0) : 1;
    DX = X.numel() / N;
    DY = Y.numel() / N;
  } else {
    N = 0;
    DX = 0;
    DY = 0;
  }

  D = std::min(DX, DY);
  restD = std::max(DX, DY) - D;
  auto* result = Output(DOT_OUT, {N}, at::dtype<float>());
  float* result_data = result->template mutable_data<float>();
  const float* X_data = X.data<float>();
  const float* Y_data = Y.data<float>();
  for (int i = 0; i < N; ++i) {
    auto offsetX = i * DX, offsetY = i * DY;
    if (replicate_) {
      // L_ points at the longer row, S_ at the shorter one.
      const float *L_data, *S_data;
      int DL, DS;
      if (DX > DY) {
        L_data = X_data + offsetX;
        S_data = Y_data + offsetY;
        DL = DX;
        DS = DY;
      } else {
        L_data = Y_data + offsetY;
        S_data = X_data + offsetX;
        DL = DY;
        DS = DX;
      }
      float sum = 0.0;
      float tmp = 0.0;
      for (int j = 0; j < DL / DS; j++) {
        math::Dot<float, CPUContext>(
            DS, L_data + j * DS, S_data, &tmp, &context_);
        sum += tmp;
      }
      *(result_data + i) = sum;
    } else {
      math::Dot<float, CPUContext>(
          D, X_data + offsetX, Y_data + offsetY, result_data + i, &context_);
    }

    if (!replicate_ && DX != DY) {
      const float* rest_data;
      float rest_sum = 0;
      if (DX > DY) {
        rest_data = X_data + offsetX + D;
      } else {
        rest_data = Y_data + offsetY + D;
      }
      math::Sum<float, CPUContext>(restD, rest_data, &rest_sum, &context_);
      result_data[i] += rest_sum * pad_value_;
    }
  }
  return true;
}
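// Editorial note on DotProductWithPaddingOp above: in the replicate branch the
// shorter row is conceptually tiled along the longer one, so the result is the
// sum of DL / DS chunk-wise dot products; in the padding branch the dot
// product covers the common prefix of length D, and the leftover entries of
// the longer row contribute rest_sum * pad_value_, which is equivalent to
// padding the shorter row with pad_value_.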
REGISTER_CPU_OPERATOR(SquaredL2Distance, SquaredL2DistanceOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    SquaredL2DistanceGradient,
    SquaredL2DistanceGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(SquaredL2Distance)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
Given two input float tensors X and Y, produces one output float tensor
holding the squared L2 distance between X and Y, computed row-wise as
0.5 * ||X - Y||^2.
)DOC")
    .Input(0, "X", "1D or 2D input tensor")
    .Input(1, "Y", "1D or 2D input tensor (must have the same shape as X)")
    .Output(0, "Z", "1D output tensor");
OPERATOR_SCHEMA(SquaredL2DistanceGradient).NumInputs(3).NumOutputs(2);

class GetSquaredL2DistanceGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "SquaredL2DistanceGradient",
        "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0), GI(1)});
  }
};
REGISTER_GRADIENT(SquaredL2Distance, GetSquaredL2DistanceGradient);
REGISTER_CPU_OPERATOR(L1Distance, L1DistanceOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    L1DistanceGradient,
    L1DistanceGradientOp<float, CPUContext>);

#ifdef CAFFE2_USE_MKLDNN
REGISTER_IDEEP_OPERATOR(
    L1DistanceGradient,
    IDEEPFallbackOp<L1DistanceGradientOp<float, CPUContext>>);
#endif
OPERATOR_SCHEMA(L1Distance)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
Computes the row-wise L1 distance between the two input tensors $X$ and $Y$,
which is defined as

$$L1Distance(\mathbf{x},\mathbf{y}) = \sum_{i}\mid x_i - y_i\mid$$

Note, both inputs must either be 1-dimensional or 2-dimensional and both must
have the same shape. The output $Z$ is always 1-dimensional, and its length
equals the number of rows in the inputs.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/distance_op.h
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/distance_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "L1Distance",
    ["X", "Y"],
    ["Z"]
)

X = 5 * np.ones((1, 4))
Y = np.ones((1, 4))

# Feed X & Y into workspace
workspace.FeedBlob("X", X.astype(np.float32))
workspace.FeedBlob("Y", Y.astype(np.float32))

workspace.RunOperatorOnce(op)

print("Z:\n", workspace.FetchBlob("Z"))
```

</details>

)DOC")
    .Input(0, "X", "First input tensor. (1D or 2D)")
    .Input(1, "Y", "Second input tensor. (must have the same shape as $X$)")
    .Output(0, "Z", "1D output tensor. One value for each row of the inputs.");
OPERATOR_SCHEMA(L1DistanceGradient).NumInputs(3).NumOutputs(2);

class GetL1DistanceGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "L1DistanceGradient",
        "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0), GI(1)});
  }
};
REGISTER_GRADIENT(L1Distance, GetL1DistanceGradient);
REGISTER_CPU_OPERATOR(DotProduct, DotProductOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    DotProductGradient,
    DotProductGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(DotProduct)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
Computes and outputs the dot product of the two input float tensors `X` and `Y`.
Note that `X` and `Y` must be either 1D or 2D, and they must be the same shape.
The output tensor is 1D: for 1D inputs it holds the element-wise products, and
for 2D inputs it holds the row-wise dot products (the sum of products along
each row). Note that for 1D inputs the actual scalar dot product is the sum of
the elements of the 1D output tensor.

For 1D inputs:
Given two vectors $X = [x_0, x_1, x_2]$ and $Y = [y_0, y_1, y_2]$; $Z = [x_0 * y_0, x_1 * y_1, x_2 * y_2]$

For 2D inputs:
Given two matrices:

$$X = [[x_0^0, x_1^0, x_2^0], \\ [x_0^1, x_1^1, x_2^1], \\ [x_0^2, x_1^2, x_2^2], \\ ..., \\ [x_0^n, x_1^n, x_2^n]]$$

and

$$Y = [[y_0^0, y_1^0, y_2^0], \\ [y_0^1, y_1^1, y_2^1], \\ [y_0^2, y_1^2, y_2^2], \\ ..., \\ [y_0^n, y_1^n, y_2^n]]$$

then

$$Z = \biggl[\Big((x_0^0 * y_0^0) + (x_1^0 * y_1^0) + (x_2^0 * y_2^0)\Big), \\ \Big((x_0^1 * y_0^1) + (x_1^1 * y_1^1) + (x_2^1 * y_2^1)\Big), \\ \Big((x_0^2 * y_0^2) + (x_1^2 * y_1^2) + (x_2^2 * y_2^2)\Big), \\ ..., \\ \Big((x_0^n * y_0^n) + (x_1^n * y_1^n) + (x_2^n * y_2^n)\Big)\biggr]$$

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/distance_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "DotProduct",
    ["X", "Y"],
    ["Z"]
)

workspace.FeedBlob("X", np.random.randint(20, size=(5)).astype(np.float32))
workspace.FeedBlob("Y", np.random.randint(20, size=(5)).astype(np.float32))
print("X:\n", workspace.FetchBlob("X"))
print("Y:\n", workspace.FetchBlob("Y"))
workspace.RunOperatorOnce(op)
print("Z:\n", workspace.FetchBlob("Z"))

workspace.ResetWorkspace()
workspace.FeedBlob("X", np.random.randint(10, size=(3,3)).astype(np.float32))
workspace.FeedBlob("Y", np.random.randint(10, size=(3,3)).astype(np.float32))
print("X:\n", workspace.FetchBlob("X"))
print("Y:\n", workspace.FetchBlob("Y"))
workspace.RunOperatorOnce(op)
print("Z:\n", workspace.FetchBlob("Z"))
```

</details>

)DOC")
    .Input(0, "X", "*(type: Tensor`<float>`)* 1D or 2D input tensor.")
    .Input(
        1,
        "Y",
        "*(type: Tensor`<float>`)* 1D or 2D input tensor (must have the same shape as X).")
    .Output(0, "Z", "*(type: Tensor`<float>`)* 1D output tensor.")
    .TensorInferenceFunction(TensorInferenceForDotProduct)
    .CostInferenceFunction(
        OpSchema::CostInferenceFunctionType(CostInferenceForDotProduct))
    .InheritOnnxSchema();
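// Illustrative example (mine, not part of the upstream documentation):
// 1D inputs X = [1, 2, 3], Y = [4, 5, 6] give Z = [4, 10, 18];
// 2D inputs X = [[1, 2], [3, 4]], Y = [[5, 6], [7, 8]] give
// Z = [1*5 + 2*6, 3*7 + 4*8] = [17, 53].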
OPERATOR_SCHEMA(DotProductGradient).NumInputs(3).NumOutputs(2);

class GetDotProductGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "DotProductGradient",
        "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0), GI(1)});
  }
};
REGISTER_GRADIENT(DotProduct, GetDotProductGradient);
REGISTER_CPU_OPERATOR(CosineSimilarity, CosineSimilarityOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    CosineSimilarityGradient,
    CosineSimilarityGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(CosineSimilarity)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
This op takes two input float tensors of the same size, $X$ and $Y$, and
produces one output float tensor, $Z$, calculated as the cosine similarity
between $X$ and $Y$. Recall, the cosine similarity between two tensors $X$ and
$Y$ is defined as:

$$\mathbf{Z}=CosineSimilarity(\mathbf{X},\mathbf{Y}) = \frac{\mathbf{X}\cdot\mathbf{Y}}{\|\mathbf{X}\|\|\mathbf{Y}\|} = \frac{\sum_{i=1}^{n}X_iY_i}{\sqrt{\sum_{i=1}^{n}X_i^2}\sqrt{\sum_{i=1}^{n}Y_i^2}}$$

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/distance_op.h
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/distance_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "CosineSimilarity",
    ["X", "Y"],
    ["Z"]
)

X = np.random.randn(3, 3)
print("X:\n", X)

Y = np.random.randn(3, 3)
print("Y:\n", Y)

# Feed X & Y into workspace
workspace.FeedBlob("X", X.astype(np.float32))
workspace.FeedBlob("Y", Y.astype(np.float32))

workspace.RunOperatorOnce(op)

print("Z:\n", workspace.FetchBlob("Z"))
```

**Result**

```
X:
 [[-0.42635564 -0.23831588 -0.25515547]
 [ 1.43914719 -1.05613228  1.01717373]
 [ 0.06883105  0.33386519 -1.46648334]]
Y:
 [[-0.90648691 -0.14241514 -1.1070837 ]
 [ 0.92152729 -0.28115511 -0.17756722]
 [-0.88394254  1.34654037 -0.80080998]]
Z:
 [-1.7849885e-23  1.7849885e-23 -1.0842022e-07]
```

</details>

)DOC")
    .Input(0, "X", "1D or 2D input tensor")
    .Input(1, "Y", "1D or 2D input tensor (must have the same shape as X)")
    .Output(0, "Z", "1D output tensor");
OPERATOR_SCHEMA(CosineSimilarityGradient).NumInputs(3).NumOutputs(2);

class GetCosineSimilarityGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "CosineSimilarityGradient",
        "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0), GI(1)});
  }
};
REGISTER_GRADIENT(CosineSimilarity, GetCosineSimilarityGradient);
REGISTER_CPU_OPERATOR(
    DotProductWithPadding,
    DotProductWithPaddingOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    DotProductWithPaddingGradient,
    DotProductWithPaddingGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(DotProductWithPadding)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Given two input float tensors X and Y with possibly different shapes, produces
one output float tensor holding the dot product between X and Y. Two strategies
are supported for reconciling the shapes before the usual dot product is taken:
1) pad the smaller tensor (with pad_value) to the shape of the larger one, or
2) replicate the smaller tensor until it matches the shape of the larger one.
Note that the first dimension of X and Y must be equal; only the second
dimension of X or Y can be padded.
)DOC")
    .Input(0, "X", "1D or 2D input tensor")
    .Input(1, "Y", "1D or 2D input tensor")
    .Output(0, "Z", "1D output tensor")
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .Arg("pad_value", "the padding value for tensors with smaller dimension")
    .Arg("replicate", "whether to replicate the smaller tensor or not");
OPERATOR_SCHEMA(DotProductWithPaddingGradient).NumInputs(3).NumOutputs(2);

class GetDotProductWithPaddingGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    float pad_value = 0;
    bool replicate = false;
    if (ArgumentHelper::HasArgument(Def(), "pad_value")) {
      pad_value = GetArgument(Def(), "pad_value").f();
    }
    if (ArgumentHelper::HasArgument(Def(), "replicate")) {
      replicate = GetArgument(Def(), "replicate").i();
    }

    const auto dot_arg =
        vector<Argument>{MakeArgument<float>("pad_value", pad_value),
                         MakeArgument<bool>("replicate", replicate)};

    return SingleGradientDef(
        "DotProductWithPaddingGradient",
        "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0), GI(1)},
        dot_arg);
  }
};
REGISTER_GRADIENT(DotProductWithPadding, GetDotProductWithPaddingGradient);

} // namespace caffe2