Caffe2 - C++ API
A deep-learning, cross-platform ML framework
softmax_op.cc
1 #include "caffe2/mobile/contrib/arm-compute/core/context.h"
2 #include "caffe2/mobile/contrib/arm-compute/core/operator.h"
3 
4 #include "caffe2/operators/softmax_op.h"
5 
6 namespace caffe2 {
7 
8 template <typename T> class GLSoftmaxOp final : public Operator<GLContext> {
9 public:
10  GLSoftmaxOp(const OperatorDef &operator_def, Workspace *ws)
11  : Operator<GLContext>(operator_def, ws) {}
12  virtual ~GLSoftmaxOp() noexcept {}
13  USE_OPERATOR_FUNCTIONS(GLContext);
14  bool RunOnDevice() override;
15 private:
16  arm_compute::GCSoftmaxLayer softmax_layer_;
17  bool first_run_ = true, second_run_ = true;
18  GLContext::deleted_unique_ptr<const GLTensor<T>> X_;
19 };
20 
21 template <typename T>
23 
24  auto *Xblob = OperatorBase::Inputs()[0];
25  if (first_run_) {
26  X_ = GLContext::getGLTensor<T>(Xblob);
27  }
28 
29  GLTensor<T> *Y =
30  OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>();
31  if (first_run_) {
32  first_run_ = false;
33  Y->ResizeLike(*X_);
34  softmax_layer_.configure(X_->get_underlying(), Y->get_underlying());
35  } else {
36  X_->lazy_allocate(Xblob, second_run_, true);
37  if (second_run_) {
38  second_run_ = false;
39  Y->allocate();
40  }
41  softmax_layer_.run();
42  }
43 
44  return true;
45 }
46 
47 REGISTER_GL_OPERATOR(Softmax, GLSoftmaxOp<DataType>);
48 
49 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...