Caffe2 - C++ API
A deep learning, cross platform ML framework
expand_squeeze_dims_op.cc
1 
17 #include "caffe2/operators/expand_squeeze_dims_op.h"
18 
19 namespace caffe2 {
20 REGISTER_CPU_OPERATOR(ExpandDims, ExpandDimsOp<CPUContext>);
21 REGISTER_CPU_OPERATOR(Squeeze, SqueezeOp<CPUContext>);
22 
23 OPERATOR_SCHEMA(ExpandDims)
24  .NumInputs(1)
25  .NumOutputs(1)
26  .AllowInplace({{0, 0}})
27  .TensorInferenceFunction([](const OperatorDef& def,
28  const vector<TensorShape>& in) {
29  ArgumentHelper helper(def);
30  auto dims = helper.template GetRepeatedArgument<int>("dims");
31  auto originalSize = dims.size();
32  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
33 
34  std::sort(dims.begin(), dims.end());
35  dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
36  if (dims.size() < originalSize) {
37  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
38  }
39 
40  CAFFE_ENFORCE(dims.front() >= 0, "Dimension ids must be non-negative.");
41  CAFFE_ENFORCE_GE(
42  in[0].dims_size() + dims.size(),
43  dims.back() + 1,
44  "Input needs at least ",
45  (1 + dims.back() - dims.size()),
46  " dimensions given `dims`.");
47 
48  vector<TensorShape> out(1);
49 
50  int cur_pos = 0;
51  int idx = 0;
52  for (const auto new_dim : dims) {
53  for (int i = cur_pos; i < new_dim; i++) {
54  out[0].add_dims(in[0].dims(idx++));
55  }
56  out[0].add_dims(1);
57  cur_pos = new_dim + 1;
58  }
59  for (; idx < in[0].dims_size(); idx++) {
60  out[0].add_dims(in[0].dims(idx));
61  }
62  out[0].set_data_type(in[0].data_type());
63  return out;
64  })
65  .SetDoc(R"DOC(
66 Insert single-dimensional entries to the shape of a tensor.
67 Takes one required argument `dims`, a list of dimensions that will be inserted.
68 Dimension indices in `dims` are as seen in the output tensor. For example:
69 
70  Given a tensor such that tensor.Shape() = [3, 4, 5], then
71  ExpandDims(tensor, dims=[0, 4]).Shape() == [1, 3, 4, 5, 1])
72 
73 If the same blob is provided in input and output, the operation is copy-free.
74 )DOC")
75  .Input(0, "data", "Original tensor")
76  .Output(0, "expanded", "Reshaped tensor with same data as input.");
77 
78 OPERATOR_SCHEMA(Squeeze)
79  .NumInputs(1)
80  .NumOutputs(1)
81  .AllowInplace({{0, 0}})
82  .SetDoc(R"DOC(
83 Remove single-dimensional entries from the shape of a tensor.
84 Takes a parameter `dims` with a list of dimension to squeeze.
85 If the same blob is provided in input and output, the operation is copy-free.
86 This is the exact inverse operation of ExpandDims given the same `dims` arg.
87 )DOC")
88  .Input(0, "data", "Tensors with at least max(dims) dimensions.")
89  .Output(0, "squeezed", "Reshaped tensor with same data as input.")
90  .TensorInferenceFunction([](const OperatorDef& def,
91  const vector<TensorShape>& in) {
92  ArgumentHelper helper(def);
93  auto dims = helper.template GetRepeatedArgument<int>("dims");
94  auto originalSize = dims.size();
95  std::sort(dims.begin(), dims.end());
96  dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
97  if (dims.size() < originalSize) {
98  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
99  }
100  CAFFE_ENFORCE(dims.front() >= 0, "Dimension ids must be non-negative.");
101 
102  vector<TensorShape> out(1);
103  std::vector<int> newDims =
104  SqueezeOp<CPUContext>::ComputeDims(GetDimsVector(in[0]), dims);
105  out[0] = CreateTensorShape(newDims, in[0].data_type());
106  return out;
107  });
108 
109 class GetSqueezeGradient : public GradientMakerBase {
110  using GradientMakerBase::GradientMakerBase;
111  vector<OperatorDef> GetGradientDefs() override {
112  return SingleGradientDef(
113  "ExpandDims", "", vector<string>{GO(0)}, vector<string>{GI(0)});
114  }
115 };
116 REGISTER_GRADIENT(Squeeze, GetSqueezeGradient);
117 
118 class GetExpandDimsGradient : public GradientMakerBase {
119  using GradientMakerBase::GradientMakerBase;
120  vector<OperatorDef> GetGradientDefs() override {
121  return SingleGradientDef(
122  "Squeeze", "", vector<string>{GO(0)}, vector<string>{GI(0)});
123  }
124 };
125 REGISTER_GRADIENT(ExpandDims, GetExpandDimsGradient);
126 }
Copyright (c) 2016-present, Facebook, Inc.