1 #ifndef CAFFE2_OPERATORS_LAYER_NORM_OP_H_ 2 #define CAFFE2_OPERATORS_LAYER_NORM_OP_H_ 7 #include <ATen/core/dispatch/OpSchemaRegistration.h> 9 #include "caffe2/core/context.h" 10 #include "caffe2/core/operator.h" 11 #include "caffe2/core/types.h" 12 #include "caffe2/utils/math.h" 14 C10_DECLARE_CAFFE2_OPERATOR(LayerNorm)
18 template <
class Context>
21 USE_OPERATOR_CONTEXT_FUNCTIONS;
23 template <
class... Args>
26 OP_SINGLE_ARG(
int,
"axis", axis_, 1),
27 OP_SINGLE_ARG(
float,
"epsilon", epsilon_, 1e-5f) {}
29 bool RunOnDevice()
override {
34 bool DoRunWithType() {
35 const auto& X = Input(0);
37 const int canonical_axis = X.canonical_axis_index(axis_);
38 std::vector<int64_t> moments_dims(
39 X.sizes().cbegin(), X.sizes().cbegin() + canonical_axis);
40 moments_dims.push_back(1);
41 auto* mean = Output(1, moments_dims, at::dtype<T>());
42 auto* sig = Output(2, moments_dims, at::dtype<T>());
44 X, canonical_axis, epsilon_, Y, mean, sig, &scale_, &bias_, &context_);
49 static void RunLayerNorm(
51 const int canonical_axis,
59 CAFFE_ENFORCE_GE(X.dim(), 2,
"LayerNorm requires input dim >= 2.");
60 const int M = X.size_to_dim(canonical_axis);
61 const int N = X.size_from_dim(canonical_axis);
63 scale_buffer->Resize(M);
64 bias_buffer->Resize(M);
65 const T* X_data = X.template data<T>();
66 T* Y_data = Y->template mutable_data<T>();
67 T* mean_data = mean->template mutable_data<T>();
68 T* sig_data = sig->template mutable_data<T>();
69 T* scale_data = scale_buffer->template mutable_data<T>();
70 T* bias_data = bias_buffer->template mutable_data<T>();
71 const std::array<int, 2> X_dims = {M, N};
72 const std::array<int, 2> Y_dims = {M, 1};
73 math::Moments<T, Context>(
74 2, X_dims.data(), Y_dims.data(), X_data, mean_data, sig_data, context);
75 ComputeStdDevAndFusedParams<T>(
84 LayerNormForward<T>(M, N, X_data, scale_data, bias_data, Y_data, context);
89 static void ComputeStdDevAndFusedParams(
100 static void LayerNormForward(
110 const float epsilon_;
112 Tensor scale_{Context::GetDeviceType()};
113 Tensor bias_{Context::GetDeviceType()};
116 template <
class Context>
119 USE_OPERATOR_CONTEXT_FUNCTIONS;
120 template <
class... Args>
123 OP_SINGLE_ARG(
int,
"axis", axis_, 1) {}
125 ~LayerNormGradientOp() {}
127 bool RunOnDevice()
override {
131 template <
typename T>
132 bool DoRunWithType() {
133 const auto& dY = Input(0);
134 const auto& Y = Input(1);
135 const auto& mean = Input(2);
136 const auto& sig = Input(3);
137 const auto& X = Input(4);
139 const int canonical_axis = X.canonical_axis_index(axis_);
140 const int M = X.size_to_dim(canonical_axis);
141 const int N = X.size_from_dim(canonical_axis);
143 auto* dX = Output(0, X.sizes(), at::dtype<T>());
145 &ds_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
147 &db_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
149 &dY_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
151 &X_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
153 &bias_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
154 const T* dY_data = dY.template data<T>();
155 const T* X_data = X.template data<T>();
156 const T* mean_data = mean.template data<T>();
157 const T* sig_data = sig.template data<T>();
158 T* dX_data = dX->template mutable_data<T>();
159 T* ds_data = ds_.template mutable_data<T>();
160 T* db_data = db_.template mutable_data<T>();
161 T* dY_scale_data = dY_scale_.template mutable_data<T>();
162 T* X_scale_data = X_scale_.template mutable_data<T>();
163 T* bias_data = bias_.template mutable_data<T>();
165 ComputeInternalGradients<T>(M, N, dY_data, X_data, ds_data, db_data);
166 ComputeFusedParams<T>(
176 LayerNormBackward<T>(
177 M, N, dY_scale_data, dY_data, X_scale_data, X_data, bias_data, dX_data);
183 template <
typename T>
184 void ComputeInternalGradients(
192 template <
typename T>
193 void ComputeFusedParams(
204 template <
typename T>
205 void LayerNormBackward(
226 #endif // CAFFE2_OPERATORS_LAYER_NORM_OP_H_
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...