#ifndef CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_
#define CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/rnn/recurrent_network_executor.h"
#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"

C10_DECLARE_bool(caffe2_rnn_executor);

namespace caffe2 {
namespace detail {
// Members of detail::Param (per-parameter gradient bookkeeping):
std::string cellGradient;

// Members of detail::RecurrentGradient (gradient routing for a recurrent state):
std::string externalGrad;
std::string lastExternalGrad;

// Scratch state shared between the forward op and its gradient op.
struct ScratchWorkspaces {
  std::vector<std::shared_ptr<Workspace>> stepWorkspaces;
  std::shared_ptr<Workspace> sharedBlobsWs = nullptr;
};
inline void UpdateTimestepBlob(Workspace* ws, std::string blob_name, int t) {
  BlobGetMutableTensor(ws->CreateBlob(blob_name), CPU)->Resize(1);
  auto timestepBlob = ws->GetBlob(blob_name);
  CAFFE_ENFORCE(timestepBlob);
  BlobGetMutableTensor(timestepBlob, CPU)->template mutable_data<int32_t>()[0] =
      t;
}
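// Usage sketch (hedged): the forward op calls this once per unrolled step so
// that each step workspace carries the current step index in a one-element
// int32 blob named by the "timestep" argument, roughly:
//
//   // inside the unrolled loop of RecurrentNetworkOp::DoRunWithType
//   detail::UpdateTimestepBlob(currentStepWorkspace.get(), timestep_, t);
//
// Ops inside the step net (including the link and gradient-accumulation ops
// added below) read this blob to know which state slice they operate on.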
CAFFE2_API std::map<string, string> GetRecurrentMapping(
    const std::vector<detail::Link>& links, bool backward);
template <typename T, typename Context>
void applyOffsetAlias(
    const OffsetAlias& oc,
    Workspace* ws,
    Context* /*context*/) {
  VLOG(1) << "Aliasing: " << oc.src << " to: " << oc.dst
          << " at offset: " << oc.offset;
  auto srcBlob = ws->GetBlob(oc.src);
  CAFFE_ENFORCE(srcBlob);
  auto* src = BlobGetMutableTensor(srcBlob, Context::GetDeviceType());
  auto* dst =
      BlobGetMutableTensor(ws->GetBlob(oc.dst), Context::GetDeviceType());
  auto timestep = src->numel() / src->size(0);
  auto dims = src->sizes().vec();
  const int32_t startDstTimestep =
      oc.offset >= 0 ? oc.offset : src->size(0) + oc.offset;
  const int32_t numDstTimesteps = src->size(0) - startDstTimestep;
  if (numDstTimesteps >= 1) {
    dims[0] = numDstTimesteps;
    dst->Resize(dims);
    CAFFE_ENFORCE(timestep == dst->numel() / numDstTimesteps, "Invalid offset");
    dst->ShareExternalPointer(
        src->template mutable_data<T>() + startDstTimestep * timestep);
  } else {
    CAFFE_ENFORCE_EQ(
        numDstTimesteps, 0, "Invalid number of timesteps: ", numDstTimesteps);
    dims[0] = 0;
    dst->Resize(dims);
    dst->template mutable_data<T>();
  }
}
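// Worked example (hedged, hypothetical blob names): with a recurrent state
// "hidden" of shape (T + 1) x B x D and T = 4 (so size(0) == 5),
//   OffsetAlias{src: "hidden", dst: "hidden_all",  offset:  1}
// gives startDstTimestep = 1 and numDstTimesteps = 4, so "hidden_all" aliases
// hidden[1:5] (every step after the initial state), while
//   OffsetAlias{src: "hidden", dst: "hidden_last", offset: -1}
// gives startDstTimestep = 5 - 1 = 4 and numDstTimesteps = 1, so "hidden_last"
// aliases only the final step. No data is copied; dst shares src's memory.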
/**
 * Copy the same n-element block from src into dst repeat_n times
 * (dst must have room for repeat_n * n elements).
 */
template <typename T, class Context>
void repeatCopy(
    size_t repeat_n,
    size_t n,
    const T* src,
    T* dst,
    Context* context) {
  for (size_t i = 0; i < repeat_n; ++i) {
    context->template CopySameDevice<T>(n, src, dst + i * n);
  }
}
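// Sketch (hedged, hypothetical sizes): tiling a single 1-D initial state of
// stateSize elements across a batch of batchSize rows,
//
//   repeatCopy<float, Context>(
//       batchSize,                     // repeat_n: number of copies
//       stateSize,                     // n: elements per copy
//       initialState.data<float>(),    // src: one row
//       state->mutable_data<float>(),  // dst: batchSize x stateSize
//       &context);
//
// This is how initializeRecurrentInput below broadcasts a 1-D initial state
// when the provided input is not already batch-shaped.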
/**
 * Copy external input to the step net into the first item of
 * (T + 1) X batch_size X input_size tensor.
 */
template <typename T, typename Context>
void initializeRecurrentInput(
    const RecurrentInput& rc,
    int32_t seqLen,
    int32_t batchSize,
    Workspace* ws,
    Context* context) {
  auto stateBlob = ws->GetBlob(rc.state);
  CAFFE_ENFORCE(stateBlob);
  auto* state = BlobGetMutableTensor(stateBlob, Context::GetDeviceType());

  auto inputBlob = ws->GetBlob(rc.input);
  CAFFE_ENFORCE(inputBlob);
  const auto& input = inputBlob->template Get<Tensor>();
  CAFFE_ENFORCE_GE(input.dim(), 1, rc.input);
  CAFFE_ENFORCE_LE(input.dim(), 3, rc.input);

  const auto stateSize = input.size(input.dim() - 1);
  // A 3-D input provides its own leading time dimension, so the initial state
  // may span more than one timestep; otherwise a single initial step is used.
  auto initialStateLength = 1;
  if (input.dim() == 3) {
    initialStateLength = input.size(0);
  }
  // States are stored over timesteps [0, ..., T] inclusive.
  state->Resize(seqLen + initialStateLength, batchSize, stateSize);

  if (input.dim() >= 2) {
    // Input is already batch-major: copy it directly into the initial slots.
    CAFFE_ENFORCE_EQ(input.size(input.dim() - 2), batchSize, rc.input);
    context->template CopySameDevice<T>(
        batchSize * stateSize * initialStateLength,
        input.template data<T>(),
        state->template mutable_data<T>());
  } else {
    // 1-D input: the same initial state is repeated for every batch item.
    repeatCopy<T, Context>(
        batchSize,
        stateSize,
        input.template data<T>(),
        state->template mutable_data<T>(),
        context);
  }
}
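// Worked shape example (hedged): seqLen = 3, batchSize = 2, stateSize = 4.
// The state blob is resized to (3 + 1) x 2 x 4. A 2-D input of shape 2 x 4
// fills state[0, :, :] by a direct copy; a 1-D input of 4 elements is instead
// repeated twice (once per batch row) via repeatCopy. The remaining
// state[1..3, :, :] slots are filled step by step by the unrolled step net.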
CAFFE2_API void PrependOps(std::vector<OperatorDef> ops, NetDef* netdef);

CAFFE2_API void AddApplyLinkOps(
    const vector<Link>& links,
    std::string timestep,
    const DeviceOption& device_option,
    NetDef* netdef);

CAFFE2_API void extractLinks(
    OperatorBase* op,
    const std::string& internalArg,
    const std::string& externalArg,
    const std::string& offsetArg,
    const std::string& windowArg,
    std::vector<detail::Link>* links);

CAFFE2_API NetDef
extractNetDef(const OperatorDef& op, const std::string& argName);

} // namespace detail
template <class Context>
class RecurrentNetworkOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  explicit RecurrentNetworkOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        sharedWs_(ws),
        enable_rnn_executor_(this->template GetSingleArgument<bool>(
            "enable_rnn_executor", false)),
        timestep_(this->template GetSingleArgument<std::string>(
            "timestep", "timestep")),
        operator_def_(operator_def) {
    stepNetDef_ = detail::extractNetDef(operator_def, "step_net");

    recurrentInputs_ = constructRecurrentInputs(operator_def, sharedWs_);
    links_ = constructLinks();
    aliases_ = constructAliases();

    stepNetDef_.add_external_input(timestep_);
    detail::AddApplyLinkOps(
        links_, timestep_, operator_def.device_option(), &stepNetDef_);

    if (FLAGS_caffe2_rnn_executor && enable_rnn_executor_) {
      InitializeExecutor(operator_def);
    }
  }
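  // Construction sketch (hedged): the "step_net" argument carries a NetDef
  // that computes one timestep. The constructor extracts it, records which
  // outer blobs seed recurrent states (constructRecurrentInputs), which step
  // blobs alias slices of the outer state tensors (constructAliases), and
  // which step-net blobs are windowed views of those tensors (constructLinks),
  // then adds the link-applying ops so each step sees the right slice.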
  size_t NumObservers() override {
    size_t num = this->observers_list_.size();
    if (rnnExecutor_) {
      num += rnnExecutor_->NumObserversStepNet();
    }
    return num;
  }
  std::vector<detail::RecurrentInput> constructRecurrentInputs(
      const OperatorDef& operator_def,
      Workspace* sharedWs) {
    const auto states =
        this->template GetRepeatedArgument<std::string>("recurrent_states");
    const auto inputs =
        this->template GetRepeatedArgument<int>("initial_recurrent_state_ids");
    CAFFE_ENFORCE_EQ(states.size(), inputs.size(), "states/inputs mismatch");
    std::vector<detail::RecurrentInput> ris;
    for (auto i = 0; i < states.size(); ++i) {
      // States must live in the shared workspace, since they are shared
      // between the forward and backward passes.
      sharedWs->CreateBlob(states[i]);

      detail::RecurrentInput ri;
      ri.state = states[i];
      ri.input = operator_def.input(inputs[i]);
      ris.push_back(ri);
    }
    return ris;
  }
  std::vector<detail::OffsetAlias> constructAliases() {
    const auto& src =
        this->template GetRepeatedArgument<std::string>("alias_src");
    const auto& dst =
        this->template GetRepeatedArgument<std::string>("alias_dst");
    const auto& offset =
        this->template GetRepeatedArgument<int32_t>("alias_offset");
    CAFFE_ENFORCE(
        src.size() == offset.size(), "alias_src/alias_offset mismatch");
    CAFFE_ENFORCE(
        dst.size() == offset.size(), "alias_dst/alias_offset mismatch");
    std::vector<detail::OffsetAlias> aliases;
    for (auto i = 0; i < src.size(); ++i) {
      detail::OffsetAlias oc;
      oc.src = src[i];
      oc.dst = dst[i];
      oc.offset = offset[i];
      aliases.push_back(oc);
    }
    return aliases;
  }
  // Some blobs can be marked as to be recomputed on the backward pass; those
  // are pre-created in the shared blob workspace rather than per step.
  void initializeBlobsToRecomputeOnBackward(Workspace* sharedBlobsWs) {
    std::vector<std::string> v;
    const auto& blobs = this->template GetRepeatedArgument<std::string>(
        "recompute_blobs_on_backward", v);
    for (const auto& b : blobs) {
      sharedBlobsWs->CreateBlob(b);
    }
  }
  std::vector<detail::Link> constructLinks() {
    std::vector<detail::Link> links;
    detail::extractLinks(
        this, "link_internal", "link_external", "link_offset", "link_window",
        &links);
    return links;
  }
  template <typename T>
  bool DoRunWithType() {
    const auto seqLen = Input(0).dim32(0);
    const auto batchSize = Input(0).dim32(1);
    for (const auto& ri : recurrentInputs_) {
      detail::initializeRecurrentInput<T, Context>(
          ri, seqLen, batchSize, sharedWs_, &context_);
    }

    // Without a backward step net the operator is forward-only and does not
    // need a scratch workspace per timestep.
    bool has_backward_pass =
        this->template HasSingleArgumentOfType<NetDef>("backward_step_net") ||
        (this->template HasSingleArgumentOfType<string>("backward_step_net") &&
         this->template GetSingleArgument<string>("backward_step_net", "") !=
             "");

    detail::ScratchWorkspaces* scratch =
        OperatorBase::Output<detail::ScratchWorkspaces>(OutputSize() - 1);
    std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
        scratch->stepWorkspaces;
    std::shared_ptr<Workspace>& sharedBlobsWs = scratch->sharedBlobsWs;
    if (!sharedBlobsWs) {
      sharedBlobsWs = std::make_shared<Workspace>(sharedWs_);
    }

    // The caller can mark some forward activations as recomputable on the
    // backward pass; pre-create those blobs in the shared workspace.
    initializeBlobsToRecomputeOnBackward(sharedBlobsWs.get());

    if (has_backward_pass && seqLen > stepWorkspaces.size()) {
      stepWorkspaces.resize(seqLen);
    }

    // In forward-only mode we cycle over a small pool of workspaces. This
    // limits timestep parallelism for the RNN executor, so a larger pool is
    // used when the executor is enabled.
    int num_workspaces_on_fwd_only = rnnExecutor_ ? 4 : 2;
    if (!has_backward_pass &&
        stepWorkspaces.size() < num_workspaces_on_fwd_only) {
      stepWorkspaces.resize(num_workspaces_on_fwd_only);
    }

    for (auto t = 0; t < seqLen; ++t) {
      auto& currentStepWorkspace =
          (has_backward_pass ? stepWorkspaces[t]
                             : stepWorkspaces[t % num_workspaces_on_fwd_only]);
      if (!currentStepWorkspace) {
        currentStepWorkspace = std::make_shared<Workspace>(sharedBlobsWs.get());
      }

      if (rnnExecutor_) {
        if (!has_backward_pass) {
          // Workspaces are recycled, so cap how many timesteps can run in
          // parallel.
          rnnExecutor_->SetMaxParallelTimesteps(num_workspaces_on_fwd_only);
        }
        rnnExecutor_->EnsureTimestepInitialized(
            t, currentStepWorkspace.get(), this->observers_list_);
      } else {
        // Run the plain Caffe2 step net for this timestep.
        detail::UpdateTimestepBlob(currentStepWorkspace.get(), timestep_, t);
        auto* stepNet = currentStepWorkspace->GetNet(stepNetDef_.name());
        if (stepNet == nullptr) {
          stepNet = currentStepWorkspace->CreateNet(stepNetDef_);
        }
        CAFFE_ENFORCE(stepNet, "Step Net construction failure");
        stepNet->RunAsync();
      }
    }

    if (rnnExecutor_) {
      try {
        rnnExecutor_->Run(seqLen);
      } catch (const std::exception& e) {
        LOG(ERROR) << "Encountered exception in RNN executor: " << e.what();
        InitializeExecutor(operator_def_);
        return false;
      } catch (...) {
        LOG(ERROR) << "Encountered exception in RNN executor: unknown";
        InitializeExecutor(operator_def_);
        return false;
      }
    }

    for (const auto& alias : aliases_) {
      detail::applyOffsetAlias<T, Context>(alias, sharedWs_, &context_);
    }
    return true;
  }
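  // Sketch of the forward-only workspace reuse above (hedged): with no
  // backward pass and no executor, num_workspaces_on_fwd_only == 2, so step t
  // runs in stepWorkspaces[t % 2] and timestep 2 reuses the workspace of
  // timestep 0. This is safe because only the outer state tensor, written
  // through the link aliases, has to survive across steps. With the RNN
  // executor the pool is 4 workspaces and SetMaxParallelTimesteps(4) keeps it
  // from scheduling more than 4 timesteps concurrently.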
  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }
 protected:
  NetDef stepNetDef_;
  Workspace* sharedWs_;
  bool enable_rnn_executor_;
  std::unique_ptr<RecurrentNetworkExecutorBase> rnnExecutor_;

  std::vector<detail::Link> links_;
  std::vector<detail::OffsetAlias> aliases_;
  std::vector<detail::RecurrentInput> recurrentInputs_;
  std::string timestep_;
  OperatorDef operator_def_;
  void InitializeExecutor(const OperatorDef& operator_def) {
    VLOG(1) << "Use RecurrentNetworkExecutor";
    auto recurrent_map =
        detail::GetRecurrentMapping(links_, false /* backward */);
    rnnExecutor_ = createRNNExecutor<Context>(
        stepNetDef_, recurrent_map, timestep_, ArgumentHelper(operator_def));
  }
};
template <class Context>
class RecurrentNetworkGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  explicit RecurrentNetworkGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<Context>(operator_def, ws),
        sharedWs_(ws),
        enable_rnn_executor_(this->template GetSingleArgument<bool>(
            "enable_rnn_executor", false)),
        timestep_(this->template GetSingleArgument<std::string>(
            "timestep", "timestep")),
        gradInputs_(this->template GetRepeatedArgument<int32_t>(
            "outputs_with_grads")) {
    stepNetDef_ = detail::extractNetDef(operator_def, "backward_step_net");

    links_ = constructLinks();
    params_ = constructParams(operator_def);
    recurrentGradients_ = constructRecurrentGradients(operator_def);
    recurrentInputIds_ = this->template GetRepeatedArgument<int32_t>(
        "initial_recurrent_state_ids");

    stepNetDef_.add_external_input(timestep_);

    AddGradientInputAccumulationOps(operator_def);
    detail::AddApplyLinkOps(
        links_, timestep_, operator_def.device_option(), &stepNetDef_);
    AddParamGradientAccumulationOps(operator_def);

    if (FLAGS_caffe2_rnn_executor && enable_rnn_executor_) {
      InitializeExecutor(operator_def);
    }
  }
  std::string remappedName(std::string blob_name) {
    return this->template GetSingleArgument<std::string>(
        blob_name + ".rename", blob_name);
  }

  detail::Link remappedLink(const detail::Link& link) {
    detail::Link renamed_link = link;
    renamed_link.internal = remappedName(link.internal);
    renamed_link.external = remappedName(link.external);
    return renamed_link;
  }
  void renameOpInputOutput(std::string from_name, std::string to_name) {
    for (int j = 0; j < stepNetDef_.op_size(); j++) {
      auto* op = stepNetDef_.mutable_op(j);
      for (int i = 0; i < op->input_size(); i++) {
        if (op->input(i) == from_name) {
          op->set_input(i, to_name);
        }
      }
      for (int i = 0; i < op->output_size(); i++) {
        if (op->output(i) == from_name) {
          op->set_output(i, to_name);
        }
      }
    }
  }
  std::vector<detail::Param> constructParams(const OperatorDef& operator_def) {
    std::vector<detail::Param> params;
    const auto& param = this->template GetRepeatedArgument<int32_t>("param");
    const auto& param_grads =
        this->template GetRepeatedArgument<string>("param_grads");
    CAFFE_ENFORCE(
        param_grads.empty() || param_grads.size() == param.size(),
        "param_grads/param mismatch");
    for (int i = 0; i < param.size(); ++i) {
      detail::Param p;
      // The forward inputs come after the gradient inputs (one per entry of
      // "outputs_with_grads"), hence the offset by gradInputs_.size().
      p.param = operator_def.input(param[i] + gradInputs_.size());
      p.grad = operator_def.output(i + numSequences_);

      std::string grad_blob =
          param_grads.empty() ? p.grad : remappedName(param_grads[i]);
      p.cellGradient = grad_blob + "_tmpstep";
      params.push_back(p);

      renameOpInputOutput(grad_blob, p.cellGradient);
    }
    return params;
  }
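  // Worked example (hedged, hypothetical names): if a forward input "W" is
  // listed in "param", its gradient output is "W_grad". The step net's
  // per-step gradient is redirected to a temporary "W_grad_tmpstep" via
  // renameOpInputOutput above, and AddParamGradientAccumulationOps below
  // appends a Sum op (W_grad + W_grad_tmpstep -> W_grad) so that the
  // contributions from all timesteps accumulate into one outer gradient blob.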
  std::vector<detail::RecurrentGradient> constructRecurrentGradients(
      const OperatorDef& operator_def) {
    std::vector<detail::RecurrentGradient> rgs;
    const auto& recurrent =
        this->template GetRepeatedArgument<std::string>("recurrent_states");
    const auto& alias_src =
        this->template GetRepeatedArgument<std::string>("alias_src");
    const auto& offset =
        this->template GetRepeatedArgument<int32_t>("alias_offset");

    for (auto i = 0; i < recurrent.size(); ++i) {
      detail::RecurrentGradient rg;
      rg.param = recurrent[i];
      rg.grad = remappedName(recurrent[i] + "_grad");

      for (int j = 0; j < alias_src.size(); ++j) {
        if (alias_src[j] != recurrent[i]) {
          // Skip aliases whose source is not this recurrent state.
          continue;
        }
        // Find which gradient input, if any, corresponds to this alias.
        int idx = -1;
        for (int k = 0; k < gradInputs_.size(); ++k) {
          if (gradInputs_[k] == j) {
            idx = k;
          }
        }
        if (idx == -1) {
          continue;
        }

        CAFFE_ENFORCE(offset[j] == 1 || offset[j] == -1);
        if (offset[j] == 1) {
          rg.externalGrad = operator_def.input(idx);
        } else if (offset[j] == -1) {
          rg.lastExternalGrad = operator_def.input(idx);
        }
      }
      rg.offset = 1;
      rgs.push_back(rg);
    }
    return rgs;
  }
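  // Sketch of the routing above (hedged): an alias with offset 1 exposes a
  // recurrent state's full per-step output, so the matching gradient input
  // feeds externalGrad and is accumulated at every timestep; an alias with
  // offset -1 exposes only the final step, so its gradient input becomes
  // lastExternalGrad and is added once, into the last slot of the state
  // gradient (see accumulateFinalInputGradients in DoRunWithType below).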
  std::vector<detail::Link> constructLinks() {
    std::vector<detail::Link> links;
    detail::extractLinks(
        this, "link_internal", "link_external", "link_offset", "link_window",
        &links);
    detail::extractLinks(
        this,
        "backward_link_internal",
        "backward_link_external",
        "backward_link_offset",
        "" /* no window argument for backward links */,
        &links);
    for (int i = 0; i < links.size(); i++) {
      links[i] = remappedLink(links[i]);
    }
    return links;
  }
  void InitializeExecutor(const OperatorDef& operator_def) {
    VLOG(1) << "Use RecurrentNetworkExecutor for backward";
    auto recurrent_map =
        detail::GetRecurrentMapping(links_, true /* backward */);
    rnnExecutor_ = createRNNExecutor<Context>(
        stepNetDef_, recurrent_map, timestep_, ArgumentHelper(operator_def));
  }
  /**
   * Add ops to the step net that accumulate the gradients arriving from the
   * outer (unrolled-output) gradient blobs into the recurrent state gradients.
   */
  void AddGradientInputAccumulationOps(const OperatorDef& operator_def) {
    std::vector<OperatorDef> ops;
    for (const auto& rg : recurrentGradients_) {
      if (rg.externalGrad.empty()) {
        continue;
      }
      VLOG(1) << "Accumulating into: " << rg.grad << " from " << rg.externalGrad
              << ", offset: " << rg.offset;

      OperatorDef opdef;
      opdef.set_type("rnn_internal_accumulate_gradient_input");
      opdef.add_input(timestep_);
      opdef.add_input(rg.externalGrad);
      opdef.add_input(rg.grad);
      opdef.add_output(rg.grad);

      // Tell the RNN executor which internal link blobs this accumulation
      // depends on, so it can order the op correctly within a timestep.
      for (auto& l : links_) {
        if (rg.grad == l.external) {
          Argument* dep_arg = opdef.add_arg();
          dep_arg->set_name("rnn_dependency." + l.internal);
          dep_arg->set_s(l.internal);
        }
      }

      opdef.mutable_device_option()->CopyFrom(operator_def.device_option());

      Argument* offset_arg = opdef.add_arg();
      offset_arg->set_name("offset");
      offset_arg->set_i(rg.offset);
      ops.push_back(opdef);

      stepNetDef_.add_external_input(rg.externalGrad);
      stepNetDef_.add_external_input(rg.grad);
    }
    detail::PrependOps(ops, &stepNetDef_);
  }
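  // Resulting op, roughly (hedged; blob names hypothetical): for a recurrent
  // state "hidden" whose per-step gradient arrives as "hidden_all_grad":
  //
  //   op {
  //     type: "rnn_internal_accumulate_gradient_input"
  //     input: "timestep"         # current step index
  //     input: "hidden_all_grad"  # external gradient (rg.externalGrad)
  //     input: "hidden_grad"      # running state gradient (rg.grad)
  //     output: "hidden_grad"
  //     arg { name: "offset" i: 1 }
  //   }
  //
  // PrependOps puts these at the front of the backward step net so the
  // external contribution is added before the step's own gradient ops run.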
  void AddParamGradientAccumulationOps(const OperatorDef& operator_def) {
    for (const auto& param : params_) {
      OperatorDef opdef;
      opdef.set_type("Sum");
      opdef.add_input(param.grad);
      opdef.add_input(param.cellGradient);
      opdef.add_output(param.grad);
      opdef.mutable_device_option()->CopyFrom(operator_def.device_option());
      stepNetDef_.add_op()->CopyFrom(opdef);
      stepNetDef_.add_external_input(param.grad);
    }
  }
  void CreateSharedBlobs(
      const std::shared_ptr<Workspace>& step0Ws,
      Workspace* sharedBlobsWs) {
    // Create backward step-net output blobs once in the shared workspace.
    for (auto& op : stepNetDef_.op()) {
      for (const string& outp : op.output()) {
        if (!step0Ws->HasBlob(outp)) {
          sharedBlobsWs->CreateBlob(outp);
        }
      }
    }
  }
  template <typename T>
  bool DoRunWithType() {
    const auto seqLen = Input(gradInputs_.size()).dim32(0);
    VLOG(1) << "seqLen: " << seqLen;

    const detail::ScratchWorkspaces& scratch =
        this->template Input<detail::ScratchWorkspaces>(InputSize() - 1);
    const std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
        scratch.stepWorkspaces;
    CAFFE_ENFORCE_GE(stepWorkspaces.size(), seqLen);
    Workspace& sharedBlobsWs = *scratch.sharedBlobsWs.get();

    const auto batchSize = Input(0).dim32(1);
    for (auto& param : params_) {
      auto pBlob = sharedWs_->GetBlob(param.param);
      CAFFE_ENFORCE(pBlob);
      const auto& p = pBlob->template Get<Tensor>();

      auto gBlob = sharedWs_->GetBlob(param.grad);
      CAFFE_ENFORCE(gBlob);
      auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
      g->ResizeLike(p);
      math::Set<T, Context>(
          g->numel(),
          convert::To<float, T>(0.0),
          g->template mutable_data<T>(),
          &context_);
    }

    for (auto& rg : recurrentGradients_) {
      auto pBlob = sharedWs_->GetBlob(rg.param);
      CAFFE_ENFORCE(pBlob);
      const auto& p = pBlob->template Get<Tensor>();

      auto gBlob = sharedWs_->CreateBlob(rg.grad);
      CAFFE_ENFORCE(gBlob);
      auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());
      g->ResizeLike(p);
      CAFFE_ENFORCE_EQ(g->dim(), 3);
      const auto timestep = g->numel() / g->size(0);
      // Fill the last timestep of the gradient with zeros.
      math::Set<T, Context>(
          timestep,
          convert::To<float, T>(0.0),
          g->template mutable_data<T>() + (g->size(0) - 1) * timestep,
          &context_);
    }

    // This loop is written as if several input sequences were supported, but
    // numSequences_ is a constant equal to 1 in the rest of the code.
    for (int i = 0; i < numSequences_; ++i) {
      // The first gradInputs_.size() inputs of the op are output gradients;
      // the forward inputs follow.
      const int gradientInputIndex = i + gradInputs_.size();
      const auto& inputName = this->debug_def().input(gradientInputIndex);
      auto gradientName = remappedName(inputName + "_grad");
      VLOG(1) << "Initializing gradient for input " << gradientInputIndex
              << " (" << inputName << ") "
              << " as blob " << gradientName
              << ". Size: " << Input(gradientInputIndex).numel();
      auto pGradientBlob = sharedWs_->GetBlob(gradientName);
      CAFFE_ENFORCE(pGradientBlob);
      auto* g = BlobGetMutableTensor(pGradientBlob, Context::GetDeviceType());
      g->ResizeLike(Input(gradientInputIndex));
      g->template mutable_data<T>();
    }

    auto accumulateFinalInputGradients = [&]() {
      for (const auto& rg : recurrentGradients_) {
        if (rg.lastExternalGrad.empty()) {
          continue;
        }
        VLOG(1) << "Accumulating into: " << rg.grad << " from "
                << rg.lastExternalGrad << " for final time step (sep. blob)";
        auto gBlob = sharedWs_->GetBlob(rg.grad);
        CAFFE_ENFORCE(gBlob);
        auto* g = BlobGetMutableTensor(gBlob, Context::GetDeviceType());

        auto oglastBlob = sharedWs_->GetBlob(rg.lastExternalGrad);
        CAFFE_ENFORCE(oglastBlob);
        const auto& oglast = oglastBlob->template Get<Tensor>();
        CAFFE_ENFORCE_EQ(g->size(1), oglast.size(1));
        CAFFE_ENFORCE_EQ(g->size(2), oglast.size(2));

        const auto t = g->size(0) - 1;
        const auto timestep_size = g->numel() / g->size(0);
        CAFFE_ENFORCE_EQ(timestep_size, oglast.numel());
        T* g_data_with_offset =
            g->template mutable_data<T>() + t * timestep_size;
        math::Add<T, Context>(
            timestep_size,
            oglast.template data<T>(),
            g_data_with_offset,
            g_data_with_offset,
            &context_);
      }
    };

    accumulateFinalInputGradients();

    // Create shared blobs for blobs that can be shared between all timesteps.
    if (stepWorkspaces.size() > 0) {
      CreateSharedBlobs(stepWorkspaces[0], &sharedBlobsWs);
    }
    for (int32_t t = seqLen - 1; t >= 0; --t) {
      if (rnnExecutor_) {
        rnnExecutor_->EnsureTimestepInitialized(
            t, stepWorkspaces[t].get(), this->observers_list_);
      } else {
        auto* stepNet = stepWorkspaces[t].get()->GetNet(stepNetDef_.name());
        if (stepNet == nullptr) {
          stepNet = stepWorkspaces[t].get()->CreateNet(stepNetDef_);
        }
        CAFFE_ENFORCE(stepNet);
        stepNet->RunAsync();
      }
    }

    if (rnnExecutor_) {
      rnnExecutor_->RunBackwards(seqLen);
    }

    CAFFE_ENFORCE_EQ(recurrentInputIds_.size(), recurrentGradients_.size());
    for (int i = 0; i < recurrentInputIds_.size(); ++i) {
      // Outputs are laid out as [sequence gradients, parameter gradients,
      // initial-state gradients]; inputs start with the output gradients.
      auto outputIdx = i + params_.size() + numSequences_;
      int inputId = recurrentInputIds_[i] + gradInputs_.size();
      VLOG(1) << "Resetting output " << this->debug_def().output(outputIdx)
              << " like input " << this->debug_def().input(inputId);
      Output(outputIdx)->ResizeLike(Input(inputId));
      T* output_data = Output(outputIdx)->template mutable_data<T>();
      auto pBlob = sharedWs_->GetBlob(recurrentGradients_[i].grad);
      CAFFE_ENFORCE(pBlob);
      auto* p = BlobGetMutableTensor(pBlob, Context::GetDeviceType());

      if (Input(inputId).dim() >= 2) {
        // The initial state was batch-shaped, so its gradient can share the
        // first slice of the state gradient tensor directly.
        Output(outputIdx)->template ShareExternalPointer<T>(
            p->template mutable_data<T>());
      } else {
        // The initial state was a single 1-D vector broadcast over the batch;
        // sum the per-batch gradients back into one vector.
        const auto recurrentStateSize = Input(inputId).dim32(0);
        math::Set<T, Context>(
            recurrentStateSize,
            convert::To<float, T>(0.0),
            output_data,
            &context_);
        math::AddStripedBatch<T, Context>(
            recurrentStateSize,
            p->template data<T>(),
            output_data,
            recurrentStateSize,
            batchSize,
            &context_);
      }
    }

    return true;
  }
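  // Index bookkeeping above, with a hedged example: suppose the forward op has
  // one sequence input, one initial hidden state, and one weight, and exactly
  // one forward output has a gradient. The gradient op's inputs are then
  // [output gradient(s), forward inputs...], so gradInputs_.size() == 1 and
  // the sequence length is read from Input(1). Its outputs are laid out as
  // [sequence gradient(s) (numSequences_ == 1), parameter gradients,
  // initial-state gradients], which is why the initial-state gradient index
  // is i + params_.size() + numSequences_.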
  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }
 protected:
  NetDef stepNetDef_;
  Workspace* sharedWs_;
  bool enable_rnn_executor_;
  std::unique_ptr<RecurrentNetworkExecutorBase> rnnExecutor_;

  std::vector<detail::Link> links_;
  std::vector<detail::Param> params_;
  std::vector<detail::RecurrentGradient> recurrentGradients_;
  std::string timestep_;
  // For now, only one input sequence is supported.
  const int numSequences_{1};
  std::vector<int32_t> recurrentInputIds_;
  std::vector<int32_t> gradInputs_;
};
template <class Context>
class AccumulateInputGradientOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit AccumulateInputGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        offset_(this->template GetSingleArgument<int>("offset", -1)) {
    CAFFE_ENFORCE(offset_ >= 0, "Offset not set");
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <typename T>
  bool DoRunWithType() {
    const auto& t0 = this->template Input<Tensor>(0, CPU);
    const auto t = t0.template data<int32_t>()[0];
    auto& og = Input(1);
    auto* g = Output(0);

    T* g_data = g->template mutable_data<T>();
    const auto timestep_size = g->numel() / g->size(0);

    CAFFE_ENFORCE(
        (t + offset_) * timestep_size + timestep_size <= g->numel(),
        "Accumulation destination address over bounds");
    CAFFE_ENFORCE(
        t * timestep_size + timestep_size <= og.numel(),
        "Accumulation source address out of bounds");

    math::Add<T, Context>(
        timestep_size,
        og.template data<T>() + t * timestep_size,
        g_data + (t + offset_) * timestep_size,
        g_data + (t + offset_) * timestep_size,
        &context_);
    return true;
  }
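  // Worked index example (hedged): with offset_ == 1, a state gradient g of
  // shape (T + 1) x B x D and an external gradient og of shape T x B x D,
  // step t adds og[t] into g[t + 1], i.e. the slot the forward pass wrote for
  // that step (slot 0 holds the initial state). timestep_size is B * D
  // elements, so both bounds checks above reduce to t + 1 <= T.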
  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 private:
  int offset_;
};
template <class Context>
class RNNApplyLinkOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit RNNApplyLinkOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        offset_(this->template GetSingleArgument<int>("offset", -1)),
        window_(this->template GetSingleArgument<int>("window", -1)) {
    CAFFE_ENFORCE(offset_ >= 0, "offset not set");
    CAFFE_ENFORCE(window_ >= 0, "window not set");
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <typename T>
  bool DoRunWithType() {
    // Both the internal and external blobs appear as input and output so that
    // dependency computation in the net is correct.
    const auto& t0 = this->template Input<Tensor>(0, CPU);
    const auto t = t0.template data<int32_t>()[0];
    auto& external = Input(1);

    auto* internal_out = Output(0);
    auto* external_out = Output(1);

    CAFFE_ENFORCE_GT(external.numel(), 0);
    const int64_t externalTimestepSize = external.numel() / external.size(0);
    auto* externalData = external_out->template mutable_data<T>() +
        (t + offset_) * externalTimestepSize;
    auto internalDims = external_out->sizes().vec();
    internalDims[0] = window_;

    internal_out->Resize(internalDims);
    internal_out->ShareExternalPointer(
        externalData, externalTimestepSize * window_);
    return true;
  }
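  // Worked example (hedged): external state of shape (T + 1) x B x D with
  // offset_ == 1 and window_ == 1. At timestep t the internal blob is resized
  // to 1 x B x D and pointed at external[t + 1], so the step net writes its
  // output directly into the unrolled state tensor with no copy. A larger
  // window exposes window_ consecutive timesteps starting at t + offset_.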
  bool RunOnDevice() override {
    return DoRunWithType<float>();
  }

 private:
  int offset_;
  int window_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_RECURRENT_NETWORK_OP_H_