Caffe2 - C++ API
A deep learning, cross platform ML framework
affine_channel_op.h
1 #ifndef CAFFE2_OPERATORS_AFFINE_CHANNEL_OP_H_
2 #define CAFFE2_OPERATORS_AFFINE_CHANNEL_OP_H_
3 
4 #include <string>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 #include "caffe2/utils/math.h"
10 
11 namespace caffe2 {
12 
13 template <typename T, class Context>
14 class AffineChannelOp final : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17 
18  template <class... Args>
19  explicit AffineChannelOp(Args&&... args)
20  : Operator<Context>(std::forward<Args>(args)...),
21  order_(StringToStorageOrder(
22  this->template GetSingleArgument<std::string>("order", "NCHW"))),
23  OP_SINGLE_ARG(bool, "is_learnable", is_learnable_, false) {
24  CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
25  }
26 
27  bool RunOnDevice() override {
28  return order_ == StorageOrder::NCHW ? RunOnDeviceWithOrderNCHW()
29  : RunOnDeviceWithOrderNHWC();
30  }
31 
32  bool RunOnDeviceWithOrderNCHW() {
33  const auto& X = Input(0);
34  const auto& scale = Input(1);
35  const auto& bias = Input(2);
36 
37  if (is_learnable_) {
38  CAFFE_ENFORCE(
39  !IsInputOutputAlias(0, 0),
40  "In-place affine_channel_op is not supported when "
41  "is_learnable = true.");
42  }
43  const int N = X.dim32(0);
44  const int C = X.dim32(1);
45  const int HxW = X.numel() / (N * C);
46  auto* Y = Output(0, X.sizes(), at::dtype<T>());
47  math::AffineChannel<T, Context, StorageOrder::NCHW>(
48  N,
49  C,
50  HxW,
51  X.template data<T>(),
52  scale.template data<T>(),
53  bias.template data<T>(),
54  Y->template mutable_data<T>(),
55  &context_);
56  return true;
57  }
58 
59  bool RunOnDeviceWithOrderNHWC() {
60  const auto& X = Input(0);
61  const auto& scale = Input(1);
62  const auto& bias = Input(2);
63 
64  if (is_learnable_) {
65  CAFFE_ENFORCE(
66  !IsInputOutputAlias(0, 0),
67  "In-place affine_channel_op is not supported when "
68  "is_learnable = true.");
69  }
70  const int ndim = X.dim();
71  const int N = X.dim32(0);
72  const int C = X.dim32(ndim - 1);
73  const int HxW = X.numel() / (N * C);
74  auto* Y =
75  Output(0, X.sizes(), at::dtype<T>());
76  math::AffineChannel<T, Context, StorageOrder::NHWC>(
77  N,
78  C,
79  HxW,
80  X.template data<T>(),
81  scale.template data<T>(),
82  bias.template data<T>(),
83  Y->template mutable_data<T>(),
84  &context_);
85  return true;
86  }
87 
88  private:
89  const StorageOrder order_;
90  const bool is_learnable_;
91 };
92 
93 template <typename T, class Context>
94 class AffineChannelGradientOp final : public Operator<Context> {
95  public:
96  USE_OPERATOR_CONTEXT_FUNCTIONS;
97 
98  template <class... Args>
99  explicit AffineChannelGradientOp(Args&&... args)
100  : Operator<Context>(std::forward<Args>(args)...),
101  order_(StringToStorageOrder(
102  this->template GetSingleArgument<std::string>("order", "NCHW"))),
103  OP_SINGLE_ARG(bool, "is_learnable", is_learnable_, false) {
104  CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
105  }
106 
107  bool RunOnDevice() override {
108  return order_ == StorageOrder::NCHW ? RunOnDeviceWithOrderNCHW()
109  : RunOnDeviceWithOrderNHWC();
110  }
111 
112  bool RunOnDeviceWithOrderNCHW();
113 
114  bool RunOnDeviceWithOrderNHWC();
115 
116  private:
117  const StorageOrder order_;
118  const bool is_learnable_;
119 };
120 
121 } // namespace caffe2
122 
123 #endif // CAFFE2_OPERATORS_AFFINE_CHANNEL_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64