Caffe2 - C++ API
A deep learning, cross-platform ML framework
lpnorm_op.h
1 
17 #ifndef CAFFE2_OPERATORS_LPNORM_OP_H_
18 #define CAFFE2_OPERATORS_LPNORM_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context>
27 class LpNormOp : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  LpNormOp(const OperatorDef& def, Workspace* ws)
31  : Operator<Context>(def, ws),
32  p_(OperatorBase::GetSingleArgument<int>("p", 2)) {
33  CAFFE_ENFORCE(p_ == 1 || p_ == 2, "p should be either 1 or 2.");
34  }
35 
36  bool RunOnDevice() override;
37 
38  protected:
39  int p_;
40  INPUT_TAGS(X_IN);
41  OUTPUT_TAGS(OUT);
42  // Input: X; Output: Norm
43 };
44 
45 template <typename T, class Context>
46 class LpNormGradientOp : public Operator<Context> {
47  public:
48  USE_OPERATOR_CONTEXT_FUNCTIONS;
49  LpNormGradientOp(const OperatorDef& def, Workspace* ws)
50  : Operator<Context>(def, ws),
51  p_(OperatorBase::GetSingleArgument<int>("p", 2)) {
52  CAFFE_ENFORCE(p_ == 1 || p_ == 2, "p should be either 1 or 2.");
53  }
54 
55  bool RunOnDevice() override;
56 
57  protected:
58  int p_;
59  INPUT_TAGS(X_IN, DER_NORM_IN);
60  OUTPUT_TAGS(DER_X_OUT);
61  // Input: X, dNorm; Output: dX
62 };
63 
64 } // namespace caffe2
65 
66 #endif // CAFFE2_OPERATORS_LPNORM_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.