Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_network_op.cc
17 #include "caffe2/operators/recurrent_network_op.h"
18 #include "caffe2/core/workspace.h"
19 #include "caffe2/utils/proto_utils.h"
20 
21 #ifndef CAFFE2_RNN_NO_TEXT_FORMAT
22 #include <google/protobuf/text_format.h>
23 #endif
24 
// Command-line flag: when true (the default), RecurrentNetworkOp uses the
// dedicated RNN executor rather than plain serial execution of the step net.
CAFFE2_DEFINE_bool(
    caffe2_rnn_executor,
    true,
    "If set, uses special RNN executor for executing RecurrentNetworkOp");
29 
30 namespace caffe2 {
// Register a TypeMeta for ScratchWorkspaces so instances can be stored in
// (and retrieved from) workspace blobs.
CAFFE_KNOWN_TYPE(detail::ScratchWorkspaces);
32 
33 REGISTER_CPU_OPERATOR(RecurrentNetwork, RecurrentNetworkOp<CPUContext>);
34 OPERATOR_SCHEMA(RecurrentNetwork)
35  .NumInputs(1, INT_MAX)
36  .NumOutputs(2, INT_MAX)
37  .SetDoc(R"DOC(
38 Run the input network in a recurrent fashion. This can be used to
39 implement fairly general recurrent neural networks (RNNs).
40 
41 The operator proceeds as follows.
42 
43 - First, initialized the states from the input recurrent states
44 - For each timestep T, apply the links (that map offsets from input/output
45 tensors into the inputs/outputs for the `step` network)
46 - Finally, alias the recurrent states to the specified output blobs.
47 
48 This is a fairly special-case meta-operator, and so the implementation
49 is somewhat complex. It trades of generality (and frankly usability)
50 against performance and control (compared to e.g. TF
51 dynamic_rnn, Theano scan, etc).
52 
53 See the usage examples for a flavor of how to use it.
54 )DOC");
55 
// Backward counterpart of RecurrentNetwork; the schema is intentionally
// unconstrained because inputs/outputs are generated by the gradient maker.
REGISTER_CPU_OPERATOR(
    RecurrentNetworkGradient,
    RecurrentNetworkGradientOp<CPUContext>);
OPERATOR_SCHEMA(RecurrentNetworkGradient);
60 
// Internal helper operator used by the RNN machinery to accumulate gradients
// over timesteps. Input 2 is updated in place as output 0.
REGISTER_CPU_OPERATOR(
    rnn_internal_accumulate_gradient_input,
    AccumulateInputGradientOp<CPUContext>);
OPERATOR_SCHEMA(rnn_internal_accumulate_gradient_input)
    .NumInputs(3)
    .NumOutputs(1, INT_MAX)
    .EnforceInplace({{2, 0}})
    .Private() // not part of the public operator surface
    .SetDoc(R"DOC(
Internal RNN operator.
)DOC");
72 
73 REGISTER_CPU_OPERATOR(
74  rnn_internal_apply_link,
76 OPERATOR_SCHEMA(rnn_internal_apply_link)
77  .NumInputs(2)
78  .NumOutputs(2)
79  .EnforceInplace({{1, 1}})
80  .Private()
81  .SetDoc(R"DOC(
82 Internal RNN operator.
83 )DOC");
84 
86  using GradientMakerBase::GradientMakerBase;
87  std::vector<OperatorDef> GetGradientDefs() override {
88  ArgumentHelper argsHelper(def_);
89  auto params = argsHelper.GetRepeatedArgument<int32_t>("param");
90  auto recurrentInputs =
91  argsHelper.GetRepeatedArgument<int32_t>("initial_recurrent_state_ids");
92 
93  std::vector<std::string> gradientInputs;
94 
95  // Argument specifies which outputs have external gradient, (0) by default
96  auto outputs_with_grads =
97  argsHelper.GetRepeatedArgument<int32_t>("outputs_with_grads");
98  CAFFE_ENFORCE(outputs_with_grads.size() > 0);
99  for (auto id : outputs_with_grads) {
100  gradientInputs.push_back(GO(id));
101  }
102 
103  // All inputs and outputs are passed back
104  for (int i = 0; i < def_.input_size(); ++i) {
105  gradientInputs.push_back(I(i));
106  }
107  for (int i = 0; i < def_.output_size(); ++i) {
108  gradientInputs.push_back(O(i));
109  }
110 
111  // We calculate gradients only for parameters and recurrent inputs
112  std::vector<std::string> gradientOutputs;
113  gradientOutputs.push_back(GI(0));
114  for (auto id : params) {
115  gradientOutputs.push_back(GI(id));
116  }
117  for (auto id : recurrentInputs) {
118  gradientOutputs.push_back(GI(id));
119  }
120 
121  VLOG(1) << "Gradient blobs: " << Join(", ", gradientOutputs);
122 
123  return SingleGradientDef(
124  "RecurrentNetworkGradient", "", gradientInputs, gradientOutputs);
125  }
126 };
127 
128 REGISTER_GRADIENT(RecurrentNetwork, GetRecurrentNetworkGradient);
129 
130 namespace detail {
131 
132 std::map<string, string> GetRecurrentMapping(
133  const std::vector<detail::Link>& links,
134  bool backward) {
135  std::map<string, string> mappings;
136  for (auto it = links.begin(); it != links.end(); ++it) {
137  const auto& l1 = *it;
138 
139  // In backward op we expect to see offset 1 before offset 0 and
140  // vice versa.
141  const int offset_l1 = backward ? 1 : 0;
142  const int offset_l2 = 1 - offset_l1;
143  if (l1.offset == offset_l1) {
144  // Find offset = 1 from links. We could probaby rely on order, but
145  // since the number of links is links small, O(n^2) algo is ok
146  for (auto it2 = it + 1; it2 != links.end(); ++it2) {
147  const auto& l2 = *it2;
148  if (l2.offset == offset_l2 && l2.external == l1.external) {
149  mappings[l2.internal] = l1.internal;
150  break;
151  }
152  }
153  }
154  }
155  return mappings;
156 }
157 
158 void PrependOps(std::vector<OperatorDef> ops, NetDef* netdef) {
159  for (auto& o : netdef->op()) {
160  ops.push_back(o);
161  }
162  netdef->mutable_op()->Clear();
163  for (auto& o : ops) {
164  auto* ao = netdef->add_op();
165  ao->CopyFrom(o);
166  }
167 }
168 
169 void AddApplyLinkOps(
170  const vector<Link>& links,
171  std::string timestep,
172  const DeviceOption& device_option,
173  NetDef* netdef) {
174  std::vector<OperatorDef> ops;
175  for (auto& link : links) {
176  OperatorDef opdef;
177  opdef.set_type("rnn_internal_apply_link");
178  opdef.add_input(timestep);
179  opdef.add_input(link.external);
180  opdef.add_output(link.internal);
181  opdef.add_output(link.external);
182  opdef.mutable_device_option()->CopyFrom(device_option);
183 
184  Argument* offset_arg = opdef.add_arg();
185  offset_arg->set_name("offset");
186  offset_arg->set_i(link.offset);
187 
188  Argument* window_arg = opdef.add_arg();
189  window_arg->set_name("window");
190  window_arg->set_i(link.window);
191 
192  // Find out if the linked blob is used first as an output: then we need
193  // to add control_input to that op
194  for (auto& op : *netdef->mutable_op()) {
195  if (HasInput(op, link.internal)) {
196  // First appears as an input, no need to do antyhing
197  continue;
198  }
199  if (HasOutput(op, link.internal)) {
200  op.add_control_input(link.internal);
201  break;
202  }
203  }
204 
205  ops.push_back(opdef);
206 
207  netdef->add_external_input(link.internal);
208  netdef->add_external_input(link.external);
209  }
210 
211  detail::PrependOps(ops, netdef);
212 }
213 
214 void extractLinks(
215  OperatorBase* op,
216  const std::string& internalArg,
217  const std::string& externalArg,
218  const std::string& offsetArg,
219  const std::string& windowArg,
220  std::vector<detail::Link>* links) {
221  const auto& internal = op->GetRepeatedArgument<std::string>(internalArg);
222  const auto& external = op->GetRepeatedArgument<std::string>(externalArg);
223  const auto& offset = op->GetRepeatedArgument<int32_t>(offsetArg);
224  const auto& window = op->GetRepeatedArgument<int32_t>(
225  windowArg, vector<int32_t>(offset.size(), 1));
226  CAFFE_ENFORCE_EQ(
227  internal.size(),
228  offset.size(),
229  "internal/offset mismatch: ",
230  internalArg,
231  " ",
232  externalArg);
233  CAFFE_ENFORCE_EQ(
234  external.size(),
235  offset.size(),
236  "external/offset mismatch: ",
237  externalArg,
238  " ",
239  offsetArg);
240  CAFFE_ENFORCE_EQ(
241  external.size(),
242  window.size(),
243  "external/window mismatch: ",
244  externalArg,
245  " ",
246  windowArg);
247  for (auto i = 0; i < internal.size(); ++i) {
248  detail::Link l;
249  l.internal = internal[i];
250  l.external = external[i];
251  l.offset = offset[i];
252  l.window = window[i];
253  links->push_back(l);
254  }
255 }
256 
257 NetDef extractNetDef(const OperatorDef& op, const std::string& argName) {
258  if (ArgumentHelper::HasSingleArgumentOfType<OperatorDef, NetDef>(
259  op, argName)) {
260  return ArgumentHelper::GetSingleArgument<OperatorDef, NetDef>(
261  op, argName, NetDef());
262  } else {
263 #ifndef CAFFE2_RNN_NO_TEXT_FORMAT
264  NetDef result;
265  const auto netString =
266  ArgumentHelper::GetSingleArgument<OperatorDef, string>(op, argName, "");
267  CAFFE_ENFORCE(
268  google::protobuf::TextFormat::ParseFromString(netString, &result),
269  "Invalid NetDef");
270  return result;
271 #else
272  CAFFE_THROW("No valid NetDef for argument ", argName);
273 #endif
274  }
275 }
276 } // namespace detail
277 }
ArgumentHelper: a helper class to index into arguments. Definition: proto_utils.h:198
Copyright (c) 2016-present, Facebook, Inc.
static vector<OperatorDef> SingleGradientDef(const Args&... args)
A helper function to create a single operator def, which is the usual case for many gradient makers.