Caffe2 - C++ API
A deep learning, cross-platform ML framework
cross_entropy_op.h
1 
17 #ifndef CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
18 #define CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class LabelCrossEntropyOp final : public Operator<Context> {
29  public:
30  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyOp);
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32  bool RunOnDevice() override;
33 
34  protected:
35  static constexpr T kLOG_THRESHOLD() {
36  return static_cast<T>(1e-20);
37  }
38  // Input: X, label
39  // Output: Y
40 };
41 
42 template <typename T, class Context>
43 class LabelCrossEntropyGradientOp final : public Operator<Context> {
44  public:
45  USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyGradientOp);
46  USE_OPERATOR_CONTEXT_FUNCTIONS;
47  bool RunOnDevice() override;
48 
49  protected:
50  // Input: X, label, dY
51  // Ouptut: dX. There is no gradient with respect to the label.
52  static constexpr T kLOG_THRESHOLD() {
53  return static_cast<T>(1e-20);
54  }
55 };
56 
57 // Hacky: turns a vector of probabilities into a 2-column matrix with
58 // complimentary probabilities for binary classification
59 template <typename T, class Context>
60 class MakeTwoClassOp final : public Operator<Context> {
61  public:
62  USE_SIMPLE_CTOR_DTOR(MakeTwoClassOp);
63  USE_OPERATOR_CONTEXT_FUNCTIONS;
64  bool RunOnDevice() override;
65 
66  protected:
67  // Input: X
68  // Output: Y = vstack(1-X, X)
69 };
70 
71 template <typename T, class Context>
72 class MakeTwoClassGradientOp final : public Operator<Context> {
73  public:
74  USE_SIMPLE_CTOR_DTOR(MakeTwoClassGradientOp);
75  USE_OPERATOR_CONTEXT_FUNCTIONS;
76  bool RunOnDevice() override;
77 
78  protected:
79  // Input: dY
80  // Ouptut: dX
81 };
82 
83 template <typename T, class Context>
84 class SigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
85  public:
86  USE_SIMPLE_CTOR_DTOR(SigmoidCrossEntropyWithLogitsOp);
87  USE_OPERATOR_CONTEXT_FUNCTIONS;
88  bool RunOnDevice() override;
89 };
90 
91 template <typename T, class Context>
92 class SigmoidCrossEntropyWithLogitsGradientOp final : public Operator<Context> {
93  public:
94  USE_SIMPLE_CTOR_DTOR(SigmoidCrossEntropyWithLogitsGradientOp);
95  USE_OPERATOR_CONTEXT_FUNCTIONS;
96  bool RunOnDevice() override;
97 };
98 
99 template <typename T, class Context>
100 class WeightedSigmoidCrossEntropyWithLogitsOp final : public Operator<Context> {
101  public:
102  USE_SIMPLE_CTOR_DTOR(WeightedSigmoidCrossEntropyWithLogitsOp);
103  USE_OPERATOR_CONTEXT_FUNCTIONS;
104  bool RunOnDevice() override;
105 };
106 
107 template <typename T, class Context>
109  : public Operator<Context> {
110  public:
112  USE_OPERATOR_CONTEXT_FUNCTIONS;
113  bool RunOnDevice() override;
114 };
115 
116 template <typename T, class Context>
117 class CrossEntropyOp final : public Operator<Context> {
118  public:
119  USE_SIMPLE_CTOR_DTOR(CrossEntropyOp);
120  USE_OPERATOR_CONTEXT_FUNCTIONS;
121  bool RunOnDevice() override;
122 
123  protected:
124  // Input: X, label
125  // Output: Y
126  static constexpr T kLOG_THRESHOLD() {
127  return static_cast<T>(1e-20);
128  }
129 };
130 
131 template <typename T, class Context>
132 class CrossEntropyGradientOp final : public Operator<Context> {
133  public:
134  USE_SIMPLE_CTOR_DTOR(CrossEntropyGradientOp);
135  USE_OPERATOR_CONTEXT_FUNCTIONS;
136  bool RunOnDevice() override;
137 
138  protected:
139  // Input: X, label, dY
140  // Ouptut: dX. There is no gradient with respect to the label.
141  static constexpr T kLOG_THRESHOLD() {
142  return static_cast<T>(1e-20);
143  }
144 };
145 
146 } // namespace caffe2
147 
148 #endif // CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
Copyright (c) 2016-present, Facebook, Inc.