// Includes the GRU unit op header (presumably declaring GRUUnitOp and
// GRUUnitGradientOp) and registers GRUUnitOp<float, CPUContext> as the CPU
// implementation of the GRUUnit operator.
// NOTE(review): this line looks like two source lines fused by extraction; the
// bare "1" and "4" tokens appear to be line-number residue, not code, and the
// original lines between them are not visible here. Confirm against the real
// file before building.
1 #include "gru_unit_op.h" 4 REGISTER_CPU_OPERATOR(GRUUnit, GRUUnitOp<float, CPUContext>);
// Schema for the GRUUnit operator: a sequence-length-aware GRU step. Per its doc
// text, it takes fused inputs X (TxNxD), the previous hidden state (NxD) and the
// sequence lengths (N), and skips computation for entries past a sequence's
// length. The new hidden state is written to output 0 ("hidden").
// Of the visible argument descriptions:
//   - one flag selects whether the hidden state is zeroed or carried forward
//     for timesteps past the given sequence length;
//   - another (presumably "sequence_lengths", the name the gradient maker reads)
//     drops the sequence-lengths input when false, shifting the following inputs
//     left by one.
// NOTE(review): several lines of this schema are missing from this view (the
// embedded numbering jumps 5 -> 9, 10 -> 12, 15 -> 20, 21 -> 24), so the
// input/output counts and the argument names are not visible here.
// NOTE(review): the doc text "(as in, the value at X[t][n] >= seqLengths[n]."
// has no closing parenthesis. It probably means the timestep t is >=
// seqLengths[n], not the value stored at X[t][n]; confirm and fix the wording.
5 OPERATOR_SCHEMA(GRUUnit)
9 GRUUnit computes the activations of a standard GRU, 10 in a sequence-length aware fashion. 12 Concretely, given the (fused) inputs X (TxNxD), the previous hidden 13 state (NxD), and the sequence lengths (N), computes the GRU 14 activations, avoiding computation if the input is invalid (as in, the 15 value at X[t][n] >= seqLengths[n]. 20 "Bool to determine if hidden state is zeroes or passed " 21 "along for timesteps past the given sequence_length.")
24 "When false, the sequence lengths input is left out, " 25 "and all following inputs are shifted left by one.")
26 .Output(0,
"hidden",
"The new GRU hidden state calculated by this op.");
// Registers GRUUnitGradientOp<float, CPUContext> as the CPU implementation of
// GRUUnitGradient and declares its schema. The visible argument description
// matches the forward op: when the flag is false, the sequence-lengths input is
// left out and the following inputs shift left by one.
// NOTE(review): schema lines 29-32 are missing from this view (the embedded
// numbering jumps 28 -> 33), so the input/output counts and the argument name
// are not visible here.
27 REGISTER_CPU_OPERATOR(GRUUnitGradient, GRUUnitGradientOp<float, CPUContext>);
28 OPERATOR_SCHEMA(GRUUnitGradient)
33 "When false, the sequence lengths input is left out, " 34 "and all following inputs are shifted left by one.");
// Inherit GradientMakerBase's constructors unchanged; this gradient maker adds no
// state of its own in the visible code.
37 using GradientMakerBase::GradientMakerBase;
// Builds the gradient operator definitions for GRUUnit.
// Which inputs the gradient op receives depends on the "sequence_lengths" flag
// argument. The third argument, true, is presumably the default when the flag
// is absent.
//   - Flag set: I(0), I(1), I(2), I(3), O(0), GO(0).
//   - Flag false: the same list without I(3). This is one input fewer,
//     consistent with the schema note that the sequence-lengths input is left
//     out and later inputs shift left by one.
// In both visible branches only GI(0) and GI(1) are produced.
// NOTE(review): the calls that consume these vectors, the else-branch opener and
// the closing braces fall in source lines not shown here (40-42, 45-48, 51+);
// confirm against the full file.
38 vector<OperatorDef> GetGradientDefs()
override {
39 if (GetFlagArgument(def_,
"sequence_lengths",
true)) {
43 vector<string>{I(0), I(1), I(2), I(3), O(0), GO(0)},
44 vector<string>{GI(0), GI(1)});
49 vector<string>{I(0), I(1), I(2), O(0), GO(0)},
50 vector<string>{GI(0), GI(1)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function for creating a single operator def, which is usually the case for many ...