Caffe2 - C++ API
A deep learning, cross-platform ML framework
spatial_softmax_with_loss_op.h
1 
17 #ifndef SPATIAL_SOFTMAX_WITH_LOSS_OP_H_
18 #define SPATIAL_SOFTMAX_WITH_LOSS_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class SpatialSoftmaxWithLossOp final : public Operator<Context> {
29  public:
30  SpatialSoftmaxWithLossOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)),
33  order_(StringToStorageOrder(
34  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
35  CAFFE_ENFORCE(scale_ >= 0);
36  CAFFE_ENFORCE_EQ(
37  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
38  }
39  USE_OPERATOR_CONTEXT_FUNCTIONS;
40 
41  bool RunOnDevice() override;
42 
43  protected:
44  float scale_;
45  StorageOrder order_;
46 
47  Tensor<Context> losses_; // Per example loss
48  Tensor<Context> rowmax_; // per example row max
49  Tensor<Context> weights_; // unignored weights
50  Tensor<Context> sum_multiplier_; // Vector of ones for summing via dot prod
51  Tensor<Context> total_weight_ptr_;
52  Tensor<Context> scratch_;
53 };
54 
55 template <typename T, class Context>
56 class SpatialSoftmaxWithLossGradientOp final : public Operator<Context> {
57  public:
58  SpatialSoftmaxWithLossGradientOp(const OperatorDef& def, Workspace* ws)
59  : Operator<Context>(def, ws),
60  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)),
61  order_(StringToStorageOrder(
62  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
63  only_loss_(OperatorBase::GetSingleArgument<bool>("only_loss", false)) {
64  CAFFE_ENFORCE(scale_ >= 0);
65  CAFFE_ENFORCE_EQ(
66  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
67  }
68  USE_OPERATOR_CONTEXT_FUNCTIONS;
69 
70  bool RunOnDevice() override;
71 
72  protected:
73  float scale_;
74  Tensor<Context> sum_multiplier_;
75  Tensor<Context> weights_; // unignored weights
76  Tensor<Context> total_weight_ptr_;
77  StorageOrder order_;
78  bool only_loss_;
79  Tensor<Context> scratch_;
80 };
81 
82 } // namespace caffe2
83 
84 #endif // SPATIAL_SOFTMAX_WITH_LOSS_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.