Caffe2 - C++ API
A deep learning, cross-platform ML framework
expand_squeeze_dims_op.h
1 #ifndef CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
2 #define CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class ExpandDimsOp : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13  template <class... Args>
14  explicit ExpandDimsOp(Args&&... args)
15  : Operator<Context>(std::forward<Args>(args)...),
16  dims_(this->template GetRepeatedArgument<int>("dims")) {
17  auto originalSize = dims_.size();
18  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
19  std::sort(dims_.begin(), dims_.end());
20  dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
21  if (dims_.size() < originalSize) {
22  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
23  }
24  CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
25  }
26 
27  bool RunOnDevice() override {
28  auto& input = Input(0);
29  auto* output = Output(0);
30  output->CopyFrom(input, true /*async*/);
31  if (dims_.empty()) {
32  return true;
33  }
34 
35  auto newDims = input.sizes().vec();
36  CAFFE_ENFORCE_GE(
37  input.sizes().size() + dims_.size(),
38  dims_.back() + 1,
39  "Input needs at least ",
40  (1 + dims_.back() - dims_.size()),
41  " dimensions given `dims`.");
42  for (const auto dim : dims_) {
43  newDims.insert(newDims.begin() + dim, 1);
44  }
45  output->Reshape(newDims);
46  return true;
47  }
48 
49  private:
50  vector<int> dims_;
51 };
52 
53 template <class Context>
54 class SqueezeOp : public Operator<Context> {
55  public:
56  USE_OPERATOR_CONTEXT_FUNCTIONS;
57  template <class... Args>
58  explicit SqueezeOp(Args&&... args)
59  : Operator<Context>(std::forward<Args>(args)...),
60  dims_(this->template GetRepeatedArgument<int>("dims")) {
61  auto originalSize = dims_.size();
62  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
63 
64  std::sort(dims_.begin(), dims_.end());
65  dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
66  if (dims_.size() < originalSize) {
67  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
68  }
69  CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
70  }
71 
72  bool RunOnDevice() override {
73  auto& input = Input(0);
74  auto* output = Output(0);
75  output->CopyFrom(input, true /*async*/);
76 
77  CAFFE_ENFORCE_GT(
78  input.dim(),
79  dims_.back(),
80  "Input needs at least ",
81  (dims_.back() + 1),
82  " dimensions.");
83 
84  std::vector<int> newDims = ComputeDims(input.sizes(), dims_);
85  output->Reshape(newDims);
86  return true;
87  }
88 
89  static std::vector<int> ComputeDims(
90  at::IntArrayRef inputDims,
91  std::vector<int> dims) {
92  size_t j = 0;
93  std::vector<int> newDims;
94  for (size_t i = 0; i < inputDims.size(); ++i) {
95  if (j < dims.size() && dims[j] == i) {
96  CAFFE_ENFORCE_EQ(
97  inputDims[i],
98  1,
99  "Dimension ",
100  i,
101  " of input must be 1",
102  " instead of ",
103  inputDims[i],
104  ".");
105  ++j;
106  continue;
107  }
108  newDims.push_back(inputDims.at(i));
109  }
110  return newDims;
111  }
112 
113  private:
114  vector<int> dims_;
115 
116  public:
117  C10_DISABLE_COPY_AND_ASSIGN(SqueezeOp);
118 };
119 } // namespace caffe2
120 #endif // CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
AT_CPP14_CONSTEXPR const T & at(size_t Index) const
Vector compatibility.
Definition: ArrayRef.h:186