Caffe2 - C++ API
A deep learning, cross-platform ML framework
fp16_momentum_sgd_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 #include "caffe2/core/timer.h"
5 
6 namespace caffe2 {
7 
8 template <class Context>
9 void fp16_momentum_sgd_update(
10  int N,
11  const at::Half* g,
12  const at::Half* m,
13  at::Half* ng,
14  at::Half* nm,
15  const float* lr,
16  float momentum,
17  bool nesterov,
18  float weight_decay,
19  bool fp32_update,
20  at::Half* param,
21  Context* /*context*/) {}
22 
23 template <typename T, class Context>
24 class FP16MomentumSGDUpdateOp final : public Operator<Context> {
25  public:
26  USE_OPERATOR_CONTEXT_FUNCTIONS;
27  FP16MomentumSGDUpdateOp(const OperatorDef& operator_def, Workspace* ws)
28  : Operator<Context>(operator_def, ws),
29  momentum_(this->template GetSingleArgument<float>("momentum", 0.0)),
30  weight_decay_(
31  this->template GetSingleArgument<float>("weight_decay", 0.0)),
32  nesterov_(this->template GetSingleArgument<int>("nesterov", 0)),
33  // when set, fp32_update will read in the fp16 data but
34  // perform all the compute in fp32 precision.
35  fp32_update_(this->template GetSingleArgument<int>("fp32_update", 0)) {}
36 
37  bool RunOnDevice() override {
38  auto device_type = Context::GetDeviceType();
39  // Iter live on the CPU
40  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type));
41  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type));
42  CAFFE_ENFORCE(Input(LR).size() == 1);
43  CAFFE_ENFORCE(Input(GRAD).size() == Input(MOMENTUM).size());
44  Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
45  Output(OUTPUT_MOMENTUM)->ResizeLike(Input(MOMENTUM));
46 
47  fp16_momentum_sgd_update<Context>(
48  Input(GRAD).size(),
49  Input(GRAD).template data<T>(),
50  Input(MOMENTUM).template data<T>(),
51  Output(OUTPUT_GRAD)->template mutable_data<T>(),
52  Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
53  Input(LR).template data<float>(),
54  momentum_,
55  nesterov_,
56  weight_decay_,
57  fp32_update_,
58  Output(OUTPUT_PARAM)->template mutable_data<T>(),
59  &context_);
60 
61  return true;
62  }
63 
64  protected:
65  float momentum_{0.9};
66  float weight_decay_{0.0};
67  bool nesterov_;
68  bool fp32_update_;
69  INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
70  OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
71 };
72 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13