Caffe2 - C++ API
A deep learning, cross-platform ML framework
layer_norm_op.cc
#include "caffe2/operators/layer_norm_op.h"

namespace caffe2 {

namespace {
template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
    Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
} // namespace
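The two aliases above are just row-major Eigen views over an existing float buffer; constructing one copies nothing, and writes through the view land in the underlying storage. A minimal standalone sketch of that idea, assuming only <Eigen/Dense> is available (names and values are illustrative, not part of this file):

    #include <Eigen/Dense>

    void map_example() {
      float buf[6] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
      // View buf as a 2 x 3 row-major matrix; no copy is made.
      Eigen::Map<
          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          view(buf, 2, 3);
      view.row(0) *= 2.0f; // buf now holds {2, 4, 6, 4, 5, 6}
    }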

template <>
template <>
bool LayerNormOp<CPUContext>::DoRunWithType<float>() {
  const auto& input = Input(0);
  auto* output = Output(0);
  auto* mean = Output(1);
  auto* stdev = Output(2);

  CAFFE_ENFORCE_GE(input.dims().size(), 2, "LayerNorm requires input dim >= 2");

  const auto canonical_axis = input.canonical_axis_index(axis_);
  const int left = input.size_to_dim(canonical_axis);
  const int right = input.size_from_dim(canonical_axis);

  output->ResizeLike(input);
  std::vector<TIndex> stats_dims(
      input.dims().begin(), input.dims().begin() + canonical_axis);
  stats_dims.push_back(1);
  mean->Resize(stats_dims);
  stdev->Resize(stats_dims);

  auto input_map = ConstEigenMatrixMapRowMajor<float>(
      input.template data<float>(), left, right);
  auto mean_map = EigenMatrixMapRowMajor<float>(
      mean->template mutable_data<float>(), left, 1);
  auto stdev_map = EigenMatrixMapRowMajor<float>(
      stdev->template mutable_data<float>(), left, 1);
  auto output_map = EigenMatrixMapRowMajor<float>(
      output->template mutable_data<float>(), left, right);

  auto sqr = [](float f) { return f * f; };
  auto add_ep = [this](float f) { return f + epsilon_; };
  auto fsqrt = [](float f) { return std::sqrt(f); };
  // Calculate row-wise statistics
  mean_map = input_map.rowwise().mean();
  stdev_map =
      (input_map.unaryExpr(sqr).rowwise().mean() - mean_map.unaryExpr(sqr))
          .unaryExpr(add_ep)
          .unaryExpr(fsqrt);
  output_map = (input_map - mean_map.replicate(1, right))
                   .cwiseQuotient(stdev_map.replicate(1, right));

  return true;
}
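The forward pass normalizes each of the `left` rows independently; note that the stdev output holds sqrt(variance + epsilon_), so epsilon never needs to be re-added downstream. A minimal standalone sketch of the same row-wise computation in plain Eigen, assuming <Eigen/Dense> (illustrative sizes and values, not part of the operator source):

    #include <Eigen/Dense>
    #include <iostream>

    int main() {
      const int left = 2, right = 3;  // rows to normalize, features per row
      const float epsilon = 0.001f;   // plays the role of the operator's epsilon_
      Eigen::MatrixXf input(left, right);
      input << 1.f, 2.f, 3.f,
               4.f, 5.f, 6.f;

      Eigen::MatrixXf mean = input.rowwise().mean();  // left x 1
      // sqrt(E[x^2] - E[x]^2 + epsilon), matching the sqr/add_ep/fsqrt chain above
      Eigen::MatrixXf stdev =
          ((input.array().square().rowwise().mean() - mean.array().square()) +
           epsilon)
              .sqrt()
              .matrix();
      Eigen::MatrixXf output = (input - mean.replicate(1, right))
                                   .cwiseQuotient(stdev.replicate(1, right));
      std::cout << output << std::endl;  // each row has ~zero mean and ~unit variance
      return 0;
    }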

REGISTER_CPU_OPERATOR(LayerNorm, LayerNormOp<CPUContext>);

template <>
template <>
bool LayerNormGradientOp<CPUContext>::DoRunWithType<float>() {
  const auto& dout = Input(0);
  const auto& norm_outputs = Input(1);
  const auto& means = Input(2);
  const auto& stdev = Input(3);
  const auto& norm_inputs = Input(4);
  auto* ginput = Output(0);

  const auto canonical_axis = norm_inputs.canonical_axis_index(axis_);
  const int left = norm_inputs.size_to_dim(canonical_axis);
  const int right = norm_inputs.size_from_dim(canonical_axis);

  ginput->ResizeLike(norm_inputs);

  auto dout_map = ConstEigenMatrixMapRowMajor<float>(
      dout.template data<float>(), left, right);
  auto means_map =
      ConstEigenMatrixMapRowMajor<float>(means.template data<float>(), left, 1);
  auto stdev_map =
      ConstEigenMatrixMapRowMajor<float>(stdev.template data<float>(), left, 1);
  auto norm_inputs_map = ConstEigenMatrixMapRowMajor<float>(
      norm_inputs.template data<float>(), left, right);
  auto ginput_map = EigenMatrixMapRowMajor<float>(
      ginput->template mutable_data<float>(), left, right);

  // Helper functors
  auto sqr = [](float f) { return f * f; };
  auto recip = [](float f) { return 1.0f / f; };
  auto neg_recip = [](float f) { return -1.0f / f; };

  // Gradients - output block
  // -1 / (stdev + epsilon)^2 * \sum_j^D (x_ij - mean) * dout_ij
  // First part: -1 / (stdev + epsilon)^2
  auto dstdev_end_0 = stdev_map.unaryExpr(sqr).unaryExpr(neg_recip);
  // Second part: \sum_j^D (x_ij - mean) * dout_ij
  auto dstdev_end_1 = (norm_inputs_map - means_map.replicate(1, right))
                          .cwiseProduct(dout_map)
                          .rowwise()
                          .sum();
  auto dstdev_end = dstdev_end_0.cwiseProduct(dstdev_end_1);
  // \sum_j^D -dout_ij * 1 / (stdev + epsilon)
  auto dmean_end = stdev_map.unaryExpr(neg_recip)
                       .replicate(1, right)
                       .cwiseProduct(dout_map)
                       .rowwise()
                       .sum();
  // 1.0 / (stdev + epsilon) * dout
  auto dx_end =
      stdev_map.unaryExpr(recip).replicate(1, right).cwiseProduct(dout_map);

  // Gradients - standard deviation block
  // -1.0 * (mean / stdev) * dstdev_end
  auto dmean_stdev = stdev_map.unaryExpr(neg_recip)
                         .cwiseProduct(means_map)
                         .replicate(1, right)
                         .cwiseProduct(dstdev_end);
  // (x_ij / (D * stdev)) * dstdev_end
  auto dx_stdev = (1.0f / right) *
      norm_inputs_map.cwiseQuotient(stdev_map.replicate(1, right))
          .cwiseProduct(dstdev_end.replicate(1, right));

  // Gradients - mean block
  auto dmean = dmean_end + dmean_stdev;
  auto dx_mean = (1.0f / right) * dmean.replicate(1, right);

  ginput_map = dx_end + dx_stdev + dx_mean;

  return true;
}
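Read as scalar formulas, the kernel above computes, for row i with D = right features per row and \sigma_i the stored stdev (which already includes the epsilon term), the following; this is only a transcription of the code into math, not an independent derivation:

    \mathrm{dstdev\_end}_i = -\frac{1}{\sigma_i^2} \sum_{j=1}^{D} (x_{ij} - \mu_i)\, dy_{ij}
    \mathrm{dmean\_end}_i = -\frac{1}{\sigma_i} \sum_{j=1}^{D} dy_{ij}
    \mathrm{dmean\_stdev}_i = -\frac{\mu_i}{\sigma_i}\, \mathrm{dstdev\_end}_i
    \frac{\partial L}{\partial x_{ij}} = \frac{dy_{ij}}{\sigma_i}
        + \frac{x_{ij}}{D\,\sigma_i}\, \mathrm{dstdev\_end}_i
        + \frac{1}{D}\left(\mathrm{dmean\_end}_i + \mathrm{dmean\_stdev}_i\right)

where dy = dout, x = norm_inputs, and \mu, \sigma come from the forward pass.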

OPERATOR_SCHEMA(LayerNormGradient).NumInputs(5).NumOutputs(1);

REGISTER_CPU_OPERATOR(LayerNormGradient, LayerNormGradientOp<CPUContext>);

namespace {

class GetLayerNormGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "LayerNormGradient",
        "",
        vector<string>{GO(0), O(0), O(1), O(2), I(0)},
        vector<string>{GI(0)});
  }
};

} // namespace

REGISTER_GRADIENT(LayerNorm, GetLayerNormGradient);
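For reference, the blob list built by GetGradientDefs lines up position-for-position with the reads at the top of LayerNormGradientOp<CPUContext>::DoRunWithType above (a restatement of the two code blocks, nothing new):

    GO(0) -> Input(0)  dout          gradient flowing into the normalized output
    O(0)  -> Input(1)  norm_outputs  declared but unused by this CPU kernel
    O(1)  -> Input(2)  means
    O(2)  -> Input(3)  stdev
    I(0)  -> Input(4)  norm_inputs   the original input x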

OPERATOR_SCHEMA(LayerNorm)
    .NumInputs(1)
    .NumOutputs(3)
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      vector<TensorShape> out(3);
      auto input_dims_long = GetDimsVector(in[0]);
      std::vector<int> input_dims(
          input_dims_long.begin(), input_dims_long.end());
      out[0] = CreateTensorShape(input_dims, TensorProto::FLOAT);

      ArgumentHelper helper(def);

      auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
      const auto canonical_axis =
          canonical_axis_index_(axis, in[0].dims().size());
      std::vector<int> stat_dims(
          input_dims.begin(), input_dims.begin() + canonical_axis);
      stat_dims.push_back(1);
      out[1] = CreateTensorShape(stat_dims, TensorProto::FLOAT);
      out[2] = CreateTensorShape(stat_dims, TensorProto::FLOAT);
      return out;
    })
    .SetDoc(R"DOC(
Computes layer normalization as described in https://arxiv.org/pdf/1607.06450.pdf.
Given an input tensor x with shape [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}],
this op treats dimensions a_k through a_{n-1} as a feature vector. For each
feature vector, the op computes the mean and standard deviation. Then,
it returns the normalized values (with respect to the feature vector).

Note that this op does not include the scale and bias terms described in the
paper. Simply follow this op with an FC op to add those. Concretely, this op
implements:

h = \frac{1}{\sigma}(a - \mu)
where \mu = \frac{1}{H}\sum_{i=1}^{H} a_i
and \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^{H}(a_i - \mu)^2 + \epsilon}
where H is the number of hidden units (i.e. the product of dimensions from
'axis' to the end) and \epsilon is the value of the 'epsilon' argument.
)DOC")
    .Arg(
        "axis",
        "(int) defaults to 1; describes the axis of the inputs. The default is 1 "
        "because the 0th axis most likely describes the batch size.")
    .Arg(
        "epsilon",
        "(float) defaults to 0.001. Small value added to the variance before "
        "taking the square root; this prevents division by zero.")
    .Input(
        0,
        "input",
        "Input tensor which layer normalization will be applied to")
    .Output(0, "output", "Normalized values")
    .Output(1, "mean", "Mean values for each feature vector")
    .Output(2, "stddev", "Standard deviations for each feature vector");
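As a concrete trace of the TensorInferenceFunction above (illustrative shapes, not from the original page): for a single input of shape (N, C, H, W) and the default axis = 1, the canonical axis is 1, so the inferred output shapes are:

    out[0]: (N, C, H, W)   normalized values, same shape as the input
    out[1]: (N, 1)         per-row means
    out[2]: (N, 1)         per-row standard deviations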

} // namespace caffe2