Caffe2 - C++ API
A deep learning, cross-platform ML framework
prepend_dim_op.h
1 
18 #ifndef CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
19 #define CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
20 
21 #include "caffe2/core/common_omp.h"
22 #include "caffe2/core/context.h"
23 #include "caffe2/core/logging.h"
24 #include "caffe2/core/operator.h"
25 
26 namespace caffe2 {
27 
28 template <class Context>
29 class PrependDimOp : public Operator<Context> {
30  public:
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32  PrependDimOp(const OperatorDef& operator_def, Workspace* ws)
33  : Operator<Context>(operator_def, ws),
34  dim_size_(OperatorBase::GetSingleArgument<int64_t>("dim_size", 0)) {
35  CAFFE_ENFORCE_GT(
36  dim_size_, 0, "Argument dim_size must be greater than zero.");
37  }
38 
39  bool RunOnDevice() override {
40  auto& input = Input(0);
41  auto* output = Output(0);
42 
43  CAFFE_ENFORCE(input.ndim() > 0, "Input must be at least 1D.");
44  CAFFE_ENFORCE(
45  input.dim(0) % dim_size_ == 0,
46  "First dimension must be multiple of prepend_dim.");
47 
48  vector<int64_t> actual_new_shape(input.ndim() + 1);
49  actual_new_shape[0] = dim_size_;
50  actual_new_shape[1] = input.dim(0) / dim_size_;
51  for (int i = 1; i < input.dims().size(); ++i) {
52  actual_new_shape[i + 1] = input.dim(i);
53  }
54  output->Resize(actual_new_shape);
55 
56  if (output != &input) {
57  // If we are not doing in-place computation, a copy is needed.
58  context_.template CopyItems<Context, Context>(
59  input.meta(),
60  input.size(),
61  input.raw_data(),
62  output->raw_mutable_data(input.meta()));
63  }
64  return true;
65  }
66 
67  private:
68  int64_t dim_size_;
69 };
70 
71 template <class Context>
72 class MergeDimOp : public Operator<Context> {
73  public:
74  USE_OPERATOR_CONTEXT_FUNCTIONS;
75  MergeDimOp(const OperatorDef& operator_def, Workspace* ws)
76  : Operator<Context>(operator_def, ws) {}
77 
78  bool RunOnDevice() override {
79  auto& input = Input(0);
80  auto* output = Output(0);
81 
82  CAFFE_ENFORCE(input.ndim() > 1, "Input must be at least 2D.");
83 
84  vector<int64_t> actual_new_shape(input.ndim() - 1);
85  actual_new_shape[0] = input.dim(0) * input.dim(1);
86  for (int i = 1; i < input.dims().size() - 1; ++i) {
87  actual_new_shape[i] = input.dim(i + 1);
88  }
89  output->Resize(actual_new_shape);
90 
91  if (output != &input) {
92  // If we are not doing in-place computation, a copy is needed.
93  context_.template CopyItems<Context, Context>(
94  input.meta(),
95  input.size(),
96  input.raw_data(),
97  output->raw_mutable_data(input.meta()));
98  }
99  return true;
100  }
101 
102  private:
103  int64_t dim_size_;
104 };
105 
106 } // namespace caffe2
107 
108 #endif // CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.