Caffe2 - C++ API
A deep learning, cross platform ML framework
expand_squeeze_dims_op.cc
1 #include "caffe2/operators/expand_squeeze_dims_op.h"
2 #include <caffe2/ideep/ideep_utils.h>
3 #include <caffe2/ideep/operators/operator_fallback_ideep.h>
4 
5 namespace caffe2 {
6 
7 class IDEEPExpandDimsOp final : public IDEEPOperator {
8  public:
9  USE_IDEEP_DEF_ALIASES();
10  USE_IDEEP_OPERATOR_FUNCTIONS();
12 
13  IDEEPExpandDimsOp(const OperatorDef& operator_def, Workspace* ws)
14  : IDEEPOperator(operator_def, ws),
15  fallback_(operator_def, ws) {
16  dims_ = OperatorBase::GetRepeatedArgument<int>("dims");
17  auto originalSize = dims_.size();
18  CAFFE_ENFORCE_GT(originalSize, 0, "Parameter `dims` must be provided.");
19  std::sort(dims_.begin(), dims_.end());
20  dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
21  if (dims_.size() < originalSize) {
22  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
23  }
24  CAFFE_ENFORCE_GE(dims_.front(), 0, "Dimension ids must be non-negative.");
25  }
26  ~IDEEPExpandDimsOp() override {}
27 
28  bool RunOnDevice() override {
29  if (!OperatorBase::InputBlob(INPUT).template IsType<itensor>()) {
30  return fallback_.Run(0);
31  }
32 
33  const auto& X = Input(INPUT);
34  auto* Y = Output(OUTPUT);
35  if (&X != Y) {
36  // Copy if not inplace
37  ideep::direct_copy::compute(X, *Y);
38  }
39  if (dims_.empty()) {
40  return true;
41  }
42 
43  auto newDims = X.get_dims();
44  CAFFE_ENFORCE_GE(
45  newDims.size() + dims_.size(),
46  dims_.back() + 1,
47  "Input needs at least ",
48  (1 + dims_.back() - dims_.size()),
49  " dimensions given `dims`.");
50 
51  for (const auto dim : dims_) {
52  newDims.insert(newDims.begin() + dim, 1);
53  }
54 
55  Y->reshape(newDims);
56  return true;
57  }
58 
59  private:
60  std::vector<int> dims_;
61  FALLBACK_OP fallback_;
62 
63  INPUT_TAGS(INPUT);
64  OUTPUT_TAGS(OUTPUT);
65 };
66 
67 
68 class IDEEPSqueezeOp final : public IDEEPOperator {
69  public:
70  USE_IDEEP_DEF_ALIASES();
71  USE_IDEEP_OPERATOR_FUNCTIONS();
73 
74  IDEEPSqueezeOp(const OperatorDef& operator_def, Workspace* ws)
75  : IDEEPOperator(operator_def, ws),
76  fallback_(operator_def, ws) {
77  dims_ = OperatorBase::GetRepeatedArgument<int>("dims");
78  auto originalSize = dims_.size();
79  CAFFE_ENFORCE_GT(originalSize, 0, "Parameter `dims` must be provided.");
80 
81  std::sort(dims_.begin(), dims_.end());
82  dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
83  if (dims_.size() < originalSize) {
84  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
85  }
86  CAFFE_ENFORCE_GE(dims_.front(), 0, "Dimension ids must be non-negative.");
87  }
88  ~IDEEPSqueezeOp() override {}
89 
90  bool RunOnDevice() override {
91  if (!OperatorBase::InputBlob(INPUT).template IsType<itensor>()) {
92  return fallback_.Run(0);
93  }
94 
95  const auto& X = Input(INPUT);
96  auto* Y = Output(OUTPUT);
97 
98  CAFFE_ENFORCE_GT(
99  X.ndims(),
100  dims_.back(),
101  "Input needs at least ",
102  (dims_.back() + 1),
103  " dimensions.");
104  const auto& ideep_dims = X.get_dims();
105  std::vector<int64_t> dims(ideep_dims.begin(), ideep_dims.end());
106  const auto new_dims = SqueezeOp<IDEEPContext>::ComputeDims(dims, dims_);
107  itensor::dims new_dims_ideep(new_dims.begin(), new_dims.end());
108  if (&X != Y) {
109  // Copy if not inplace
110  ideep::direct_copy::compute(X, *Y);
111  }
112 
113  Y->reshape(new_dims_ideep);
114  return true;
115  }
116 
117  private:
118  std::vector<int> dims_;
119  FALLBACK_OP fallback_;
120 
121  INPUT_TAGS(INPUT);
122  OUTPUT_TAGS(OUTPUT);
123 };
124 
125 
126 REGISTER_IDEEP_OPERATOR(ExpandDims, IDEEPExpandDimsOp);
127 REGISTER_IDEEP_OPERATOR(Squeeze, IDEEPSqueezeOp);
128 
129 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
A templated class to allow one to wrap a CPU operator as an IDEEP operator.