Caffe2 - C++ API
A deep learning, cross-platform ML framework
group_spatial_softmax_op.h
1 
17 #ifndef GROUP_SPATIAL_SOFTMAX_OP_H_
18 #define GROUP_SPATIAL_SOFTMAX_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class GroupSpatialSoftmaxOp final : public Operator<Context> {
29  public:
30  GroupSpatialSoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  num_classes_(OperatorBase::GetSingleArgument<int>("num_classes", 81)),
33  order_(StringToStorageOrder(
34  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
35  CAFFE_ENFORCE_EQ(
36  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
37  }
38  USE_OPERATOR_CONTEXT_FUNCTIONS;
39 
40  bool RunOnDevice() override {
41  // No CPU implementation for now
42  CAFFE_NOT_IMPLEMENTED;
43  }
44 
45  protected:
46  int num_classes_;
47  StorageOrder order_;
48 };
49 
50 template <typename T, class Context>
51 class GroupSpatialSoftmaxGradientOp final : public Operator<Context> {
52  public:
53  GroupSpatialSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
54  : Operator<Context>(def, ws),
55  num_classes_(OperatorBase::GetSingleArgument<int>("num_classes", 81)),
56  order_(StringToStorageOrder(
57  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
58  CAFFE_ENFORCE_EQ(
59  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
60  }
61  USE_OPERATOR_CONTEXT_FUNCTIONS;
62 
63  bool RunOnDevice() override {
64  // No CPU implementation for now
65  CAFFE_NOT_IMPLEMENTED;
66  }
67 
68  protected:
69  int num_classes_;
70  StorageOrder order_;
71  Tensor<Context> sum_probs_;
72 };
73 
74 } // namespace caffe2
75 
76 #endif // GROUP_SPATIAL_SOFTMAX_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.