Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv_transpose_op_mobile.h
1 #ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_MOBILE_OP_H_
2 #define CAFFE2_OPERATORS_CONV_TRANSPOSE_MOBILE_OP_H_
3 
4 #include "caffe2/core/common.h"
5 
6 #ifdef C10_MOBILE
7 
8 #include "caffe2/core/context.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/operators/conv_transpose_unpool_op_base.h"
11 
12 namespace caffe2 {
13 
14 template <typename T, class Context>
15 class ConvTransposeMobileOp final : public ConvTransposeUnpoolBase<Context> {
16  public:
17  USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context);
18  ConvTransposeMobileOp(const OperatorDef& operator_def, Workspace* ws)
19  : ConvTransposeUnpoolBase<Context>(operator_def, ws) {
20  OPERATOR_NEEDS_FEATURE(order_ == StorageOrder::NCHW, "Only NCHW order is supported right now.");
21  OPERATOR_NEEDS_FEATURE(
22  this->pad_l() == 0, "operator does not handle row width padding");
23  OPERATOR_NEEDS_FEATURE(
24  this->pad_r() == 0, "operator does not handle row width padding");
25  OPERATOR_NEEDS_FEATURE(this->stride_w() <= 4, "stride width must be <= 4");
26  }
27 
28  bool RunOnDeviceWithOrderNCHW() override;
29  bool RunOnDeviceWithOrderNHWC() override;
30 
31  private:
32  // We store a numThreasds per-worker tiles of Y, and numThreads per-worker threadBuffer for the
33  // gemm output, laid out in that order.
34  Tensor threadBuffer_{CPU};
35 
36  // Input: X, W, b
37  // Output: Y
38  INPUT_TAGS(INPUT, FILTER, BIAS);
39 };
40 
41 } // namespace caffe2
42 
43 #endif // C10_MOBILE
44 
45 #endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_MOBILE_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13