1 #include "batch_sparse_to_dense_op.h" 3 #include "caffe2/core/context.h" 7 template <
typename T,
class Context>
8 bool BatchSparseToDenseOp<T, Context>::RunOnDevice() {
9 auto& lengths = Input(LENGTHS);
10 auto& indices = Input(INDICES);
11 auto& values = Input(VALUES);
13 CAFFE_ENFORCE_EQ(indices.numel(), values.numel());
14 CAFFE_ENFORCE_EQ(lengths.dim(), 1);
15 CAFFE_ENFORCE_EQ(indices.dim(), 1);
17 const int64_t* lengths_data = lengths.template data<int64_t>();
18 const int64_t* indices_data = indices.template data<int64_t>();
19 const T* values_data = values.template data<T>();
20 int64_t batch_size = lengths.numel();
21 int64_t lengths_sum = 0;
22 math::Sum<int64_t, Context>(batch_size, lengths_data, &lengths_sum, &context_);
23 CAFFE_ENFORCE_EQ(lengths_sum, indices.numel());
25 vector<int64_t> output_shape = {batch_size};
26 if (InputSize() == 4) {
27 auto& shaper = Input(3);
28 CAFFE_ENFORCE_EQ(shaper.dim(), 2);
29 if (dense_last_dim_ == -1) {
30 dense_last_dim_ = shaper.size(1);
33 dense_last_dim_ == shaper.size(1),
34 "The last dim argument is not aligned with the shape input last dim");
37 CAFFE_ENFORCE(dense_last_dim_ >= 1,
"The last dim of dense must be >= 1");
39 output_shape.push_back(dense_last_dim_);
40 auto* output = Output(0, output_shape, at::dtype<T>());
41 T* output_data = output->template mutable_data<T>();
43 output->numel(),
static_cast<T>(default_value_), output_data, &context_);
46 for (int64_t i = 0; i < batch_size; ++i) {
47 for (int64_t j = 0; j < lengths_data[i]; ++j) {
49 indices_data[k] < dense_last_dim_,
52 ") is larger then last dim of dense (",
55 output_data[i * dense_last_dim_ + indices_data[k]] = values_data[k];
63 template <
typename T,
class Context>
64 bool BatchDenseToSparseOp<T, Context>::RunOnDevice() {
65 auto& lengths = Input(LENGTHS);
66 auto& indices = Input(INDICES);
67 auto& dense = Input(DENSE);
69 CAFFE_ENFORCE_EQ(lengths.dim(), 1);
70 CAFFE_ENFORCE_EQ(indices.dim(), 1);
71 CAFFE_ENFORCE_EQ(dense.dim(), 2);
72 const int64_t* lengths_data = lengths.template data<int64_t>();
73 const int64_t* indices_data = indices.template data<int64_t>();
74 const T* dense_data = dense.template data<T>();
76 int64_t batch_size = lengths.numel();
77 int64_t lengths_sum = 0;
78 math::Sum<int64_t, Context>(batch_size, lengths_data, &lengths_sum, &context_);
79 CAFFE_ENFORCE_EQ(lengths_sum, indices.numel());
81 CAFFE_ENFORCE_EQ(batch_size, dense.size(0));
82 dense_last_dim_ = dense.size(1);
83 vector<int64_t> output_shape = indices.sizes().vec();
84 auto* output = Output(0, output_shape, at::dtype<T>());
85 T* output_data = output->template mutable_data<T>();
88 for (int64_t i = 0; i < batch_size; ++i) {
89 for (int64_t j = 0; j < lengths_data[i]; ++j) {
91 indices_data[k] < dense.size(1),
94 ") is larger then last dim of dense (",
97 output_data[k] = dense_data[i * dense.size(1) + indices_data[k]];
104 REGISTER_CPU_OPERATOR(
106 BatchSparseToDenseOp<float, CPUContext>);
108 OPERATOR_SCHEMA(BatchSparseToDense)
111 .DisallowInputFillers()
113 Convert sparse matrix representation into dense matrix. 115 A sparse matrix is represented by `lengths` vector, `indices` vector, 116 and `values` vector. Each element in `lengths` vector (lengths[`i`]) represents 117 the number of indices in this batch (batch `i`). 118 With in each batch, `indices` should not have duplicate number. 120 For example, with input: 123 indices = [0, 1, 2, 3, 4, 5] 124 values = [6, 7, 8, 9, 10, 11] 130 output = [[6, 7, 0, 0, 0, 0], 134 after running this operator. 139 "Flatten tensor, used to break down indices and values into per batch indices and values.")
143 "Flatten tensor of total size = \\sum lengths, containing the indices ")
144 .Input(2,
"values",
"Data tensor, dimension has to match `indices`")
147 "output_shape_inference",
148 "Optional, a dense tensor whose shape define the output shape")
152 "2-D dense tensor, with 1st dim = len(lengths), 2nd dim = dense_last_dim" 153 "in the arg list, the tensor is of the same data type as `values`." 154 "Missing values are filled with default_value")
157 "Optional, output dense last dimension. " 158 "If both this argument and output_shape_inference are set, " 159 "it should be consistent with output_shape_inference's last dim")
162 "Optional, missing values are filled with this value." 163 "default_value = 0 when not set");
165 REGISTER_CPU_OPERATOR(
167 BatchDenseToSparseOp<float, CPUContext>);
169 OPERATOR_SCHEMA(BatchDenseToSparse)
173 This Op is a inverse of BatchSparseToDenseOp. 174 Basically, given a `lengths` vector, a `indices` vector, 175 and a dense matrix `dense`, output `value` vector so that, along with 176 `lengths` vector and `indices` vector, forms a sparse representation 179 A sparse matrix is represented by `lengths` vector, `indices` vector, 180 and `values` vector. Each element in `lengths` vector (lengths[`i`]) represents 181 the number of indices in this batch (batch `i`). 182 With in each batch, `indices` should not have duplicate number. 184 For example, with input: 187 indices = [0, 1, 2, 3, 4, 5] 188 output = [[6, 7, 0, 0, 0, 0], 194 values = [6, 7, 8, 9, 10, 11] 196 after running this operator. 201 "Flatten lengths, Used to break down indices into per batch indices")
205 "Flatten indices, tensor of total size = \\sum lengths, containing the indices ")
209 "dense 2-D tensor, first dim = len(lengths), last dim > Any(indices)")
213 "Values, tensor of the same size as `indices` and same data type as dense tensor.");
217 class GetBatchSparseToDenseGradient :
public GradientMakerBase {
218 using GradientMakerBase::GradientMakerBase;
219 vector<OperatorDef> GetGradientDefs()
override {
220 return SingleGradientDef(
221 "BatchDenseToSparse",
223 vector<string>{I(0), I(1), GO(0)},
224 vector<string>{GI(2)});
228 class GetBatchDenseToSparseGradient :
public GradientMakerBase {
229 using GradientMakerBase::GradientMakerBase;
230 vector<OperatorDef> GetGradientDefs()
override {
231 return SingleGradientDef(
232 "BatchSparseToDense",
234 vector<string>{I(0), I(1), GO(0), I(2)},
235 vector<string>{GI(2)});
239 REGISTER_GRADIENT(BatchSparseToDense, GetBatchSparseToDenseGradient);
240 REGISTER_GRADIENT(BatchDenseToSparse, GetBatchDenseToSparseGradient);
// NOTE(review): extraction residue from unrelated Caffe2 module documentation,
// commented out so it does not break compilation:
// A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...