Caffe2 - C++ API
A deep learning, cross platform ML framework
adadelta_op.h
#include <cmath>
#include <cstdint>

#include "caffe2/core/operator.h"
2 
3 namespace caffe2 {
4 
5 namespace {
6 
/**
 * Applies one element-wise Adadelta step to N contiguous floats.
 *
 * For every i in [0, N):
 *   nh[i] = decay * h[i] + (1 - decay) * g[i]^2
 *   ng    = sqrt(d[i] + epsilon) / sqrt(nh[i] + epsilon) * g[i]
 *   nw[i] = w[i] + lr[0] * ng
 *   nd[i] = decay * d[i] + (1 - decay) * ng^2
 *
 * Each input element is read before the output with the same index is
 * written, so nw/nh/nd may alias w/h/d respectively (in-place update).
 *
 * @param N       Number of elements to process. 64-bit so that the tensor
 *                sizes callers pass (numel(), block sizes) are not narrowed.
 * @param w       Current parameter values.
 * @param g       Gradient values.
 * @param h       Running average of squared gradients.
 * @param d       Running average of squared updates.
 * @param epsilon Added under both square roots. With epsilon == 0 and a zero
 *                nh[i] the ratio is not finite; nothing here guards that.
 * @param decay   Weight given to the previous running averages.
 * @param lr      Pointer to the learning rate; only lr[0] is read. The step
 *                is *added* to w, so callers presumably pass a negative
 *                rate for descent (NOTE(review): confirm against callers).
 * @param nw      Output: updated parameters.
 * @param nh      Output: updated squared-gradient average.
 * @param nd      Output: updated squared-update average.
 * The trailing Context pointer is unused.
 */
template <typename Context>
void AdadeltaUpdate(
    std::int64_t N,
    const float* w,
    const float* g,
    const float* h,
    const float* d,
    const float epsilon,
    const float decay,
    const float* lr,
    float* nw,
    float* nh,
    float* nd,
    Context* /*context*/) {
  for (std::int64_t i = 0; i < N; ++i) {
    // Read the inputs up front so an aliased output cannot clobber them.
    const float gi = g[i];
    const float di = d[i];
    const float hi = nh[i] = decay * h[i] + (1.0f - decay) * gi * gi;
    const float ng = (std::sqrt(di + epsilon) / std::sqrt(hi + epsilon)) * gi;
    nw[i] = w[i] + lr[0] * ng;
    nd[i] = decay * di + (1.0f - decay) * ng * ng;
  }
}
30 
31 } // namespace
32 
33 template <class Context>
34 class AdadeltaOp final : public Operator<Context> {
35  public:
36  USE_OPERATOR_CONTEXT_FUNCTIONS;
37  AdadeltaOp(const OperatorDef& operator_def, Workspace* ws)
38  : Operator<Context>(operator_def, ws),
39  OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f),
40  OP_SINGLE_ARG(float, "decay", decay_, 0.95f) {}
41 
42  bool RunOnDevice() override {
43  CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_GRAD).numel());
44  CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_DELTA).numel());
45  CAFFE_ENFORCE(Input(GRAD).numel() == Input(PARAM).numel());
46  CAFFE_ENFORCE_GE(epsilon_, 0.0f);
47  CAFFE_ENFORCE_GT(decay_, 0.0f);
48  CAFFE_ENFORCE_LT(decay_, 1.0f);
49 
50  Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
51  Output(OUTPUT_MOMENT_GRAD)->ResizeLike(Input(MOMENT_GRAD));
52  Output(OUTPUT_MOMENT_DELTA)->ResizeLike(Input(MOMENT_DELTA));
53  AdadeltaUpdate<Context>(
54  Input(GRAD).numel(),
55  Input(PARAM).template data<float>(),
56  Input(GRAD).template data<float>(),
57  Input(MOMENT_GRAD).template data<float>(),
58  Input(MOMENT_DELTA).template data<float>(),
59  epsilon_,
60  decay_,
61  Input(LR).template data<float>(),
62  Output(OUTPUT_PARAM)->template mutable_data<float>(),
63  Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>(),
64  Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>(),
65  &context_);
66  return true;
67  }
68 
69  protected:
70  const float epsilon_;
71  const float decay_;
72  INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, GRAD, LR);
73  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA);
74 };
75 
76 template <class Context>
77 class SparseAdadeltaOp final : public Operator<Context> {
78  public:
79  USE_OPERATOR_CONTEXT_FUNCTIONS;
80  SparseAdadeltaOp(const OperatorDef& operator_def, Workspace* ws)
81  : Operator<Context>(operator_def, ws),
82  OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f),
83  OP_SINGLE_ARG(float, "decay", decay_, 0.95f) {}
84 
85  bool RunOnDevice() override {
86  // Enforce shapes
87  CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_GRAD).numel());
88  CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_DELTA).numel());
89  CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);
90  CAFFE_ENFORCE_EQ(
91  Input(PARAM).size_from_dim(1),
92  Input(GRAD).size_from_dim(Input(INDICES).dim()));
93 
94  // Enforce domain constraints for attributes
95  CAFFE_ENFORCE_GE(epsilon_, 0.0f);
96  CAFFE_ENFORCE_GT(decay_, 0.0f);
97  CAFFE_ENFORCE_LT(decay_, 1.0f);
98 
100  this, Input(INDICES));
101  }
102 
103  template <typename SIndex>
104  bool DoRunWithType() {
105  const auto* lr = Input(LR).template data<float>();
106  const auto* indices = Input(INDICES).template data<SIndex>();
107  const auto* gradIn = Input(GRAD).template data<float>();
108  const auto* paramIn = Input(PARAM).template data<float>();
109  const auto* momentIn = Input(MOMENT_GRAD).template data<float>();
110  const auto* momentDeltaIn = Input(MOMENT_DELTA).template data<float>();
111  auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>();
112  auto* momentOut =
113  Output(OUTPUT_MOMENT_GRAD)->template mutable_data<float>();
114  auto* momentDeltaOut =
115  Output(OUTPUT_MOMENT_DELTA)->template mutable_data<float>();
116 
117  auto n = Input(INDICES).numel();
118  if (n == 0) {
119  return true;
120  }
121 
122  auto block_size = Input(GRAD).numel() / n;
123  for (int i = 0; i < n; ++i) {
124  auto idx = indices[i];
125  if (block_size == 1) {
126  float gi = gradIn[i];
127  float di = momentDeltaIn[idx];
128  float hi = momentOut[idx] =
129  decay_ * momentIn[idx] + (1.0f - decay_) * gi * gi;
130  float ng = (std::sqrt(di + epsilon_) / std::sqrt(hi + epsilon_)) * gi;
131  paramOut[idx] = paramIn[idx] + lr[0] * ng;
132  momentDeltaOut[idx] = decay_ * di + (1.0f - decay_) * ng * ng;
133  } else {
134  auto offsetI = i * block_size;
135  auto offsetIdx = idx * block_size;
136 
137 #ifndef NDEBUG
138  CAFFE_ENFORCE_GE(
139  Input(PARAM).numel(),
140  block_size + offsetIdx,
141  this->debug_def().input(PARAM),
142  ", out of bound, idx:",
143  idx,
144  " for input i:",
145  i,
146  " and block size:",
147  block_size);
148  CAFFE_ENFORCE_GE(
149  Input(GRAD).numel(),
150  block_size + offsetI,
151  this->debug_def().input(GRAD),
152  ", out of bound idx, idx:",
153  idx,
154  " for input i:",
155  i);
156 #endif
157  AdadeltaUpdate(
158  block_size,
159  paramIn + offsetIdx,
160  gradIn + offsetI,
161  momentIn + offsetIdx,
162  momentDeltaIn + offsetIdx,
163  epsilon_,
164  decay_,
165  lr,
166  paramOut + offsetIdx,
167  momentOut + offsetIdx,
168  momentDeltaOut + offsetIdx,
169  &context_);
170  }
171  }
172  return true;
173  }
174 
175  protected:
176  const float epsilon_;
177  const float decay_;
178  INPUT_TAGS(PARAM, MOMENT_GRAD, MOMENT_DELTA, INDICES, GRAD, LR);
179  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_GRAD, OUTPUT_MOMENT_DELTA);
180 };
181 
182 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13