#pragma once

#include <cmath>

#include "caffe2/core/operator.h"

namespace caffe2 {

template <typename Context>
void adam_update(
    int N,
    const float* g,
    const float* m,
    const float* v,
    float* ng,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    ng[i] = lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
  }
}
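
// For reference: the loop above is the element-wise Adam update written with
// "new" buffers nm/nv/ng. In the code's own notation it computes
//   nm[i] = beta1 * m[i] + (1 - beta1) * g[i]            // first moment
//   nv[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i]     // second moment
//   ng[i] = lr[0] * correction * nm[i] / (sqrt(nv[i]) + eps_hat)
// where `correction` is the bias-correction factor
// sqrt(1 - beta2^t) / (1 - beta1^t), precomputed by the operators below from
// the iteration count t, so the inner loop never touches pow() or the counter.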

template <typename Context>
void adam_compute(
    int N,
    const float* w,
    const float* g,
    const float* m,
    const float* v,
    float* nw,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    nw[i] = w[i] + lr[0] * correction * mi / (std::sqrt(vi) + eps_hat);
  }
}
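
// adam_compute mirrors adam_update but folds the step directly into the
// updated weights, nw[i] = w[i] + lr[0] * correction * nm[i] / (sqrt(nv[i]) + eps_hat),
// instead of materializing the step separately. A minimal calling sketch
// (illustrative only; `iter`, `beta1`, `beta2`, `eps`, the raw float arrays,
// and the CPUContext `context` are assumed to be set up by the caller):
//
//   const auto t = iter + 1;  // iter is the 0-based iteration counter
//   const float correction =
//       std::sqrt(1.0 - std::pow(beta2, t)) / (1.0 - std::pow(beta1, t));
//   adam_compute<CPUContext>(
//       N, w, g, m, v, nw, nm, nv, beta1, beta2, eps, correction, lr, &context);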

template <typename Context>
void adam_compute_output_grad(
    int N,
    const float* w,
    const float* g,
    const float* m,
    const float* v,
    float* nw,
    float* nm,
    float* nv,
    float* ng,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr,
    Context* /*context*/) {
  for (auto i = 0; i < N; ++i) {
    float gi = g[i];
    float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    float ngi = ng[i] = correction * mi / (std::sqrt(vi) + eps_hat);
    nw[i] = w[i] + lr[0] * ngi;
  }
}
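
// adam_compute_output_grad additionally materializes ng, the bias-corrected
// step direction without the learning rate (ng[i] = correction * nm[i] /
// (sqrt(nv[i]) + eps_hat)), and then applies nw[i] = w[i] + lr[0] * ng[i].
// The operators below take this path when a fourth output (OUTPUT_GRAD) is
// requested, presumably so callers can reuse the effective update elsewhere.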

template <typename T, class Context>
class AdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  AdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // The iteration counter is expected to live on the CPU.
    CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU));
    CAFFE_ENFORCE(Input(LR).numel() == 1);
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(PARAM).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_1).numel());
    CAFFE_ENFORCE(Input(GRAD).numel() == Input(MOMENT_2).numel());
    Output(OUTPUT_PARAM)->ResizeLike(Input(PARAM));
    Output(OUTPUT_MOMENT_1)->ResizeLike(Input(MOMENT_1));
    Output(OUTPUT_MOMENT_2)->ResizeLike(Input(MOMENT_2));

    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
    if (OutputSize() == 3) {
      adam_compute<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(MOMENT_1).template data<T>(),
          Input(MOMENT_2).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
          beta1_,
          beta2_,
          epsilon_,
          correction,
          Input(LR).template data<T>(),
          &context_);
    } else {
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      adam_compute_output_grad<Context>(
          Input(GRAD).numel(),
          Input(PARAM).template data<T>(),
          Input(GRAD).template data<T>(),
          Input(MOMENT_1).template data<T>(),
          Input(MOMENT_2).template data<T>(),
          Output(OUTPUT_PARAM)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_1)->template mutable_data<T>(),
          Output(OUTPUT_MOMENT_2)->template mutable_data<T>(),
          Output(OUTPUT_GRAD)->template mutable_data<T>(),
          beta1_,
          beta2_,
          epsilon_,
          correction,
          Input(LR).template data<T>(),
          &context_);
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
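
// Summary of the dense operator above: inputs are PARAM, MOMENT_1, MOMENT_2,
// GRAD, LR (a single-element tensor) and ITER (an int64 counter that must be a
// CPU tensor, per the enforce at the top of RunOnDevice). With three outputs it
// writes the updated parameters and moments; with four outputs it also emits
// OUTPUT_GRAD, the bias-corrected, learning-rate-free step computed by
// adam_compute_output_grad.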

template <typename T, class Context>
class SparseAdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_2).numel());
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
    auto n = Input(GRAD).numel() / block_size;
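
    // The parameter is treated as a [size(0) x block_size] matrix: block_size
    // is the number of elements per row, and n is the number of gradient rows
    // carried by GRAD/INDICES. Each indices[i] names the parameter row that
    // gradient row i updates.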

    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* moment1In = Input(MOMENT_1).template data<T>();
    const auto* moment2In = Input(MOMENT_2).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();

    if (OutputSize() == 3) {
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          paramOut[idx] = paramIn[idx] +
              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound,  idx:",
              idx,
              " for input i:",
              i,
              " and block size:",
              block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx,
              " for input i:",
              i);
#endif

          adam_compute(
              block_size,
              paramIn + offsetIdx,
              gradIn + offsetI,
              moment1In + offsetIdx,
              moment2In + offsetIdx,
              paramOut + offsetIdx,
              moment1Out + offsetIdx,
              moment2Out + offsetIdx,
              beta1_,
              beta2_,
              epsilon_,
              correction,
              lr,
              &context_);
        }
      }
    } else {
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
          paramOut[idx] = paramIn[idx] + lr[0] * ngi;
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound,  idx:",
              idx,
              " for input i:",
              i,
              " and block size:",
              block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx,
              " for input i:",
              i);
#endif

          adam_compute_output_grad(
              block_size,
              paramIn + offsetIdx,
              gradIn + offsetI,
              moment1In + offsetIdx,
              moment2In + offsetIdx,
              paramOut + offsetIdx,
              moment1Out + offsetIdx,
              moment2Out + offsetIdx,
              gradOut + offsetI,
              beta1_,
              beta2_,
              epsilon_,
              correction,
              lr,
              &context_);
        }
      }
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};
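
// SparseAdamOp applies the same update as AdamOp, but only to the rows named
// in INDICES; all other rows are not touched by the loops above. Both moment
// tensors are full-sized here (numel equal to PARAM, per the enforces) and are
// indexed at idx * block_size, just like PARAM itself.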

template <typename T, class Context>
class RowWiseSparseAdamOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  RowWiseSparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        beta1_(this->template GetSingleArgument<float>("beta1", 0.9f)),
        beta2_(this->template GetSingleArgument<float>("beta2", 0.999f)),
        epsilon_(this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}

  bool RunOnDevice() override {
    // Enforce shapes
    CAFFE_ENFORCE_EQ(Input(PARAM).numel(), Input(MOMENT_1).numel());
    CAFFE_ENFORCE_EQ(Input(PARAM).sizes()[0], Input(MOMENT_2).numel());
    CAFFE_ENFORCE_EQ(
        Input(PARAM).size_from_dim(1),
        Input(GRAD).size_from_dim(Input(INDICES).dim()));
    CAFFE_ENFORCE_EQ(Input(LR).numel(), 1);

    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, Input(INDICES));
  }

  template <typename SIndex>
  bool DoRunWithType() {
    const auto* lr = Input(LR).template data<T>();
    const auto iter =
        OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0];

    const auto t = iter + 1;
    const auto correction =
        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));

    auto block_size = Input(PARAM).numel() / Input(PARAM).size(0);
    auto n = Input(GRAD).numel() / block_size;

    const auto* paramIn = Input(PARAM).template data<T>();
    const auto* indices = Input(INDICES).template data<SIndex>();
    const auto* gradIn = Input(GRAD).template data<T>();
    const auto* moment1In = Input(MOMENT_1).template data<T>();
    const auto* moment2In = Input(MOMENT_2).template data<T>();
    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
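
    // Unlike SparseAdamOp, MOMENT_2 here holds a single value per parameter
    // row (RunOnDevice enforces Input(PARAM).sizes()[0] == Input(MOMENT_2).numel()),
    // so moment2In/moment2Out are indexed by idx rather than idx * block_size
    // below. This shrinks the second-moment state from numel(PARAM) floats to
    // size(0) floats.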

    if (OutputSize() == 3) {
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          paramOut[idx] = paramIn[idx] +
              lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound,  idx:",
              idx,
              " for input i:",
              i,
              " and block size:",
              block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx,
              " for input i:",
              i);
#endif

          const float* w = paramIn + offsetIdx;
          const float* g = gradIn + offsetI;
          const float* m1 = moment1In + offsetIdx;
          const float* m2 = moment2In + idx;
          float* nw = paramOut + offsetIdx;
          float* nm1 = moment1Out + offsetIdx;
          float* nm2 = moment2Out + idx;

          float m2_sum = 0.;
          for (auto j = 0; j < block_size; ++j) {
            float gj = g[j];
            m2_sum += gj * gj;
          }
          float vi = nm2[0] =
              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
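
          // The whole row shares one second-moment value: the squared
          // gradients are averaged over the block (m2_sum / block_size) and
          // folded into nm2[0], and every element of the row is then scaled
          // by the same denominator sqrt(vi) + epsilon_ in the loop below.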
          for (auto j = 0; j < block_size; ++j) {
            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
            nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
          }
        }
      }
    } else {
      Output(OUTPUT_GRAD)->ResizeLike(Input(GRAD));
      auto* gradOut = Output(OUTPUT_GRAD)->template mutable_data<T>();
      for (auto i = 0; i < n; ++i) {
        auto idx = indices[i];

        if (block_size == 1) {
          float gi = gradIn[i];
          float mi = moment1Out[idx] =
              moment1In[idx] * beta1_ + gi * (1 - beta1_);
          float vi = moment2Out[idx] =
              moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
          float ngi = gradOut[i] = correction * mi / (std::sqrt(vi) + epsilon_);
          paramOut[idx] = paramIn[idx] + lr[0] * ngi;
        } else {
          auto offsetI = i * block_size;
          auto offsetIdx = idx * block_size;

#ifndef NDEBUG
          CAFFE_ENFORCE_GE(
              Input(PARAM).numel(),
              block_size + offsetIdx,
              this->debug_def().input(PARAM),
              ", out of bound,  idx:",
              idx,
              " for input i:",
              i,
              " and block size:",
              block_size);
          CAFFE_ENFORCE_GE(
              Input(GRAD).numel(),
              block_size + offsetI,
              this->debug_def().input(GRAD),
              ", out of bound idx, idx:",
              idx,
              " for input i:",
              i);
#endif

          const float* w = paramIn + offsetIdx;
          const float* g = gradIn + offsetI;
          const float* m1 = moment1In + offsetIdx;
          const float* m2 = moment2In + idx;
          float* nw = paramOut + offsetIdx;
          float* nm1 = moment1Out + offsetIdx;
          float* nm2 = moment2Out + idx;
          float* ng = gradOut + offsetI;

          float m2_sum = 0.;
          for (auto j = 0; j < block_size; ++j) {
            float gj = g[j];
            m2_sum += gj * gj;
          }
          float vi = nm2[0] =
              m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
          for (auto j = 0; j < block_size; ++j) {
            float mi = nm1[j] = m1[j] * beta1_ + g[j] * (1 - beta1_);
            float ngi = ng[j] = correction * mi / (std::sqrt(vi) + epsilon_);
            nw[j] = w[j] + lr[0] * ngi;
          }
        }
      }
    }
    return true;
  }

 protected:
  T beta1_;
  T beta2_;
  T epsilon_;
  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
};

} // namespace caffe2
 