Caffe2 - C++ API
A deep learning, cross-platform ML framework
layer_norm_op.h
1 
17 #ifndef CAFFE2_OPERATORS_LAYER_NORM_OP_H
18 #define CAFFE2_OPERATORS_LAYER_NORM_OP_H
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <class Context>
27 class LayerNormOp : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  LayerNormOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)),
33  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
34  ~LayerNormOp() {}
35 
36  template <typename T>
37  bool DoRunWithType();
38 
39  bool RunOnDevice() override {
40  return DoRunWithType<float>();
41  }
42 
43  protected:
44  int axis_;
45  float epsilon_;
46 
47  Tensor<Context> scratch_;
48  Tensor<Context> seg_indices_;
49 };
50 
51 template <class Context>
52 class LayerNormGradientOp : public Operator<Context> {
53  public:
54  USE_OPERATOR_CONTEXT_FUNCTIONS;
55  LayerNormGradientOp(const OperatorDef& operator_def, Workspace* ws)
56  : Operator<Context>(operator_def, ws),
57  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)),
58  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 0.001f)) {}
60 
61  template <typename T>
62  bool DoRunWithType();
63 
64  bool RunOnDevice() override {
65  return DoRunWithType<float>();
66  }
67 
68  protected:
69  int axis_;
70  float epsilon_;
71 
72  Tensor<Context> scratch_;
73  Tensor<Context> gscratch_;
74  Tensor<Context> seg_indices_;
75  Tensor<Context> dstdev_;
76  Tensor<Context> dmean_;
77 };
78 
79 } // namespace caffe2
80 
81 #endif /* CAFFE2_OPERATORS_LAYER_NORM_OP_H */
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.