Caffe2 - C++ API
A deep learning, cross platform ML framework
sparse_to_dense_op.cc
1 #include "sparse_to_dense_op.h"
2 
3 #include "caffe2/core/context.h"
4 
5 namespace caffe2 {
6 
7 REGISTER_CPU_OPERATOR(SparseToDense, SparseToDenseOp<CPUContext>);
8 
9 OPERATOR_SCHEMA(SparseToDense)
10  .NumInputs(2, 3)
11  .NumOutputs(1)
12  .SetDoc(R"DOC(
13 Convert sparse representations to dense with given indices.
14 
15 Transforms a sparse representation of map<id, value> represented as `indices`
16 vector and `values` tensor into a compacted tensor where the first dimension
17 is determined by the first dimension of the 3rd input if it is given or the
18 max index. Missing values are filled with zeros.
19 
20 The op supports duplicated indices and performs summation over corresponding
21 values. This behavior is useful for converting GradientSlices into dense
22 representation.
23 
24 After running this op:
25 
26  output[indices[i], :] += values[i] // sum over all indices[i] equal to the index
27  output[j, ...] = 0 if j not in indices
28 )DOC")
29  .Input(0, "indices", "1-D int32/int64 tensor of concatenated ids of data")
30  .Input(
31  1,
32  "values",
33  "Data tensor, first dimension has to match `indices`, "
34  "basic numeric types are supported")
35  .Input(
36  2,
37  "data_to_infer_dim",
38  "Optional: if provided, the first dimension of output is the first "
39  "dimension of this tensor.")
40  .Output(
41  0,
42  "output",
43  "Output tensor of the same type as `values` of shape `[len(lengths), "
44  "len(mask)] + shape(default_value)` (if `lengths` is not provided the "
45  "first dimension is omitted)");
46 
47 
48 namespace {
49 class GetSparseToDenseGradient : public GradientMakerBase {
50  using GradientMakerBase::GradientMakerBase;
51  vector<OperatorDef> GetGradientDefs() override {
52  return SingleGradientDef(
53  "Gather", "", vector<string>{GO(0), I(0)}, vector<string>{GI(1)});
54  }
55 };
56 
57 REGISTER_GRADIENT(SparseToDense, GetSparseToDenseGradient);
58 }
59 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13