1 #ifndef CAFFE2_OPERATORS_LARS_OP_H_ 2 #define CAFFE2_OPERATORS_LARS_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
typename T,
class Context>
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
17 offset_(this->
template GetSingleArgument<float>(
"offset", 0.5)),
18 lr_min_(this->
template GetSingleArgument<float>(
"lr_min", 0.02)) {}
20 bool RunOnDevice()
override {
24 dX.numel() == X.numel(),
"Gradient size doesn't match parameter size.");
25 CAFFE_ENFORCE_GE(offset_, 0);
26 CAFFE_ENFORCE_GE(lr_min_, 0);
29 auto& trust =
Input(3);
30 auto& lr_max =
Input(4);
32 auto* lr_rescaled = Output(0, vector<int64_t>{1}, at::dtype<T>());
34 ReinitializeTensor(&X_norm_tensor_, {1}, at::dtype<T>().device(Context::GetDeviceType()));
35 T* X_norm_ = X_norm_tensor_.template mutable_data<T>();
37 ReinitializeTensor(&dX_norm_tensor_, {1}, at::dtype<T>().device(Context::GetDeviceType()));
38 T* dX_norm_ = dX_norm_tensor_.template mutable_data<T>();
43 dX.template data<T>(),
48 wd.template data<T>(),
49 trust.template data<T>(),
50 lr_max.template data<T>(),
55 lr_rescaled->template mutable_data<T>());
68 math::SumSqr(N, X_data, X_norm, &context_);
69 math::Sqrt(1, X_norm, X_norm, &context_);
70 math::SumSqr(N, dX_data, dX_norm, &context_);
71 math::Sqrt(1, dX_norm, dX_norm, &context_);
74 void ComputeLearningRate(
93 #endif // CAFFE2_OPERATORS_LARS_OP_H_ void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...