Caffe2 - C++ API
A deep learning, cross-platform ML framework
wngrad_op.cc
1 #include "wngrad_op.h"
2 
3 namespace caffe2 {
4 
5 REGISTER_CPU_OPERATOR(Wngrad, WngradOp<float, CPUContext>);
6 OPERATOR_SCHEMA(Wngrad)
7  .NumInputs(4)
8  .NumOutputs(2, 4)
9  .AllowInplace({{0, 0}, {1, 1}})
10  .SetDoc(R"DOC(
11 
12 Computes the WnGrad update for an input gradient and accumulated
13 history. This operator implement the optimization algorithm
14 in https://arxiv.org/abs/1803.02865 by Wu, Ward and Bottou.
15 Concretely, given inputs (param, grad, seq_b, learning_rate),
16 computes
17 
18  new_seq_b = seq_b + 1 / seq_b * norm(grad)^2
19  effective_lr = learning_rate / (new_seq_b + epsilon)
20  update = learning_rate * grad / (new_seq_b + epsilon)
21  new_param = param + update
22 and returns (new_param, new_seq_b).
23 
24 Optionally returns effective_lr and update as well.
25 
26 )DOC")
27  .Input(0, "param", "Parameters to be updated")
28  .Input(1, "seq_b", "Seq_b history")
29  .Input(2, "grad", "Gradient computed")
30  .Input(3, "lr", "learning rate")
31  .Output(0, "output_param", "Updated parameters")
32  .Output(1, "output_seq_b", "Updated seq_b")
33  .Output(2, "output_effective_lr", "(optional) Effective learning rate")
34  .Output(3, "output_update", "(optional) Actual update that is applied.")
35 
36  .Arg("epsilon", "Default 1e-5");
37 
38 REGISTER_CPU_OPERATOR(SparseWngrad, SparseWngradOp<float, CPUContext>);
39 OPERATOR_SCHEMA(SparseWngrad)
40  .NumInputs(5)
41  .NumOutputs(2)
42  .EnforceOneToOneInplace()
43  .SetDoc(R"DOC(
44 
45 This operator implement the optimization algorithm
46 in https://arxiv.org/abs/1803.02865 by Wu, Ward and Bottou.
47 Given inputs (param, seq_b, indices, grad, lr), runs the dense WnGrad
48 update on (param, grad, seq_b, lr), and returns (new_param,
49 new_seq_b) as in the dense case.
50 
51 )DOC")
52  .Input(0, "param", "Parameters to be updated")
53  .Input(1, "seq_b", "seq_b history")
54  .Input(2, "indices", "Sparse indices")
55  .Input(3, "grad", "Gradient computed")
56  .Input(4, "lr", "learning rate")
57  .Output(0, "output_param", "Updated parameters")
58  .Output(1, "output_seq_b", "Updated seq_b")
59  .Arg("epsilon", "Default 1e-5");
60 
61 SHOULD_NOT_DO_GRADIENT(Wngrad);
62 SHOULD_NOT_DO_GRADIENT(SparseWngrad);
63 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13