3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/timer.h" 8 template <
class Context>
9 void fp32_momentum_sgd_update(
22 template <
typename T,
class Context>
25 USE_OPERATOR_CONTEXT_FUNCTIONS;
28 momentum_(this->
template GetSingleArgument<float>(
"momentum", 0.0)),
30 this->
template GetSingleArgument<float>(
"weight_decay", 0.0)),
31 nesterov_(this->
template GetSingleArgument<int>(
"nesterov", 0)) {}
33 bool RunOnDevice()
override {
34 auto device_type = Context::GetDeviceType();
36 CAFFE_ENFORCE(OperatorBase::InputIsTensorType(GRAD, device_type));
37 CAFFE_ENFORCE(OperatorBase::InputIsTensorType(MOMENTUM, device_type));
38 CAFFE_ENFORCE(
Input(LR).size() == 1);
39 CAFFE_ENFORCE(
Input(GRAD).size() ==
Input(MOMENTUM).size());
40 Output(OUTPUT_GRAD)->ResizeLike(
Input(GRAD));
41 Output(OUTPUT_MOMENTUM)->ResizeLike(
Input(MOMENTUM));
43 fp32_momentum_sgd_update<Context>(
45 Input(GRAD).template data<T>(),
46 Input(MOMENTUM).template data<T>(),
47 Output(OUTPUT_GRAD)->template mutable_data<T>(),
48 Output(OUTPUT_MOMENTUM)->template mutable_data<T>(),
49 Input(LR).template data<float>(),
53 Output(OUTPUT_PARAM)->template mutable_data<T>(),
61 float weight_decay_{0.0};
63 INPUT_TAGS(GRAD, MOMENTUM, LR, PARAM);
64 OUTPUT_TAGS(OUTPUT_GRAD, OUTPUT_MOMENTUM, OUTPUT_PARAM);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...