Caffe2 - C++ API
A deep learning, cross platform ML framework
sparse_matrix_reshape_op.h
1 
17 #ifndef CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_
18 #define CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_
19 
#include <limits>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <class Context>
27 class SparseMatrixReshapeOp : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  SparseMatrixReshapeOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws) {
32  CAFFE_ENFORCE(
33  OperatorBase::HasArgument("old_shape"),
34  "Argument `old_shape` is missing.");
35  CAFFE_ENFORCE(
36  OperatorBase::HasArgument("new_shape"),
37  "Argument `new_shape` is missing.");
38 
39  vector<int64_t> old_shape =
40  OperatorBase::GetRepeatedArgument<int64_t>("old_shape");
41  vector<int64_t> new_shape =
42  OperatorBase::GetRepeatedArgument<int64_t>("new_shape");
43 
44  CAFFE_ENFORCE(
45  old_shape.size() == 2,
46  "Argument `old_shape` must contain exactly two integers.");
47  CAFFE_ENFORCE(
48  new_shape.size() == 2,
49  "Argument `new_shape` must contain exactly two integers.");
50 
51  CAFFE_ENFORCE(
52  old_shape[1] > 0,
53  "The second dimension in argument `old_shape` must be positive.");
54 
55  old_stride_ = old_shape[1];
56 
57  if (old_shape[0] == -1) {
58  CAFFE_ENFORCE(
59  new_shape[1] > 0,
60  "The second dimension in `new_shape` must be positive.");
61  } else {
62  CAFFE_ENFORCE(
63  old_shape[0] > 0,
64  "The first dimension in `old_shape` must be positive.");
65 
66  int64_t matrix_size = old_shape[0] * old_shape[1];
67 
68  if (new_shape[0] == -1) {
69  CAFFE_ENFORCE(
70  new_shape[1] > 0,
71  "Only one dimension in argument `new_shape` can be -1.");
72  CAFFE_ENFORCE(
73  matrix_size % new_shape[1] == 0,
74  "Argument `new_shape` does not agree with `old_shape`.");
75  } else {
76  CAFFE_ENFORCE(
77  new_shape[0] > 0 && (new_shape[1] == -1 || new_shape[1] > 0),
78  "Dimensions in argument `new_shape` must be positive or -1.");
79  if (new_shape[1] == -1) {
80  CAFFE_ENFORCE(
81  matrix_size % new_shape[0] == 0,
82  "Argument `new_shape` does not agree with `old_shape`.");
83  new_shape[1] = matrix_size / new_shape[0];
84  } else {
85  CAFFE_ENFORCE(
86  new_shape[0] * new_shape[1] == matrix_size,
87  "Argument `new_shape` does not agree with `old_shape`.");
88  }
89  }
90  }
91  new_stride_ = new_shape[1];
92  }
93 
94  bool RunOnDevice() override {
95  auto& old_col = Input(0);
96  CAFFE_ENFORCE(old_col.dim() == 1, "Row index tensor must be 1-D.");
97  auto& old_row = Input(1);
98  CAFFE_ENFORCE(old_row.dim() == 1, "Column index tensor must be 1-D.");
99 
100  const auto nnz = old_col.numel();
101  CAFFE_ENFORCE(
102  old_row.numel() == nnz,
103  "Column and row tensors must have the same size.");
104 
105  auto* new_col = Output(0, {nnz}, at::dtype<int64_t>());
106  auto* new_row = Output(1, {nnz}, at::dtype<int>());
107 
108  const auto* old_col_data = old_col.template data<int64_t>();
109  const auto* old_row_data = old_row.template data<int>();
110 
111  auto* new_col_data = new_col->template mutable_data<int64_t>();
112  auto* new_row_data = new_row->template mutable_data<int>();
113 
114  for (int i = 0; i < nnz; ++i) {
115  int64_t offset = old_row_data[i] * old_stride_ + old_col_data[i];
116  new_row_data[i] = offset / new_stride_;
117  new_col_data[i] = offset % new_stride_;
118  }
119 
120  return true;
121  }
122 
123  private:
124  int64_t old_stride_;
125  int64_t new_stride_;
126 };
127 
128 } // namespace caffe2
129 
130 #endif // CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70