Caffe2 - C++ API
A deep-learning, cross-platform ML framework
transpose_op.h
1 
17 #ifndef CAFFE2_OPERATORS_TRANSPOSE_H_
18 #define CAFFE2_OPERATORS_TRANSPOSE_H_
19 #define MAX_BLOB_NUM 1024
20 
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <class Context>
28 class TransposeOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  USE_DISPATCH_HELPER;
32  TransposeOp(const OperatorDef& operator_def, Workspace* ws)
33  : Operator<Context>(operator_def, ws),
34  axes_(OperatorBase::GetRepeatedArgument<int>("axes")) {
35  // We will check the legality of axes_: it should be from 0 to axes_.size().
36  std::vector<int> axes_sorted(axes_);
37  std::sort(axes_sorted.begin(), axes_sorted.end());
38  for (int i = 0; i < axes_sorted.size(); ++i) {
39  if (axes_sorted[i] != i) {
40  CAFFE_THROW("Axes should be a permutation of 0 to ndim.");
41  }
42  }
43  }
44  ~TransposeOp() {}
45 
46  bool RunOnDevice() override {
47  const auto& X = Input(0);
48  auto* Y = Output(0);
49  const int num_axes = X.ndim();
50  const std::vector<int> x_dims(X.dims().cbegin(), X.dims().cend());
51  std::vector<int> y_dims(num_axes);
52  if (axes_.empty()) {
53  axes_.resize(num_axes);
54  for (int i = 0; i < num_axes; ++i) {
55  axes_[i] = num_axes - 1 - i;
56  }
57  y_dims.assign(X.dims().rbegin(), X.dims().rend());
58  } else {
59  CAFFE_ENFORCE_EQ(X.ndim(), axes_.size());
60  for (int i = 0; i < num_axes; ++i) {
61  y_dims[i] = X.dim32(axes_[i]);
62  }
63  }
64  Y->Resize(y_dims);
65  SetDeviceTensor(x_dims, &x_dims_device_);
66  SetDeviceTensor(y_dims, &y_dims_device_);
67  SetDeviceTensor(axes_, &axes_device_);
68 
69  // Do the actual transpose, which is implemented in DoRunWithType().
71  this, Input(0));
72  }
73 
74  protected:
75  void SetDeviceTensor(const std::vector<int>& data, Tensor<Context>* tensor) {
76  tensor->Resize(data.size());
77  context_.template Copy<int, CPUContext, Context>(
78  data.size(), data.data(), tensor->template mutable_data<int>());
79  }
80 
81  template <typename T>
82  bool DoRunWithType() {
83  const auto& X = Input(0);
84  auto* Y = Output(0);
85  math::Transpose<T, Context>(
86  axes_.size(),
87  x_dims_device_.template data<int>(),
88  y_dims_device_.template data<int>(),
89  axes_device_.template data<int>(),
90  X.size(),
91  X.template data<T>(),
92  Y->template mutable_data<T>(),
93  &context_);
94  return true;
95  }
96 
97  std::vector<int> axes_;
98 
99  Tensor<Context> x_dims_device_;
100  Tensor<Context> y_dims_device_;
101  Tensor<Context> axes_device_;
102 };
103 
104 } // namespace caffe2
105 
106 #endif // CAFFE2_OPERATORS_TRANSPOSE_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:304
Copyright (c) 2016-present, Facebook, Inc.