Caffe2 - C++ API
A deep learning, cross-platform ML framework
normalize_l1_op.h
1 #ifndef CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
2 #define CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class NormalizeL1Op final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  USE_SIMPLE_CTOR_DTOR(NormalizeL1Op)
15 
16  bool RunOnDevice() override {
17  const auto& x = Input(0);
18 
19  const auto* xData = x.template data<T>();
20  auto* y = Output(0, x.sizes(), at::dtype<T>());
21  auto* yData = y->template mutable_data<T>();
22 
23  const auto canonical_axis = x.canonical_axis_index(
24  this->template GetSingleArgument<int>("axis", -1));
25  const int m = x.dim32(canonical_axis);
26  const int n = x.numel() / m;
27  const int sf = x.size_from_dim(canonical_axis + 1);
28  DoNormalize(xData, yData, m, n, sf);
29  return true;
30  }
31 
32  private:
33  void
34  DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
35 };
36 
37 } // namespace caffe2
38 
39 #endif // CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13