1 #ifndef CAFFE2_OPERATORS_RESHAPE_OP_H_ 2 #define CAFFE2_OPERATORS_RESHAPE_OP_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 13 template <
typename F,
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
17 template <
class... Args>
20 new_shape_(this->
template GetRepeatedArgument<int64_t>(
"shape")) {}
22 bool RunOnDevice()
override {
23 if (InputSize() == 2) {
28 return this->
template DoRunWithType<int64_t>();
32 bool DoRunWithType() {
33 DoRunWithTypeImpl<T>(
Input(0), Output(0));
39 void DoRunWithTypeImpl(
const Tensor& input,
Tensor* output) {
40 vector<int64_t> actual_new_shape = new_shape_;
41 if (InputSize() == 2) {
44 "New shape is specified by the input blob, do not pass in " 45 "the argument `shape`.");
47 auto& shape =
Input(1);
48 CAFFE_ENFORCE(shape.dim() == 1,
"Shape should be 1-D");
50 const T* shape_data = shape.template data<T>();
53 std::vector<T> tmpv(shape.numel());
54 if (shape.numel() > 0) {
55 context_.CopyBytesToCPU(
56 shape.numel() *
sizeof(
T), shape_data, &tmpv[0]);
57 actual_new_shape.assign(tmpv.begin(), tmpv.begin() + shape.numel());
62 for (
size_t i = 0; i < actual_new_shape.size() && i < input.dim(); ++i) {
63 if (actual_new_shape[i] == 0) {
64 actual_new_shape[i] = input.size(i);
71 auto total_size = input.numel();
74 for (
int i = 0; i < actual_new_shape.size(); ++i) {
75 const auto dim = actual_new_shape[i];
79 "Argument `shape` has more than one missing dimension.");
85 if (size == 0 && total_size != 0) {
86 CAFFE_THROW(
"Can not reshape a non-zero size (", total_size,
") tensor to zero size.");
89 if (unknown_idx != -1) {
95 " can not be inferred since new size is zero.");
97 total_size % size == 0,
98 "Argument `shape` does not agree with the input data.",
104 actual_new_shape[unknown_idx] = total_size / size;
109 "Argument `shape` does not agree with the input data.",
119 auto* old_shape = Output(1, {input.dim()}, at::dtype<T>());
120 T* old_shape_data = old_shape->template mutable_data<T>();
121 for (
int i = 0; i < input.dim(); ++i) {
122 math::Set<T, Context>(1, input.size(i), old_shape_data + i, &context_);
125 output->Resize(actual_new_shape);
126 if (output != &input) {
128 context_.CopyItemsSameDevice(
132 output->raw_mutable_data(input.dtype()));
137 vector<int64_t> new_shape_;
142 #endif // CAFFE2_OPERATORS_RESHAPE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.