Caffe2 - C++ API
A deep learning, cross-platform ML framework
distance_op.h
#ifndef CAFFE2_OPERATORS_DISTANCE_OP_H_
#define CAFFE2_OPERATORS_DISTANCE_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <typename T, class Context>
class SquaredL2DistanceOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit SquaredL2DistanceOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Input: X, Y; Output: Distance
};
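// RunOnDevice is implemented per device in the corresponding .cc/.cu files.
// Judging from the gradient operator below (which computes dX as
// dDistance * (X - Y)), the forward pass computes a row-wise half squared
// L2 distance:
//   Distance_i = 0.5 * sum_j (X_ij - Y_ij)^2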

template <typename T, class Context>
class SquaredL2DistanceGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit SquaredL2DistanceGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    auto& X = Input(0);
    auto& Y = Input(1);
    auto& dDistance = Input(2);

    int N = X.dim() > 0 ? X.dim32(0) : 1;
    int D = N > 0 ? X.numel() / N : 0;
    CAFFE_ENFORCE(X.dim() == Y.dim());
    for (int i = 0; i < X.dim(); ++i) {
      CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i));
    }
    CAFFE_ENFORCE(dDistance.dim() == 1);
    CAFFE_ENFORCE(dDistance.dim32(0) == N);
    auto* dX = Output(0, X.sizes(), at::dtype<T>());
    auto* dY = Output(1, Y.sizes(), at::dtype<T>());
    math::Sub<T, Context>(
        X.numel(),
        X.template data<T>(),
        Y.template data<T>(),
        dX->template mutable_data<T>(),
        &context_);
    for (int i = 0; i < N; ++i) {
      math::Scale<T, T, Context>(
          D,
          dDistance.template data<T>() + i,
          dX->template data<T>() + i * D,
          dX->template mutable_data<T>() + i * D,
          &context_);
    }
    // The gradient of the other side is basically the negative.
    math::Scale<T, T, Context>(
        X.numel(),
        -1,
        dX->template data<T>(),
        dY->template mutable_data<T>(),
        &context_);
    return true;
  }

 protected:
  // Input: X, Y, dDistance; Output: dX, dY
};
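// Worked derivation for the inlined gradient above, assuming the forward
// operator computes Distance_i = 0.5 * sum_j (X_ij - Y_ij)^2 as noted above:
// the chain rule gives
//   dX_ij = dDistance_i * (X_ij - Y_ij)
//   dY_ij = -dX_ij
// which is exactly the Sub, the per-row Scale by dDistance_i, and the final
// Scale by -1 performed in RunOnDevice.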

template <typename T, class Context>
class L1DistanceOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit L1DistanceOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Input: X, Y; Output: Distance
};

template <typename T, class Context>
class L1DistanceGradientOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit L1DistanceGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  // Input: X, Y, dDistance; Output: dX, dY
};
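// Both RunOnDevice implementations live in the .cc/.cu files. The usual
// convention for this pair (not spelled out in this header) is
//   Distance_i = sum_j |X_ij - Y_ij|
//   dX_ij = dDistance_i * sign(X_ij - Y_ij),  dY_ij = -dX_ij
// with the subgradient at X_ij == Y_ij left to the implementation.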

template <typename T, class Context>
class DotProductOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(DOT_OUT);
};

template <typename T, class Context>
class DotProductGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
};
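// For reference (implementations are in the .cc/.cu files), the row-wise
// dot product and its gradient are
//   Dot_i = sum_j X_ij * Y_ij
//   dX_ij = dDot_i * Y_ij,  dY_ij = dDot_i * X_ij
// The same two Scale calls appear in the non-replicated branch of
// DotProductWithPaddingGradientOp further down.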

template <typename T, class Context>
class DotProductWithPaddingOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductWithPaddingOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_value_(this->template GetSingleArgument<float>("pad_value", 0.0)),
        replicate_(this->template GetSingleArgument<bool>("replicate", false)) {
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  float pad_value_;
  bool replicate_;
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(DOT_OUT);
};
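// Inferred from the gradient operator at the end of this file: when the rows
// of X and Y have different lengths DX and DY, the shorter row is either
// padded with pad_value (replicate == false) or tiled up to the longer length
// (replicate == true, which requires one length to divide the other), and the
// output is the row-wise dot product of the two aligned rows.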

template <typename T, class Context>
class CosineSimilarityOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit CosineSimilarityOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(COS_OUT);

 private:
  Tensor aux_;
};

template <typename T, class Context>
class CosineSimilarityGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit CosineSimilarityGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN, DER_COS_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);

 private:
  Tensor aux_;
};
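// Implementations are in the .cc/.cu files; aux_ presumably holds
// intermediate per-row norms. The standard cosine-similarity formulas this
// pair corresponds to are
//   Cos_i = <X_i, Y_i> / (||X_i|| * ||Y_i||)
//   dX_ij = dCos_i * (Y_ij / (||X_i|| * ||Y_i||) - Cos_i * X_ij / ||X_i||^2)
// and symmetrically for dY.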

template <typename T, class Context>
class DotProductWithPaddingGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductWithPaddingGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        pad_value_(this->template GetSingleArgument<float>("pad_value", 0.0)),
        replicate_(this->template GetSingleArgument<bool>("replicate", false)) {
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    auto& X = Input(X_IN);
    auto& Y = Input(Y_IN);
    auto& dDot = Input(DER_DOT_IN);

    int N, D, DX, DY, restD;
    if (X.numel() > 0) {
      N = X.dim() > 0 ? X.dim32(0) : 1;
      DX = X.numel() / N;
      DY = Y.numel() / N;
    } else {
      N = 0;
      DX = 0;
      DY = 0;
    }
    CAFFE_ENFORCE(!replicate_ || DX % DY == 0 || DY % DX == 0);
    D = std::min(DX, DY);
    restD = std::max(DX, DY) - D;
    CAFFE_ENFORCE_EQ(X.dim(), Y.dim());
    CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0));
    CAFFE_ENFORCE_EQ(dDot.dim(), 1);
    CAFFE_ENFORCE_EQ(dDot.dim32(0), N);
    auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<T>());
    auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<T>());

    const auto* X_data = X.template data<T>();
    const auto* Y_data = Y.template data<T>();
    const auto* dDot_data = dDot.template data<T>();
    auto* dX_data = dX->template mutable_data<T>();
    auto* dY_data = dY->template mutable_data<T>();
    for (int i = 0; i < N; ++i) { // TODO: multithreading
      auto offsetX = i * DX;
      auto offsetY = i * DY;
      if (replicate_) {
        // L_ for longer vector and S_ for shorter vector
        const T *L_data, *S_data;
        T *dL_data, *dS_data;
        int DL, DS;
        if (DX > DY) {
          L_data = X_data + offsetX;
          S_data = Y_data + offsetY;
          dL_data = dX_data + offsetX;
          dS_data = dY_data + offsetY;
          DL = DX;
          DS = DY;
        } else {
          L_data = Y_data + offsetY;
          S_data = X_data + offsetX;
          dL_data = dY_data + offsetY;
          dS_data = dX_data + offsetX;
          DL = DY;
          DS = DX;
        }

        // TODO: get rid of temp memory use
        std::vector<T> tmp_data(DS);
        math::Set<T, Context>(DS, 0.0, dS_data, &context_);
        for (int j = 0; j < DL / DS; j++) {
          math::Scale<T, T, Context>(
              DS, dDot_data[i], S_data, dL_data + j * DS, &context_);
          math::Scale<T, T, Context>(
              DS, dDot_data[i], L_data + j * DS, tmp_data.data(), &context_);
          math::Axpy<T, Context>(DS, 1.0, tmp_data.data(), dS_data, &context_);
        }
      } else {
        math::Scale<T, T, Context>(
            D, dDot_data[i], X_data + offsetX, dY_data + offsetY, &context_);
        math::Scale<T, T, Context>(
            D, dDot_data[i], Y_data + offsetY, dX_data + offsetX, &context_);
      }

      if (!replicate_ && DX != DY) {
        T* rest_data;
        if (DX > DY) {
          rest_data = dX_data + offsetX + D;
        } else {
          rest_data = dY_data + offsetY + D;
        }
        auto pad_gradient = dDot_data[i] * pad_value_;
        math::Set<T, Context>(restD, pad_gradient, rest_data, &context_);
      }
    }

    return true;
  }

 protected:
  float pad_value_;
  bool replicate_;
  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_DISTANCE_OP_H_
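The padding behaviour is easiest to see outside the framework. Below is a minimal, framework-free sketch of the replicate == false case, assuming the forward semantics inferred from the gradient code above; the helper name PaddedRowDot and the sample values are illustrative only and are not part of Caffe2.

#include <algorithm>
#include <cstdio>
#include <vector>

// Dot product of one row of X with one row of Y, padding the shorter row
// with pad_value. This mirrors the non-replicated branch above, where the
// tail of the longer row receives the gradient dDot * pad_value.
float PaddedRowDot(
    const std::vector<float>& x,
    const std::vector<float>& y,
    float pad_value) {
  const size_t d = std::min(x.size(), y.size());
  float dot = 0.0f;
  for (size_t j = 0; j < d; ++j) {
    dot += x[j] * y[j];
  }
  const std::vector<float>& longer = x.size() > y.size() ? x : y;
  for (size_t j = d; j < longer.size(); ++j) {
    dot += longer[j] * pad_value;
  }
  return dot;
}

int main() {
  const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  const std::vector<float> y = {0.5f, -1.0f};
  // 1*0.5 + 2*(-1) + (3 + 4) * 0.25 = 0.25
  std::printf("padded dot = %f\n", PaddedRowDot(x, y, 0.25f));
  return 0;
}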