Caffe2 - C++ API
A deep learning, cross platform ML framework
recurrent_network_op.cc
1 #include "caffe2/operators/rnn/recurrent_network_op.h"
2 #include "caffe2/core/workspace.h"
3 #include "caffe2/utils/proto_utils.h"
4 
5 #ifndef CAFFE2_RNN_NO_TEXT_FORMAT
6 #endif
7 
// Boolean flag `caffe2_rnn_executor`, defaulting to true. Per its help text it
// selects the special RNN executor when running RecurrentNetworkOp.
8 C10_DEFINE_bool(
9  caffe2_rnn_executor,
10  true,
11  "If set, uses special RNN executor for executing RecurrentNetworkOp");
12 
13 namespace caffe2 {
// Declares detail::ScratchWorkspaces through CAFFE_KNOWN_TYPE.
// NOTE(review): presumably this is what allows the type to be stored in a
// Blob; confirm against the macro's definition.
14 CAFFE_KNOWN_TYPE(detail::ScratchWorkspaces);
15 
// CPU registration and schema for the forward RecurrentNetwork operator.
// The schema is declared with NumInputs(1, INT_MAX) and NumOutputs(2, INT_MAX)
// (presumably a lower bound with no practical upper bound) and carries the
// user-facing documentation string below.
16 REGISTER_CPU_OPERATOR(RecurrentNetwork, RecurrentNetworkOp<CPUContext>);
17 OPERATOR_SCHEMA(RecurrentNetwork)
18  .NumInputs(1, INT_MAX)
19  .NumOutputs(2, INT_MAX)
20  .SetDoc(R"DOC(
21 Run the input network in a recurrent fashion. This can be used to
22 implement fairly general recurrent neural networks (RNNs).
23 
24 The operator proceeds as follows.
25 
26 - First, initialized the states from the input recurrent states
27 - For each timestep T, apply the links (that map offsets from input/output
28 tensors into the inputs/outputs for the `step` network)
29 - Finally, alias the recurrent states to the specified output blobs.
30 
31 This is a fairly special-case meta-operator, and so the implementation
32 is somewhat complex. It trades of generality (and frankly usability)
33 against performance and control (compared to e.g. TF
34 dynamic_rnn, Theano scan, etc).
35 
36 See the usage examples for a flavor of how to use it.
37 )DOC");
38 
// CPU registration for the backward operator. Its schema is declared bare:
// no input/output count constraints and no documentation string are attached.
39 REGISTER_CPU_OPERATOR(
40  RecurrentNetworkGradient,
41  RecurrentNetworkGradientOp<CPUContext>);
42 OPERATOR_SCHEMA(RecurrentNetworkGradient);
43 
// CPU registration and schema for the internal gradient-accumulation helper
// operator. The schema declares three inputs, one or more outputs, an
// in-place constraint on the pair {2, 0} (presumably input 2 aliased with
// output 0 -- confirm against EnforceInplace's argument order), and marks the
// operator Private().
44 REGISTER_CPU_OPERATOR(
45  rnn_internal_accumulate_gradient_input,
46  AccumulateInputGradientOp<CPUContext>);
47 OPERATOR_SCHEMA(rnn_internal_accumulate_gradient_input)
48  .NumInputs(3)
49  .NumOutputs(1, INT_MAX)
50  .EnforceInplace({{2, 0}})
51  .Private()
52  .SetDoc(R"DOC(
53 Internal RNN operator.
54 )DOC");
55 
// CPU registration and schema for the internal link-application operator
// (the op type emitted by detail::AddApplyLinkOps further down in this file).
// NOTE(review): the REGISTER_CPU_OPERATOR call below is incomplete in this
// listing -- the line carrying the operator class argument and the closing
// `);` (listing line 58) is missing, so this does not parse as shown. Restore
// it from the upstream source.
56 REGISTER_CPU_OPERATOR(
57  rnn_internal_apply_link,
// The schema declares two inputs, two outputs, an in-place constraint on the
// pair {1, 1}, and marks the operator Private().
59 OPERATOR_SCHEMA(rnn_internal_apply_link)
60  .NumInputs(2)
61  .NumOutputs(2)
62  .EnforceInplace({{1, 1}})
63  .Private()
64  .SetDoc(R"DOC(
65 Internal RNN operator.
66 )DOC");
67 
69  using GradientMakerBase::GradientMakerBase;
70  std::vector<OperatorDef> GetGradientDefs() override {
71  ArgumentHelper argsHelper(def_);
72  auto params = argsHelper.GetRepeatedArgument<int32_t>("param");
73  auto recurrentInputs =
74  argsHelper.GetRepeatedArgument<int32_t>("initial_recurrent_state_ids");
75 
76  std::vector<std::string> gradientInputs;
77 
78  // Argument specifies which outputs have external gradient, (0) by default
79  auto outputs_with_grads =
80  argsHelper.GetRepeatedArgument<int32_t>("outputs_with_grads");
81  CAFFE_ENFORCE(outputs_with_grads.size() > 0);
82  for (auto id : outputs_with_grads) {
83  gradientInputs.push_back(GO(id));
84  }
85 
86  // All inputs and outputs are passed back
87  for (int i = 0; i < def_.input_size(); ++i) {
88  gradientInputs.push_back(I(i));
89  }
90  for (int i = 0; i < def_.output_size(); ++i) {
91  gradientInputs.push_back(O(i));
92  }
93 
94  // We calculate gradients only for parameters and recurrent inputs
95  std::vector<std::string> gradientOutputs;
96  gradientOutputs.push_back(GI(0));
97  for (auto id : params) {
98  gradientOutputs.push_back(GI(id));
99  }
100  for (auto id : recurrentInputs) {
101  gradientOutputs.push_back(GI(id));
102  }
103 
104  VLOG(1) << "Gradient blobs: " << c10::Join(", ", gradientOutputs);
105 
106  return SingleGradientDef(
107  "RecurrentNetworkGradient", "", gradientInputs, gradientOutputs);
108  }
109 };
110 
111 REGISTER_GRADIENT(RecurrentNetwork, GetRecurrentNetworkGradient);
112 
113 namespace detail {
114 
115 std::map<string, string> GetRecurrentMapping(
116  const std::vector<detail::Link>& links,
117  bool backward) {
118  std::map<string, string> mappings;
119  for (auto it = links.begin(); it != links.end(); ++it) {
120  const auto& l1 = *it;
121 
122  // In backward op we expect to see offset 1 before offset 0 and
123  // vice versa.
124  const int offset_l1 = backward ? 1 : 0;
125  const int offset_l2 = 1 - offset_l1;
126  if (l1.offset == offset_l1) {
127  // Find offset = 1 from links. We could probaby rely on order, but
128  // since the number of links is links small, O(n^2) algo is ok
129  for (auto it2 = it + 1; it2 != links.end(); ++it2) {
130  const auto& l2 = *it2;
131  if (l2.offset == offset_l2 && l2.external == l1.external) {
132  mappings[l2.internal] = l1.internal;
133  break;
134  }
135  }
136  }
137  }
138  return mappings;
139 }
140 
141 void PrependOps(std::vector<OperatorDef> ops, NetDef* netdef) {
142  for (auto& o : netdef->op()) {
143  ops.push_back(o);
144  }
145  netdef->mutable_op()->Clear();
146  for (auto& o : ops) {
147  auto* ao = netdef->add_op();
148  ao->CopyFrom(o);
149  }
150 }
151 
152 void AddApplyLinkOps(
153  const vector<Link>& links,
154  std::string timestep,
155  const DeviceOption& device_option,
156  NetDef* netdef) {
157  std::vector<OperatorDef> ops;
158  for (auto& link : links) {
159  OperatorDef opdef;
160  opdef.set_type("rnn_internal_apply_link");
161  opdef.add_input(timestep);
162  opdef.add_input(link.external);
163  opdef.add_output(link.internal);
164  opdef.add_output(link.external);
165  opdef.mutable_device_option()->CopyFrom(device_option);
166 
167  Argument* offset_arg = opdef.add_arg();
168  offset_arg->set_name("offset");
169  offset_arg->set_i(link.offset);
170 
171  Argument* window_arg = opdef.add_arg();
172  window_arg->set_name("window");
173  window_arg->set_i(link.window);
174 
175  // Find out if the linked blob is used first as an output: then we need
176  // to add control_input to that op
177  for (auto& op : *netdef->mutable_op()) {
178  if (HasInput(op, link.internal)) {
179  // First appears as an input, no need to do antyhing
180  continue;
181  }
182  if (HasOutput(op, link.internal)) {
183  op.add_control_input(link.internal);
184  break;
185  }
186  }
187 
188  ops.push_back(opdef);
189 
190  netdef->add_external_input(link.internal);
191  netdef->add_external_input(link.external);
192  }
193 
194  detail::PrependOps(ops, netdef);
195 }
196 
197 void extractLinks(
198  OperatorBase* op,
199  const std::string& internalArg,
200  const std::string& externalArg,
201  const std::string& offsetArg,
202  const std::string& windowArg,
203  std::vector<detail::Link>* links) {
204  const auto& internal = op->GetRepeatedArgument<std::string>(internalArg);
205  const auto& external = op->GetRepeatedArgument<std::string>(externalArg);
206  const auto& offset = op->GetRepeatedArgument<int32_t>(offsetArg);
207  const auto& window = op->GetRepeatedArgument<int32_t>(
208  windowArg, vector<int32_t>(offset.size(), 1));
209  CAFFE_ENFORCE_EQ(
210  internal.size(),
211  offset.size(),
212  "internal/offset mismatch: ",
213  internalArg,
214  " ",
215  externalArg);
216  CAFFE_ENFORCE_EQ(
217  external.size(),
218  offset.size(),
219  "external/offset mismatch: ",
220  externalArg,
221  " ",
222  offsetArg);
223  CAFFE_ENFORCE_EQ(
224  external.size(),
225  window.size(),
226  "external/window mismatch: ",
227  externalArg,
228  " ",
229  windowArg);
230  for (auto i = 0; i < internal.size(); ++i) {
231  detail::Link l;
232  l.internal = internal[i];
233  l.external = external[i];
234  l.offset = offset[i];
235  l.window = window[i];
236  links->push_back(l);
237  }
238 }
239 
240 NetDef extractNetDef(const OperatorDef& op, const std::string& argName) {
241  if (ArgumentHelper::HasSingleArgumentOfType<OperatorDef, NetDef>(
242  op, argName)) {
243  return ArgumentHelper::GetSingleArgument<OperatorDef, NetDef>(
244  op, argName, NetDef());
245  } else {
246 #ifndef CAFFE2_RNN_NO_TEXT_FORMAT
247  NetDef result;
248  const auto netString =
249  ArgumentHelper::GetSingleArgument<OperatorDef, string>(op, argName, "");
250  CAFFE_ENFORCE(
251  TextFormat::ParseFromString(netString, &result),
252  "Invalid NetDef");
253  return result;
254 #else
255  CAFFE_THROW("No valid NetDef for argument ", argName);
256 #endif
257  }
258 }
259 } // namespace detail
260 }
A helper class to index into arguments.
Definition: proto_utils.h:200
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...