Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_permutation_op.cc
1 
17 #include "batch_permutation_op.h"
18 #ifdef CAFFE2_USE_MKLDNN
19 #include <caffe2/ideep/operators/operator_fallback_ideep.h>
20 #include <caffe2/ideep/utils/ideep_operator.h>
21 #endif
22 
23 namespace caffe2 {
24 
#ifdef CAFFE2_USE_MKLDNN
// In MKL-DNN builds, expose BatchPermutation to the IDEEP device by wrapping
// the float CPU operator in the IDEEP fallback adapter.
REGISTER_IDEEP_OPERATOR(BatchPermutation,
                        IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif
30 
31 REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
32 REGISTER_CPU_OPERATOR(
33  BatchPermutationGradient,
34  BatchPermutationGradientOp<float, CPUContext>);
35 
36 OPERATOR_SCHEMA(BatchPermutation)
37  .NumInputs(2)
38  .NumOutputs(1)
39  .SetDoc(R"DOC(
40 Permute the batch elements of the input tensor X according to the permutation
41 specified in the input indices.
42 
43 Warning: this op does not verify that indices is a valid permutation; gradient
44 comptuation is only correct if indices is a permutation.
45 )DOC")
46  .Input(
47  0,
48  "X",
49  "Tensor of at least 1D shape (N, D0, D1, ...).")
50  .Input(
51  1,
52  "indices",
53  "1D tensor of type int with shape (N, ) specifying a valid permutation "
54  "of the indices in [0, N - 1] (inclusive).")
55  .Output(
56  0,
57  "Y",
58  "Tensor with the same shape as X where the (D0, D1, ...) dimensional "
59  "batch elements of X are permuted according to the input indices.");
60 
61 OPERATOR_SCHEMA(BatchPermutationGradient)
62  .NumInputs(2)
63  .NumOutputs(1)
64  .Input(
65  0,
66  "indices",
67  "See BatchPermutation.")
68  .Input(
69  1,
70  "dY",
71  "Gradient of forward output 0 (Y).")
72  .Output(
73  0,
74  "dX",
75  "Gradient of forward input 0 (X).");
76 
77 template <>
78 bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
79  const auto& X = Input(0);
80  const auto& indices = Input(1);
81 
82  CAFFE_ENFORCE_EQ(indices.dim(), 1, "indices must be 1-d");
83  CAFFE_ENFORCE_EQ(
84  X.dim32(0), indices.dim32(0),
85  "X.dim32(0) must be equal to indices.dim32(0)",
86  "(",
87  X.dim32(0),
88  " vs. ",
89  indices.dim32(0),
90  ")");
91 
92  auto* Y = Output(0, X.sizes(), at::dtype<float>());
93 
94  const int N = X.dim32(0);
95  const int C = X.dim32(1);
96  const int H = X.dim32(2);
97  const int W = X.dim32(3);
98 
99  const float *src = X.template data<float>();
100  float *dst = Y->template mutable_data<float>();
101 
102 #ifdef _OPENMP
103 #if (_OPENMP >= 201307)
104 #pragma omp parallel for simd
105 #else
106 #pragma omp parallel for
107 #endif
108 #endif
109  for (int i = 0; i < N; i++) {
110  int idx = indices.template data<int>()[i];
111 
112  std::memcpy(dst + i * C * H * W, src + idx * C * H * W, sizeof(float) * C * H * W);
113  }
114 
115  return true;
116 }
117 
119  using GradientMakerBase::GradientMakerBase;
120  vector<OperatorDef> GetGradientDefs() override {
121  return SingleGradientDef(
122  "BatchPermutationGradient",
123  "",
124  vector<string>{I(1), GO(0)},
125  vector<string>{GI(0)});
126  }
127 };
128 
129 REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);
130 
131 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...
Definition: static.cpp:64