Caffe2 - C++ API
A deep learning, cross-platform ML framework
funhash_op.cc
1 
17 #include "caffe2/experiments/operators/funhash_op.h"
18 
19 namespace caffe2 {
20 namespace {
21 
22 REGISTER_CPU_OPERATOR(FunHash, FunHashOp<float, CPUContext>);
23 REGISTER_CPU_OPERATOR(FunHashGradient, FunHashGradientOp<float, CPUContext>);
24 
25 OPERATOR_SCHEMA(FunHash)
26  .NumInputs(4, 5)
27  .NumOutputs(1)
28  .SetDoc(R"DOC(
29 This layer compresses a fully-connected layer for sparse inputs
30 via hashing.
31 It takes four required inputs and an optional fifth input.
32 The first three inputs `scalars`, `indices`, and `segment_ids` are
33 the sparse segmented representation of sparse data, which are the
34 same as the last three inputs of the `SparseSortedSegmentWeightedSum`
35 operator. If the argument `num_segments` is specified, it would be used
36 as the first dimension for the output; otherwise it would be derived
37 from the maximum segment ID.
38 
39 The fourth input is a 1D weight vector. Each entry of the fully-connected
40 layer would be randomly mapped from one of the entries in this vector.
41 
42 When the optional fifth input vector is present, each weight of the
43 fully-connected layer would be the linear combination of K entries
44 randomly mapped from the weight vector, provided the input
45 (length-K vector) serves as the coefficients.
46 )DOC")
47  .Input(0, "scalars", "Values of the non-zero entries of the sparse data.")
48  .Input(1, "indices", "Indices to the non-zero valued features.")
49  .Input(2, "segment_ids",
50  "Segment IDs corresponding to the non-zero entries.")
51  .Input(3, "weight", "Weight vector")
52  .Input(4, "alpha",
53  "Optional coefficients for linear combination of hashed weights.")
54  .Output(0, "output",
55  "Output tensor with the first dimension equal to the number "
56  "of segments.")
57  .Arg("num_outputs", "Number of outputs")
58  .Arg("num_segments", "Number of segments");
59 
60 OPERATOR_SCHEMA(FunHashGradient)
61  .NumInputs(5, 6)
62  .NumOutputs(1, 2);
63 
64 class GetFunHashGradient : public GradientMakerBase {
65  using GradientMakerBase::GradientMakerBase;
66  vector<OperatorDef> GetGradientDefs() override {
67  if (def_.input_size() == 4) {
68  return SingleGradientDef(
69  "FunHashGradient", "",
70  vector<string>{GO(0), I(0), I(1), I(2), I(3)},
71  vector<string>{GI(3)});
72  }
73  // def_.input_size() == 5
74  return SingleGradientDef(
75  "FunHashGradient", "",
76  vector<string>{GO(0), I(0), I(1), I(2), I(3), I(4)},
77  vector<string>{GI(3), GI(4)});
78  }
79 };
80 
81 REGISTER_GRADIENT(FunHash, GetFunHashGradient);
82 
83 } // namespace
84 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.