Caffe2 - C++ API
A deep learning, cross-platform ML framework
normalize_op.h
1 
17 #ifndef CAFFE2_OPERATORS_NORMALIZE_OP_H_
18 #define CAFFE2_OPERATORS_NORMALIZE_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context>
27 class NormalizeOp final : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  NormalizeOp(const OperatorDef& def, Workspace* ws)
31  : Operator<Context>(def, ws) {}
32 
33  bool RunOnDevice() override {
34  const auto& x = Input(0);
35  auto* y = Output(0);
36  const auto* xData = x.template data<T>();
37  y->ResizeLike(x);
38  auto* yData = y->template mutable_data<T>();
39 
40  const auto canonical_axis = x.canonical_axis_index(
41  OperatorBase::GetSingleArgument<int>("axis", -1));
42  const int m = x.dim32(canonical_axis);
43  const int n = x.size() / m;
44  const int sf = x.size_from_dim(canonical_axis + 1);
45  DoNormalize(xData, yData, m, n, sf);
46  return true;
47  }
48 
49  private:
50  void
51  DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
52 };
53 
54 template <typename T, class Context>
55 class NormalizeGradientOp final : public Operator<Context> {
56  public:
57  USE_OPERATOR_CONTEXT_FUNCTIONS;
58  NormalizeGradientOp(const OperatorDef& def, Workspace* ws)
59  : Operator<Context>(def, ws) {}
60 
61  bool RunOnDevice() override {
62  const auto& x = Input(0);
63  const auto& gOut = Input(GRAD_OUT);
64  auto* gIn = Output(GRAD_IN);
65  gIn->ResizeLike(gOut);
66 
67  const auto* xData = x.template data<T>();
68  const auto* gOutData = gOut.template data<T>();
69  auto* gInData = gIn->template mutable_data<T>();
70 
71  const auto canonical_axis = x.canonical_axis_index(
72  OperatorBase::GetSingleArgument<int>("axis", -1));
73  const int m = x.dim32(canonical_axis);
74  const int n = x.size() / m;
75  const int sf = x.size_from_dim(canonical_axis + 1);
76  DoNormalize(xData, gOutData, gInData, m, n, sf);
77  return true;
78  }
79 
80  private:
81  void DoNormalize(
82  const T* xData,
83  const T* gOutData,
84  T* gInData,
85  const int m,
86  const int n,
87  const int sf);
88 
89  INPUT_TAGS(INPUT, GRAD_OUT);
90  OUTPUT_TAGS(GRAD_IN);
91 };
92 
93 } // namespace caffe2
94 
95 #endif // CAFFE2_OPERATORS_NORMALIZE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.