Caffe2 - C++ API
A deep learning, cross platform ML framework
upsample_nearest_op.h
1 
17 #ifndef UPSAMPLE_NEAREST_OP_H_
18 #define UPSAMPLE_NEAREST_OP_H_
19 
#include <cstdint>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class UpsampleNearestOp final : public Operator<Context> {
29  public:
30  UpsampleNearestOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  scale_(this->template GetSingleArgument<int>("scale", 2)) {
33  DCHECK_GE(scale_, 1);
34  }
35  USE_OPERATOR_CONTEXT_FUNCTIONS;
36 
37  bool RunOnDevice() override {
38  auto& X = Input(0);
39 
40  auto out_shape = X.sizes().vec();
41  out_shape[X.dim() - 1] *= scale_;
42  out_shape[X.dim() - 2] *= scale_;
43  auto* Y = Output(0, out_shape, at::dtype<T>());
44 
45  int d1;
46  int d2;
47  int d3;
48  if (X.dim() == 3) {
49  d1 = Y->dim32(0);
50  d2 = Y->dim32(1);
51  d3 = Y->dim32(2);
52  } else {
53  d1 = Y->dim32(0) * Y->dim32(1);
54  d2 = Y->dim32(2);
55  d3 = Y->dim32(3);
56  }
57 
58  const T *input_data = X.template data<T>();
59  T *output_data = Y->template mutable_data<T>();
60  int scaled_d2 = d2 / scale_;
61  int scaled_d3 = d3 / scale_;
62 
63 #ifdef _OPENMP
64 #if (_OPENMP >= 201307)
65 #pragma omp parallel for simd
66 #else
67 #pragma omp parallel for
68 #endif
69 #endif
70  for (int i = 0; i < d1; ++i) {
71  for (int j = 0; j < d2; ++j) {
72  for (int u = 0; u < d3; ++u) {
73  int ii = (i * d2 + j) * d3 + u;
74  int scaled_u = u / scale_;
75  int scaled_j = j / scale_;
76  int ipidx = ((i * scaled_d2) + scaled_j) * scaled_d3 + scaled_u;
77  output_data[ii] = input_data[ipidx];
78  }
79  }
80  }
81 
82  return true;
83  }
84 
85  protected:
86  int scale_;
87 };
88 
89 template <typename T, class Context>
90 class UpsampleNearestGradientOp final : public Operator<Context> {
91  public:
92  UpsampleNearestGradientOp(const OperatorDef& def, Workspace* ws)
93  : Operator<Context>(def, ws),
94  scale_(this->template GetSingleArgument<int>("scale", 2)) {
95  DCHECK_GE(scale_, 1);
96  }
97  USE_OPERATOR_CONTEXT_FUNCTIONS;
98 
99  bool RunOnDevice() override {
100  // No CPU implementation for now
101  CAFFE_NOT_IMPLEMENTED;
102  }
103 
104  protected:
105  int scale_;
106 };
107 
108 } // namespace caffe2
109 
110 #endif // UPSAMPLE_NEAREST_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13