Caffe2 - C++ API
A deep learning, cross platform ML framework
expand_squeeze_dims_op.h
1 
17 #ifndef CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
18 #define CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
19 
#include <algorithm>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
22 
23 namespace caffe2 {
24 
25 template <class Context>
26 class ExpandDimsOp : public Operator<Context> {
27  public:
28  USE_OPERATOR_CONTEXT_FUNCTIONS;
29  ExpandDimsOp(const OperatorDef& operator_def, Workspace* ws)
30  : Operator<Context>(operator_def, ws),
31  dims_(OperatorBase::GetRepeatedArgument<int>("dims")) {
32  auto originalSize = dims_.size();
33  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
34  std::sort(dims_.begin(), dims_.end());
35  dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
36  if (dims_.size() < originalSize) {
37  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
38  }
39  CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
40  }
41 
42  bool RunOnDevice() override {
43  auto& input = Input(0);
44  auto* output = Output(0);
45  output->CopyFrom(input, &context_);
46  if (dims_.empty()) {
47  return true;
48  }
49 
50  auto newDims = input.dims();
51  CAFFE_ENFORCE_GE(
52  input.dims().size() + dims_.size(),
53  dims_.back() + 1,
54  "Input needs at least ",
55  (1 + dims_.back() - dims_.size()),
56  " dimensions given `dims`.");
57  for (const auto dim : dims_) {
58  newDims.insert(newDims.begin() + dim, 1);
59  }
60  output->Reshape(newDims);
61  return true;
62  }
63 
64  private:
65  vector<int> dims_;
66 };
67 
68 template <class Context>
69 class SqueezeOp : public Operator<Context> {
70  public:
71  USE_OPERATOR_CONTEXT_FUNCTIONS;
72  SqueezeOp(const OperatorDef& operator_def, Workspace* ws)
73  : Operator<Context>(operator_def, ws),
74  dims_(OperatorBase::GetRepeatedArgument<int>("dims")) {
75  auto originalSize = dims_.size();
76  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
77 
78  std::sort(dims_.begin(), dims_.end());
79  dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
80  if (dims_.size() < originalSize) {
81  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
82  }
83  CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
84  }
85 
86  bool RunOnDevice() override {
87  auto& input = Input(0);
88  auto* output = Output(0);
89  output->CopyFrom(input, &context_);
90 
91  CAFFE_ENFORCE_GT(
92  input.ndim(),
93  dims_.back(),
94  "Input needs at least ",
95  (dims_.back() + 1),
96  " dimensions.");
97 
98  std::vector<int> newDims = ComputeDims(input.dims(), dims_);
99  output->Reshape(newDims);
100  return true;
101  }
102 
103  static std::vector<int> ComputeDims(
104  std::vector<TIndex> inputDims,
105  std::vector<int> dims) {
106  int j = 0;
107  std::vector<int> newDims;
108  for (int i = 0; i < inputDims.size(); ++i) {
109  if (j < dims.size() && dims[j] == i) {
110  CAFFE_ENFORCE_EQ(
111  inputDims[i],
112  1,
113  "Dimension ",
114  i,
115  " of input must be 1",
116  " instead of ",
117  inputDims[i],
118  ".");
119  ++j;
120  continue;
121  }
122  newDims.push_back(inputDims.at(i));
123  }
124  return newDims;
125  }
126 
127  private:
128  vector<int> dims_;
129 
130  public:
131  DISABLE_COPY_AND_ASSIGN(SqueezeOp);
132 };
133 } // namespace caffe2
134 #endif // CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.