Caffe2 - C++ API
A deep learning, cross platform ML framework
sparse_funhash_op.cc
1 
17 #include "caffe2/experiments/operators/sparse_funhash_op.h"
18 
19 namespace caffe2 {
20 namespace {
21 
22 REGISTER_CPU_OPERATOR(SparseFunHash, SparseFunHashOp<float, CPUContext>);
23 REGISTER_CPU_OPERATOR(
24  SparseFunHashGradient,
25  SparseFunHashGradientOp<float, CPUContext>);
26 
27 OPERATOR_SCHEMA(SparseFunHash)
28  .NumInputs(4, 5)
29  .NumOutputs(1)
30  .SetDoc(R"DOC(
31 This layer compresses a fully-connected layer for sparse inputs
32 via hashing.
33 It takes four required inputs and an option fifth input.
34 The first three inputs `scalars`, `indices`, and `segment_ids` are
35 the sparse segmented representation of sparse data, which are the
36 same as the last three inputs of the `SparseSortedSegmentWeightedSum`
37 operator. If the argument `num_segments` is specified, it would be used
38 as the first dimension for the output; otherwise it would be derived
39 from the maximum segment ID.
40 
41 The fourth input is a 1D weight vector. Each entry of the fully-connected
42 layer would be randomly mapped from one of the entries in this vector.
43 
44 When the optional fifth input vector is present, each weight of the
45 fully-connected layer would be the linear combination of K entries
46 randomly mapped from the weight vector, provided the input
47 (length-K vector) serves as the coefficients.
48 )DOC")
49  .Input(0, "scalars", "Values of the non-zero entries of the sparse data.")
50  .Input(1, "indices", "Indices to the non-zero valued features.")
51  .Input(
52  2,
53  "segment_ids",
54  "Segment IDs corresponding to the non-zero entries.")
55  .Input(3, "weight", "Weight vector")
56  .Input(
57  4,
58  "alpha",
59  "Optional coefficients for linear combination of hashed weights.")
60  .Output(
61  0,
62  "output",
63  "Output tensor with the first dimension equal to the number "
64  "of segments.")
65  .Arg("num_outputs", "Number of outputs")
66  .Arg("num_segments", "Number of segments");
67 
68 OPERATOR_SCHEMA(SparseFunHashGradient).NumInputs(5, 6).NumOutputs(2, 3);
69 
70 class GetSparseFunHashGradient : public GradientMakerBase {
71  using GradientMakerBase::GradientMakerBase;
72  vector<OperatorDef> GetGradientDefs() override {
73  if (def_.input_size() == 4) {
74  return SingleGradientDef(
75  "SparseFunHashGradient",
76  "",
77  vector<string>{GO(0), I(0), I(1), I(2), I(3)},
78  vector<string>{GI_V(3), GI_I(3)});
79  }
80  // def_.input_size() == 5
81  return SingleGradientDef(
82  "SparseFunHashGradient",
83  "",
84  vector<string>{GO(0), I(0), I(1), I(2), I(3), I(4)},
85  vector<string>{GI_V(3), GI_I(3), GI(4)});
86  }
87 };
88 
89 REGISTER_GRADIENT(SparseFunHash, GetSparseFunHashGradient);
90 
91 } // namespace
92 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.