// Element-wise momentum SGD update kernel. Only part of the function is
// visible in this excerpt (the parameter list, the branch conditions and the
// closing braces are elided), so the notes below cover just the statements shown.
// NOTE(review): presumably g/m are the incoming gradient and momentum buffers
// and ng/nm the corresponding outputs -- confirm against the full signature.
1 #include <caffe2/ideep/ideep_utils.h> 5 void momentum_sgd_update(
// The learning rate is read once, from the first element of lr.
15 const float LR = lr[0];
// The index loop over [0, N) carries an OpenMP parallel-for pragma with a
// static schedule; each iteration shown touches only element i of the arrays.
17 #pragma omp parallel for schedule(static) 19 for (
auto i = 0; i < N; ++i) {
// First update path: the blended step LR * g[i] + momentum * m[i] is stored
// as both the new momentum (nm) and the new gradient (ng).
21 const float adjusted_gradient = LR * g[i] + momentum * m[i];
22 nm[i] = adjusted_gradient;
23 ng[i] = adjusted_gradient;
// Second update path (presumably the Nesterov variant -- the condition that
// selects it is not visible here): m[i] is cached in mi before computing the
// new momentum, and ng[i] = (1 + momentum) * mi_new - momentum * mi.
// NOTE(review): the store of mi_new into nm[i] is not shown in this excerpt;
// confirm it exists in the full file.
25 const float mi = m[i];
26 const float mi_new = momentum * mi + LR * g[i];
28 ng[i] = (1 + momentum) * mi_new - momentum * mi;
// Project macros; presumably they bring the IDEEP type aliases and operator
// helper functions into the class scope -- see their definitions to confirm.
39 USE_IDEEP_DEF_ALIASES();
40 USE_IDEEP_OPERATOR_FUNCTIONS();
// Constructor member initializers (the constructor signature is not visible in
// this excerpt): momentum_ is read from the "momentum" operator argument with a
// default of 0.0, and nesterov_ from the integer "nesterov" argument with a
// default of 0.
44 momentum_(OperatorBase::GetSingleArgument<float>(
"momentum", 0.0)),
45 nesterov_(OperatorBase::GetSingleArgument<int>(
"nesterov", 0)) {}
// Runs the operator: validates the inputs, prepares the two outputs, then
// passes raw float pointers to an update routine. The callee name, the closing
// braces and the trailing arguments are elided from this excerpt.
47 bool RunOnDevice()
override {
// Gradient and momentum must hold the same number of elements.
48 CAFFE_ENFORCE(Input(GRAD).get_nelems() == Input(MOMENTUM).get_nelems());
// When an output compares unequal to its matching input, it is re-initialized
// from that input's descriptor.
// NOTE(review): presumably this distinguishes in-place from out-of-place
// execution -- confirm the semantics of the tensor comparison.
49 if (Input(GRAD) != *Output(OUTPUT_GRAD)) {
50 Output(OUTPUT_GRAD)->reinit(Input(GRAD).get_descriptor());
52 if (Input(MOMENTUM) != *Output(OUTPUT_MOMENTUM)) {
53 Output(OUTPUT_MOMENTUM)->reinit(Input(MOMENTUM).get_descriptor());
// The learning rate is fetched as a CPU tensor and must contain exactly one
// element.
57 const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU);
58 CAFFE_ENFORCE(lr.numel() == 1);
// Call arguments: element count, then gradient / momentum inputs and outputs
// with their data handles cast to float*, then the learning-rate data pointer.
// NOTE(review): the casts assume the tensors hold float data; no data-type
// check is visible in this excerpt.
61 Input(GRAD).get_nelems(),
62 static_cast<float*>(Input(GRAD).get_data_handle()),
63 static_cast<float*>(Input(MOMENTUM).get_data_handle()),
64 static_cast<float*>(Output(OUTPUT_GRAD)->get_data_handle()),
65 static_cast<float*>(Output(OUTPUT_MOMENTUM)->get_data_handle()),
66 lr.template data<float>(),
// Names for the operator's inputs and outputs, in the order listed:
// three inputs (GRAD, MOMENTUM, LR) and two outputs.
76 INPUT_TAGS(GRAD, MOMENTUM, LR);
77 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM);
// Project macros; presumably they bring the IDEEP type aliases and operator
// helper functions into the class scope -- see their definitions to confirm.
82 USE_IDEEP_DEF_ALIASES();
83 USE_IDEEP_OPERATOR_FUNCTIONS();
// Constructor member initializers (the constructor signature is not visible in
// this excerpt): momentum_ is read from the "momentum" operator argument with a
// default of 0.0, and nesterov_ from the integer "nesterov" argument with a
// default of 0.
86 momentum_(OperatorBase::GetSingleArgument<float>(
"momentum", 0.0)),
87 nesterov_(OperatorBase::GetSingleArgument<int>(
"nesterov", 0)) {}
// Runs the operator: validates the inputs, prepares the gradient and momentum
// outputs, then passes raw float pointers -- including the parameter output
// buffer -- to an update routine. The callee name, the closing braces and some
// arguments are elided from this excerpt.
89 bool RunOnDevice()
override {
// Gradient and momentum must hold the same number of elements.
90 CAFFE_ENFORCE(Input(GRAD).get_nelems() == Input(MOMENTUM).get_nelems());
// When an output compares unequal to its matching input, it is re-initialized
// from that input's descriptor.
// NOTE(review): presumably this distinguishes in-place from out-of-place
// execution -- confirm the semantics of the tensor comparison.
91 if (Input(GRAD) != *Output(OUTPUT_GRAD)) {
92 Output(OUTPUT_GRAD)->reinit(Input(GRAD).get_descriptor());
94 if (Input(MOMENTUM) != *Output(OUTPUT_MOMENTUM)) {
95 Output(OUTPUT_MOMENTUM)->reinit(Input(MOMENTUM).get_descriptor());
// The learning rate is fetched as a CPU tensor and must contain exactly one
// element.
99 const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU);
100 CAFFE_ENFORCE(lr.numel() == 1);
// Call arguments: element count, then gradient / momentum inputs and outputs
// with their data handles cast to float*, then the learning-rate data pointer.
// NOTE(review): the casts assume the tensors hold float data; no data-type
// check is visible in this excerpt.
103 Input(GRAD).get_nelems(),
104 static_cast<float*>(Input(GRAD).get_data_handle()),
105 static_cast<float*>(Input(MOMENTUM).get_data_handle()),
106 static_cast<float*>(Output(OUTPUT_GRAD)->get_data_handle()),
107 static_cast<float*>(Output(OUTPUT_MOMENTUM)->get_data_handle()),
108 lr.template data<float>(),
// Final argument: the OUTPUT_PARAM buffer as a float pointer.
// NOTE(review): no reinit of OUTPUT_PARAM is visible here; presumably the
// parameter tensor is expected to be updated in place -- confirm in the full
// file.
111 static_cast<float*>(Output(OUTPUT_PARAM)->get_data_handle()));
// Momentum coefficient with an in-class default of 0.9.
// NOTE(review): the constructor initializer visible earlier in this listing
// sets momentum_ from the "momentum" argument with a default of 0.0, so if
// that constructor belongs to this class the 0.9 here never takes effect --
// confirm which default is intended.
116 float momentum_{0.9};
// Names for the operator's inputs and outputs, in the order listed:
// four inputs (GRAD, MOMENTUM, LR, PARAM) and three outputs.
118 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
119 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...