Caffe2 — C++ API
A deep learning, cross-platform ML framework
layer_norm_op.h
1 #ifndef CAFFE2_OPERATORS_LAYER_NORM_OP_H_
2 #define CAFFE2_OPERATORS_LAYER_NORM_OP_H_
3 
4 #include <array>
5 #include <vector>
6 
7 #include <ATen/core/dispatch/OpSchemaRegistration.h>
8 
9 #include "caffe2/core/context.h"
10 #include "caffe2/core/operator.h"
11 #include "caffe2/core/types.h"
12 #include "caffe2/utils/math.h"
13 
14 C10_DECLARE_CAFFE2_OPERATOR(LayerNorm)
15 
16 namespace caffe2 {
17 
18 template <class Context>
19 class LayerNormOp final : public Operator<Context> {
20  public:
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22 
23  template <class... Args>
24  explicit LayerNormOp(Args&&... args)
25  : Operator<Context>(std::forward<Args>(args)...),
26  OP_SINGLE_ARG(int, "axis", axis_, 1),
27  OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f) {}
28 
29  bool RunOnDevice() override {
30  return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
31  }
32 
33  template <typename T>
34  bool DoRunWithType() {
35  const auto& X = Input(0);
36  auto* Y = Output(0);
37  const int canonical_axis = X.canonical_axis_index(axis_);
38  std::vector<int64_t> moments_dims(
39  X.sizes().cbegin(), X.sizes().cbegin() + canonical_axis);
40  moments_dims.push_back(1);
41  auto* mean = Output(1, moments_dims, at::dtype<T>());
42  auto* sig = Output(2, moments_dims, at::dtype<T>());
43  RunLayerNorm<T>(
44  X, canonical_axis, epsilon_, Y, mean, sig, &scale_, &bias_, &context_);
45  return true;
46  }
47 
48  template <typename T>
49  static void RunLayerNorm(
50  const Tensor& X,
51  const int canonical_axis,
52  const float epsilon,
53  Tensor* Y,
54  Tensor* mean,
55  Tensor* sig,
56  Tensor* scale_buffer,
57  Tensor* bias_buffer,
58  Context* context) {
59  CAFFE_ENFORCE_GE(X.dim(), 2, "LayerNorm requires input dim >= 2.");
60  const int M = X.size_to_dim(canonical_axis);
61  const int N = X.size_from_dim(canonical_axis);
62  Y->ResizeLike(X);
63  scale_buffer->Resize(M);
64  bias_buffer->Resize(M);
65  const T* X_data = X.template data<T>();
66  T* Y_data = Y->template mutable_data<T>();
67  T* mean_data = mean->template mutable_data<T>();
68  T* sig_data = sig->template mutable_data<T>();
69  T* scale_data = scale_buffer->template mutable_data<T>();
70  T* bias_data = bias_buffer->template mutable_data<T>();
71  const std::array<int, 2> X_dims = {M, N};
72  const std::array<int, 2> Y_dims = {M, 1};
73  math::Moments<T, Context>(
74  2, X_dims.data(), Y_dims.data(), X_data, mean_data, sig_data, context);
75  ComputeStdDevAndFusedParams<T>(
76  M,
77  mean_data,
78  sig_data,
79  sig_data,
80  scale_data,
81  bias_data,
82  epsilon,
83  context);
84  LayerNormForward<T>(M, N, X_data, scale_data, bias_data, Y_data, context);
85  }
86 
87  private:
88  template <typename T>
89  static void ComputeStdDevAndFusedParams(
90  const int N,
91  const T* mean,
92  const T* var,
93  T* stddev,
94  T* scale,
95  T* bias,
96  float epsilon,
97  Context* context);
98 
99  template <typename T>
100  static void LayerNormForward(
101  const int M,
102  const int N,
103  const T* X,
104  const T* scale,
105  const T* bias,
106  T* Y,
107  Context* context);
108 
109  const int axis_;
110  const float epsilon_;
111 
112  Tensor scale_{Context::GetDeviceType()};
113  Tensor bias_{Context::GetDeviceType()};
114 };
115 
116 template <class Context>
117 class LayerNormGradientOp final : public Operator<Context> {
118  public:
119  USE_OPERATOR_CONTEXT_FUNCTIONS;
120  template <class... Args>
121  explicit LayerNormGradientOp(Args&&... args)
122  : Operator<Context>(std::forward<Args>(args)...),
123  OP_SINGLE_ARG(int, "axis", axis_, 1) {}
124 
125  ~LayerNormGradientOp() {}
126 
127  bool RunOnDevice() override {
128  return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
129  }
130 
131  template <typename T>
132  bool DoRunWithType() {
133  const auto& dY = Input(0);
134  const auto& Y = Input(1);
135  const auto& mean = Input(2);
136  const auto& sig = Input(3);
137  const auto& X = Input(4);
138 
139  const int canonical_axis = X.canonical_axis_index(axis_);
140  const int M = X.size_to_dim(canonical_axis);
141  const int N = X.size_from_dim(canonical_axis);
142 
143  auto* dX = Output(0, X.sizes(), at::dtype<T>());
145  &ds_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
147  &db_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
149  &dY_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
151  &X_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
153  &bias_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
154  const T* dY_data = dY.template data<T>();
155  const T* X_data = X.template data<T>();
156  const T* mean_data = mean.template data<T>();
157  const T* sig_data = sig.template data<T>();
158  T* dX_data = dX->template mutable_data<T>();
159  T* ds_data = ds_.template mutable_data<T>();
160  T* db_data = db_.template mutable_data<T>();
161  T* dY_scale_data = dY_scale_.template mutable_data<T>();
162  T* X_scale_data = X_scale_.template mutable_data<T>();
163  T* bias_data = bias_.template mutable_data<T>();
164 
165  ComputeInternalGradients<T>(M, N, dY_data, X_data, ds_data, db_data);
166  ComputeFusedParams<T>(
167  M,
168  N,
169  mean_data,
170  sig_data,
171  ds_data,
172  db_data,
173  dY_scale_data,
174  X_scale_data,
175  bias_data);
176  LayerNormBackward<T>(
177  M, N, dY_scale_data, dY_data, X_scale_data, X_data, bias_data, dX_data);
178 
179  return true;
180  }
181 
182  private:
183  template <typename T>
184  void ComputeInternalGradients(
185  const int M,
186  const int N,
187  const T* dY,
188  const T* X,
189  T* ds,
190  T* db);
191 
192  template <typename T>
193  void ComputeFusedParams(
194  const int M,
195  const int N,
196  const T* mean,
197  const T* sig,
198  const T* ds,
199  const T* db,
200  T* dY_scale,
201  T* X_scale,
202  T* bias);
203 
204  template <typename T>
205  void LayerNormBackward(
206  const int M,
207  const int N,
208  const T* dY_scale,
209  const T* dY,
210  const T* X_scale,
211  const T* X,
212  const T* bias,
213  T* dX);
214 
215  const int axis_;
216 
217  Tensor ds_;
218  Tensor db_;
219  Tensor dY_scale_;
220  Tensor X_scale_;
221  Tensor bias_;
222 };
223 
224 } // namespace caffe2
225 
226 #endif // CAFFE2_OPERATORS_LAYER_NORM_OP_H_
Reference notes (from extracted tooltips):
- Definition: any.cpp:108
- void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
  Reinitializes a Tensor to the given dims and options if necessary; note that this is a
  no-op when the tensor already matches. Definition: tensor.cc:127
- Caffe2 maintains a global dictionary holding information about which Caffe2 modules
  have been loaded in the current runtime. Definition: blob.h:13