// NOTE(review): this text is a lossy extraction of a Caffe2 operator header,
// not compilable source. The numbers fused onto the start of many lines look
// like the original line numbers, and the gaps between them (e.g. 9 -> 12,
// 30 -> 35, 45 -> 53) show that lines were dropped: class heads, constructor
// signatures, closing braces, and part of the range check in RunOnDevice.
// Restore the file from upstream instead of editing this text by hand.
//
// First operator template in this header; its class-name line is among the
// missing ones. Judging by the visible RunOnDevice body, it copies Input(0)
// into Output(0) and reshapes the output with a size-1 axis inserted at every
// position listed in the "dims" argument -- presumably ExpandDimsOp; confirm
// against the full source.
1 #ifndef CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_ 2 #define CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 9 template <
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor (its signature line is missing here). Reads the repeated int
// argument "dims", requires at least one entry, sorts and de-duplicates it
// (logging a warning if duplicates were dropped), and requires every id to
// be non-negative.
13 template <
class... Args>
16 dims_(this->
template GetRepeatedArgument<int>(
"dims")) {
17 auto originalSize = dims_.size();
18 CAFFE_ENFORCE(originalSize > 0,
"Parameter `dims` must be provided.");
// Sorting first lets std::unique (which only collapses adjacent equal
// elements) drop every duplicate, and makes front() the smallest id.
19 std::sort(dims_.begin(), dims_.end());
20 dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
21 if (dims_.size() < originalSize) {
22 LOG(WARNING) <<
"Parameter `dims` has repeated dimensions.";
// dims_ is sorted ascending, so checking the smallest id covers all of them.
24 CAFFE_ENFORCE(dims_.front() >= 0,
"Dimension ids must be non-negative.");
// Copies the input into the output, then reshapes the output.
// The second CopyFrom argument is presumably an async-copy flag -- confirm.
27 bool RunOnDevice()
override {
28 auto& input =
Input(0);
29 auto* output = Output(0);
30 output->CopyFrom(input,
true );
// Start from the input shape; a 1 is inserted at each requested position.
35 auto newDims = input.sizes().vec();
// Fragment of a range check whose opening line and right-hand operand are
// missing here. From the message arithmetic, the input rank plus
// dims_.size() (the expanded rank) must exceed dims_.back(), i.e. the largest
// id must be a valid position in the expanded shape.
37 input.sizes().size() + dims_.size(),
39 "Input needs at least ",
40 (1 + dims_.back() - dims_.size()),
41 " dimensions given `dims`.");
// dims_ is ascending, so by the time `dim` is inserted every smaller id has
// already been inserted: each id is a position in the expanded output shape,
// not in the input shape.
42 for (
const auto dim : dims_) {
43 newDims.insert(newDims.begin() + dim, 1);
45 output->Reshape(newDims);
// Second operator template in this header. Its class name, SqueezeOp, is only
// visible in the C10_DISABLE_COPY_AND_ASSIGN line at the end. The class head,
// constructor signature, closing braces and most of two checks are missing
// from this extraction (see the gaps in the fused line numbers).
// RunOnDevice copies Input(0) into Output(0) and reshapes the output to the
// shape produced by ComputeDims, which appears to drop the axes listed in
// the "dims" argument.
53 template <
class Context>
56 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor (its signature line is missing here). Reads the repeated int
// argument "dims", requires at least one entry, sorts and de-duplicates it
// (logging a warning if duplicates were dropped), and requires every id to
// be non-negative.
// NOTE(review): this validation repeats the first operator's constructor
// line for line; a shared helper would keep the two in sync.
57 template <
class... Args>
60 dims_(this->
template GetRepeatedArgument<int>(
"dims")) {
61 auto originalSize = dims_.size();
62 CAFFE_ENFORCE(originalSize > 0,
"Parameter `dims` must be provided.");
// Sorting first lets std::unique (which only collapses adjacent equal
// elements) drop every duplicate, and makes front() the smallest id.
64 std::sort(dims_.begin(), dims_.end());
65 dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
66 if (dims_.size() < originalSize) {
67 LOG(WARNING) <<
"Parameter `dims` has repeated dimensions.";
// dims_ is sorted ascending, so checking the smallest id covers all of them.
69 CAFFE_ENFORCE(dims_.front() >= 0,
"Dimension ids must be non-negative.");
// Copies the input into the output, then reshapes the output to the shape
// computed by ComputeDims. The second CopyFrom argument is presumably an
// async-copy flag -- confirm.
72 bool RunOnDevice()
override {
73 auto& input =
Input(0);
74 auto* output = Output(0);
75 output->CopyFrom(input,
true );
// Only the message prefix of a check survives here; presumably it verifies
// that the input has enough axes for the ids in dims_ -- confirm upstream.
80 "Input needs at least ",
84 std::vector<int> newDims = ComputeDims(input.sizes(), dims_);
85 output->Reshape(newDims);
// Appears to return the input shape with the axes listed in `dims` removed.
// The declaration of the first parameter (inputDims) and of `j`, the cursor
// into `dims`, are on lines missing from this extraction.
// The single pass over the input axes only ever compares against dims[j],
// which appears to assume `dims` is sorted ascending -- RunOnDevice passes
// dims_, which the constructor sorts; confirm in the full source. The elided
// body of the matching branch carries a " of input must be 1" message, so a
// squeezed axis is presumably required to have size 1.
// NOTE(review): `dims` is taken by value, so every call copies the vector.
// `dims[j] == i` compares int with size_t (a sign-compare warning); a
// negative entry would convert to a huge unsigned value and never match.
// dims_ itself is safe because the constructor rejects negative ids.
89 static std::vector<int> ComputeDims(
91 std::vector<int> dims) {
93 std::vector<int> newDims;
94 for (
size_t i = 0; i < inputDims.
size(); ++i) {
95 if (j < dims.size() && dims[j] == i) {
101 " of input must be 1",
// Appends input axis i to the result, in input order; the matching branch
// above presumably skips this for squeezed axes (its exit is on a missing
// line).
108 newDims.push_back(inputDims.
at(i));
// Presumably deletes the copy constructor and copy assignment operator (per
// the macro name) -- confirm against the c10 macro definition.
117 C10_DISABLE_COPY_AND_ASSIGN(SqueezeOp);
120 #endif // CAFFE2_OPERATORS_EXPAND_SQUEEZE_DIMS_OP_H_
constexpr size_t size() const
size - Get the array size.
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
AT_CPP14_CONSTEXPR const T & at(size_t Index) const
Vector compatibility.