Caffe2 - C++ API
A deep-learning, cross-platform ML framework
gftrl_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 
5 namespace caffe2 {
6 
7 template <typename T>
8 struct GFtrlParams {
9  explicit GFtrlParams(OperatorBase* op)
10  : alphaInv(1.0 / op->GetSingleArgument<float>("alpha", 0.005f)),
11  beta(op->GetSingleArgument<float>("beta", 1.0f)),
12  lambda1(op->GetSingleArgument<float>("lambda1", 0.001f)),
13  lambda2(op->GetSingleArgument<float>("lambda2", 0.001f)) {}
14  T alphaInv;
15  T beta;
16  T lambda1;
17  T lambda2;
18 };
19 
20 template <typename T, class Context>
21 class GFtrlOp final : public Operator<Context> {
22  public:
23  USE_OPERATOR_CONTEXT_FUNCTIONS;
24  GFtrlOp(const OperatorDef& operator_def, Workspace* ws)
25  : Operator<Context>(operator_def, ws), params_(this) {
26  CAFFE_ENFORCE(
27  !HasArgument("alpha") || ALPHA >= InputSize(),
28  "Cannot specify alpha by both input and argument");
29  }
30  bool RunOnDevice() override;
31 
32  protected:
33  GFtrlParams<T> params_;
34  INPUT_TAGS(VAR, N_Z, GRAD, ALPHA);
35  OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z);
36 };
37 
38 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13