Caffe2 - C++ API
A deep learning, cross platform ML framework
sparse_to_dense_mask_op.cc
1 
17 #include "caffe2/operators/sparse_to_dense_mask_op.h"
18 
19 namespace caffe2 {
20 namespace {
21 
22 REGISTER_CPU_OPERATOR(SparseToDenseMask, SparseToDenseMaskOp<CPUContext>);
23 REGISTER_CPU_OPERATOR(
24  SparseToDenseMaskGradient,
25  SparseToDenseMaskGradientOp<CPUContext>);
26 
27 OPERATOR_SCHEMA(SparseToDenseMask)
28  .NumInputs(3, 4)
29  .NumOutputs(1, 2)
30  .TensorInferenceFunction([](const OperatorDef& def,
31  const vector<TensorShape>& in) {
32  ArgumentHelper helper(def);
33  auto mask = helper.template GetRepeatedArgument<int64_t>("mask");
34  bool return_presence_mask = helper.template GetSingleArgument<bool>(
35  "return_presence_mask", false);
36  vector<TensorShape> out(1);
37 
38  if (in.size() == 4) {
39  out[0].add_dims(in[3].dims(0));
40  }
41  out[0].add_dims(mask.size());
42  for (const auto dim : in[2].dims()) {
43  out[0].add_dims(dim);
44  }
45  out[0].set_data_type(in[2].data_type());
46 
47  if (return_presence_mask) {
48  out.emplace_back();
49  if (in.size() == 4) {
50  out[1].add_dims(in[3].dims(0));
51  }
52  out[1].add_dims(mask.size());
53  out[1].set_data_type(TensorProto::BOOL);
54  }
55 
56  return out;
57  })
58  .SetDoc(R"DOC(
59 Convert sparse representations to dense with given indices.
60 
61 Transforms a sparse representation of map<id, value> represented as `indices`
62 vector and `values` tensor into a compacted tensor where the first dimension
63 corresponds to each id provided in mask argument. Missing values are filled with
64 the value of `default_value`. After running this op:
65 
66  output[j, :] = values[i] # where mask[j] == indices[i]
67  output[j, ...] = default_value # when mask[j] doesn't appear in indices
68 
69 If `lengths` is provided and not empty, and extra "batch" dimension is prepended
70 to the output.
71 
72 `values` and `default_value` can have additional matching dimensions, operation
73 is performed on the entire subtensor in thise case.
74 
75 For example, if `lengths` is supplied and `values` is 1-D vector of floats and
76 `default_value` is a float scalar, the output is going to be a float matrix
77 of size `len(lengths) X len(mask)`
78 )DOC")
79  .Arg(
80  "mask",
81  "list(int) argument with desired ids on the 'dense' output dimension")
82  .Arg(
83  "return_presence_mask",
84  "bool whether to return presence mask, false by default")
85  .Input(0, "indices", "1-D int32/int64 tensor of concatenated ids of data")
86  .Input(1, "values", "Data tensor, first dimension has to match `indices`")
87  .Input(
88  2,
89  "default_value",
90  "Default value for the output if the id is not present in `indices`. "
91  "Must have the same type as `values` and the same shape, but without "
92  "the first dimension")
93  .Input(
94  3,
95  "lengths",
96  "Optional lengths to represent a batch of `indices` and `values`.")
97  .Output(
98  0,
99  "output",
100  "Output tensor of the same type as `values` of shape `[len(lengths), "
101  "len(mask)] + shape(default_value)` (if `lengths` is not provided the "
102  "first dimension is omitted)")
103  .Output(
104  1,
105  "presence_mask",
106  "Bool tensor of shape `[len(lengths), len(mask)]` (if `lengths` is not "
107  "provided the first dimension is omitted). True when a value for given "
108  "id was present, false otherwise.");
109 
110 OPERATOR_SCHEMA(SparseToDenseMaskGradient)
111  .NumInputs(2, 3)
112  .NumOutputs(1)
113  .SetDoc(R"DOC(
114 The output is the gradient of the input value from SparseToDenseMask. The
115 gradient for default_value has not been implemented.
116 )DOC");
117 
118 class GetSparseToDenseMaskGradient : public GradientMakerBase {
119  using GradientMakerBase::GradientMakerBase;
120  vector<OperatorDef> GetGradientDefs() override {
121  vector<string> blob_names{I(0), GO(0)};
122 
123  // Add lengths blob if given
124  if (def_.input_size() == 4) {
125  blob_names.push_back(I(3));
126  }
127  return SingleGradientDef(
128  "SparseToDenseMaskGradient", "", blob_names, vector<string>{GI(1)});
129  }
130 };
131 REGISTER_GRADIENT(SparseToDenseMask, GetSparseToDenseMaskGradient);
132 } // namespace
133 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.