Caffe2 - C++ API
A deep-learning, cross-platform ML framework
tt_pad_op.h
1 
17 #ifndef CAFFE2_OPERATORS_TT_PAD_OP_H_
18 #define CAFFE2_OPERATORS_TT_PAD_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context, class Engine = DefaultEngine>
27 class TTPadOp final : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  TTPadOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  scale_(OperatorBase::GetSingleArgument<TIndex>("scale", 0)) {
33  CAFFE_ENFORCE(
34  OperatorBase::HasArgument("scale"), "Argument `scale` is missing.");
35  }
36 
37  bool RunOnDevice() override {
38  const auto& X = Input(0);
39  auto* X_pad = Output(0);
40  CAFFE_ENFORCE(&X == X_pad);
41 
42  CAFFE_ENFORCE(X.ndim() == 2, X.ndim());
43 
44  auto X_dim0 = X.dim(0);
45  auto X_dim1 = X.dim(1);
46 
47  auto* X_orig_dim0 = Output(1);
48  X_orig_dim0->Resize(1);
49  *X_orig_dim0->template mutable_data<TIndex>() = X_dim0;
50 
51  if (X_dim0 % scale_ != 0) {
52  TIndex padded_dim0 = (X_dim0 / scale_ + 1) * scale_;
53  auto dim0_diff = padded_dim0 - X_dim0;
54  // set growthPct to the upper bound percentage: (100 * scale_ / X_dim0)
55  X_pad->template Extend(dim0_diff, 100 * scale_ / X_dim0, &context_);
56 
57  auto* X_pad_data = X_pad->template mutable_data<T>();
58  TIndex X_size = X_dim0 * X_dim1;
59  memset(X_pad_data + X_size, 0, dim0_diff * X_dim1 * sizeof(T));
60  }
61 
62  return true;
63  }
64 
65  protected:
66  TIndex scale_;
67 };
68 
69 template <typename T, class Context, class Engine = DefaultEngine>
70 class TTPadGradientOp final : public Operator<Context> {
71  public:
72  USE_OPERATOR_CONTEXT_FUNCTIONS;
73  TTPadGradientOp(const OperatorDef& operator_def, Workspace* ws)
74  : Operator<Context>(operator_def, ws) {}
75 
76  bool RunOnDevice() override {
77  const auto& G = Input(0);
78  auto* output = Output(0);
79  CAFFE_ENFORCE(&G == output);
80 
81  auto old_dim0 = *Input(1).template data<TIndex>();
82  auto new_dim0 = G.dim(0);
83  auto dim1 = G.dim(1);
84 
85  if (old_dim0 < new_dim0) {
86  output->Shrink(old_dim0);
87  }
88 
89  return true;
90  }
91 };
92 
93 } // namespace caffe2
94 
95 #endif // CAFFE2_OPERATORS_TT_PAD_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52