Caffe2 - C++ API
A deep learning, cross-platform ML framework
prepend_dim_op.h
1 
2 #ifndef CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
3 #define CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
4 
5 #include "caffe2/core/common_omp.h"
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class PrependDimOp : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  template <class... Args>
17  explicit PrependDimOp(Args&&... args)
18  : Operator<Context>(std::forward<Args>(args)...),
19  dim_size_(this->template GetSingleArgument<int64_t>("dim_size", 0)) {
20  CAFFE_ENFORCE_GT(
21  dim_size_, 0, "Argument dim_size must be greater than zero.");
22  }
23 
24  bool RunOnDevice() override {
25  auto& input = Input(0);
26  auto* output = Output(0);
27 
28  CAFFE_ENFORCE(input.dim() > 0, "Input must be at least 1D.");
29  CAFFE_ENFORCE(
30  input.size(0) % dim_size_ == 0,
31  "First dimension must be multiple of prepend_dim. Current first dimension: ",
32  input.size(0));
33 
34  vector<int64_t> actual_new_shape(input.dim() + 1);
35  actual_new_shape[0] = dim_size_;
36  actual_new_shape[1] = input.size(0) / dim_size_;
37  for (int i = 1; i < input.sizes().size(); ++i) {
38  actual_new_shape[i + 1] = input.size(i);
39  }
40  output->Resize(actual_new_shape);
41 
42  if (output != &input) {
43  // If we are not doing in-place computation, a copy is needed.
44  context_.CopyItemsSameDevice(
45  input.dtype(),
46  input.numel(),
47  input.raw_data(),
48  output->raw_mutable_data(input.dtype()));
49  }
50  return true;
51  }
52 
53  private:
54  int64_t dim_size_;
55 };
56 
57 template <class Context>
58 class MergeDimOp : public Operator<Context> {
59  public:
60  USE_OPERATOR_CONTEXT_FUNCTIONS;
61  template <class... Args>
62  explicit MergeDimOp(Args&&... args)
63  : Operator<Context>(std::forward<Args>(args)...) {}
64 
65  bool RunOnDevice() override {
66  auto& input = Input(0);
67  auto* output = Output(0);
68 
69  CAFFE_ENFORCE(input.dim() > 1, "Input must be at least 2D.");
70 
71  vector<int64_t> actual_new_shape(input.dim() - 1);
72  actual_new_shape[0] = input.size(0) * input.size(1);
73  for (int i = 1; i < input.sizes().size() - 1; ++i) {
74  actual_new_shape[i] = input.size(i + 1);
75  }
76  output->Resize(actual_new_shape);
77 
78  if (output != &input) {
79  // If we are not doing in-place computation, a copy is needed.
80  context_.CopyItemsSameDevice(
81  input.dtype(),
82  input.numel(),
83  input.raw_data(),
84  output->raw_mutable_data(input.dtype()));
85  }
86  return true;
87  }
88 
89  private:
90  int64_t dim_size_;
91 };
92 
93 } // namespace caffe2
94 
95 #endif // CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13