3 #include "caffe2/core/operator.h" 7 template <
typename Context>
8 void momentum_sgd_update(
19 const float LR = lr[0];
20 for (
auto i = 0; i < N; ++i) {
22 const float adjusted_gradient = LR * g[i] + momentum * m[i];
23 nm[i] = adjusted_gradient;
24 ng[i] = adjusted_gradient;
26 const float mi = m[i];
27 const float mi_new = momentum * mi + LR * g[i];
29 ng[i] = (1 + momentum) * mi_new - momentum * mi;
38 template <
typename T,
class Context>
41 USE_OPERATOR_CONTEXT_FUNCTIONS;
44 momentum_(this->
template GetSingleArgument<T>(
"momentum", 0.0)),
45 nesterov_(this->
template GetSingleArgument<int>(
"nesterov", 0)) {}
47 bool RunOnDevice()
override {
48 auto device_type = Context::GetDeviceType();
50 CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type));
51 CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type));
52 CAFFE_ENFORCE(
Input(LR).numel() == 1);
53 CAFFE_ENFORCE(
Input(GRAD).numel() ==
Input(MOMENTUM).numel());
54 Output(OUTPUT_GRAD)->ResizeLike(
Input(GRAD));
55 Output(OUTPUT_MOMENTUM)->ResizeLike(
Input(MOMENTUM));
57 momentum_sgd_update<Context>(
59 Input(GRAD).template data<T>(),
60 Input(MOMENTUM).template data<T>(),
61 Output(OUTPUT_GRAD)->template mutable_data<T>(),
62 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
63 Input(LR).template data<T>(),
74 INPUT_TAGS(GRAD, MOMENTUM, LR);
75 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM);
78 template <
typename T,
class Context>
81 USE_OPERATOR_CONTEXT_FUNCTIONS;
84 momentum_(this->
template GetSingleArgument<T>(
"momentum", 0.0)),
85 nesterov_(this->
template GetSingleArgument<int>(
"nesterov", 0)) {}
87 bool RunOnDevice()
override {
88 auto device_type = Context::GetDeviceType();
90 CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type));
91 CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type));
92 CAFFE_ENFORCE_EQ(
Input(LR).numel(), 1);
93 CAFFE_ENFORCE_EQ(
Input(GRAD).numel(),
Input(MOMENTUM).numel());
94 Output(OUTPUT_GRAD)->ResizeLike(
Input(GRAD));
95 Output(OUTPUT_MOMENTUM)->ResizeLike(
Input(MOMENTUM));
97 momentum_sgd_update<Context>(
99 Input(GRAD).template data<T>(),
100 Input(MOMENTUM).template data<T>(),
101 Output(OUTPUT_GRAD)->template mutable_data<T>(),
102 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
103 Input(LR).template data<T>(),
106 Output(OUTPUT_PARAM)->template mutable_data<T>(),
114 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
115 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
118 template <
typename T,
class Context>
121 USE_OPERATOR_CONTEXT_FUNCTIONS;
124 momentum_(this->
template GetSingleArgument<T>(
"momentum", 0.0)),
125 nesterov_(this->
template GetSingleArgument<int>(
"nesterov", 0)) {}
127 bool RunOnDevice()
override {
129 Output(OUTPUT_GRAD)->ResizeLike(
Input(GRAD));
132 CAFFE_ENFORCE_EQ(
Input(LR).numel(), 1);
133 CAFFE_ENFORCE_EQ(
Input(PARAM).numel(),
Input(MOMENTUM).numel());
135 Input(PARAM).size_from_dim(1),
136 Input(GRAD).size_from_dim(
Input(INDICES).dim()));
139 this,
Input(INDICES));
142 template <
typename SIndex>
143 bool DoRunWithType() {
144 auto block_size =
Input(PARAM).numel() /
Input(PARAM).size(0);
145 auto n =
Input(GRAD).numel() / block_size;
147 const auto* gradIn =
Input(GRAD).template data<T>();
148 const auto* momentumIn =
Input(MOMENTUM).template data<T>();
149 const auto* lr =
Input(LR).template data<T>();
150 const auto* paramIn =
Input(PARAM).template data<T>();
151 const auto* indices =
Input(INDICES).template data<SIndex>();
153 auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
154 auto* momentumOut = Output(OUTPUT_MOMENTUM)->template mutable_data<T>();
155 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
157 for (
auto i = 0; i < n; ++i) {
158 auto idx = indices[i];
159 auto offsetI = i * block_size;
160 auto offsetIdx = idx * block_size;
162 CAFFE_ENFORCE(offsetIdx + block_size <=
Input(PARAM).numel());
163 CAFFE_ENFORCE(offsetI + block_size <=
Input(GRAD).numel());
165 momentum_sgd_update<Context>(
168 momentumIn + offsetIdx,
170 momentumOut + offsetIdx,
174 paramOut + offsetIdx,
183 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM, INDICES);
184 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...