17 #ifndef CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_ 18 #define CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_ 20 #include "caffe2/core/context.h" 21 #include "caffe2/core/operator.h" 22 #include "caffe2/utils/math.h" 26 template <
class Context>
29 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor logic (its signature and several statement heads are not visible
// in this extraction): reads the two-element `old_shape` / `new_shape` arguments,
// validates them, and derives the strides that RunOnDevice uses to remap indices.
34 "Argument `old_shape` is missing.");
37 "Argument `new_shape` is missing.");
39 vector<int64_t> old_shape =
40 OperatorBase::GetRepeatedArgument<int64_t>(
"old_shape");
41 vector<int64_t> new_shape =
42 OperatorBase::GetRepeatedArgument<int64_t>(
"new_shape");
45 old_shape.size() == 2,
46 "Argument `old_shape` must contain exactly two integers.");
48 new_shape.size() == 2,
49 "Argument `new_shape` must contain exactly two integers.");
53 "The second dimension in argument `old_shape` must be positive.");
// Column count of the old matrix; RunOnDevice uses it as the row-major stride
// when linearizing (row, col) into a single offset.
55 old_stride_ = old_shape[1];
// old_shape[0] == -1 presumably means the old row count is unspecified; in that
// branch the visible code only checks that new_shape[1] is positive, since the
// total element count is unknown.
57 if (old_shape[0] == -1) {
60 "The second dimension in `new_shape` must be positive.");
64 "The first dimension in `old_shape` must be positive.");
// Total number of elements in the old matrix.
// NOTE(review): this product is not checked for int64_t overflow on huge shapes.
66 int64_t matrix_size = old_shape[0] * old_shape[1];
// A -1 entry in new_shape is inferred from matrix_size; at most one entry may be
// -1, and the resulting shape must hold exactly matrix_size elements.
68 if (new_shape[0] == -1) {
71 "Only one dimension in argument `new_shape` can be -1.");
73 matrix_size % new_shape[1] == 0,
74 "Argument `new_shape` does not agree with `old_shape`.");
77 new_shape[0] > 0 && (new_shape[1] == -1 || new_shape[1] > 0),
78 "Dimensions in argument `new_shape` must be positive or -1.");
79 if (new_shape[1] == -1) {
81 matrix_size % new_shape[0] == 0,
82 "Argument `new_shape` does not agree with `old_shape`.");
83 new_shape[1] = matrix_size / new_shape[0];
86 new_shape[0] * new_shape[1] == matrix_size,
87 "Argument `new_shape` does not agree with `old_shape`.");
// Column count of the new matrix; RunOnDevice splits linear offsets by it.
91 new_stride_ = new_shape[1];
// Remaps every stored (row, col) coordinate from the old 2-D shape to the new one:
// each pair is linearized in row-major order using old_stride_ (old column count)
// and re-split into (row, col) using new_stride_ (new column count).
// Input(0): column indices read as int64_t; Input(1): row indices read as int.
// Output(0)/Output(1): new column (int64_t) / row (int) indices, same length nnz.
94 bool RunOnDevice()
override {
95 auto& old_col =
Input(0);
// NOTE(review): Input(0) is the column-index tensor, but this message says "Row";
// the two 1-D error messages below appear swapped — confirm and fix the strings.
96 CAFFE_ENFORCE(old_col.dim() == 1,
"Row index tensor must be 1-D.");
97 auto& old_row =
Input(1);
98 CAFFE_ENFORCE(old_row.dim() == 1,
"Column index tensor must be 1-D.");
// nnz is the number of stored (row, col) pairs; both index tensors must agree.
100 const auto nnz = old_col.numel();
102 old_row.numel() == nnz,
103 "Column and row tensors must have the same size.");
105 auto* new_col = Output(0, {nnz}, at::dtype<int64_t>());
106 auto* new_row = Output(1, {nnz}, at::dtype<int>());
108 const auto* old_col_data = old_col.template data<int64_t>();
109 const auto* old_row_data = old_row.template data<int>();
111 auto* new_col_data = new_col->template mutable_data<int64_t>();
112 auto* new_row_data = new_row->template mutable_data<int>();
// NOTE(review): `int i` is compared against the int64_t nnz; a 64-bit counter
// would avoid overflow for very large inputs. No range check of the input
// indices against old_shape is visible here, so negative or out-of-range
// indices flow straight into the offset arithmetic.
114 for (
int i = 0; i < nnz; ++i) {
// Row-major linear offset in the old matrix, then split by the new column count.
115 int64_t offset = old_row_data[i] * old_stride_ + old_col_data[i];
116 new_row_data[i] = offset / new_stride_;
117 new_col_data[i] = offset % new_stride_;
130 #endif // CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.