1 #include "caffe2/operators/rnn/recurrent_network_op.h" 2 #include "caffe2/core/workspace.h" 3 #include "caffe2/utils/proto_utils.h" 5 #ifndef CAFFE2_RNN_NO_TEXT_FORMAT 11 "If set, uses special RNN executor for executing RecurrentNetworkOp");
// Registers detail::ScratchWorkspaces as a known Caffe2 type.
14 CAFFE_KNOWN_TYPE(detail::ScratchWorkspaces);
// Binds the CPU implementation RecurrentNetworkOp<CPUContext> to the
// operator name RecurrentNetwork.
16 REGISTER_CPU_OPERATOR(RecurrentNetwork, RecurrentNetworkOp<CPUContext>);
// Schema for RecurrentNetwork: between 1 and INT_MAX inputs and between
// 2 and INT_MAX outputs.
17 OPERATOR_SCHEMA(RecurrentNetwork)
18 .NumInputs(1, INT_MAX)
19 .NumOutputs(2, INT_MAX)
21 Run the input network in a recurrent fashion. This can be used to 22 implement fairly general recurrent neural networks (RNNs). 24 The operator proceeds as follows. 26 - First, initialize the states from the input recurrent states 27 - For each timestep T, apply the links (that map offsets from input/output 28 tensors into the inputs/outputs for the `step` network) 29 - Finally, alias the recurrent states to the specified output blobs. 31 This is a fairly special-case meta-operator, and so the implementation 32 is somewhat complex. It trades off generality (and frankly usability) 33 against performance and control (compared to e.g. TF 34 dynamic_rnn, Theano scan, etc). 36 See the usage examples for a flavor of how to use it. 39 REGISTER_CPU_OPERATOR( 40 RecurrentNetworkGradient, 41 RecurrentNetworkGradientOp<CPUContext>); 42 OPERATOR_SCHEMA(RecurrentNetworkGradient); 44 REGISTER_CPU_OPERATOR( 45 rnn_internal_accumulate_gradient_input, 46 AccumulateInputGradientOp<CPUContext>); 47 OPERATOR_SCHEMA(rnn_internal_accumulate_gradient_input) 49 .NumOutputs(1, INT_MAX) 50 .EnforceInplace({{2, 0}}) 53 Internal RNN operator. 56 REGISTER_CPU_OPERATOR( 57 rnn_internal_apply_link, 59 OPERATOR_SCHEMA(rnn_internal_apply_link) 62 .EnforceInplace({{1, 1}}) 65 Internal RNN operator. 69 using GradientMakerBase::GradientMakerBase;
// Assembles the blob lists for the RecurrentNetworkGradient operator.
//
// Gradient inputs, in order:
//   1. GO(id) for every id listed in the outputs_with_grads argument,
//   2. every forward input I(i),
//   3. every forward output O(i).
// Gradient outputs, in order:
//   1. GI(0),
//   2. GI(id) for every id listed in the param argument,
//   3. GI(id) for every id listed in initial_recurrent_state_ids.
//
// NOTE(review): this listing omits several source lines (the declaration of
// argsHelper, the opening of the final call and all closing braces), so only
// the visible steps are described here.
70 std::vector<OperatorDef> GetGradientDefs()
override {
// Indices of the parameter inputs.
72 auto params = argsHelper.GetRepeatedArgument<int32_t>(
"param");
// Indices of the initial recurrent state inputs.
73 auto recurrentInputs =
74 argsHelper.GetRepeatedArgument<int32_t>(
"initial_recurrent_state_ids");
76 std::vector<std::string> gradientInputs;
// Indices of the forward outputs for which a gradient is supplied.
79 auto outputs_with_grads =
80 argsHelper.GetRepeatedArgument<int32_t>(
"outputs_with_grads");
// At least one output gradient is required.
81 CAFFE_ENFORCE(outputs_with_grads.size() > 0);
// Output gradients come first in the gradient op input list.
82 for (
auto id : outputs_with_grads) {
83 gradientInputs.push_back(GO(
id));
// Then all forward inputs.
87 for (
int i = 0; i < def_.input_size(); ++i) {
88 gradientInputs.push_back(I(i));
// Then all forward outputs.
90 for (
int i = 0; i < def_.output_size(); ++i) {
91 gradientInputs.push_back(O(i));
95 std::vector<std::string> gradientOutputs;
// The gradient of input 0 is always produced.
96 gradientOutputs.push_back(GI(0));
// Followed by the gradients of the parameter inputs.
97 for (
auto id : params) {
98 gradientOutputs.push_back(GI(
id));
// Followed by the gradients of the initial recurrent state inputs.
100 for (
auto id : recurrentInputs) {
101 gradientOutputs.push_back(GI(
id));
104 VLOG(1) <<
"Gradient blobs: " << c10::Join(
", ", gradientOutputs);
// Arguments of the call that creates the gradient op definition; the line
// that opens this call is not shown in this listing.
107 "RecurrentNetworkGradient",
"", gradientInputs, gradientOutputs);
// Builds a map between the internal blob names of paired links.
//
// For every link l1 whose offset equals offset_l1, each later link l2 in
// the list that has offset offset_l2 and the same external blob as l1
// produces the entry mappings[l2.internal] = l1.internal.
// offset_l1 is 1 when backward is true and 0 otherwise; offset_l2 is the
// other of the two values.
//
// NOTE(review): the parameter line declaring backward, the return statement
// and the closing braces are not shown in this listing.
115 std::map<string, string> GetRecurrentMapping(
116 const std::vector<detail::Link>& links,
118 std::map<string, string> mappings;
119 for (
auto it = links.begin(); it != links.end(); ++it) {
120 const auto& l1 = *it;
// The direction flag selects which offset plays the role of l1.
124 const int offset_l1 = backward ? 1 : 0;
125 const int offset_l2 = 1 - offset_l1;
126 if (l1.offset == offset_l1) {
// Only links that come after l1 in the list are considered as partners.
129 for (
auto it2 = it + 1; it2 != links.end(); ++it2) {
130 const auto& l2 = *it2;
131 if (l2.offset == offset_l2 && l2.external == l1.external) {
132 mappings[l2.internal] = l1.internal;
// Rebuilds the op list of netdef from the local vector ops.
//
// ops is taken by value, so the caller vector is not modified. The visible
// steps are: iterate over the ops currently in netdef, clear the op list of
// netdef, then add one op to netdef per element of ops.
//
// NOTE(review): both loop bodies are incomplete in this listing. Presumably
// the first loop appends each existing op to ops and the second copies each
// element into the newly added op - confirm against the full source.
141 void PrependOps(std::vector<OperatorDef> ops, NetDef* netdef) {
142 for (
auto& o : netdef->op()) {
145 netdef->mutable_op()->Clear();
146 for (
auto& o : ops) {
147 auto* ao = netdef->add_op();
// Creates one rnn_internal_apply_link op per link and hands the resulting
// list to detail::PrependOps together with netdef.
//
// Each generated op:
//   - takes the timestep blob and the external blob of the link as inputs,
//   - lists the internal blob and the external blob of the link as outputs,
//   - copies device_option as its device option,
//   - carries the offset and window of the link as integer arguments.
// For every link, the internal and external blob names are also added to
// the external inputs of netdef.
//
// NOTE(review): the parameter line declaring netdef, the declaration of
// opdef and parts of the conditional bodies are not shown in this listing.
152 void AddApplyLinkOps(
153 const vector<Link>& links,
154 std::string timestep,
155 const DeviceOption& device_option,
157 std::vector<OperatorDef> ops;
158 for (
auto& link : links) {
160 opdef.set_type(
"rnn_internal_apply_link");
// Inputs: current timestep, then the external blob.
161 opdef.add_input(timestep);
162 opdef.add_input(link.external);
// Outputs: the internal blob, then the external blob again.
163 opdef.add_output(link.internal);
164 opdef.add_output(link.external);
165 opdef.mutable_device_option()->CopyFrom(device_option);
167 Argument* offset_arg = opdef.add_arg();
168 offset_arg->set_name(
"offset");
169 offset_arg->set_i(link.offset);
171 Argument* window_arg = opdef.add_arg();
172 window_arg->set_name(
"window");
173 window_arg->set_i(link.window);
// Inspect how the ops already present in netdef use the internal blob.
177 for (
auto& op : *netdef->mutable_op()) {
// Case: the op reads the internal blob (handling not shown in this listing).
178 if (HasInput(op, link.internal)) {
// Case: the op writes the internal blob, so the blob is added to that op as
// a control input.
182 if (HasOutput(op, link.internal)) {
183 op.add_control_input(link.internal);
188 ops.push_back(opdef);
190 netdef->add_external_input(link.internal);
191 netdef->add_external_input(link.external);
194 detail::PrependOps(ops, netdef);
// Reads four parallel repeated arguments from op and turns them into link
// records with the fields internal, external, offset and window.
//
// internalArg and externalArg name repeated string arguments; offsetArg and
// windowArg name repeated int32 arguments. When the window argument is
// absent, a vector of ones with the same length as offset is used instead.
// The enforce messages visible below indicate that the argument lengths are
// checked against each other.
//
// NOTE(review): the opening of the signature (function name and the op
// parameter), the enforce calls themselves, the declaration of l and the
// statement that stores l are not shown in this listing. Presumably each l
// is appended to links - confirm against the full source.
199 const std::string& internalArg,
200 const std::string& externalArg,
201 const std::string& offsetArg,
202 const std::string& windowArg,
203 std::vector<detail::Link>* links) {
204 const auto&
internal = op->GetRepeatedArgument<std::string>(internalArg);
205 const auto& external = op->GetRepeatedArgument<std::string>(externalArg);
206 const auto& offset = op->GetRepeatedArgument<int32_t>(offsetArg);
// Default window: one entry of value 1 per offset entry.
207 const auto& window = op->GetRepeatedArgument<int32_t>(
208 windowArg, vector<int32_t>(offset.size(), 1));
212 "internal/offset mismatch: ",
219 "external/offset mismatch: ",
226 "external/window mismatch: ",
// One link record per index of the parallel argument lists.
230 for (
auto i = 0; i <
internal.size(); ++i) {
232 l.internal =
internal[i];
233 l.external = external[i];
234 l.offset = offset[i];
235 l.window = window[i];
// Returns the NetDef carried by the argument argName of op.
//
// If the argument holds a NetDef directly, that value is returned. When
// CAFFE2_RNN_NO_TEXT_FORMAT is not defined, the argument is also read as a
// string (default: empty) and parsed with TextFormat::ParseFromString.
// A CAFFE_THROW reports that no valid NetDef exists for the argument.
//
// NOTE(review): the remaining arguments of the type check, the declaration
// of result, the check wrapped around the parse call and the preprocessor
// lines that decide when the throw is reached are not shown in this listing.
240 NetDef extractNetDef(
const OperatorDef& op,
const std::string& argName) {
// Preferred form: the argument already stores a NetDef.
241 if (ArgumentHelper::HasSingleArgumentOfType<OperatorDef, NetDef>(
243 return ArgumentHelper::GetSingleArgument<OperatorDef, NetDef>(
244 op, argName, NetDef());
// Text form: read the argument as a string and parse it into result.
246 #ifndef CAFFE2_RNN_NO_TEXT_FORMAT 248 const auto netString =
249 ArgumentHelper::GetSingleArgument<OperatorDef, string>(op, argName,
"");
251 TextFormat::ParseFromString(netString, &result),
255 CAFFE_THROW(
"No valid NetDef for argument ", argName);
A helper class to index into arguments.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that lets one create a single operator def, which is usually the case for many ...