Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_permutation_op.cc
1 
17 #include "batch_permutation_op.h"
18 
19 namespace caffe2 {
20 
21 REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
22 REGISTER_CPU_OPERATOR(
23  BatchPermutationGradient,
24  BatchPermutationGradientOp<float, CPUContext>);
25 
26 OPERATOR_SCHEMA(BatchPermutation)
27  .NumInputs(2)
28  .NumOutputs(1)
29  .SetDoc(R"DOC(
30 Permute the batch elements of the input tensor X according to the permutation
31 specified in the input indices.
32 
33 Warning: this op does not verify that indices is a valid permutation; gradient
34 comptuation is only correct if indices is a permutation.
35 )DOC")
36  .Input(
37  0,
38  "X",
39  "Tensor of at least 1D shape (N, D0, D1, ...).")
40  .Input(
41  1,
42  "indices",
43  "1D tensor of type int with shape (N, ) specifying a valid permutation "
44  "of the indices in [0, N - 1] (inclusive).")
45  .Output(
46  0,
47  "Y",
48  "Tensor with the same shape as X where the (D0, D1, ...) dimensional "
49  "batch elements of X are permuted according to the input indices.");
50 
51 OPERATOR_SCHEMA(BatchPermutationGradient)
52  .NumInputs(2)
53  .NumOutputs(1)
54  .Input(
55  0,
56  "indices",
57  "See BatchPermutation.")
58  .Input(
59  1,
60  "dY",
61  "Gradient of forward output 0 (Y).")
62  .Output(
63  0,
64  "dX",
65  "Gradient of forward input 0 (X).");
66 
68  using GradientMakerBase::GradientMakerBase;
69  vector<OperatorDef> GetGradientDefs() override {
70  return SingleGradientDef(
71  "BatchPermutationGradient",
72  "",
73  vector<string>{I(1), GO(0)},
74  vector<string>{GI(0)});
75  }
76 };
77 
78 REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);
79 
80 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that creates a single operator def, which is the usual case for many simple operators.