1 #include "caffe2/operators/expand_squeeze_dims_op.h" 2 #include <caffe2/ideep/ideep_utils.h> 3 #include <caffe2/ideep/operators/operator_fallback_ideep.h> 9 USE_IDEEP_DEF_ALIASES();
10 USE_IDEEP_OPERATOR_FUNCTIONS();
// Tail of a constructor; its signature and any earlier initializers are not
// part of this excerpt. fallback_ is built from the same operator_def and ws.
15 fallback_(operator_def, ws) {
// Read the repeated integer argument "dims" into dims_.
16 dims_ = OperatorBase::GetRepeatedArgument<int>(
"dims");
// Remember the size before de-duplication so repeats can be detected below.
17 auto originalSize = dims_.size();
// At least one dimension index must have been supplied.
18 CAFFE_ENFORCE_GT(originalSize, 0,
"Parameter `dims` must be provided.");
// Sort, then drop adjacent duplicates so each index appears exactly once.
19 std::sort(dims_.begin(), dims_.end());
20 dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
// A smaller size means duplicates were removed; a warning is logged.
21 if (dims_.size() < originalSize) {
22 LOG(WARNING) <<
"Parameter `dims` has repeated dimensions.";
// NOTE(review): the listing jumps from 22 to 24, so the closing brace of the
// `if` above is presumably on the elided line.
// dims_ is sorted, so front() is the minimum: checking it covers every entry.
24 CAFFE_ENFORCE_GE(dims_.front(), 0,
"Dimension ids must be non-negative.");
// Runs the operator on an input tensor: copies X into Y and builds a shape
// with a size-1 entry inserted at each index listed in dims_.
// NOTE(review): several listing lines are elided in this excerpt (e.g. 35-36,
// 38-42, 44, 46, 53 onward), so guards, the enforce opening, the final use of
// newDims and the return statement are not visible here.
28 bool RunOnDevice()
override {
// Inputs that are not stored as itensor are delegated to the fallback operator.
29 if (!OperatorBase::InputBlob(INPUT).
template IsType<itensor>()) {
30 return fallback_.Run(0);
33 const auto& X = Input(INPUT);
34 auto* Y = Output(OUTPUT);
// Copy the contents of X into Y.
// NOTE(review): any condition guarding this copy is on elided lines - confirm.
37 ideep::direct_copy::compute(X, *Y);
// `auto` deduces a value type, so newDims is a local copy that is mutated below.
43 auto newDims = X.get_dims();
// The next four lines are arguments of a call whose opening is elided
// (presumably a CAFFE_ENFORCE_* check - confirm). The visible message reports
// the minimum input rank as 1 + dims_.back() - dims_.size().
45 newDims.size() + dims_.size(),
47 "Input needs at least ",
48 (1 + dims_.back() - dims_.size()),
49 " dimensions given `dims`.");
// Insert a size-1 dimension at each requested index. The constructor sorts
// dims_, so insertions happen in ascending index order.
51 for (
const auto dim : dims_) {
52 newDims.insert(newDims.begin() + dim, 1);
// Dimension indices read from the "dims" argument; the constructor sorts and
// de-duplicates them and enforces that they are non-negative.
60 std::vector<int> dims_;
70 USE_IDEEP_DEF_ALIASES();
71 USE_IDEEP_OPERATOR_FUNCTIONS();
// Tail of a constructor; its signature and any earlier initializers are not
// part of this excerpt. fallback_ is built from the same operator_def and ws.
76 fallback_(operator_def, ws) {
// Read the repeated integer argument "dims" into dims_.
77 dims_ = OperatorBase::GetRepeatedArgument<int>(
"dims");
// Remember the size before de-duplication so repeats can be detected below.
78 auto originalSize = dims_.size();
// At least one dimension index must have been supplied.
79 CAFFE_ENFORCE_GT(originalSize, 0,
"Parameter `dims` must be provided.");
// Sort, then drop adjacent duplicates so each index appears exactly once.
81 std::sort(dims_.begin(), dims_.end());
82 dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
// A smaller size means duplicates were removed; a warning is logged.
83 if (dims_.size() < originalSize) {
84 LOG(WARNING) <<
"Parameter `dims` has repeated dimensions.";
// NOTE(review): the listing jumps from 84 to 86, so the closing brace of the
// `if` above is presumably on the elided line.
// dims_ is sorted, so front() is the minimum: checking it covers every entry.
86 CAFFE_ENFORCE_GE(dims_.front(), 0,
"Dimension ids must be non-negative.");
// Runs the operator on an input tensor: copies X into Y and reshapes Y to
// new_dims_ideep.
// NOTE(review): several listing lines are elided in this excerpt (e.g. 97-100,
// 102-103, 106, 108-109, 111-112, 114 onward), so the enforce call, the
// computation of `new_dims` and the return statement are not visible here.
90 bool RunOnDevice()
override {
// Inputs that are not stored as itensor are delegated to the fallback operator.
91 if (!OperatorBase::InputBlob(INPUT).
template IsType<itensor>()) {
92 return fallback_.Run(0);
95 const auto& X = Input(INPUT);
96 auto* Y = Output(OUTPUT);
// Fragment of an error message; the call it belongs to is on elided lines.
101 "Input needs at least ",
// Copy the input's dims into a std::vector<int64_t>.
104 const auto& ideep_dims = X.get_dims();
105 std::vector<int64_t> dims(ideep_dims.begin(), ideep_dims.end());
// `new_dims` is defined on an elided line (presumably derived from `dims` and
// dims_ - confirm); it is converted here to the itensor dims type.
107 itensor::dims new_dims_ideep(new_dims.begin(), new_dims.end());
// Copy the contents of X into Y.
// NOTE(review): any condition guarding this copy is on elided lines - confirm.
110 ideep::direct_copy::compute(X, *Y);
// Give Y the new shape.
113 Y->reshape(new_dims_ideep);
// Dimension indices read from the "dims" argument; the constructor sorts and
// de-duplicates them and enforces that they are non-negative.
118 std::vector<int> dims_;
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
A templated class to allow one to wrap a CPU operator as an IDEEP operator.