// Include guard, Caffe2 core/math includes, and the template header
// (element type T, device Context, Engine defaulting to DefaultEngine) of the
// first operator class in this header.
// NOTE(review): this listing has original line numbers fused into the text,
// and the `class ...` declaration line and constructor head are not visible
// here; the class is presumably the TT pad operator named by the guard.
17 #ifndef CAFFE2_OPERATORS_TT_PAD_OP_H_ 18 #define CAFFE2_OPERATORS_TT_PAD_OP_H_ 20 #include "caffe2/core/context.h" 21 #include "caffe2/core/operator.h" 22 #include "caffe2/utils/math.h" 26 template <
typename T,
class Context,
class Engine = DefaultEngine>
29 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Member initializer: scale_ is read from the operator argument "scale" as an
// int64_t, with 0 passed as the default value.
// NOTE(review): RunOnDevice computes `X_dim0 % scale_`; a scale_ of 0 would
// divide by zero there. Confirm the (not visible) constructor body rejects a
// missing or zero "scale".
32 scale_(OperatorBase::GetSingleArgument<int64_t>(
"scale", 0)) {
// Pads a 2-D tensor in place so that its first dimension becomes a multiple
// of scale_, zero-filling the added elements, and writes the original
// first-dimension size to output 1.
// NOTE(review): the end of this method (closing braces and return statement)
// is not visible in this listing.
37 bool RunOnDevice()
override {
38 const auto& X =
Input(0);
39 auto* X_pad = Output(0);
// Input 0 and output 0 must be the very same tensor object (in-place only).
40 CAFFE_ENFORCE(&X == X_pad);
// Only rank-2 tensors are accepted; the actual rank is passed along as the
// extra CAFFE_ENFORCE argument.
42 CAFFE_ENFORCE(X.dim() == 2, X.dim());
44 auto X_dim0 = X.size(0);
45 auto X_dim1 = X.size(1);
// Output 1 is a one-element int64 tensor that stores the pre-padding size of
// dimension 0.
47 auto* X_orig_dim0 = Output(1, {1}, at::dtype<int64_t>());
48 *X_orig_dim0->template mutable_data<int64_t>() = X_dim0;
// Pad only when dimension 0 is not already a multiple of scale_.
// NOTE(review): scale_ is initialized with a default of 0; this modulo divides
// by zero unless a nonzero scale is enforced elsewhere -- confirm.
50 if (X_dim0 % scale_ != 0) {
// Round dimension 0 up to the next multiple of scale_.
51 int64_t padded_dim0 = (X_dim0 / scale_ + 1) * scale_;
52 auto dim0_diff = padded_dim0 - X_dim0;
// Grow the tensor by dim0_diff (presumably along dimension 0). The second
// argument, 100 * scale_ / X_dim0, is presumably Extend's growth percentage
// -- verify against the Tensor API. X_dim0 cannot be 0 here because
// 0 % scale_ == 0 skips this branch.
54 X_pad->Extend(dim0_diff, 100 * scale_ / X_dim0);
56 auto* X_pad_data = X_pad->template mutable_data<T>();
57 int64_t X_size = X_dim0 * X_dim1;
// Zero the dim0_diff * X_dim1 new elements that follow the original X_size
// elements.
// NOTE(review): memset on the raw pointer presumably assumes host-accessible
// memory -- confirm for non-CPU Context.
58 memset(X_pad_data + X_size, 0, dim0_diff * X_dim1 *
sizeof(
T));
// Template header (element type T, device Context, Engine defaulting to
// DefaultEngine) of the second operator class in this header.
// NOTE(review): the `class ...` declaration line and any constructor are not
// visible in this listing; this is presumably the gradient counterpart of the
// pad operator -- confirm against the full file.
68 template <
typename T,
class Context,
class Engine = DefaultEngine>
71 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Shrinks the tensor in place: when the first-dimension size recorded in
// input 1 is smaller than the current first dimension of input 0, output 0 is
// shrunk back to that recorded size.
// NOTE(review): the end of this method (closing braces and return statement)
// is not visible in this listing.
75 bool RunOnDevice()
override {
76 const auto& G =
Input(0);
77 auto* output = Output(0);
// Input 0 and output 0 must be the very same tensor object (in-place only).
78 CAFFE_ENFORCE(&G == output);
// Read the first int64 element of input 1 as the size to shrink back to
// (presumably the original size recorded by the forward pad op).
80 auto old_dim0 = *
Input(1).template data<int64_t>();
81 auto new_dim0 = G.size(0);
// NOTE(review): dim1 is not used anywhere in the visible part of this method
// -- possibly dead code. Reading G.size(1) presumably also requires G to
// have at least two dimensions, though no rank check is visible here.
82 auto dim1 = G.size(1);
// Only shrink when the recorded size is strictly smaller than the current one.
84 if (old_dim0 < new_dim0) {
85 output->ShrinkTo(old_dim0);
94 #endif // CAFFE2_OPERATORS_TT_PAD_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.