Caffe2 - C++ API
A deep learning, cross platform ML framework
fp32_momentum_sgd_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 #include "caffe2/core/timer.h"
5 
6 namespace caffe2 {
7 
8 template <class Context>
9 void fp32_momentum_sgd_update(
10  int N,
11  const float* g,
12  const float* m,
13  float* ng,
14  float* nm,
15  const float* lr,
16  float momentum,
17  bool nesterov,
18  float weight_decay,
19  float* param,
20  Context* /*context*/) {}
21 
22 template <typename T, class Context>
23 class FP32MomentumSGDUpdateOp final : public Operator<Context> {
24  public:
25  USE_OPERATOR_CONTEXT_FUNCTIONS;
26  FP32MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
27  : Operator<Context>(operator_def, ws),
28  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.0)),
29  weight_decay_(
30  OperatorBase::GetSingleArgument<float>("weight_decay", 0.0)),
31  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)) {}
32 
33  bool RunOnDevice() override {
34  // Iter live on the CPU
35  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(GRAD));
36  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(MOMENTUM));
37  CAFFE_ENFORCE(Input(LR).size() == 1);
38  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
39  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
40  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
41 
42  fp32_momentum_sgd_update<Context>(
43  Input(GRAD).size(),
44  Input(GRAD).template data<T>(),
45  Input(MOMENTUM).template data<T>(),
46  Output(OUTPUT_GRAD)->template mutable_data<T>(),
47  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
48  Input(LR).template data<float>(),
49  momentum_,
50  nesterov_,
51  weight_decay_,
52  Output(OUTPUT_PARAM)->template mutable_data<T>(),
53  &context_);
54 
55  return true;
56  }
57 
58  protected:
59  float momentum_{0.9};
60  float weight_decay_{0.0};
61  bool nesterov_;
62  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
63  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
64 };
65 }
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.