Caffe2 - C++ API
A cross-platform deep learning (ML) framework
gru_unit_op.cc
1 #include "gru_unit_op.h"
2 
3 namespace caffe2 {
4 REGISTER_CPU_OPERATOR(GRUUnit, GRUUnitOp<float, CPUContext>);
5 OPERATOR_SCHEMA(GRUUnit)
6  .NumInputs(3, 4)
7  .NumOutputs(1)
8  .SetDoc(R"DOC(
9 GRUUnit computes the activations of a standard GRU,
10 in a sequence-length aware fashion.
11 
12 Concretely, given the (fused) inputs X (TxNxD), the previous hidden
13 state (NxD), and the sequence lengths (N), computes the GRU
14 activations, avoiding computation if the input is invalid (as in, the
15 value at X[t][n] >= seqLengths[n].
16 
17 )DOC")
18  .Arg(
19  "drop_states",
20  "Bool to determine if hidden state is zeroes or passed "
21  "along for timesteps past the given sequence_length.")
22  .Arg(
23  "sequence_lengths",
24  "When false, the sequence lengths input is left out, "
25  "and all following inputs are shifted left by one.")
26  .Output(0, "hidden", "The new GRU hidden state calculated by this op.");
27 REGISTER_CPU_OPERATOR(GRUUnitGradient, GRUUnitGradientOp<float, CPUContext>);
28 OPERATOR_SCHEMA(GRUUnitGradient)
29  .NumInputs(5, 6)
30  .NumOutputs(2)
31  .Arg(
32  "sequence_lengths",
33  "When false, the sequence lengths input is left out, "
34  "and all following inputs are shifted left by one.");
35 
37  using GradientMakerBase::GradientMakerBase;
38  vector<OperatorDef> GetGradientDefs() override {
39  if (GetFlagArgument(def_, "sequence_lengths", true)) {
40  return SingleGradientDef(
41  "GRUUnitGradient",
42  "",
43  vector<string>{I(0), I(1), I(2), I(3), O(0), GO(0)},
44  vector<string>{GI(0), GI(1)});
45  } else {
46  return SingleGradientDef(
47  "GRUUnitGradient",
48  "",
49  vector<string>{I(0), I(1), I(2), O(0), GO(0)},
50  vector<string>{GI(0), GI(1)});
51  }
52  }
53 };
54 REGISTER_GRADIENT(GRUUnit, GetGRUUnitGradient);
55 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...