Caffe2 - C++ API
A deep learning, cross-platform ML framework
fp32_momentum_sgd_op.h
1 #pragma once
2 
#include <stdexcept>

#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
5 
6 namespace caffe2 {
7 
/**
 * @brief Fused momentum-SGD step over N fp32 elements, implemented per Context.
 *
 * The primary template has no generic implementation: every Context used with
 * FP32MomentumSGDUpdateOp must provide an explicit specialization
 * (NOTE(review): presumably the device specialization lives in the operator's
 * device translation unit — confirm). An empty default body would let an
 * unspecialized Context report success while leaving @p ng and @p nm
 * uninitialized, so the primary template fails loudly instead.
 *
 * @param N            Number of elements to process.
 * @param g            Input gradient.
 * @param m            Input momentum.
 * @param ng           Output gradient buffer (the operator's OUTPUT_GRAD).
 * @param nm           Output momentum buffer (the operator's OUTPUT_MOMENTUM).
 * @param lr           Pointer to the learning rate; the operator enforces that
 *                     LR holds exactly one element.
 * @param momentum     Momentum coefficient.
 * @param nesterov     Whether Nesterov momentum is requested.
 * @param weight_decay Weight-decay coefficient.
 * @param param        Parameter buffer (the operator's OUTPUT_PARAM).
 * @param context      Execution context; unused by the primary template.
 * @throws std::logic_error Always, when this unspecialized template is called.
 */
template <class Context>
void fp32_momentum_sgd_update(
    int N,
    const float* g,
    const float* m,
    float* ng,
    float* nm,
    const float* lr,
    float momentum,
    bool nesterov,
    float weight_decay,
    float* param,
    Context* /*context*/) {
  throw std::logic_error(
      "fp32_momentum_sgd_update: no implementation for this Context; "
      "an explicit specialization is required");
}
21 
22 template <typename T, class Context>
23 class FP32MomentumSGDUpdateOp final : public Operator<Context> {
24  public:
25  USE_OPERATOR_CONTEXT_FUNCTIONS;
26  FP32MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
27  : Operator<Context>(operator_def, ws),
28  momentum_(this->template GetSingleArgument<float>("momentum", 0.0)),
29  weight_decay_(
30  this->template GetSingleArgument<float>("weight_decay", 0.0)),
31  nesterov_(this->template GetSingleArgument<int>("nesterov", 0)) {}
32 
33  bool RunOnDevice() override {
34  auto device_type = Context::GetDeviceType();
35  // Iter live on the CPU
36  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type));
37  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type));
38  CAFFE_ENFORCE(Input(LR).size() == 1);
39  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
40  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
41  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
42 
43  fp32_momentum_sgd_update<Context>(
44  Input(GRAD).size(),
45  Input(GRAD).template data<T>(),
46  Input(MOMENTUM).template data<T>(),
47  Output(OUTPUT_GRAD)->template mutable_data<T>(),
48  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
49  Input(LR).template data<float>(),
50  momentum_,
51  nesterov_,
52  weight_decay_,
53  Output(OUTPUT_PARAM)->template mutable_data<T>(),
54  &context_);
55 
56  return true;
57  }
58 
59  protected:
60  float momentum_{0.9};
61  float weight_decay_{0.0};
62  bool nesterov_;
63  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
64  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
65 };
66 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13