Caffe2 - C++ API
A deep learning, cross-platform ML framework
ftrl_op.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 
5 namespace caffe2 {
6 
7 template <typename T>
8 struct FtrlParams {
9  explicit FtrlParams(OperatorBase* op)
10  : alphaInv(1.0 / op->GetSingleArgument<float>("alpha", 0.005f)),
11  beta(op->GetSingleArgument<float>("beta", 1.0f)),
12  lambda1(op->GetSingleArgument<float>("lambda1", 0.001f)),
13  lambda2(op->GetSingleArgument<float>("lambda2", 0.001f)) {}
14  T alphaInv;
15  T beta;
16  T lambda1;
17  T lambda2;
18 };
19 
20 // TODO(dzhulgakov): implement GPU version if necessary
21 template <typename T, class Context>
22 class FtrlOp final : public Operator<Context> {
23  public:
24  USE_OPERATOR_CONTEXT_FUNCTIONS;
25  FtrlOp(const OperatorDef& operator_def, Workspace* ws)
26  : Operator<Context>(operator_def, ws), params_(this) {
27  CAFFE_ENFORCE(
28  !HasArgument("alpha") || ALPHA >= InputSize(),
29  "Cannot specify alpha by both input and argument");
30  }
31  bool RunOnDevice() override;
32 
33  protected:
34  FtrlParams<T> params_;
35  INPUT_TAGS(VAR, N_Z, GRAD, ALPHA);
36  OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z);
37 };
38 
39 template <typename T>
40 class SparseFtrlOp final : public Operator<CPUContext> {
41  public:
42  SparseFtrlOp(const OperatorDef& operator_def, Workspace* ws)
43  : Operator<CPUContext>(operator_def, ws), params_(this) {
44  CAFFE_ENFORCE(
45  !HasArgument("alpha") || ALPHA >= InputSize(),
46  "Cannot specify alpha by both input and argument");
47  }
48 
49  bool RunOnDevice() override {
50  // run time learning rate override
51  if (ALPHA < InputSize()) {
52  CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be real-valued");
53  params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
54  }
55  // Use run-time polymorphism
56  auto& indices = Input(INDICES);
57  if (indices.template IsType<int32_t>()) {
58  DoRun<int32_t>();
59  } else if (indices.template IsType<int64_t>()) {
60  DoRun<int64_t>();
61  } else {
62  LOG(FATAL) << "Unsupported type of INDICES in SparseFtrlOp: "
63  << indices.dtype().name();
64  }
65  return true;
66  }
67 
68  protected:
69  FtrlParams<T> params_;
70  INPUT_TAGS(VAR, N_Z, INDICES, GRAD, ALPHA);
71  OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z);
72 
73  private:
74  template <typename SIndex>
75  void DoRun();
76 };
77 
78 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13