17 #include "batch_permutation_op.h" 18 #ifdef CAFFE2_USE_MKLDNN 19 #include <caffe2/ideep/operators/operator_fallback_ideep.h> 20 #include <caffe2/ideep/utils/ideep_operator.h> 25 #ifdef CAFFE2_USE_MKLDNN 26 REGISTER_IDEEP_OPERATOR(
28 IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
// Register the float / CPUContext instantiations of the forward operator
// (BatchPermutation) and of its gradient operator (BatchPermutationGradient).
31 REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
32 REGISTER_CPU_OPERATOR(
33 BatchPermutationGradient,
34 BatchPermutationGradientOp<float, CPUContext>);
36 OPERATOR_SCHEMA(BatchPermutation)
40 Permute the batch elements of the input tensor X according to the permutation 41 specified in the input indices. 43 Warning: this op does not verify that indices is a valid permutation; gradient 44 computation is only correct if indices is a permutation. 49 "Tensor of at least 1D shape (N, D0, D1, ...).")
53 "1D tensor of type int with shape (N, ) specifying a valid permutation " 54 "of the indices in [0, N - 1] (inclusive).")
58 "Tensor with the same shape as X where the (D0, D1, ...) dimensional " 59 "batch elements of X are permuted according to the input indices.");
// Schema for the gradient operator. Only the description strings of its
// arguments are visible in this listing; the calls that attach them to the
// schema are not shown.
61 OPERATOR_SCHEMA(BatchPermutationGradient)
// Defers to the forward operator's schema for this argument
// (presumably the indices input -- TODO confirm).
67 "See BatchPermutation.")
// Incoming gradient with respect to the forward output Y.
71 "Gradient of forward output 0 (Y).")
// Produced gradient with respect to the forward input X.
75 "Gradient of forward input 0 (X).");
// CPU forward pass: for every i in [0, N), copies the slice of X selected by
// indices[i] into slice i of the output Y. Each slice holds C * H * W floats.
// NOTE: this listing is missing some lines of the function (including its
// end), so the comments below describe only the statements that are visible.
78 bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
// Input 0 is the data tensor, input 1 the index tensor.
79 const auto& X = Input(0);
80 const auto& indices = Input(1);
// The index tensor must be rank 1.
82 CAFFE_ENFORCE_EQ(indices.dim(), 1,
"indices must be 1-d");
// Arguments of a second check relating the leading dimension of X to the
// length of indices; the head of that call is not visible in this listing.
84 X.dim32(0), indices.dim32(0),
85 "X.dim32(0) must be equal to indices.dim32(0)",
// The output has the same sizes as X and float dtype.
92 auto* Y = Output(0, X.sizes(), at::dtype<float>());
// NOTE(review): dimensions 1..3 are read unconditionally, which presumes X
// has at least four dimensions, while the schema text describes a tensor of
// "at least 1D shape" -- confirm which contract is intended.
94 const int N = X.dim32(0);
95 const int C = X.dim32(1);
96 const int H = X.dim32(2);
97 const int W = X.dim32(3);
// Raw float pointers into the input and output buffers.
99 const float *src = X.template data<float>();
100 float *dst = Y->template mutable_data<float>();
// The batch loop is annotated with an OpenMP pragma; the `simd` variant is
// guarded by _OPENMP >= 201307 (the matching #else / #endif lines are not
// visible in this listing).
103 #if (_OPENMP >= 201307) 104 #pragma omp parallel for simd 106 #pragma omp parallel for 109 for (
int i = 0; i < N; i++) {
// NOTE(review): idx is used as a slice offset without a range check in the
// visible code; an index outside [0, N) would read outside src -- confirm
// that callers guarantee valid indices.
110 int idx = indices.template data<int>()[i];
// Copy one whole C*H*W slice. NOTE(review): the offsets i * C * H * W and
// idx * C * H * W are computed in int and could overflow for very large
// tensors -- confirm expected sizes.
112 std::memcpy(dst + i * C * H * W, src + idx * C * H * W,
sizeof(
float) * C * H * W);
// Inherit the constructors of GradientMakerBase.
119 using GradientMakerBase::GradientMakerBase;
// Describes the gradient operator for BatchPermutation. The head of the
// returned expression is not visible in this listing; only its trailing
// arguments are.
120 vector<OperatorDef> GetGradientDefs()
override {
// Operator type of the gradient op.
122 "BatchPermutationGradient",
// Gradient-op inputs: I(1) and GO(0) -- presumably forward input 1 (indices)
// and the gradient of forward output 0; TODO confirm against GradientMakerBase.
124 vector<string>{I(1), GO(0)},
// Gradient-op output: GI(0) -- presumably the gradient of forward input 0 (X).
125 vector<string>{GI(0)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...