Caffe2 - C++ API
A deep-learning, cross-platform ML framework
sparse_matrix_reshape_op.h
1 
17 #ifndef CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_
18 #define CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <class Context>
27 class SparseMatrixReshapeOp : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  SparseMatrixReshapeOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws) {
32  CAFFE_ENFORCE(
33  OperatorBase::HasArgument("old_shape"),
34  "Argument `old_shape` is missing.");
35  CAFFE_ENFORCE(
36  OperatorBase::HasArgument("new_shape"),
37  "Argument `new_shape` is missing.");
38 
39  vector<TIndex> old_shape =
40  OperatorBase::GetRepeatedArgument<TIndex>("old_shape");
41  vector<TIndex> new_shape =
42  OperatorBase::GetRepeatedArgument<TIndex>("new_shape");
43 
44  CAFFE_ENFORCE(
45  old_shape.size() == 2,
46  "Argument `old_shape` must contain exactly two integers.");
47  CAFFE_ENFORCE(
48  new_shape.size() == 2,
49  "Argument `new_shape` must contain exactly two integers.");
50 
51  CAFFE_ENFORCE(
52  old_shape[1] > 0,
53  "The second dimension in argument `old_shape` must be positive.");
54 
55  old_stride_ = old_shape[1];
56 
57  if (old_shape[0] == -1) {
58  CAFFE_ENFORCE(
59  new_shape[1] > 0,
60  "The second dimension in `new_shape` must be positive.");
61  } else {
62  CAFFE_ENFORCE(
63  old_shape[0] > 0,
64  "The first dimension in `old_shape` must be positive.");
65 
66  TIndex matrix_size = old_shape[0] * old_shape[1];
67 
68  if (new_shape[0] == -1) {
69  CAFFE_ENFORCE(
70  new_shape[1] > 0,
71  "Only one dimension in argument `new_shape` can be -1.");
72  CAFFE_ENFORCE(
73  matrix_size % new_shape[1] == 0,
74  "Argument `new_shape` does not agree with `old_shape`.");
75  } else {
76  CAFFE_ENFORCE(
77  new_shape[0] > 0 && (new_shape[1] == -1 || new_shape[1] > 0),
78  "Dimensions in argument `new_shape` must be positive or -1.");
79  if (new_shape[1] == -1) {
80  CAFFE_ENFORCE(
81  matrix_size % new_shape[0] == 0,
82  "Argument `new_shape` does not agree with `old_shape`.");
83  new_shape[1] = matrix_size / new_shape[0];
84  } else {
85  CAFFE_ENFORCE(
86  new_shape[0] * new_shape[1] == matrix_size,
87  "Argument `new_shape` does not agree with `old_shape`.");
88  }
89  }
90  }
91  new_stride_ = new_shape[1];
92  }
93 
94  bool RunOnDevice() override {
95  auto& old_col = Input(0);
96  CAFFE_ENFORCE(old_col.ndim() == 1, "Row index tensor must be 1-D.");
97  auto& old_row = Input(1);
98  CAFFE_ENFORCE(old_row.ndim() == 1, "Column index tensor must be 1-D.");
99 
100  const auto nnz = old_col.size();
101  CAFFE_ENFORCE(
102  old_row.size() == nnz,
103  "Column and row tensors must have the same size.");
104 
105  auto* new_col = Output(0);
106  auto* new_row = Output(1);
107  new_col->Resize(nnz);
108  new_row->Resize(nnz);
109 
110  const auto* old_col_data = old_col.template data<TIndex>();
111  const auto* old_row_data = old_row.template data<int>();
112 
113  auto* new_col_data = new_col->template mutable_data<TIndex>();
114  auto* new_row_data = new_row->template mutable_data<int>();
115 
116  for (int i = 0; i < nnz; ++i) {
117  TIndex offset = old_row_data[i] * old_stride_ + old_col_data[i];
118  new_row_data[i] = offset / new_stride_;
119  new_col_data[i] = offset % new_stride_;
120  }
121 
122  return true;
123  }
124 
125  private:
126  TIndex old_stride_;
127  TIndex new_stride_;
128 };
129 
130 } // namespace caffe2
131 
132 #endif // CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52