Caffe2 - C++ API
A deep-learning, cross-platform ML framework
transpose_op.cc
1 
17 #include "caffe2/operators/transpose_op.h"
18 
19 #ifdef CAFFE2_USE_MKL
20 #include "caffe2/mkl/operators/operator_fallback_mkl.h"
21 #endif // CAFFE2_USE_MKL
22 
23 namespace caffe2 {
24 
// Register the CPU implementation of the Transpose operator.
REGISTER_CPU_OPERATOR(Transpose, TransposeOp<CPUContext>);

#ifdef CAFFE2_HAS_MKL_DNN
// Registering in operator_fallback_mkl.cc results in a linker error in
// the opt build related to DoRunWithType(), so the MKL fallback for
// Transpose is registered here instead.
REGISTER_MKL_OPERATOR(Transpose, mkl::MKLFallbackOp<TransposeOp<CPUContext>>);
#endif // CAFFE2_HAS_MKL_DNN
32 
33 OPERATOR_SCHEMA(Transpose)
34  .NumInputs(1)
35  .NumOutputs(1)
36  .TensorInferenceFunction([](
37  const OperatorDef& def,
38  const vector<TensorShape>& in) {
39  ArgumentHelper helper(def);
40  vector<int> axes = helper.GetRepeatedArgument<int>("axes");
41  vector<TensorShape> out(1);
42  out[0].set_data_type(in[0].data_type());
43 
44  if (axes.empty()) {
45  for (auto axis = in [0].dims().rbegin(); axis != in[0].dims().rend();
46  ++axis) {
47  out[0].add_dims(*axis);
48  }
49  } else {
50  auto tensor_size = in[0].dims().size();
51  auto valid_axes =
52  std::all_of(axes.begin(), axes.end(), [&tensor_size](int& axis) {
53  return axis >= 0 && axis < tensor_size;
54  });
55 
56  CAFFE_ENFORCE(valid_axes, "Axes argument passed in had invalid values");
57  CAFFE_ENFORCE(
58  axes.size() == tensor_size,
59  "Axes argument passed in had the incorrect size");
60 
61  for (auto axis = axes.begin(); axis != axes.end(); ++axis) {
62  out[0].add_dims(in[0].dims().Get(*axis));
63  }
64  }
65 
66  return out;
67  })
68  .SetDoc(R"DOC(
69 Transpose the input tensor similar to numpy.transpose. For example, when
70 axes=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape
71 will be (2, 1, 3).
72 )DOC")
73  .Arg(
74  "axes",
75  "A list of integers. By default, reverse the dimensions, "
76  "otherwise permute the axes according to the values given.")
77  .Input(0, "data", "An input tensor.")
78  .Output(0, "transposed", "Transposed output.");
79 
81  using GradientMakerBase::GradientMakerBase;
82  // We will create our own arguments.
83  bool CopyArguments() const override {
84  return false;
85  }
86  vector<OperatorDef> GetGradientDefs() override {
87  auto ops = SingleGradientDef(
88  "Transpose", "", vector<string>{GO(0)}, vector<string>{GI(0)});
89  ops[0].mutable_arg()->CopyFrom(Def().arg());
90  if (ArgumentHelper::HasArgument(Def(), "axes")) {
91  // If axes is specified, we will need to figure out the inverse index.
92  const Argument& old_axes = GetArgument(Def(), "axes");
93  const int axes_size = old_axes.ints_size();
94  Argument* new_arg = GetMutableArgument("axes", false, &ops[0]);
95  for (int i = 0; i < axes_size; ++i) {
96  new_arg->set_ints(old_axes.ints(i), i);
97  }
98  }
99  return ops;
100  }
101 };
102 
// Register GetTransposeGradient as the gradient maker for Transpose.
REGISTER_GRADIENT(Transpose, GetTransposeGradient);
104 
105 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that creates a single operator def, which is usually the case for many ...