3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/timer.h" 8 template <
class Context>
9 void fp16_momentum_sgd_update(
23 template <
typename T,
class Context>
26 USE_OPERATOR_CONTEXT_FUNCTIONS;
29 momentum_(this->
template GetSingleArgument<float>(
"momentum", 0.0)),
31 this->
template GetSingleArgument<float>(
"weight_decay", 0.0)),
32 nesterov_(this->
template GetSingleArgument<int>(
"nesterov", 0)),
35 fp32_update_(this->
template GetSingleArgument<int>(
"fp32_update", 0)) {}
// Validates the operator's inputs, sizes the gradient/momentum outputs and
// dispatches to fp16_momentum_sgd_update<Context>.
// NOTE(review): this listing is missing several original lines (some call
// arguments and the end of the function), so only the visible steps are
// described here.
37 bool RunOnDevice()
override {
38 auto device_type = Context::GetDeviceType();
// GRAD and MOMENTUM must both be tensors of this Context's device type.
40 CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type));
41 CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type));
// The learning-rate input must hold exactly one element.
42 CAFFE_ENFORCE(
Input(LR).size() == 1);
// Gradient and momentum must have the same number of elements.
43 CAFFE_ENFORCE(
Input(GRAD).size() ==
Input(MOMENTUM).size());
// Size each output after its corresponding input before requesting
// mutable data pointers below.
// NOTE(review): OUTPUT_PARAM is not resized in the visible lines -
// presumably it is expected to alias the PARAM input; confirm.
44 Output(OUTPUT_GRAD)->ResizeLike(
Input(GRAD));
45 Output(OUTPUT_MOMENTUM)->ResizeLike(
Input(MOMENTUM));
// Hand the raw buffers to the update routine: GRAD/MOMENTUM and their
// outputs are accessed as T, while LR is read as float.
// NOTE(review): further arguments of this call are not visible here.
47 fp16_momentum_sgd_update<Context>(
49 Input(GRAD).template data<T>(),
50 Input(MOMENTUM).template data<T>(),
51 Output(OUTPUT_GRAD)->template mutable_data<T>(),
52 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
53 Input(LR).template data<float>(),
58 Output(OUTPUT_PARAM)->template mutable_data<T>(),
// Weight-decay setting; defaults to 0.0 in-class.
// NOTE(review): presumably overridden from the "weight_decay" operator
// argument in the constructor - confirm against the full initializer list.
66 float weight_decay_{0.0};
// Symbolic names for the operator's inputs, used as the index passed to
// Input(...) in RunOnDevice.
// NOTE(review): presumably these map to positions 0..3 in the listed order.
69 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
// Symbolic names for the operator's outputs, used as the index passed to
// Output(...) in RunOnDevice.
70 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...