Caffe2 - C++ API
A deep learning, cross platform ML framework
softmax_with_loss_op.h
1 
17 #ifndef SOFTMAX_WITH_LOSS_OP_H_
18 #define SOFTMAX_WITH_LOSS_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class SoftmaxWithLossOp final : public Operator<Context> {
29  public:
30  SoftmaxWithLossOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)),
33  label_prob_mode_(OperatorBase::GetSingleArgument<int>("label_prob", 0)),
34  order_(StringToStorageOrder(
35  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
36  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
37  CAFFE_ENFORCE(scale_ >= 0);
38  CAFFE_ENFORCE_EQ(
39  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
40  }
41  USE_OPERATOR_CONTEXT_FUNCTIONS;
42 
43  bool RunOnDevice() override;
44 
45  protected:
46  float scale_;
47  int label_prob_mode_;
48  StorageOrder order_;
49  int axis_;
50 
51  Tensor<Context> losses_; // Per example loss
52  Tensor<Context> rowmax_; // per example row max
53  Tensor<Context> weights_; // unignored weights
54  Tensor<Context> sum_multiplier_; // Vector of ones for summing via dot prod
55  Tensor<Context> total_weight_ptr_;
56  Tensor<Context> scratch_;
57 };
58 
59 template <typename T, class Context>
60 class SoftmaxWithLossGradientOp final : public Operator<Context> {
61  public:
62  SoftmaxWithLossGradientOp(const OperatorDef& def, Workspace* ws)
63  : Operator<Context>(def, ws),
64  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)),
65  label_prob_mode_(OperatorBase::GetSingleArgument<int>("label_prob", 0)),
66  order_(StringToStorageOrder(
67  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
68  only_loss_(OperatorBase::GetSingleArgument<bool>("only_loss", false)),
69  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
70  CAFFE_ENFORCE(scale_ >= 0);
71  CAFFE_ENFORCE_EQ(
72  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
73  }
74  USE_OPERATOR_CONTEXT_FUNCTIONS;
75 
76  bool RunOnDevice() override;
77 
78  protected:
79  float scale_;
80  int label_prob_mode_;
81  Tensor<Context> sum_multiplier_;
82  Tensor<Context> weights_; // unignored weights
83  Tensor<Context> total_weight_ptr_;
84  StorageOrder order_;
85  bool only_loss_;
86  int axis_;
87  Tensor<Context> scratch_;
88 };
89 
90 } // namespace caffe2
91 
92 #endif // SOFTMAX_WITH_LOSS_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.