Caffe2 - C++ API
A deep learning, cross-platform ML framework
lstm_unit_op.cc
1 #include "lstm_unit_op.h"
2 
3 namespace caffe2 {
4 REGISTER_CPU_OPERATOR(LSTMUnit, LSTMUnitOp<CPUContext>);
5 OPERATOR_SCHEMA(LSTMUnit)
6  .NumInputs(4, 5)
7  .NumOutputs(2)
8  .SetDoc(R"DOC(
9 LSTMUnit computes the activations of a standard LSTM (without peephole
10 connections), in a sequence-length aware fashion.
11 
12 Concretely, given the (fused) inputs X (TxNxD), the previous cell
13 state (NxD), and the sequence lengths (N), computes the LSTM
14 activations, avoiding computation if the input is invalid (as in, the
15 value at X{t][n] >= seqLengths[n].
16 
17 )DOC")
18  .Arg("forget_bias", "Bias term to add in while calculating forget gate")
19  .Arg(
20  "sequence_lengths",
21  "When false, the sequence lengths input is left out, "
22  "and all following inputs are shifted left by one.");
23 REGISTER_CPU_OPERATOR(LSTMUnitGradient, LSTMUnitGradientOp<CPUContext>);
24 OPERATOR_SCHEMA(LSTMUnitGradient)
25  .NumInputs(8, 9)
26  .NumOutputs(3)
27  .Arg(
28  "sequence_lengths",
29  "When false, the sequence lengths input is left out, "
30  "and all following inputs are shifted left by one.");
31 
33  using GradientMakerBase::GradientMakerBase;
34  vector<OperatorDef> GetGradientDefs() override {
35  if (GetFlagArgument(def_, "sequence_lengths", true)) {
36  return SingleGradientDef(
37  "LSTMUnitGradient",
38  "",
39  vector<string>{
40  I(0), I(1), I(2), I(3), I(4), O(0), O(1), GO(0), GO(1)},
41  vector<string>{GI(0), GI(1), GI(2)});
42  } else {
43  return SingleGradientDef(
44  "LSTMUnitGradient",
45  "",
46  vector<string>{I(0), I(1), I(2), I(3), O(0), O(1), GO(0), GO(1)},
47  vector<string>{GI(0), GI(1), GI(2)});
48  }
49  }
50 };
// Registers GetLSTMUnitGradient as the gradient maker for the LSTMUnit op.
REGISTER_GRADIENT(LSTMUnit, GetLSTMUnitGradient);
52 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...