1 #include "caffe2/operators/transpose_op.h" 5 REGISTER_CPU_OPERATOR(Transpose, TransposeOp<CPUContext>);
7 OPERATOR_SCHEMA(Transpose)
10 .TensorInferenceFunction([](
const OperatorDef& def,
11 const vector<TensorShape>& in) {
12 ArgumentHelper helper(def);
13 vector<int> axes = helper.GetRepeatedArgument<
int>(
"axes");
14 vector<TensorShape> out(1);
15 out[0].set_data_type(in[0].data_type());
18 for (
auto axis = in [0].dims().rbegin(); axis != in[0].dims().rend();
20 out[0].add_dims(*axis);
23 auto tensor_size = in[0].dims().size();
25 std::all_of(axes.begin(), axes.end(), [&tensor_size](
int& axis) {
26 return axis >= 0 && axis < tensor_size;
29 CAFFE_ENFORCE(valid_axes,
"Axes argument passed in had invalid values");
31 axes.size() == tensor_size,
32 "Axes argument passed in had the incorrect size");
34 for (
auto axis = axes.begin(); axis != axes.end(); ++axis) {
35 out[0].add_dims(in[0].dims().Get(*axis));
42 Transpose the input tensor by permuting the axes of the input according 43 to the `axes` argument. Similar to numpy's 44 [transpose](https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html) 47 For example, when axes=(1, 0, 2), given an input tensor of shape 48 (1, 2, 3), the output shape will be (2, 1, 3). 52 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/transpose_op.cc 56 <summary> <b>Example</b> </summary> 61 workspace.ResetWorkspace() 63 op = core.CreateOperator( 70 x = np.random.rand(1,32,32,3) 71 workspace.FeedBlob("X", x) 72 print("X.shape (NHWC order):", workspace.FetchBlob("X").shape) 73 workspace.RunOperatorOnce(op) 74 print("Y.shape (NCHW order):", workspace.FetchBlob("Y").shape) 80 X.shape (NHWC order): (1, 32, 32, 3) 81 Y.shape (NCHW order): (1, 3, 32, 32) 89 "*(type: Tuple(int))* Order to permute axes of input tensor. Reverses " 90 "the dimensions by default.")
91 .Input(0,
"X",
"*(type: Tensor)* Input tensor.")
92 .Output(0,
"Y",
"*(type: Tensor)* Transposed output.")
96 using GradientMakerBase::GradientMakerBase;
98 bool CopyArguments()
const override {
101 vector<OperatorDef> GetGradientDefs()
override {
103 "Transpose",
"", vector<string>{GO(0)}, vector<string>{GI(0)});
104 ops[0].mutable_arg()->CopyFrom(Def().arg());
105 if (ArgumentHelper::HasArgument(Def(),
"axes")) {
107 const Argument& old_axes = GetArgument(Def(),
"axes");
108 const int axes_size = old_axes.ints_size();
109 Argument* new_arg = GetMutableArgument(
"axes",
false, &ops[0]);
110 for (
int i = 0; i < axes_size; ++i) {
111 new_arg->set_ints(old_axes.ints(i), i);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...