17 #include "caffe2/experiments/operators/funhash_op.h" 22 REGISTER_CPU_OPERATOR(FunHash, FunHashOp<float, CPUContext>);
23 REGISTER_CPU_OPERATOR(FunHashGradient, FunHashGradientOp<float, CPUContext>);
25 OPERATOR_SCHEMA(FunHash)
29 This layer compresses a fully-connected layer for sparse inputs 31 It takes four required inputs and an optional fifth input. 32 The first three inputs `scalars`, `indices`, and `segment_ids` are 33 the sparse segmented representation of sparse data, which are the 34 same as the last three inputs of the `SparseSortedSegmentWeightedSum` 35 operator. If the argument `num_segments` is specified, it would be used 36 as the first dimension for the output; otherwise it would be derived 37 from the maximum segment ID. 39 The fourth input is a 1D weight vector. Each entry of the fully-connected 40 layer would be randomly mapped from one of the entries in this vector. 42 When the optional fifth input vector is present, each weight of the 43 fully-connected layer would be the linear combination of K entries 44 randomly mapped from the weight vector, provided the input 45 (length-K vector) serves as the coefficients. 47 .Input(0, "scalars",
"Values of the non-zero entries of the sparse data.")
48 .Input(1,
"indices",
"Indices to the non-zero valued features.")
49 .Input(2,
"segment_ids",
50 "Segment IDs corresponding to the non-zero entries.")
51 .Input(3,
"weight",
"Weight vector")
53 "Optional coefficients for linear combination of hashed weights.")
55 "Output tensor with the first dimension equal to the number " 57 .Arg(
"num_outputs",
"Number of outputs")
58 .Arg(
"num_segments",
"Number of segments");
60 OPERATOR_SCHEMA(FunHashGradient)
64 class GetFunHashGradient :
public GradientMakerBase {
65 using GradientMakerBase::GradientMakerBase;
66 vector<OperatorDef> GetGradientDefs()
override {
67 if (def_.input_size() == 4) {
68 return SingleGradientDef(
69 "FunHashGradient",
"",
70 vector<string>{GO(0), I(0), I(1), I(2), I(3)},
71 vector<string>{GI(3)});
74 return SingleGradientDef(
75 "FunHashGradient",
"",
76 vector<string>{GO(0), I(0), I(1), I(2), I(3), I(4)},
77 vector<string>{GI(3), GI(4)});
81 REGISTER_GRADIENT(FunHash, GetFunHashGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...