Caffe2 - C++ API
A cross-platform deep learning (ML) framework
fp16_momentum_sgd_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 #include "caffe2/core/timer.h"
5 
6 namespace caffe2 {
7 
8 template <class Context>
9 void fp16_momentum_sgd_update(
10  int N,
11  const float16* g,
12  const float16* m,
13  float16* ng,
14  float16* nm,
15  const float* lr,
16  float momentum,
17  bool nesterov,
18  float weight_decay,
19  bool fp32_update,
20  float16* param,
21  Context* /*context*/) {}
22 
23 template <typename T, class Context>
24 class FP16MomentumSGDUpdateOp final : public Operator<Context> {
25  public:
26  USE_OPERATOR_CONTEXT_FUNCTIONS;
27  FP16MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
28  : Operator<Context>(operator_def, ws),
29  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.0)),
30  weight_decay_(
31  OperatorBase::GetSingleArgument<float>("weight_decay", 0.0)),
32  nesterov_(OperatorBase::GetSingleArgument<int>("nesterov", 0)),
33  // when set, fp32_update will read in the fp16 data but
34  // perform all the compute in fp32 precision.
35  fp32_update_(OperatorBase::GetSingleArgument<int>("fp32_update", 0)) {}
36 
37  bool RunOnDevice() override {
38  // Iter live on the CPU
39  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(GRAD));
40  CAFFE_ENFORCE(OperatorBase::InputIsType<Tensor<Context>>(MOMENTUM));
41  CAFFE_ENFORCE(Input(LR).size() == 1);
42  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
43  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
44  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
45 
46  fp16_momentum_sgd_update<Context>(
47  Input(GRAD).size(),
48  Input(GRAD).template data<T>(),
49  Input(MOMENTUM).template data<T>(),
50  Output(OUTPUT_GRAD)->template mutable_data<T>(),
51  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
52  Input(LR).template data<float>(),
53  momentum_,
54  nesterov_,
55  weight_decay_,
56  fp32_update_,
57  Output(OUTPUT_PARAM)->template mutable_data<T>(),
58  &context_);
59 
60  return true;
61  }
62 
63  protected:
64  float momentum_{0.9};
65  float weight_decay_{0.0};
66  bool nesterov_;
67  bool fp32_update_;
68  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
69  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
70 };
71 }
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.