Caffe2 - C++ API
A deep learning, cross-platform ML framework
local_response_normalization_op.h
1 #ifndef CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
2 #define CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class LRNOpBase : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  template <class... Args>
16  explicit LRNOpBase(Args&&... args)
17  : Operator<Context>(std::forward<Args>(args)...),
18  size_(this->template GetSingleArgument<int>("size", 0)),
19  alpha_(this->template GetSingleArgument<float>("alpha", 0)),
20  beta_(this->template GetSingleArgument<float>("beta", 0)),
21  bias_(this->template GetSingleArgument<float>("bias", 1)),
22  order_(StringToStorageOrder(
23  this->template GetSingleArgument<string>("order", "NCHW"))),
24  pre_pad_((size_ - 1) / 2) {
25  DCHECK_GT(size_, 0);
26  DCHECK_EQ(size_ % 2, 1);
27  DCHECK_GT(alpha_, 0);
28  DCHECK_GT(beta_, 0);
29  }
30 
31  bool RunOnDevice() override {
32  switch (order_) {
33  case StorageOrder::NHWC:
34  return RunOnDeviceWithOrderNHWC();
35  case StorageOrder::NCHW:
36  return RunOnDeviceWithOrderNCHW();
37  default:
38  LOG(FATAL) << "Unknown storage order: " << order_;
39  }
40  // To suppress old compiler warnings
41  return true;
42  }
43 
44  virtual bool RunOnDeviceWithOrderNCHW() = 0;
45  virtual bool RunOnDeviceWithOrderNHWC() = 0;
46 
47  protected:
48  const int size_;
49  const float alpha_;
50  const float beta_;
51  const float bias_;
52  const StorageOrder order_;
53  const int pre_pad_;
54  // Input: X; Output: Y, scale.
55 };
56 
57 template <typename T, class Context>
58 class LRNOp final : public LRNOpBase<T, Context> {
59  public:
60  USE_OPERATOR_CONTEXT_FUNCTIONS;
61  template <class... Args>
62  explicit LRNOp(Args&&... args)
63  : LRNOpBase<T, Context>(std::forward<Args>(args)...) {}
64 
65  bool RunOnDeviceWithOrderNCHW() override;
66  bool RunOnDeviceWithOrderNHWC() override;
67 
68  protected:
69  // Input: X; Output: Y, scale.
70  OUTPUT_TAGS(OUTPUT, SCALE);
71  Tensor* scale_ = nullptr;
72  Tensor local_scale_tensor_{Context::GetDeviceType()};
73 };
74 
75 template <typename T, class Context>
76 class LRNGradientOp final : public LRNOpBase<T, Context> {
77  public:
78  USE_OPERATOR_CONTEXT_FUNCTIONS;
79  template <class... Args>
80  explicit LRNGradientOp(Args&&... args)
81  : LRNOpBase<T, Context>(std::forward<Args>(args)...) {}
82 
83  bool RunOnDeviceWithOrderNCHW() override;
84  bool RunOnDeviceWithOrderNHWC() override;
85 
86  protected:
87  // Input: X, Y, scale, dY; Output: dX
88  INPUT_TAGS(INPUT, OUTPUT, SCALE, OUTPUT_GRAD);
89  Tensor* scale_ = nullptr;
90  Tensor local_scale_tensor_{Context::GetDeviceType()};
91 };
92 
93 } // namespace caffe2
94 
95 #endif // CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13