Caffe2 - C++ API
A deep learning, cross-platform ML framework
transpose_op.h
1 #ifndef CAFFE2_OPERATORS_TRANSPOSE_H_
2 #define CAFFE2_OPERATORS_TRANSPOSE_H_
3 
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
10 
11 namespace caffe2 {
12 
13 template <class Context>
14 class TransposeOp final : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17  USE_DISPATCH_HELPER;
18 
19  template <class... Args>
20  explicit TransposeOp(Args&&... args)
21  : Operator<Context>(std::forward<Args>(args)...),
22  axes_(this->template GetRepeatedArgument<int>("axes")) {
23  // We will check the legality of axes_: it should be from 0 to axes_.size().
24  std::vector<int> axes_sorted = axes_;
25  std::sort(axes_sorted.begin(), axes_sorted.end());
26  for (std::size_t i = 0; i < axes_sorted.size(); ++i) {
27  if (axes_sorted[i] != i) {
28  CAFFE_THROW("Axes should be a permutation of 0 to ndim.");
29  }
30  }
31  }
32 
33  bool RunOnDevice() override {
34  // Do the actual transpose, which is implemented in DoRunWithType().
36  this, Input(0));
37  }
38 
39  private:
40  template <typename T>
41  bool DoRunWithType() {
42  const auto& X = Input(0);
43 
44  const int ndim = X.dim();
45  if (axes_.empty()) {
46  axes_.resize(ndim);
47  std::iota(axes_.rbegin(), axes_.rend(), 0);
48  } else {
49  CAFFE_ENFORCE_EQ(ndim, axes_.size());
50  }
51  const std::vector<std::int64_t> X_dims = X.sizes().vec();
52  std::vector<std::int64_t> Y_dims(ndim);
53  for (int i = 0; i < ndim; ++i) {
54  Y_dims[i] = X_dims[axes_[i]];
55  }
56  auto* Y = Output(0, Y_dims, at::dtype<T>());
57  math::Transpose<std::int64_t, T, Context>(
58  X_dims.size(),
59  X_dims.data(),
60  axes_.data(),
61  X.template data<T>(),
62  Y->template mutable_data<T>(),
63  &context_);
64  return true;
65  }
66 
67  std::vector<int> axes_;
68 };
69 
70 } // namespace caffe2
71 
72 #endif // CAFFE2_OPERATORS_TRANSPOSE_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13