Caffe2 - C++ API
A deep learning, cross platform ML framework
tt_pad_op.h
1 
17 #ifndef CAFFE2_OPERATORS_TT_PAD_OP_H_
18 #define CAFFE2_OPERATORS_TT_PAD_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context, class Engine = DefaultEngine>
27 class TTPadOp final : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  TTPadOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  scale_(OperatorBase::GetSingleArgument<int64_t>("scale", 0)) {
33  CAFFE_ENFORCE(
34  OperatorBase::HasArgument("scale"), "Argument `scale` is missing.");
35  }
36 
37  bool RunOnDevice() override {
38  const auto& X = Input(0);
39  auto* X_pad = Output(0);
40  CAFFE_ENFORCE(&X == X_pad);
41 
42  CAFFE_ENFORCE(X.dim() == 2, X.dim());
43 
44  auto X_dim0 = X.size(0);
45  auto X_dim1 = X.size(1);
46 
47  auto* X_orig_dim0 = Output(1, {1}, at::dtype<int64_t>());
48  *X_orig_dim0->template mutable_data<int64_t>() = X_dim0;
49 
50  if (X_dim0 % scale_ != 0) {
51  int64_t padded_dim0 = (X_dim0 / scale_ + 1) * scale_;
52  auto dim0_diff = padded_dim0 - X_dim0;
53  // set growthPct to the upper bound percentage: (100 * scale_ / X_dim0)
54  X_pad->Extend(dim0_diff, 100 * scale_ / X_dim0);
55 
56  auto* X_pad_data = X_pad->template mutable_data<T>();
57  int64_t X_size = X_dim0 * X_dim1;
58  memset(X_pad_data + X_size, 0, dim0_diff * X_dim1 * sizeof(T));
59  }
60 
61  return true;
62  }
63 
64  protected:
65  int64_t scale_;
66 };
67 
68 template <typename T, class Context, class Engine = DefaultEngine>
69 class TTPadGradientOp final : public Operator<Context> {
70  public:
71  USE_OPERATOR_CONTEXT_FUNCTIONS;
72  TTPadGradientOp(const OperatorDef& operator_def, Workspace* ws)
73  : Operator<Context>(operator_def, ws) {}
74 
75  bool RunOnDevice() override {
76  const auto& G = Input(0);
77  auto* output = Output(0);
78  CAFFE_ENFORCE(&G == output);
79 
80  auto old_dim0 = *Input(1).template data<int64_t>();
81  auto new_dim0 = G.size(0);
82  auto dim1 = G.size(1);
83 
84  if (old_dim0 < new_dim0) {
85  output->ShrinkTo(old_dim0);
86  }
87 
88  return true;
89  }
90 };
91 
92 } // namespace caffe2
93 
94 #endif // CAFFE2_OPERATORS_TT_PAD_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70