Caffe2 - C++ API
A deep-learning, cross-platform ML framework
sparse_funhash_op.cc
1 
17 #include "caffe2/experiments/operators/sparse_funhash_op.h"
18 
19 namespace caffe2 {
20 namespace {
21 
22 REGISTER_CPU_OPERATOR(SparseFunHash, SparseFunHashOp<float, CPUContext>);
23 REGISTER_CPU_OPERATOR(
24  SparseFunHashGradient,
25  SparseFunHashGradientOp<float, CPUContext>);
26 
27 OPERATOR_SCHEMA(SparseFunHash)
28  .NumInputs(4, 5)
29  .NumOutputs(1)
30  .DisallowInputFillers() // TODO: enable the filler
31  .SetDoc(R"DOC(
32 This layer compresses a fully-connected layer for sparse inputs
33 via hashing.
34 It takes four required inputs and an option fifth input.
35 The first three inputs `scalars`, `indices`, and `segment_ids` are
36 the sparse segmented representation of sparse data, which are the
37 same as the last three inputs of the `SparseSortedSegmentWeightedSum`
38 operator. If the argument `num_segments` is specified, it would be used
39 as the first dimension for the output; otherwise it would be derived
40 from the maximum segment ID.
41 
42 The fourth input is a 1D weight vector. Each entry of the fully-connected
43 layer would be randomly mapped from one of the entries in this vector.
44 
45 When the optional fifth input vector is present, each weight of the
46 fully-connected layer would be the linear combination of K entries
47 randomly mapped from the weight vector, provided the input
48 (length-K vector) serves as the coefficients.
49 )DOC")
50  .Input(0, "scalars", "Values of the non-zero entries of the sparse data.")
51  .Input(1, "indices", "Indices to the non-zero valued features.")
52  .Input(
53  2,
54  "segment_ids",
55  "Segment IDs corresponding to the non-zero entries.")
56  .Input(3, "weight", "Weight vector")
57  .Input(
58  4,
59  "alpha",
60  "Optional coefficients for linear combination of hashed weights.")
61  .Output(
62  0,
63  "output",
64  "Output tensor with the first dimension equal to the number "
65  "of segments.")
66  .Arg("num_outputs", "Number of outputs")
67  .Arg("num_segments", "Number of segments");
68 
69 OPERATOR_SCHEMA(SparseFunHashGradient)
70  .NumInputs(5, 6)
71  .NumOutputs(2, 3)
72  .DisallowInputFillers();
73 
74 class GetSparseFunHashGradient : public GradientMakerBase {
75  using GradientMakerBase::GradientMakerBase;
76  vector<OperatorDef> GetGradientDefs() override {
77  if (def_.input_size() == 4) {
78  return SingleGradientDef(
79  "SparseFunHashGradient",
80  "",
81  vector<string>{GO(0), I(0), I(1), I(2), I(3)},
82  vector<string>{GI_V(3), GI_I(3)});
83  }
84  // def_.input_size() == 5
85  return SingleGradientDef(
86  "SparseFunHashGradient",
87  "",
88  vector<string>{GO(0), I(0), I(1), I(2), I(3), I(4)},
89  vector<string>{GI_V(3), GI_I(3), GI(4)});
90  }
91 };
92 
93 REGISTER_GRADIENT(SparseFunHash, GetSparseFunHashGradient);
94 
95 } // namespace
96 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13