Caffe2 - C++ API
A deep learning, cross-platform ML framework
lstm_unit_op.cc
1 
17 #include "lstm_unit_op.h"
18 
19 namespace caffe2 {
20 REGISTER_CPU_OPERATOR(LSTMUnit, LSTMUnitOp<CPUContext>);
21 OPERATOR_SCHEMA(LSTMUnit)
22  .NumInputs(4, 5)
23  .NumOutputs(2)
24  .SetDoc(R"DOC(
25 LSTMUnit computes the activations of a standard LSTM (without peephole
26 connections), in a sequence-length aware fashion.
27 
28 Concretely, given the (fused) inputs X (TxNxD), the previous cell
29 state (NxD), and the sequence lengths (N), computes the LSTM
30 activations, avoiding computation if the input is invalid (as in, the
31 value at X{t][n] >= seqLengths[n].
32 
33 )DOC")
34  .Arg("forget_bias", "Bias term to add in while calculating forget gate")
35  .Arg(
36  "sequence_lengths",
37  "When false, the sequence lengths input is left out, "
38  "and all following inputs are shifted left by one.");
39 REGISTER_CPU_OPERATOR(LSTMUnitGradient, LSTMUnitGradientOp<CPUContext>);
40 OPERATOR_SCHEMA(LSTMUnitGradient)
41  .NumInputs(8, 9)
42  .NumOutputs(3)
43  .Arg(
44  "sequence_lengths",
45  "When false, the sequence lengths input is left out, "
46  "and all following inputs are shifted left by one.");
47 
49  using GradientMakerBase::GradientMakerBase;
50  vector<OperatorDef> GetGradientDefs() override {
51  if (GetFlagArgument(def_, "sequence_lengths", true)) {
52  return SingleGradientDef(
53  "LSTMUnitGradient",
54  "",
55  vector<string>{
56  I(0), I(1), I(2), I(3), I(4), O(0), O(1), GO(0), GO(1)},
57  vector<string>{GI(0), GI(1), GI(2)});
58  } else {
59  return SingleGradientDef(
60  "LSTMUnitGradient",
61  "",
62  vector<string>{I(0), I(1), I(2), I(3), O(0), O(1), GO(0), GO(1)},
63  vector<string>{GI(0), GI(1), GI(2)});
64  }
65  }
66 };
// Registers GetLSTMUnitGradient as the gradient maker used for LSTMUnit ops.
67 REGISTER_GRADIENT(LSTMUnit, GetLSTMUnitGradient);
68 }
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function for creating a single operator def, which is usually the case for many ...