1 #ifndef CAFFE2_OPERATORS_NORMALIZE_OP_H_ 2 #define CAFFE2_OPERATORS_NORMALIZE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/eigen_utils.h" 7 #include "caffe2/utils/math.h" 13 template <
typename T,
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
17 template <
class... Args>
21 bool RunOnDevice()
override {
22 const auto& x =
Input(0);
24 const auto* xData = x.template data<T>();
25 auto* y = Output(0, x.sizes(), at::dtype<T>());
26 auto* yData = y->template mutable_data<T>();
28 const auto canonical_axis = x.canonical_axis_index(
29 this->
template GetSingleArgument<int>(
"axis", -1));
30 const int m = x.dim32(canonical_axis);
31 const int n = x.numel() / m;
32 const int sf = x.size_from_dim(canonical_axis + 1);
33 DoNormalize(xData, yData, m, n, sf);
45 using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
47 Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
48 using ConstStridedVec =
49 Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
51 for (
int i = 0; i < n; ++i) {
52 auto base = (i / sf) * sf * m + (i % sf);
53 ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
54 auto norm = xVec.template lpNorm<2>();
55 norm = std::max(norm, kEps_);
56 StridedVec yVec(yData + base, 1, m, InnerStride(sf));
62 template <
typename T,
class Context>
65 USE_OPERATOR_CONTEXT_FUNCTIONS;
66 template <
class... Args>
70 bool RunOnDevice()
override {
71 const auto& x =
Input(0);
72 const auto& gOut =
Input(GRAD_OUT);
74 auto* gIn = Output(GRAD_IN, gOut.sizes(), at::dtype<T>());
76 const auto* xData = x.template data<T>();
77 const auto* gOutData = gOut.template data<T>();
78 auto* gInData = gIn->template mutable_data<T>();
80 const auto canonical_axis = x.canonical_axis_index(
81 this->
template GetSingleArgument<int>(
"axis", -1));
82 const int m = x.dim32(canonical_axis);
83 const int n = x.numel() / m;
84 const int sf = x.size_from_dim(canonical_axis + 1);
85 DoNormalize(xData, gOutData, gInData, m, n, sf);
99 INPUT_TAGS(INPUT, GRAD_OUT);
100 OUTPUT_TAGS(GRAD_IN);
105 #endif // CAFFE2_OPERATORS_NORMALIZE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...