Caffe2 - C++ API
A deep learning, cross platform ML framework
gru_unit_op.cc
1 
17 #include "gru_unit_op.h"
18 
19 namespace caffe2 {
20 REGISTER_CPU_OPERATOR(GRUUnit, GRUUnitOp<float, CPUContext>);
21 OPERATOR_SCHEMA(GRUUnit)
22  .NumInputs(3, 4)
23  .NumOutputs(1)
24  .SetDoc(R"DOC(
25 GRUUnit computes the activations of a standard GRU,
26 in a sequence-length aware fashion.
27 
28 Concretely, given the (fused) inputs X (TxNxD), the previous hidden
29 state (NxD), and the sequence lengths (N), computes the GRU
30 activations, avoiding computation if the input is invalid (as in, the
31 value at X[t][n] >= seqLengths[n].
32 
33 )DOC")
34  .Arg(
35  "drop_states",
36  "Bool to determine if hidden state is zeroes or passed "
37  "along for timesteps past the given sequence_length.")
38  .Arg(
39  "sequence_lengths",
40  "When false, the sequence lengths input is left out, "
41  "and all following inputs are shifted left by one.")
42  .Output(0, "hidden", "The new GRU hidden state calculated by this op.");
43 REGISTER_CPU_OPERATOR(GRUUnitGradient, GRUUnitGradientOp<float, CPUContext>);
44 OPERATOR_SCHEMA(GRUUnitGradient)
45  .NumInputs(5, 6)
46  .NumOutputs(2)
47  .Arg(
48  "sequence_lengths",
49  "When false, the sequence lengths input is left out, "
50  "and all following inputs are shifted left by one.");
51 
53  using GradientMakerBase::GradientMakerBase;
54  vector<OperatorDef> GetGradientDefs() override {
55  if (GetFlagArgument(def_, "sequence_lengths", true)) {
56  return SingleGradientDef(
57  "GRUUnitGradient",
58  "",
59  vector<string>{I(0), I(1), I(2), I(3), O(0), GO(0)},
60  vector<string>{GI(0), GI(1)});
61  } else {
62  return SingleGradientDef(
63  "GRUUnitGradient",
64  "",
65  vector<string>{I(0), I(1), I(2), O(0), GO(0)},
66  vector<string>{GI(0), GI(1)});
67  }
68  }
69 };
70 REGISTER_GRADIENT(GRUUnit, GetGRUUnitGradient);
71 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...