Caffe2 - C++ API
A deep learning, cross platform ML framework
reduction_ops.h
1 
17 #ifndef CAFFE2_OPERATORS_REDUCTION_OPS_H_
18 #define CAFFE2_OPERATORS_REDUCTION_OPS_H_
19 
20 #include "caffe2/core/common_omp.h"
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/logging.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/utils/math.h"
25 
26 namespace caffe2 {
27 
28 template <typename T, class Context>
29 class SumElementsOp : public Operator<Context> {
30  public:
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32 
33  SumElementsOp(const OperatorDef& operator_def, Workspace* ws)
34  : Operator<Context>(operator_def, ws),
35  average_(OperatorBase::GetSingleArgument<bool>("average", false)) {}
36  SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average)
37  : Operator<Context>(operator_def, ws), average_(average) {}
38  ~SumElementsOp() {}
39 
40  bool RunOnDevice() override
41 // TODO: T21635002 fix float-divide-by-zero undefined behavior
42 #if defined(__has_feature)
43 #if __has_feature(__address_sanitizer__)
44  __attribute__((__no_sanitize__("float-divide-by-zero")))
45 #endif
46 #endif
47  {
48  auto& X = Input(0);
49  auto* sum = Output(0);
50  sum->Resize(vector<TIndex>());
51  T* data = sum->template mutable_data<T>();
52  math::Sum<T, Context>(
53  X.size(), X.template data<T>(), data, &context_, &scratch_);
54  if (average_) {
55  math::Scale<T, Context>(
56  1,
57  static_cast<T>(1.) / X.size(),
58  sum->template data<T>(),
59  data,
60  &context_);
61  }
62  return true;
63  }
64 
65  private:
66  bool average_;
67  Tensor<Context> scratch_;
68 };
69 
70 template <typename T, class Context>
71 class SumElementsGradientOp : public Operator<Context> {
72  public:
73  USE_OPERATOR_CONTEXT_FUNCTIONS;
74 
75  SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws)
76  : Operator<Context>(operator_def, ws),
77  average_(OperatorBase::GetSingleArgument<bool>("average", false)) {}
79  const OperatorDef& operator_def,
80  Workspace* ws,
81  bool average)
82  : Operator<Context>(operator_def, ws), average_(average) {}
84 
85  bool RunOnDevice() override;
86 
87  private:
88  bool average_;
89 };
90 
91 template <class Context>
92 class SumSqrElementsOp : public Operator<Context> {
93  public:
94  USE_SIMPLE_CTOR_DTOR(SumSqrElementsOp)
95  USE_OPERATOR_CONTEXT_FUNCTIONS;
96 
97  bool RunOnDevice() override {
98  return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
99  }
100 
101  template <typename T>
102  bool DoRunWithType() {
103  bool average = OperatorBase::GetSingleArgument<bool>("average", false);
104  auto& X = Input(0);
105  auto* sum = Output(0);
106  sum->Resize(vector<TIndex>());
107  math::SumSqr<T, Context>(
108  X.size(),
109  X.template data<T>(),
110  sum->template mutable_data<T>(),
111  &context_,
112  &scratch_);
113  if (average) {
114  math::Scale<T, Context>(
115  1,
116  float(1.) / X.size(),
117  sum->template data<T>(),
118  sum->template mutable_data<T>(),
119  &context_);
120  }
121  return true;
122  }
123 
124  private:
125  Tensor<Context> scratch_;
126 };
127 
128 template <typename T, class Context, bool ROWWISE>
129 class MaxReductionOp : public Operator<Context> {
130  public:
131  USE_SIMPLE_CTOR_DTOR(MaxReductionOp)
132  USE_OPERATOR_CONTEXT_FUNCTIONS;
133 
134  bool RunOnDevice() override {
135  auto& X = Input(0);
136  CAFFE_ENFORCE_EQ(X.ndim(), 3);
137 
138  const int batch_size = X.dim32(0);
139  const int M = X.dim32(1);
140  const int N = X.dim32(2);
141 
142  auto* Y = Output(0);
143  ROWWISE ? Y->Resize(batch_size, M) : Y->Resize(batch_size, N);
144 
145  if (ROWWISE) {
146  math::RowwiseMax<T, Context>(
147  batch_size * M,
148  N,
149  X.template data<T>(),
150  Y->template mutable_data<T>(),
151  &context_);
152  } else {
153  const int input_size = N * M;
154  for (int i = 0; i < batch_size; ++i) {
155  math::ColwiseMax<T, Context>(
156  M,
157  N,
158  X.template data<T>() + i * input_size,
159  Y->template mutable_data<T>() + i * N,
160  &context_);
161  }
162  }
163  return true;
164  }
165 };
166 
167 template <typename T, class Context, bool ROWWISE>
168 class MaxReductionGradientOp : public Operator<Context> {
169  public:
170  USE_SIMPLE_CTOR_DTOR(MaxReductionGradientOp)
171  USE_OPERATOR_CONTEXT_FUNCTIONS;
172 
173  bool RunOnDevice() override;
174 };
175 
176 } // namespace caffe2
177 
178 #endif
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.