Caffe2 - C++ API
A deep learning, cross platform ML framework
local_response_normalization_op.h
1 
17 #ifndef CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
18 #define CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class LRNOpBase : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  LRNOpBase(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  size_(OperatorBase::GetSingleArgument<int>("size", 0)),
34  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
35  beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
36  bias_(OperatorBase::GetSingleArgument<float>("bias", 1)),
37  order_(StringToStorageOrder(
38  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
39  pre_pad_((size_ - 1) / 2) {
40  DCHECK_GT(size_, 0);
41  DCHECK_EQ(size_ % 2, 1);
42  DCHECK_GT(alpha_, 0);
43  DCHECK_GT(beta_, 0);
44  }
45 
46  bool RunOnDevice() override {
47  switch (order_) {
48  case StorageOrder::NHWC:
49  return RunOnDeviceWithOrderNHWC();
50  case StorageOrder::NCHW:
51  return RunOnDeviceWithOrderNCHW();
52  default:
53  LOG(FATAL) << "Unknown storage order: " << order_;
54  }
55  // To suppress old compiler warnings
56  return true;
57  }
58 
59  virtual bool RunOnDeviceWithOrderNCHW() = 0;
60  virtual bool RunOnDeviceWithOrderNHWC() = 0;
61 
62  protected:
63  const int size_;
64  const float alpha_;
65  const float beta_;
66  const float bias_;
67  const StorageOrder order_;
68  const int pre_pad_;
69  // Input: X; Output: Y, scale.
70 };
71 
72 template <typename T, class Context>
73 class LRNOp final : public LRNOpBase<T, Context> {
74  public:
75  USE_OPERATOR_CONTEXT_FUNCTIONS;
76  LRNOp(const OperatorDef& operator_def, Workspace* ws)
77  : LRNOpBase<T, Context>(operator_def, ws) {}
78 
79  bool RunOnDeviceWithOrderNCHW() override;
80  bool RunOnDeviceWithOrderNHWC() override;
81 
82  protected:
83  // Input: X; Output: Y, scale.
84  OUTPUT_TAGS(OUTPUT, SCALE);
85  Tensor<Context>* scale_ = nullptr;
86  Tensor<Context> local_scale_tensor_;
87 };
88 
89 template <typename T, class Context>
90 class LRNGradientOp final : public LRNOpBase<T, Context> {
91  public:
92  USE_OPERATOR_CONTEXT_FUNCTIONS;
93  LRNGradientOp(const OperatorDef& operator_def, Workspace* ws)
94  : LRNOpBase<T, Context>(operator_def, ws) {}
95 
96  bool RunOnDeviceWithOrderNCHW() override;
97  bool RunOnDeviceWithOrderNHWC() override;
98 
99  protected:
100  // Input: X, Y, scale, dY; Output: dX
101  INPUT_TAGS(INPUT, OUTPUT, SCALE, OUTPUT_GRAD);
102  Tensor<Context>* scale_ = nullptr;
103  Tensor<Context> local_scale_tensor_;
104 };
105 
106 } // namespace caffe2
107 
108 #endif // CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.