1 #include "lstm_unit_op.h" 4 REGISTER_CPU_OPERATOR(LSTMUnit, LSTMUnitOp<CPUContext>);
5 OPERATOR_SCHEMA(LSTMUnit)
9 LSTMUnit computes the activations of a standard LSTM (without peephole 10 connections), in a sequence-length aware fashion. 12 Concretely, given the (fused) inputs X (TxNxD), the previous cell 13 state (NxD), and the sequence lengths (N), computes the LSTM 14 activations, avoiding computation if the input is invalid (as in, the 15 value at X[t][n] has t >= seqLengths[n]). 18 .Arg("forget_bias",
"Bias term to add in while calculating forget gate")
21 "When false, the sequence lengths input is left out, " 22 "and all following inputs are shifted left by one.");
23 REGISTER_CPU_OPERATOR(LSTMUnitGradient, LSTMUnitGradientOp<CPUContext>);
24 OPERATOR_SCHEMA(LSTMUnitGradient)
29 "When false, the sequence lengths input is left out, " 30 "and all following inputs are shifted left by one.");
33 using GradientMakerBase::GradientMakerBase;
34 vector<OperatorDef> GetGradientDefs()
override {
35 if (GetFlagArgument(def_,
"sequence_lengths",
true)) {
40 I(0), I(1), I(2), I(3), I(4), O(0), O(1), GO(0), GO(1)},
41 vector<string>{GI(0), GI(1), GI(2)});
46 vector<string>{I(0), I(1), I(2), I(3), O(0), O(1), GO(0), GO(1)},
47 vector<string>{GI(0), GI(1), GI(2)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...