Caffe2 - C++ API
A deep learning, cross platform ML framework
order_switch_ops.h
1 #ifndef CAFFE2_OPERATORS_ORDER_SWITCH_OPS_H_
2 #define CAFFE2_OPERATORS_ORDER_SWITCH_OPS_H_
3 
4 #include <vector>
5 
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 // Note(Yangqing): I think it is possible to do a more general swapaxes operator
12 // but I am a little afraid of going down that general path. Only implementing
13 // the two actually needed ones here.
14 
15 template <typename T, class Context>
16 class NHWC2NCHWOp final : public Operator<Context> {
17  public:
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19 
20  USE_SIMPLE_CTOR_DTOR(NHWC2NCHWOp);
21 
22  bool RunOnDevice() override {
23  const auto& X = Input(0);
24 
25  const int ndim = X.dim();
26  CAFFE_ENFORCE_GE(ndim, 3);
27  const int N = X.dim32(0);
28  const int C = X.dim32(ndim - 1);
29  std::vector<int64_t> Y_dims(ndim);
30  Y_dims[0] = N;
31  Y_dims[1] = C;
32  int HxW = 1;
33  for (int i = 2; i < ndim; ++i) {
34  Y_dims[i] = X.dim32(i - 1);
35  HxW *= Y_dims[i];
36  }
37  auto* Y = Output(0, Y_dims, at::dtype<T>());
38  if (X.numel() <= 0) {
39  return true;
40  }
41  math::NHWC2NCHW<T, Context>(
42  N,
43  C,
44  HxW,
45  X.template data<T>(),
46  Y->template mutable_data<T>(),
47  &context_);
48  return true;
49  }
50 };
51 
52 template <typename T, class Context>
53 class NCHW2NHWCOp final : public Operator<Context> {
54  public:
55  USE_OPERATOR_CONTEXT_FUNCTIONS;
56 
57  USE_SIMPLE_CTOR_DTOR(NCHW2NHWCOp);
58 
59  bool RunOnDevice() override {
60  const auto& X = Input(0);
61 
62  const int ndim = X.dim();
63  CAFFE_ENFORCE_GE(ndim, 3);
64  const int N = X.dim32(0);
65  const int C = X.dim32(1);
66  std::vector<int64_t> Y_dims(ndim);
67  Y_dims[0] = N;
68  Y_dims[ndim - 1] = C;
69  int HxW = 1;
70  for (int i = 1; i < ndim - 1; ++i) {
71  Y_dims[i] = X.dim32(i + 1);
72  HxW *= Y_dims[i];
73  }
74  auto* Y = Output(0, Y_dims, at::dtype<T>());
75  if (X.numel() <= 0) {
76  return true;
77  }
78  math::NCHW2NHWC<T, Context>(
79  N,
80  C,
81  HxW,
82  X.template data<T>(),
83  Y->template mutable_data<T>(),
84  &context_);
85  return true;
86  }
87 };
88 
89 } // namespace caffe2
90 
91 #endif // CAFFE2_OPERATORS_ORDER_SWITCH_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64