Caffe2 - C++ API
A deep-learning, cross-platform ML framework
cross_entropy_op.h
1 #ifndef CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
2 #define CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class LabelCrossEntropyOp final : public Operator<Context> {
13  public:
14  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyOp);
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  bool RunOnDevice() override;
17 
18  protected:
19  static constexpr T kLOG_THRESHOLD() {
20  return static_cast<T>(1e-20);
21  }
22  // Input: X, label
23  // Output: Y
24 };
25 
26 template <typename T, class Context>
27 class LabelCrossEntropyGradientOp final : public Operator<Context> {
28  public:
29  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyGradientOp);
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  bool RunOnDevice() override;
32 
33  protected:
34  // Input: X, label, dY
35  // Ouptut: dX. There is no gradient with respect to the label.
36  static constexpr T kLOG_THRESHOLD() {
37  return static_cast<T>(1e-20);
38  }
39 };
40 
41 // Hacky: turns a vector of probabilities into a 2-column matrix with
42 // complimentary probabilities for binary classification
43 template <typename T, class Context>
44 class MakeTwoClassOp final : public Operator<Context> {
45  public:
46  USE_SIMPLE_CTOR_DTOR(MakeTwoClassOp);
47  USE_OPERATOR_CONTEXT_FUNCTIONS;
48  bool RunOnDevice() override;
49 
50  protected:
51  // Input: X
52  // Output: Y = vstack(1-X, X)
53 };
54 
55 template <typename T, class Context>
56 class MakeTwoClassGradientOp final : public Operator<Context> {
57  public:
58  USE_SIMPLE_CTOR_DTOR(MakeTwoClassGradientOp);
59  USE_OPERATOR_CONTEXT_FUNCTIONS;
60  bool RunOnDevice() override;
61 
62  protected:
63  // Input: dY
64  // Ouptut: dX
65 };
66 
67 template <typename T, class Context>
68 class SigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
69  public:
70  USE_OPERATOR_CONTEXT_FUNCTIONS;
71  template <class... Args>
72  explicit SigmoidCrossEntropyWithLogitsOp(Args&&... args)
73  : Operator<Context>(std::forward<Args>(args)...),
74  log_D_trick_(
75  this->template GetSingleArgument<bool>("log_D_trick", false)),
76  unjoined_lr_loss_(
77  this->template GetSingleArgument<bool>("unjoined_lr_loss", false)) {
78  CAFFE_ENFORCE(
79  !(log_D_trick_ && unjoined_lr_loss_),
80  "log_D_trick_ and unjoined_lr_loss_ cannot be set as True simultaneously");
81  }
82 
83  bool RunOnDevice() override;
84 
85  protected:
86  bool log_D_trick_;
87  bool unjoined_lr_loss_;
88 };
89 
90 template <typename T, class Context>
91 class SigmoidCrossEntropyWithLogitsGradientOp final : public Operator<Context> {
92  public:
93  USE_OPERATOR_CONTEXT_FUNCTIONS;
94  template <class... Args>
95  explicit SigmoidCrossEntropyWithLogitsGradientOp(Args&&... args)
96  : Operator<Context>(std::forward<Args>(args)...),
97  log_D_trick_(
98  this->template GetSingleArgument<bool>("log_D_trick", false)),
99  unjoined_lr_loss_(
100  this->template GetSingleArgument<bool>("unjoined_lr_loss", false)) {
101  }
102 
103  bool RunOnDevice() override;
104 
105  protected:
106  bool log_D_trick_;
107  bool unjoined_lr_loss_;
108 };
109 
110 template <typename T, class Context>
111 class WeightedSigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
112  public:
113  USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsOp);
114  USE_OPERATOR_CONTEXT_FUNCTIONS;
115  bool RunOnDevice() override;
116 };
117 
118 template <typename T, class Context>
120  : public Operator<Context> {
121  public:
123  USE_OPERATOR_CONTEXT_FUNCTIONS;
124  bool RunOnDevice() override;
125 };
126 
127 template <typename T, class Context>
128 class CAFFE2_API CrossEntropyOp final : public Operator<Context> {
129  public:
130  USE_SIMPLE_CTOR_DTOR(CrossEntropyOp);
131  USE_OPERATOR_CONTEXT_FUNCTIONS;
132  bool RunOnDevice() override;
133 
134  protected:
135  // Input: X, label
136  // Output: Y
137  static constexpr T kLOG_THRESHOLD() {
138  return static_cast<T>(1e-20);
139  }
140 };
141 
142 template <typename T, class Context>
143 class CAFFE2_API CrossEntropyGradientOp final : public Operator<Context> {
144  public:
145  USE_SIMPLE_CTOR_DTOR(CrossEntropyGradientOp);
146  USE_OPERATOR_CONTEXT_FUNCTIONS;
147  bool RunOnDevice() override;
148 
149  protected:
150  // Input: X, label, dY
151  // Ouptut: dX. There is no gradient with respect to the label.
152  static constexpr T kLOG_THRESHOLD() {
153  return static_cast<T>(1e-20);
154  }
155 };
156 
157 } // namespace caffe2
158 
159 #endif // CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13