Caffe2 - C++ API
A deep learning, cross platform ML framework
recurrent_network_op.h
1 #ifndef CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_
2 #define CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/tensor.h"
8 #include "caffe2/operators/rnn/recurrent_network_executor.h"
9 #include "caffe2/utils/conversions.h"
10 #include "caffe2/utils/math.h"
11 
12 C10_DECLARE_bool(caffe2_rnn_executor);
13 
14 namespace caffe2 {
15 namespace detail {
16 
// Blob names used to accumulate the gradient of a single step-net parameter
// over all timesteps of the backward pass.
struct Param {
  // Forward parameter blob (an input of the forward operator).
  std::string param;
  // Gradient output blob that accumulates across timesteps.
  std::string grad;
  // Per-step gradient blob produced by the backward step net; it is summed
  // into `grad` after each step.
  std::string cellGradient;
};
22 
24  std::string state;
25  std::string input;
26 };
27 
29  std::string param;
30  std::string grad;
31  std::string externalGrad;
32  std::string lastExternalGrad;
33  int32_t offset;
34 };
35 
// Describes an alias that makes blob `dst` share the data of blob `src`,
// starting at timestep `offset` (negative values count from the end).
struct OffsetAlias {
  std::string src;
  std::string dst;
  int32_t offset = 0;
};
41 
// Connects an internal step-net blob to a window of an external
// (whole-sequence) blob: `window` timesteps starting at `t + offset`.
struct Link {
  std::string internal;
  std::string external;
  int32_t offset = 0;
  int32_t window = 1;
};
48 
49 struct CAFFE2_API ScratchWorkspaces {
50  std::vector<std::shared_ptr<Workspace>> stepWorkspaces;
51  std::shared_ptr<Workspace> sharedBlobsWs = nullptr;
52 };
53 
54 inline void UpdateTimestepBlob(Workspace* ws, std::string blob_name, int t) {
55  BlobGetMutableTensor(ws->CreateBlob(blob_name), CPU)->Resize(1);
56  auto timestepBlob = ws->GetBlob(blob_name);
57  CAFFE_ENFORCE(timestepBlob);
58  BlobGetMutableTensor(timestepBlob, CPU)->template mutable_data<int32_t>()[0] =
59  t;
60 }
61 
62 CAFFE2_API std::map<string, string> GetRecurrentMapping(
63  const std::vector<detail::Link>& links,
64  bool backward);
65 
66 template <typename T, typename Context>
67 void applyOffsetAlias(
68  const OffsetAlias& oc,
69  Workspace* ws,
70  Context* /*context*/) {
71  VLOG(1) << "Aliasing: " << oc.src << " to: " << oc.dst
72  << " at offset: " << oc.offset;
73  auto srcBlob = ws->GetBlob(oc.src);
74  CAFFE_ENFORCE(srcBlob);
75  auto* src = BlobGetMutableTensor(srcBlob, Context::GetDeviceType());
76  auto* dst =
77  BlobGetMutableTensor(ws->GetBlob(oc.dst), Context::GetDeviceType());
78  auto timestep = src->numel() / src->size(0);
79  auto dims = src->sizes().vec();
80  const int32_t startDstTimestep =
81  oc.offset >= 0 ? oc.offset : src->size(0) + oc.offset;
82  const int32_t numDstTimesteps = src->size(0) - startDstTimestep;
83  if (numDstTimesteps >= 1) {
84  dims[0] = numDstTimesteps;
85  dst->Resize(dims);
86  CAFFE_ENFORCE(timestep == dst->numel() / numDstTimesteps, "Invalid offset");
87  dst->ShareExternalPointer(
88  src->template mutable_data<T>() + startDstTimestep * timestep);
89  } else {
90  CAFFE_ENFORCE_EQ(
91  numDstTimesteps, 0, "Invalid number of timesteps: ", numDstTimesteps);
92  dims[0] = 0;
93  dst->Resize(dims);
94  dst->template mutable_data<T>();
95  }
96 }
97 
// Writes `repeat_n` back-to-back copies of the `n` elements at `src` into
// `dst` using the context's same-device copy. `dst` must have room for
// repeat_n * n elements.
template <typename T, class Context>
void repeatCopy(
    size_t repeat_n,
    size_t n,
    const T* src,
    T* dst,
    Context* context) {
  // size_t index: matches repeat_n's type (no signed/unsigned comparison) and
  // cannot overflow before reaching repeat_n.
  for (size_t i = 0; i < repeat_n; ++i) {
    context->template CopySameDevice<T>(n, src, dst + i * n);
  }
}
108 }
109 
114 template <typename T, typename Context>
116  const RecurrentInput& rc,
117  int32_t seqLen,
118  int32_t batchSize,
119  Workspace* ws,
120  Context* context) {
121  auto stateBlob = ws->GetBlob(rc.state);
122  CAFFE_ENFORCE(stateBlob);
123  auto* state = BlobGetMutableTensor(stateBlob, Context::GetDeviceType());
124 
125  auto inputBlob = ws->GetBlob(rc.input);
126  CAFFE_ENFORCE(inputBlob);
127  const auto& input = inputBlob->template Get<Tensor>();
128  CAFFE_ENFORCE_GE(input.dim(), 1, rc.input);
129  CAFFE_ENFORCE_LE(input.dim(), 3, rc.input);
130 
131  const auto stateSize = input.size(input.dim() - 1);
132  // Sometimes we want to provide more than one initial step.
133  // For example, if we do a convolution op in step net
134  // and need a sufficient left padding around the input.
135  // This could be used together with links where window != 1.
136  auto initialStateLength = 1;
137  if (input.dim() == 3) {
138  initialStateLength = input.size(0);
139  }
140  // States at [0, ..., (T + initialStateLength - 1)] (inclusive)
141  state->Resize(seqLen + initialStateLength, batchSize, stateSize);
142 
143  if (input.dim() >= 2) {
144  CAFFE_ENFORCE_EQ(input.size(input.dim() - 2), batchSize, rc.input);
145  context->template CopySameDevice<T>(
146  batchSize * stateSize * initialStateLength,
147  input.template data<T>(),
148  state->template mutable_data<T>());
149  } else {
150  // Usually, the initial state is the same for all inputs in the batch.
151  // So the op conveniently accepts 1-D input and copies it batchSize times.
152  repeatCopy<T, Context>(
153  batchSize,
154  stateSize,
155  input.template data<T>(),
156  state->template mutable_data<T>(),
157  context);
158  }
159 }
160 
161 CAFFE2_API void PrependOps(std::vector<OperatorDef> ops, NetDef* netdef);
162 
163 CAFFE2_API void AddApplyLinkOps(
164  const vector<Link>& links,
165  std::string timestep,
166  const DeviceOption& device_option,
167  NetDef* netdef);
168 
169 CAFFE2_API void extractLinks(
170  OperatorBase* op,
171  const std::string& internalArg,
172  const std::string& externalArg,
173  const std::string& offsetArg,
174  const std::string& windowArg,
175  std::vector<detail::Link>* links);
176 
177 CAFFE2_API NetDef
178 extractNetDef(const OperatorDef& op, const std::string& argName);
179 } // namespace detail
180 
181 template <class Context>
182 class RecurrentNetworkOp final : public Operator<Context> {
183  public:
184  USE_OPERATOR_CONTEXT_FUNCTIONS;
185  explicit RecurrentNetworkOp(const OperatorDef& operator_def, Workspace* ws)
186  : Operator<Context>(operator_def, ws),
187  sharedWs_(ws),
188  enable_rnn_executor_(this->template GetSingleArgument<bool>(
189  "enable_rnn_executor",
190  false)),
191  timestep_(this->template GetSingleArgument<std::string>(
192  "timestep",
193  "timestep")),
194  operator_def_(operator_def) {
195  CAFFE_ENFORCE(ws);
196 
197  stepNetDef_ = detail::extractNetDef(operator_def, "step_net");
198 
199  recurrentInputs_ = constructRecurrentInputs(operator_def, sharedWs_);
200  links_ = constructLinks();
201  aliases_ = constructAliases();
202 
203  stepNetDef_.add_external_input(timestep_);
204  detail::AddApplyLinkOps(
205  links_, timestep_, operator_def.device_option(), &stepNetDef_);
206 
207  if (FLAGS_caffe2_rnn_executor && enable_rnn_executor_) {
208  InitializeExecutor(operator_def);
209  }
210  }
211 
212  size_t NumObservers() override {
213  size_t num = this->observers_list_.size();
214  if (rnnExecutor_) {
215  num += rnnExecutor_->NumObserversStepNet();
216  }
217  return num;
218  }
219 
220  std::vector<detail::RecurrentInput> constructRecurrentInputs(
221  const OperatorDef& operator_def,
222  Workspace* sharedWs) {
223  const auto states =
224  this->template GetRepeatedArgument<std::string>("recurrent_states");
225  const auto inputs =
226  this->template GetRepeatedArgument<int>("initial_recurrent_state_ids");
227  CAFFE_ENFORCE_EQ(states.size(), inputs.size(), "states/inputs mismatch");
228  std::vector<detail::RecurrentInput> ris;
229  for (auto i = 0; i < states.size(); ++i) {
230  // States need to be "global" (since they are shared between
231  // forward and backward).
232  sharedWs->CreateBlob(states[i]);
233 
235  ri.state = states[i];
236  ri.input = operator_def.input(inputs[i]);
237  ris.push_back(ri);
238  }
239  return ris;
240  }
241 
242  std::vector<detail::OffsetAlias> constructAliases() {
243  const auto& src =
244  this->template GetRepeatedArgument<std::string>("alias_src");
245  const auto& dst =
246  this->template GetRepeatedArgument<std::string>("alias_dst");
247  const auto& offset =
248  this->template GetRepeatedArgument<int32_t>("alias_offset");
249  CAFFE_ENFORCE(
250  src.size() == offset.size(), "alias_src/alias_offset mismatch");
251  CAFFE_ENFORCE(
252  dst.size() == offset.size(), "alias_dst/alias_offset mismatch");
253  std::vector<detail::OffsetAlias> aliases;
254  for (auto i = 0; i < src.size(); ++i) {
256  oc.src = src[i];
257  oc.dst = dst[i];
258  oc.offset = offset[i];
259  aliases.push_back(oc);
260  }
261  return aliases;
262  }
263 
271  std::vector<std::string> v;
272  const auto& blobs = this->template GetRepeatedArgument<std::string>(
273  "recompute_blobs_on_backward", v);
274  for (const auto& b : blobs) {
275  // Note: if the blob already was created, this is a no-op.
276  sharedBlobsWs->CreateBlob(b);
277  }
278  }
279 
280  std::vector<detail::Link> constructLinks() {
281  std::vector<detail::Link> links;
282  detail::extractLinks(
283  this,
284  "link_internal",
285  "link_external",
286  "link_offset",
287  "link_window",
288  &links);
289  return links;
290  }
291 
292  template<typename T>
293  bool DoRunWithType() {
294  const auto seqLen = Input(0).dim32(0);
295  const auto batchSize = Input(0).dim32(1);
296  for (const auto& ri : recurrentInputs_) {
297  detail::initializeRecurrentInput<T, Context>(
298  ri, seqLen, batchSize, sharedWs_, &context_);
299  }
300 
301  // If we don't have a backward step net, this operator is forward_only
302  // and we can avoid creating multiple workspaces.
303  bool has_backward_pass =
304  this->template HasSingleArgumentOfType<NetDef>("backward_step_net") ||
305  (this->template HasSingleArgumentOfType<string>("backward_step_net") &&
306  this->template GetSingleArgument<string>("backward_step_net", "") !=
307  "");
308 
309  // With backward pass: we need to create workspace for each timestep
310  detail::ScratchWorkspaces* scratch =
311  OperatorBase::Output<detail::ScratchWorkspaces>(OutputSize() - 1);
312  std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
313  scratch->stepWorkspaces;
314  std::shared_ptr<Workspace>& sharedBlobsWs = scratch->sharedBlobsWs;
315  if (!sharedBlobsWs) {
316  sharedBlobsWs = std::make_shared<Workspace>(sharedWs_);
317  }
318 
319  // Caller can decide that some of the forward activations
320  // are recomputed on backward pass. Then those activations do not
321  // have to be stored in step workspaces but can be shared.
322  initializeBlobsToRecomputeOnBackward(sharedBlobsWs.get());
323 
324  if (has_backward_pass && seqLen > stepWorkspaces.size()) {
325  stepWorkspaces.resize(seqLen);
326  }
327 
328  // In forward-only mode, we cycle over workspaces. This limits the amount
329  // of parallelism over timesteps that the RNNExecutor provides. So with
330  // RNN executor we use more workspaces to get better perf.
331  int num_workspaces_on_fwd_only = rnnExecutor_ ? 4 : 2;
332 
333  if (!has_backward_pass && stepWorkspaces.size() < num_workspaces_on_fwd_only) {
334  // Use alternating stepWorkspaces when forward_only=True.
335  // Note that the step workspaces can be shared by other ops, thus
336  // we cannot shrink it to 2 if there are more than 2 step workspaces.
337  stepWorkspaces.resize(num_workspaces_on_fwd_only);
338  }
339 
340  for (auto t = 0; t < seqLen; ++t) {
341  auto& currentStepWorkspace =
342  (has_backward_pass ? stepWorkspaces[t] :
343  stepWorkspaces[t % num_workspaces_on_fwd_only]);
344  if (!currentStepWorkspace) {
345  currentStepWorkspace = std::make_shared<Workspace>(sharedBlobsWs.get());
346  }
347 
348  if (rnnExecutor_) {
349  if (!has_backward_pass) {
350  // Need to limit timestep parallelism because we cycle over workspaces
351  rnnExecutor_->SetMaxParallelTimesteps(num_workspaces_on_fwd_only);
352  }
353  rnnExecutor_->EnsureTimestepInitialized(
354  t, currentStepWorkspace.get(), this->observers_list_);
355  } else {
356  // Use plain Caffe2 nets
357  detail::UpdateTimestepBlob(currentStepWorkspace.get(), timestep_, t);
358  auto* stepNet = currentStepWorkspace->GetNet(stepNetDef_.name());
359  if (stepNet == nullptr) {
360  stepNet = currentStepWorkspace->CreateNet(stepNetDef_);
361  }
362  CAFFE_ENFORCE(stepNet, "Step Net construction failure");
363  // Since we have a SimpleNet, there are no races here.
364  stepNet->RunAsync();
365  }
366  }
367 
368  if (rnnExecutor_) {
369  try {
370  rnnExecutor_->Run(seqLen);
371  } catch (const std::exception& e) {
372  LOG(ERROR) << "Encountered exception in RNN executor: " << e.what();
373  InitializeExecutor(operator_def_);
374  return false;
375  } catch (...) {
376  LOG(ERROR) << "Encountered exception in RNN executor: unknown";
377  InitializeExecutor(operator_def_);
378  return false;
379  }
380  }
381 
382  for (const auto& alias : aliases_) {
383  detail::applyOffsetAlias<T, Context>(alias, sharedWs_, &context_);
384  }
385 
386  return true;
387  }
388 
389  bool RunOnDevice() override {
390  return DoRunWithType<float>();
391  }
392 
393  protected:
394  NetDef stepNetDef_;
395  Workspace* sharedWs_;
396  bool enable_rnn_executor_;
397  std::unique_ptr<RecurrentNetworkExecutorBase> rnnExecutor_;
398 
399  std::vector<detail::Link> links_;
400  std::vector<detail::OffsetAlias> aliases_;
401  std::vector<detail::RecurrentInput> recurrentInputs_;
402  std::string timestep_;
403  OperatorDef operator_def_;
404 
405  private:
406  void InitializeExecutor(const OperatorDef& operator_def) {
407  VLOG(1) << "Use RecurrentNetworkExecutor";
408  auto recurrent_map =
409  detail::GetRecurrentMapping(links_, false /* backward */);
410  rnnExecutor_ = createRNNExecutor<Context>(
411  stepNetDef_, recurrent_map, timestep_, ArgumentHelper(operator_def));
412  }
413 };
414 
415 template <class Context>
416 class RecurrentNetworkGradientOp final : public Operator<Context> {
417  public:
418  USE_OPERATOR_CONTEXT_FUNCTIONS;
419  explicit RecurrentNetworkGradientOp(const OperatorDef& operator_def, Workspace* ws)
420  : Operator<Context>(operator_def, ws),
421  sharedWs_(ws),
422  enable_rnn_executor_(this->template GetSingleArgument<bool>(
423  "enable_rnn_executor",
424  false)),
425  timestep_(this->template GetSingleArgument<std::string>(
426  "timestep",
427  "timestep")),
428  gradInputs_(
429  this->template GetRepeatedArgument<int32_t>("outputs_with_grads")) {
430  CAFFE_ENFORCE(ws);
431 
432  stepNetDef_ = detail::extractNetDef(operator_def, "backward_step_net");
433 
434  links_ = constructLinks();
435  params_ = constructParams(operator_def);
436  recurrentGradients_ = constructRecurrentGradients(operator_def);
437  recurrentInputIds_ = this->template GetRepeatedArgument<int32_t>(
438  "initial_recurrent_state_ids");
439 
440  /* Add operators to the backward step net to handle accumulation of
441  gradients over timesteps
442  */
443  stepNetDef_.add_external_input(timestep_);
444 
445  AddGradientInputAccumulationOps(operator_def);
446  detail::AddApplyLinkOps(
447  links_, timestep_, operator_def.device_option(), &stepNetDef_);
448  AddParamGradientAccumulationOps(operator_def);
449 
450  if (FLAGS_caffe2_rnn_executor && enable_rnn_executor_) {
451  InitializeExecutor(operator_def);
452  }
453  }
454 
455  // Renaming maps (generated by memonger.py)
456  std::string remappedName(std::string blob_name) {
457  return this->template GetSingleArgument<std::string>(
458  blob_name + ".rename", blob_name);
459  }
460 
461  detail::Link remappedLink(const detail::Link& link) {
462  detail::Link renamed_link = link;
463  renamed_link.internal = remappedName(link.internal);
464  renamed_link.external = remappedName(link.external);
465  return renamed_link;
466  }
467 
468  void renameOpInputOutput(std::string from_name, std::string to_name) {
469  for (int j = 0; j < stepNetDef_.op_size(); j++) {
470  auto* op = stepNetDef_.mutable_op(j);
471  for (int i = 0; i < op->input_size(); i++) {
472  if (op->input(i) == from_name) {
473  op->set_input(i, to_name);
474  }
475  }
476  for (int i = 0; i < op->output_size(); i++) {
477  if (op->output(i) == from_name) {
478  op->set_output(i, to_name);
479  }
480  }
481  }
482  }
483 
484  std::vector<detail::Param> constructParams(const OperatorDef& operator_def) {
485  std::vector<detail::Param> params;
486  const auto& param = this->template GetRepeatedArgument<int32_t>("param");
487  const auto& param_grads =
488  this->template GetRepeatedArgument<string>("param_grads");
489  CAFFE_ENFORCE(
490  param_grads.empty() || param_grads.size() == param.size(),
491  param.size(),
492  " != ",
493  param_grads.size());
494  for (int i = 0; i < param.size(); ++i) {
495  detail::Param p;
496  // Forward inputs come after [outputs_with_grads] gradient inputs
497  p.param = operator_def.input(param[i] + gradInputs_.size());
498  // See GetRecurrentNetworkGradient to understand offseting here
499  p.grad = operator_def.output(i + numSequences_);
500 
501  std::string grad_blob =
502  param_grads.empty() ? p.grad : remappedName(param_grads[i]);
503  p.cellGradient = grad_blob + "_tmpstep";
504  params.push_back(p);
505 
506  renameOpInputOutput(grad_blob, p.cellGradient);
507  }
508  return params;
509  }
510 
511  std::vector<detail::RecurrentGradient> constructRecurrentGradients(
512  const OperatorDef& operator_def) {
513  std::vector<detail::RecurrentGradient> rgs;
514  const auto& recurrent =
515  this->template GetRepeatedArgument<std::string>("recurrent_states");
516  const auto& alias_src =
517  this->template GetRepeatedArgument<std::string>("alias_src");
518  const auto& offset =
519  this->template GetRepeatedArgument<int32_t>("alias_offset");
520 
521  for (auto i = 0; i < recurrent.size(); ++i) {
523  rg.param = recurrent[i];
524  rg.grad = remappedName(recurrent[i] + "_grad");
525 
526  for (int j = 0; j < alias_src.size(); ++j) {
527  if (alias_src[j] != recurrent[i]) {
528  continue;
529  }
530  int idx = -1;
531  for (int k = 0; k < gradInputs_.size(); ++k) {
532  if (gradInputs_[k] == j) {
533  idx = k;
534  }
535  }
536  if (idx == -1) {
537  continue;
538  }
539 
540  CAFFE_ENFORCE(offset[j] == 1 || offset[j] == -1);
541  if (offset[j] == 1) {
542  rg.externalGrad = operator_def.input(idx);
543  } else if (offset[j] == -1) {
544  rg.lastExternalGrad = operator_def.input(idx);
545  }
546  }
547  rg.offset = 1;
548  rgs.push_back(rg);
549  }
550  return rgs;
551  }
552 
553  std::vector<detail::Link> constructLinks() {
554  std::vector<detail::Link> links;
555  detail::extractLinks(
556  this,
557  "link_internal",
558  "link_external",
559  "link_offset",
560  "link_window",
561  &links);
562  detail::extractLinks(
563  this,
564  "backward_link_internal",
565  "backward_link_external",
566  "backward_link_offset",
567  "",
568  &links);
569  for (int i = 0; i < links.size(); i++) {
570  links[i] = remappedLink(links[i]);
571  }
572  return links;
573  }
574 
575  void InitializeExecutor(const OperatorDef& operator_def) {
576  VLOG(1) << "Use RecurrentNetworkExecutor for backward";
577  auto recurrent_map = detail::GetRecurrentMapping(links_, true /* backward */);
578  rnnExecutor_ = createRNNExecutor<Context>(
579  stepNetDef_, recurrent_map, timestep_, ArgumentHelper(operator_def));
580  }
581 
582  void AddGradientInputAccumulationOps(const OperatorDef& operator_def) {
586  std::vector<OperatorDef> ops;
587  for (const auto& rg : recurrentGradients_) {
588  if (rg.externalGrad.empty()) {
589  continue;
590  }
591  VLOG(1) << "Accumulating into: " << rg.grad << " from " << rg.externalGrad
592  << ", offset: " << rg.offset;
593 
594  OperatorDef opdef;
595  opdef.set_type("rnn_internal_accumulate_gradient_input");
596  opdef.add_input(timestep_);
597  opdef.add_input(rg.externalGrad);
598  opdef.add_input(rg.grad);
599  opdef.add_output(rg.grad);
600 
601  // Add also the linked blobs to outputs, to ensure correct
602  // chaining.
603  for (auto& l : links_) {
604  if (rg.grad == l.external) {
605  Argument* dep_arg = opdef.add_arg();
606  dep_arg->set_name("rnn_dependency." + l.internal);
607  dep_arg->set_s(l.internal);
608  }
609  }
610 
611  opdef.mutable_device_option()->CopyFrom(operator_def.device_option());
612 
613  Argument* offset_arg = opdef.add_arg();
614  offset_arg->set_name("offset");
615  offset_arg->set_i(rg.offset);
616  ops.push_back(opdef);
617 
618  stepNetDef_.add_external_input(rg.externalGrad);
619  stepNetDef_.add_external_input(rg.grad);
620  }
621  detail::PrependOps(ops, &stepNetDef_);
622  }
623 
624  void AddParamGradientAccumulationOps(const OperatorDef& operator_def) {
625  // If a user passes in param_grads mapping, we can copy dirrectly
626  // form a blob where backward cell net written data to.
627  // This becomes handy in a case where gradient from the cell net
628  // is an internal blob of the backward cell. This happens, for example,
629  // when SumOp is the first op of the cell
630  for (const auto& param : params_) {
631  OperatorDef opdef;
632  opdef.set_type("Sum");
633  opdef.add_input(param.grad);
634  opdef.add_input(param.cellGradient);
635  opdef.add_output(param.grad);
636  opdef.mutable_device_option()->CopyFrom(operator_def.device_option());
637  stepNetDef_.add_op()->CopyFrom(opdef);
638  stepNetDef_.add_external_input(param.grad);
639  }
640  }
641 
643  const std::shared_ptr<Workspace>& step0Ws,
644  Workspace* sharedBlobsWs) {
649  for (auto& op : stepNetDef_.op()) {
650  for (const string& outp : op.output()) {
651  if (!step0Ws->HasBlob(outp)) {
652  sharedBlobsWs->CreateBlob(outp);
653  }
654  }
655  }
656  }
657 
658  template<typename T>
659  bool DoRunWithType() {
660  const auto seqLen = Input(gradInputs_.size()).dim32(0);
661  VLOG(1) << "seqLen: " << seqLen;
662 
663  const detail::ScratchWorkspaces& scratch =
664  this->template Input<detail::ScratchWorkspaces>(InputSize() - 1);
665  const std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
666  scratch.stepWorkspaces;
667  CAFFE_ENFORCE_GE(stepWorkspaces.size(), seqLen);
668  Workspace& sharedBlobsWs = *scratch.sharedBlobsWs.get();
669 
670  const auto batchSize = Input(0).dim32(1);
671  for (auto& param : params_) {
672  auto pBlob = sharedWs_->GetBlob(param.param);
673  CAFFE_ENFORCE(pBlob);
674  const auto& p = pBlob->template Get<Tensor>();
675 
676  auto gBlob = sharedWs_->GetBlob(param.grad);
677  CAFFE_ENFORCE(gBlob);
678  auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
679  g->ResizeLike(p);
680  math::Set<T, Context>(
681  g->numel(),
682  convert::To<float, T>(0.0),
683  g->template mutable_data<T>(),
684  &context_);
685  }
686 
687  for (auto& rg : recurrentGradients_) {
688  auto pBlob = sharedWs_->GetBlob(rg.param);
689  CAFFE_ENFORCE(pBlob);
690  const auto& p = pBlob->template Get<Tensor>();
691 
692  auto gBlob = sharedWs_->CreateBlob(rg.grad);
693  CAFFE_ENFORCE(gBlob);
694  auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
695  g->ResizeLike(p);
696  CAFFE_ENFORCE_EQ(g->dim(), 3);
697  const auto timestep = g->numel() / g->size(0);
698  // Fill the last timestep with zeros for the gradient
699  math::Set<T, Context>(
700  timestep,
701  convert::To<float, T>(0.0),
702  g->template mutable_data<T>() + (g->size(0) - 1) * timestep,
703  &context_);
704  }
705 
706  // This code assumes that there are several inputs
707  // sequences. Actually it is not supported by the rest of the code,
708  // and numSequences_ is a constant, equal to 1.
709  for (int i = 0; i < numSequences_; ++i) {
710  // Offseting as the first gradInputs_.size() inputs of the op
711  // are from GO. Then all I(0..N).
712  const int gradientInputIndex = i + gradInputs_.size();
713  const auto& inputName = this->debug_def().input(gradientInputIndex);
714  auto gradientName = remappedName(inputName + "_grad");
715  VLOG(1) << "Initializing gradient for input " << gradientInputIndex
716  << " (" << inputName << ") "
717  << " as blob " << gradientName
718  << ". Size: " << Input(gradientInputIndex).numel();
719  auto pGradientBlob = sharedWs_->GetBlob(gradientName);
720  CAFFE_ENFORCE(pGradientBlob);
721  auto* g = BlobGetMutableTensor(pGradientBlob, Context::GetDeviceType());
722  g->ResizeLike(Input(gradientInputIndex));
723  g->template mutable_data<T>();
724  }
725 
726  auto accumulateFinalInputGradients = [&]() {
727  for (const auto& rg : recurrentGradients_) {
728  if (rg.lastExternalGrad.empty()) {
729  continue;
730  }
731  VLOG(1) << "Accumulating into: " << rg.grad << " from "
732  << rg.lastExternalGrad << " for final time step (sep. blob)";
733  auto gBlob = sharedWs_->GetBlob(rg.grad);
734  CAFFE_ENFORCE(gBlob);
735  auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
736 
737  auto oglastBlob = sharedWs_->GetBlob(rg.lastExternalGrad);
738  CAFFE_ENFORCE(oglastBlob);
739  const auto& oglast = oglastBlob->template Get<Tensor>();
740  CAFFE_ENFORCE_EQ(g->size(1), oglast.size(1));
741  CAFFE_ENFORCE_EQ(g->size(2), oglast.size(2));
742 
743  const auto t = g->size(0) - 1;
744  const auto timestep_size = g->numel() / g->size(0);
745  CAFFE_ENFORCE_EQ(timestep_size, oglast.numel());
746  T* g_data_with_offset =
747  g->template mutable_data<T>() + t * timestep_size;
748  math::Add<T, Context>(
749  timestep_size,
750  oglast.template data<T>(),
751  g_data_with_offset,
752  g_data_with_offset,
753  &context_);
754  }
755  };
756 
757  accumulateFinalInputGradients();
758 
759  // Create shared blobs for blobs that can be shared between
760  // all timesteps.
761  if (stepWorkspaces.size() > 0) {
762  CreateSharedBlobs(stepWorkspaces[0], &sharedBlobsWs);
763  }
764  for (int32_t t = seqLen - 1; t >= 0; --t) {
765  if (rnnExecutor_) {
766  rnnExecutor_->EnsureTimestepInitialized(
767  t, stepWorkspaces[t].get(), this->observers_list_);
768  } else {
769  auto* stepNet = stepWorkspaces[t].get()->GetNet(stepNetDef_.name());
770  if (stepNet == nullptr) {
771  stepNet = stepWorkspaces[t].get()->CreateNet(stepNetDef_);
772  }
773  CAFFE_ENFORCE(stepNet);
774  stepNet->RunAsync();
775  }
776  }
777 
778  if (rnnExecutor_) {
779  rnnExecutor_->RunBackwards(seqLen);
780  }
781 
782  CAFFE_ENFORCE_EQ(recurrentInputIds_.size(), recurrentGradients_.size());
783  for (int i = 0; i < recurrentInputIds_.size(); ++i) {
784  // See GetRecurrentNetworkGradient to understand offseting here
785  // Outputs of the gradient are inputs of the forward pass.
786  // So we need to offset on all inputs that go before recurrent
787  // initial ones
788  auto outputIdx = i + params_.size() + numSequences_;
789  // because first gradInputs_.size() inputs are from GO
790  int inputId = recurrentInputIds_[i] + gradInputs_.size();
791  VLOG(1) << "Resetting output " << this->debug_def().output(outputIdx)
792  << " like input " << this->debug_def().input(inputId);
793  Output(outputIdx)->ResizeLike(Input(inputId));
794  T* output_data = Output(outputIdx)->template mutable_data<T>();
795  auto pBlob = sharedWs_->GetBlob(recurrentGradients_[i].grad);
796  CAFFE_ENFORCE(pBlob);
797  auto* p = BlobGetMutableTensor(pBlob, Context::GetDeviceType());
798 
799  if (Input(inputId).dim() >= 2) {
800  // Gradient states blob should live. And if it gets changed by the
801  // backward pass, then output should be changed as well. Thus it should
802  // be okay to share data here
803  Output(outputIdx)->template ShareExternalPointer<T>(
804  p->template mutable_data<T>());
805  } else {
806  // We need to do a bunch of Adds any way. So lets not worry about
807  // copy / share data here. One way to speed this up could be a kernel
808  // which sums up several tensors together instead of going 1 by 1
809  const auto recurrentStateSize = Input(inputId).dim32(0);
810 
811  math::Set<T, Context>(
812  recurrentStateSize,
813  convert::To<float,T>(0.0),
814  output_data,
815  &context_);
816 
817  math::AddStripedBatch<T, Context>(
818  recurrentStateSize,
819  p->template data<T>(),
820  output_data,
821  recurrentStateSize,
822  batchSize,
823  &context_);
824  }
825  }
826 
827  return true;
828  }
829 
830  bool RunOnDevice() override {
831  return DoRunWithType<float>();
832  }
833 
834  protected:
835  NetDef stepNetDef_;
836  Workspace* sharedWs_;
837  bool enable_rnn_executor_;
838  std::unique_ptr<RecurrentNetworkExecutorBase> rnnExecutor_;
839  std::vector<detail::Link> links_;
840  std::vector<detail::Param> params_;
841  std::vector<detail::RecurrentGradient> recurrentGradients_;
842  std::string timestep_;
843  // For now we support only one input sequence
844  const int numSequences_{1};
845  std::vector<int32_t> recurrentInputIds_;
846  std::vector<int32_t> gradInputs_;
847 };
848 
849 template <class Context>
850 class AccumulateInputGradientOp : public Operator<Context> {
851  public:
852  template <class... Args>
853  explicit AccumulateInputGradientOp(Args&&... args)
854  : Operator<Context>(std::forward<Args>(args)...),
855  offset_(this->template GetSingleArgument<int>("offset", -1)) {
856  CAFFE_ENFORCE(offset_ >= 0, "Offset not set");
857  }
858  USE_OPERATOR_CONTEXT_FUNCTIONS;
859 
860  template<typename T>
861  bool DoRunWithType() {
862  const auto& t0 = this->template Input<Tensor>(0, CPU);
863  const auto t = t0.template data<int32_t>()[0];
864  auto& og = Input(1);
865  auto* g = Output(0);
866 
867  T* g_data = g->template mutable_data<T>();
868  const auto timestep_size = g->numel() / g->size(0);
869 
870  CAFFE_ENFORCE(
871  (t + offset_) * timestep_size + timestep_size <= g->numel(),
872  "Accumulation destination address over bounds");
873  CAFFE_ENFORCE(
874  t * timestep_size + timestep_size <= og.numel(),
875  "Accumulation source address out of bounds");
876 
877  math::Add<T, Context>(
878  timestep_size,
879  og.template data<T>() + t * timestep_size,
880  g_data + (t + offset_) * timestep_size,
881  g_data + (t + offset_) * timestep_size,
882  &context_);
883  return true;
884  }
885 
886  bool RunOnDevice() override {
887  return DispatchHelper<TensorTypes<float>>::call(this, Input(1));
888  }
889 
890  private:
891  int offset_;
892 };
893 
894 template <class Context>
895 class RNNApplyLinkOp : public Operator<Context> {
896  public:
897  template <class... Args>
898  explicit RNNApplyLinkOp(Args&&... args)
899  : Operator<Context>(std::forward<Args>(args)...),
900  offset_(this->template GetSingleArgument<int>("offset", -1)),
901  window_(this->template GetSingleArgument<int>("window", -1)) {
902  CAFFE_ENFORCE(offset_ >= 0, "offset not set");
903  CAFFE_ENFORCE(window_ >= 0, "window not set");
904  }
905 
906  USE_OPERATOR_CONTEXT_FUNCTIONS;
907 
908  template <typename T>
909  bool DoRunWithType() {
910  // Both internal and external appear as both input and output to enforce
911  // correct dependency computation.
912  const auto& t0 = this->template Input<Tensor>(0, CPU);
913  const auto t = t0.template data<int32_t>()[0];
914  auto& external = Input(1);
915 
916  auto* internal_out = Output(0);
917  auto* external_out = Output(1);
918 
919  CAFFE_ENFORCE_GT(external.numel(), 0);
920  const int64_t externalTimestepSize = external.numel() / external.size(0);
921  auto* externalData = external_out->template mutable_data<T>() +
922  (t + offset_) * externalTimestepSize;
923  auto internalDims = external_out->sizes().vec();
924  internalDims[0] = window_;
925 
926  internal_out->Resize(internalDims);
927  internal_out->ShareExternalPointer(
928  externalData, externalTimestepSize * window_);
929  return true;
930  }
931 
932  bool RunOnDevice() override {
933  return DoRunWithType<float>();
934  }
935 
936  private:
937  int offset_;
938  int window_;
939 };
940 
941 } // namespace caffe2
942 
943 #endif // CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_
void AddGradientInputAccumulationOps(const OperatorDef &operator_def)
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:100
void initializeRecurrentInput(const RecurrentInput &rc, int32_t seqLen, int32_t batchSize, Workspace *ws, Context *context)
Copy external input to the step net into the first item of (T + 1) X batch_size X input_size tensor...
void initializeBlobsToRecomputeOnBackward(Workspace *sharedBlobsWs)
Some blobs can be marked as to be recomputed on backward pass.
A helper class to index into arguments.
Definition: proto_utils.h:200
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:160
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
void CreateSharedBlobs(const std::shared_ptr< Workspace > &step0Ws, Workspace *sharedBlobsWs)