2 #ifndef CAFFE2_OPERATORS_PREPEND_DIM_OP_H_ 3 #define CAFFE2_OPERATORS_PREPEND_DIM_OP_H_ 5 #include "caffe2/core/common_omp.h" 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/operator.h" 12 template <
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor fragment. The constructor's declaration line is not present in
// this listing, so only its template header, one member initializer and the
// argument check are visible below.
16 template <
class... Args>
// dim_size_ is read from the operator argument named "dim_size"; 0 is the
// second value handed to GetSingleArgument (presumably the default used when
// the argument is absent — TODO confirm against the operator API).
19 dim_size_(this->
template GetSingleArgument<int64_t>(
"dim_size", 0)) {
// Operands and message of a validation whose opening macro line is missing
// from this listing. NOTE(review): judging by the message it rejects
// dim_size_ <= 0 — confirm against the full source.
21 dim_size_, 0,
"Argument dim_size must be greater than zero.");
// Resizes the output so that the input's first dimension is split in two:
// the new shape is [dim_size_, input.size(0) / dim_size_, input.size(1), ...],
// i.e. one more dimension than the input. Data is copied only when the output
// tensor is a different object from the input.
// NOTE(review): several original lines (the opening of the second check, some
// call arguments, closing braces and the return statement) are absent from
// this listing.
24 bool RunOnDevice()
override {
25 auto& input =
Input(0);
26 auto* output = Output(0);
// The input needs at least one dimension, since the first one is split below.
28 CAFFE_ENFORCE(input.dim() > 0,
"Input must be at least 1D.");
// Condition and message of a second check: the first dimension has to divide
// evenly by dim_size_. The line that opens this check and the trailing
// argument after the message are not visible here.
30 input.size(0) % dim_size_ == 0,
31 "First dimension must be multiple of prepend_dim. Current first dimension: ",
// The new shape has one entry more than the input has dimensions.
34 vector<int64_t> actual_new_shape(input.dim() + 1);
// Slot 0 is the prepended dimension; slot 1 is what remains of the old first
// dimension after dividing by dim_size_.
35 actual_new_shape[0] = dim_size_;
36 actual_new_shape[1] = input.size(0) / dim_size_;
// Every input dimension from index 1 onwards moves one slot to the right.
// NOTE(review): `int i` is compared against input.sizes().size(), which is
// presumably an unsigned size type — a signed/unsigned comparison; confirm.
37 for (
int i = 1; i < input.sizes().size(); ++i) {
38 actual_new_shape[i + 1] = input.size(i);
40 output->Resize(actual_new_shape);
// The copy is issued only when output and input are distinct objects; when
// they are the same object the Resize above is the only change made.
42 if (output != &input) {
44 context_.CopyItemsSameDevice(
// Last argument of the copy call: the output buffer, requested with the
// input's dtype. The preceding arguments are on lines missing from this
// listing.
48 output->raw_mutable_data(input.dtype()));
57 template <
class Context>
60 USE_OPERATOR_CONTEXT_FUNCTIONS;
61 template <
class... Args>
// Resizes the output so that the input's first two dimensions are merged into
// one: the new shape is [input.size(0) * input.size(1), input.size(2), ...],
// i.e. one dimension fewer than the input. Data is copied only when the output
// tensor is a different object from the input.
// NOTE(review): some original lines (call arguments, closing braces and the
// return statement) are absent from this listing.
65 bool RunOnDevice()
override {
66 auto& input =
Input(0);
67 auto* output = Output(0);
// Two leading dimensions are required because both are read below.
69 CAFFE_ENFORCE(input.dim() > 1,
"Input must be at least 2D.");
// The new shape has one entry fewer than the input has dimensions.
71 vector<int64_t> actual_new_shape(input.dim() - 1);
// Slot 0 holds the product of the first two input dimensions.
72 actual_new_shape[0] = input.size(0) * input.size(1);
// Each remaining input dimension (index 2 onwards) moves one slot to the left.
// NOTE(review): `int i` is compared against input.sizes().size() - 1, which is
// presumably computed in an unsigned size type; the dim() > 1 check above
// keeps the subtraction from wrapping — confirm the types.
73 for (
int i = 1; i < input.sizes().size() - 1; ++i) {
74 actual_new_shape[i] = input.size(i + 1);
76 output->Resize(actual_new_shape);
// The copy is issued only when output and input are distinct objects; when
// they are the same object the Resize above is the only change made.
78 if (output != &input) {
80 context_.CopyItemsSameDevice(
// Last argument of the copy call: the output buffer, requested with the
// input's dtype. The preceding arguments are on lines missing from this
// listing.
84 output->raw_mutable_data(input.dtype()));
95 #endif // CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...