1 #ifndef CAFFE2_OPERATORS_TRANSPOSE_H_ 2 #define CAFFE2_OPERATORS_TRANSPOSE_H_ 7 #include "caffe2/core/context.h" 8 #include "caffe2/core/operator.h" 9 #include "caffe2/utils/math.h" 13 template <
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
19 template <
class... Args>
22 axes_(this->
template GetRepeatedArgument<int>(
"axes")) {
24 std::vector<int> axes_sorted = axes_;
25 std::sort(axes_sorted.begin(), axes_sorted.end());
26 for (std::size_t i = 0; i < axes_sorted.size(); ++i) {
27 if (axes_sorted[i] != i) {
28 CAFFE_THROW(
"Axes should be a permutation of 0 to ndim.");
33 bool RunOnDevice()
override {
41 bool DoRunWithType() {
42 const auto& X =
Input(0);
44 const int ndim = X.dim();
47 std::iota(axes_.rbegin(), axes_.rend(), 0);
49 CAFFE_ENFORCE_EQ(ndim, axes_.size());
51 const std::vector<std::int64_t> X_dims = X.sizes().vec();
52 std::vector<std::int64_t> Y_dims(ndim);
53 for (
int i = 0; i < ndim; ++i) {
54 Y_dims[i] = X_dims[axes_[i]];
56 auto* Y = Output(0, Y_dims, at::dtype<T>());
57 math::Transpose<std::int64_t, T, Context>(
62 Y->template mutable_data<T>(),
67 std::vector<int> axes_;
72 #endif // CAFFE2_OPERATORS_TRANSPOSE_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...