Caffe2 - C++ API
A deep learning, cross-platform ML framework
snpe_op.cc
1 #include "caffe2/core/context.h"
2 #include "caffe2/core/logging.h"
3 #include "caffe2/core/operator.h"
4 #include "caffe2/core/timer.h"
5 #include "snpe_ffi.h"
6 #include <dlfcn.h>
7 
8 namespace caffe2 {
9 
// Owning smart pointer whose deleter is supplied at construction time as any
// callable taking a Pointee* (e.g. a lambda that calls dlclose on a handle
// returned by dlopen). A null pointer is never passed to the deleter.
template <class Pointee>
using deleted_unique_ptr =
    std::unique_ptr<Pointee, std::function<void(Pointee*)>>;
12 
13 class SNPEOp final : public Operator<CPUContext> {
14  public:
15  SNPEOp(const OperatorDef& def, Workspace* ws) : Operator<CPUContext>(def, ws),
16  model_buffer_(OperatorBase::GetSingleArgument<string>("model_buffer", "")),
17  input_name_(OperatorBase::GetSingleArgument<string>("input_name", "data"))
18  {
19  CAFFE_ENFORCE(gSNPELocation() != "", "SNPE library \"", gSNPELocation(), "\" does not exist.");
20  std::ostringstream snpe_ffi;
21  snpe_ffi << gSNPELocation() << "/" << snpe_ffi_so;
22  handle_ = deleted_unique_ptr<void>(dlopen(snpe_ffi.str().c_str(), RTLD_LAZY), [](void* handle) {
23  if (handle) {
24  dlclose(handle);
25  }
26  });
27  if (!handle_.get()) {
28  std::cerr << dlerror() << std::endl;
29  }
30 
31  OPERATOR_NEEDS_FEATURE(handle_.get(), "Couldn't find ", snpe_ffi.str());
32 
33 #define X(n) \
34  dlerror(); \
35  auto* n##_f = (decltype(&n))dlsym(handle_.get(), #n); \
36  OPERATOR_NEEDS_FEATURE(n##_f, dlerror());
37 
38  {
39  X(snpe_has_gpu);
40  X(snpe_create);
41  X(snpe_destroy);
42  X(snpe_get_input_dims);
43  X(snpe_run);
44  X(snpe_copy_output_to);
45  }
46 
47  X(snpe_has_gpu);
48  OPERATOR_NEEDS_FEATURE(snpe_has_gpu_f(), "No GPU found, cannot use SNPE.");
49 
50  X(snpe_create)
51 #undef X
52 
53 // Redefine to use CAFFE_ENFORCE instead of OPERATOR_NEEDS_FEATURE.
54 
55 #define X(n) \
56  dlerror(); \
57  auto* n##_f = (decltype(&n))dlsym(handle_.get(), #n); \
58  CAFFE_ENFORCE(n##_f, dlerror());
59 
60  CAFFE_ENFORCE(def.input_size(), "No inputs.");
61  if (input_name_ == "") {
62  input_name_ = def.input().Get(0);
63  }
64  ctx_ = deleted_unique_ptr<void>(snpe_create_f(reinterpret_cast<const unsigned char *>(model_buffer_.data()),
65  model_buffer_.length(), input_name_.c_str()), [this](void* ctx) {
66  if (ctx) {
67  X(snpe_destroy);
68  snpe_destroy_f(ctx);
69  }
70  });
71  }
72 
73  bool RunOnDevice() override {
74  CAFFE_ENFORCE(gSNPELocation() != "", "SNPE library was never loaded.");
75 
76  X(snpe_get_input_dims);
77  size_t const* dims;
78  size_t dimSize;
79  snpe_get_input_dims_f(ctx_.get(), &dims, &dimSize);
80  if (Input(0).ndim() != dimSize) {
81  if (dimSize == 3 && dimSize == Input(0).ndim() - 1 && Input(0).dim32(0) == 1) {
82  const int C = Input(0).dim32(1);
83  const int H = Input(0).dim32(2);
84  const int W = Input(0).dim32(3);
85  if (dims[0] != C ||
86  dims[1] != H ||
87  dims[2] != W) {
88  CAFFE_THROW("Input size must match what SNPE expects, which in this case is: ",
89  dims[0], " ", dims[1], " ", dims[2]);
90  }
91  } else {
92  CAFFE_THROW("SNPE input dimensions are not compatible.");
93  }
94  } else {
95  for (auto i = 0; i < Input(0).ndim(); ++i) {
96  CAFFE_ENFORCE_EQ(dims[i], Input(0).dim32(i), "SNPE input dimension is not compatible.");
97  }
98  }
99 
100  X(snpe_run);
101  CAFFE_ENFORCE(ctx_.get(), "SNPE context doesn't exist.");
102  snpe_run_f(ctx_.get(), Input(0).data<float>(), Input(0).size(), &dims, &dimSize);
103 
104  std::vector<int64_t> outputDims(dimSize + 1);
105  outputDims[0] = 1;
106  for (auto i = 0; i < dimSize; ++i) {
107  outputDims[i+1] = dims[i];
108  };
109 
110  Output(0)->Resize(outputDims);
111  X(snpe_copy_output_to);
112  snpe_copy_output_to_f(ctx_.get(), Output(0)->mutable_data<float>());
113 
114  CAFFE_ENFORCE(
115  Output(0)->data<float>(), "nullptr where output should be!\n");
116  return true;
117  }
118 
119  private:
120  string model_buffer_;
121  string input_name_;
122  deleted_unique_ptr<void> handle_;
123  // needs to be destroyed *before* handle_
124  deleted_unique_ptr<void> ctx_;
125 };
126 
127 REGISTER_CPU_OPERATOR(SNPE, SNPEOp);
128 }
129 
130 #undef X
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=CPUContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64