REGISTER_CPU_OPERATOR(Adam, AdamOp<float, CPUContext>);
OPERATOR_SCHEMA(Adam)
    .NumInputs(6)
    .NumOutputs(3, 4)
    .AllowInplace({{0, 0}, {1, 1}, {2, 2}})
    .DeviceInferenceFunction([](const OperatorDef& def) {
      auto op_device =
          def.has_device_option() ? def.device_option() : DeviceOption();
      vector<DeviceOption> in_dev(def.input_size(), op_device);
      vector<DeviceOption> out_dev(def.output_size(), op_device);
      // Input 5 is the iteration counter, which always lives on CPU.
      in_dev[5] = DeviceOption();
      return std::make_pair(in_dev, out_dev);
    })
    .SetDoc(R"DOC(

Computes the Adam update (https://arxiv.org/abs/1412.6980) for an
input gradient and momentum parameters. Concretely, given inputs
(param, m1, m2, grad, lr, iters),

    t = iters + 1
    correction_multiplier = sqrt(1 - power(beta2, t)) /
        (1 - power(beta1, t))
    m1_o = (beta1 * m1) + (1 - beta1) * grad
    m2_o = (beta2 * m2) + (1 - beta2) * np.square(grad)
    grad_o = correction_multiplier * m1_o / \
        (sqrt(m2_o) + epsilon)
    param_o = param + lr * grad_o

and returns (param_o, m1_o, m2_o, grad_o), in which grad_o is an
optional output.

)DOC")
    .Input(0, "param", "Parameters to be updated")
    .Input(1, "moment_1", "First moment history")
    .Input(2, "moment_2", "Second moment history")
    .Input(3, "grad", "Gradient computed")
    .Input(4, "lr", "Learning rate")
    .Input(5, "iter", "Iteration number")
    .Output(0, "output_param", "Updated parameters")
    .Output(1, "output_moment_1", "Updated first moment")
    .Output(2, "output_moment_2", "Updated second moment")
    .Output(3, "output_grad", "Optional effective gradient")
    .Arg("beta1", "First moment decay rate (default 0.9)")
    .Arg("beta2", "Second moment decay rate (default 0.999)")
    .Arg("epsilon", "Numerical stability constant (default 1e-5)");

REGISTER_CPU_OPERATOR(SparseAdam, SparseAdamOp<float, CPUContext>);
OPERATOR_SCHEMA(SparseAdam)
    .NumInputs(7)
    .NumOutputs(3, 4)
    .EnforceInplace({{0, 0}, {1, 1}, {2, 2}})
    .SetDoc(R"DOC(

Computes the Adam update for the sparse case.
Given inputs (param, moment1, moment2, indices, grad, lr, iter), runs the
dense Adam update on (param, moment1[indices], moment2[indices], grad, lr,
iter) and returns (new_param, new_moment1, new_moment2) as in the dense case.

)DOC")
    .Input(0, "param", "Parameters to be updated")
    .Input(1, "moment_1", "First moment history")
    .Input(2, "moment_2", "Second moment history")
    .Input(3, "indices", "Sparse indices")
    .Input(4, "grad", "Gradient computed")
    .Input(5, "lr", "Learning rate")
    .Input(6, "iter", "Iteration number")
    .Output(0, "output_param", "Updated parameters")
    .Output(1, "output_moment_1", "Updated first moment")
    .Output(2, "output_moment_2", "Updated second moment")
    .Output(3, "output_grad", "Optional effective gradient")
    .Arg("beta1", "First moment decay rate (default 0.9)")
    .Arg("beta2", "Second moment decay rate (default 0.999)")
    .Arg("epsilon", "Numerical stability constant (default 1e-5)");

REGISTER_CPU_OPERATOR(
    RowWiseSparseAdam,
    RowWiseSparseAdamOp<float, CPUContext>);
OPERATOR_SCHEMA(RowWiseSparseAdam)
    .NumInputs(7)
    .NumOutputs(3, 4)
    .EnforceInplace({{0, 0}, {1, 1}, {2, 2}})
    .SetDoc(R"DOC(

Computes a modified Adam update for the sparse case.
Given inputs (param, moment1, moment2, indices, grad, lr, iter), runs the
Adam update on (param, moment1[indices], moment2[indices], grad, lr, iter)
and returns (new_param, new_moment1, new_moment2), where moment2 is a 1D
tensor with one value per row of param:
len(moment2) == shape(param)[0]. Each element of moment2 is applied to an
entire row of param, and the new moment2 values are computed by averaging
the squared gradient across the row.

)DOC")
    .Input(0, "param", "Parameters to be updated")
    .Input(1, "moment_1", "First moment history")
    .Input(2, "moment_2", "Second moment history")
    .Input(3, "indices", "Sparse indices")
    .Input(4, "grad", "Gradient computed")
    .Input(5, "lr", "Learning rate")
    .Input(6, "iter", "Iteration number")
    .Output(0, "output_param", "Updated parameters")
    .Output(1, "output_moment_1", "Updated first moment")
    .Output(2, "output_moment_2", "Updated second moment")
    .Output(3, "output_grad", "Optional effective gradient")
    .Arg("beta1", "First moment decay rate (default 0.9)")
    .Arg("beta2", "Second moment decay rate (default 0.999)")
    .Arg("epsilon", "Numerical stability constant (default 1e-5)");

SHOULD_NOT_DO_GRADIENT(Adam);
SHOULD_NOT_DO_GRADIENT(SparseAdam);
SHOULD_NOT_DO_GRADIENT(RowWiseSparseAdam);
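
// The reference sketches below are not part of the original file; they
// restate the schema docs above as plain C++ so the update rules can be read
// (and checked) in isolation. All function names, the raw-array layout, and
// the defaulted hyperparameters are illustrative assumptions, not Caffe2
// API; <cmath> and <cstdint> are assumed to be available.

// Dense Adam update, mirroring the equations in the Adam schema doc.
static void AdamDenseReference(
    float* param,
    float* m1,
    float* m2,
    const float* grad,
    std::int64_t n,
    float lr,
    std::int64_t iters,
    float beta1 = 0.9f,
    float beta2 = 0.999f,
    float epsilon = 1e-5f) {
  // Bias correction with t = iters + 1, as in the schema doc.
  const float t = static_cast<float>(iters + 1);
  const float correction_multiplier =
      std::sqrt(1.0f - std::pow(beta2, t)) / (1.0f - std::pow(beta1, t));
  for (std::int64_t i = 0; i < n; ++i) {
    m1[i] = beta1 * m1[i] + (1.0f - beta1) * grad[i];
    m2[i] = beta2 * m2[i] + (1.0f - beta2) * grad[i] * grad[i];
    const float grad_o =
        correction_multiplier * m1[i] / (std::sqrt(m2[i]) + epsilon);
    param[i] += lr * grad_o;
  }
}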
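
// SparseAdam, read in terms of the dense rule above: each entry of `indices`
// selects one row of param/moment1/moment2, and the dense update runs on
// that row with the matching gradient slice. The row-major [rows, block]
// layout is an assumption for illustration.
static void SparseAdamReference(
    float* param,                 // [rows, block]
    float* m1,                    // [rows, block]
    float* m2,                    // [rows, block]
    const std::int64_t* indices,  // [num_indices]
    const float* grad,            // [num_indices, block]
    std::int64_t num_indices,
    std::int64_t block,
    float lr,
    std::int64_t iters) {
  for (std::int64_t i = 0; i < num_indices; ++i) {
    const std::int64_t row = indices[i];
    // Dense Adam restricted to the selected row.
    AdamDenseReference(
        param + row * block,
        m1 + row * block,
        m2 + row * block,
        grad + i * block,
        block,
        lr,
        iters);
  }
}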
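
// RowWiseSparseAdam: moment2 keeps a single scalar per row of param, updated
// with the mean squared gradient of that row; that one scalar then
// normalizes every element of the row. The exact operation order inside the
// real operator may differ; this only sketches the documented math.
static void RowWiseSparseAdamReference(
    float* param,                 // [rows, block]
    float* m1,                    // [rows, block]
    float* m2,                    // [rows], one scalar per row
    const std::int64_t* indices,  // [num_indices]
    const float* grad,            // [num_indices, block]
    std::int64_t num_indices,
    std::int64_t block,
    float lr,
    std::int64_t iters,
    float beta1 = 0.9f,
    float beta2 = 0.999f,
    float epsilon = 1e-5f) {
  const float t = static_cast<float>(iters + 1);
  const float correction_multiplier =
      std::sqrt(1.0f - std::pow(beta2, t)) / (1.0f - std::pow(beta1, t));
  for (std::int64_t i = 0; i < num_indices; ++i) {
    const std::int64_t row = indices[i];
    const float* g = grad + i * block;
    // Second moment: decay the row scalar toward the squared gradient
    // averaged across the row.
    float mean_sq = 0.0f;
    for (std::int64_t j = 0; j < block; ++j) {
      mean_sq += g[j] * g[j];
    }
    mean_sq /= static_cast<float>(block);
    m2[row] = beta2 * m2[row] + (1.0f - beta2) * mean_sq;
    const float denom = std::sqrt(m2[row]) + epsilon;
    float* p = param + row * block;
    float* m = m1 + row * block;
    for (std::int64_t j = 0; j < block; ++j) {
      m[j] = beta1 * m[j] + (1.0f - beta1) * g[j];
      p[j] += lr * correction_multiplier * m[j] / denom;
    }
  }
}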
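
// A hypothetical usage sketch: wiring the dense Adam operator into a net by
// hand through the OperatorDef protobuf. All blob names are placeholders;
// the outputs deliberately reuse the param/moment input blobs, which the
// schema's AllowInplace({{0, 0}, {1, 1}, {2, 2}}) permits.
static OperatorDef MakeAdamOpDef() {
  OperatorDef op;
  op.set_type("Adam");
  for (const char* name : {"param", "m1", "m2", "grad", "lr", "iter"}) {
    op.add_input(name);
  }
  for (const char* name : {"param", "m1", "m2"}) {
    op.add_output(name);
  }
  // Example of overriding a default hyperparameter via an Argument.
  auto* beta1 = op.add_arg();
  beta1->set_name("beta1");
  beta1->set_f(0.99f);
  return op;
}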