1 #include "caffe2/operators/rnn/recurrent_network_op.h"     2 #include "caffe2/core/workspace.h"     3 #include "caffe2/utils/proto_utils.h"     5 #ifndef CAFFE2_RNN_NO_TEXT_FORMAT    11     "If set, uses special RNN executor for executing RecurrentNetworkOp");

CAFFE_KNOWN_TYPE(detail::ScratchWorkspaces);

REGISTER_CPU_OPERATOR(RecurrentNetwork, RecurrentNetworkOp<CPUContext>);
OPERATOR_SCHEMA(RecurrentNetwork)
    .NumInputs(1, INT_MAX)
    .NumOutputs(2, INT_MAX)
    .SetDoc(R"DOC(
Run the input network in a recurrent fashion. This can be used to
implement fairly general recurrent neural networks (RNNs).

The operator proceeds as follows.

- First, initialize the states from the input recurrent states
- For each timestep T, apply the links (that map offsets from input/output
tensors into the inputs/outputs for the `step` network)
- Finally, alias the recurrent states to the specified output blobs.

This is a fairly special-case meta-operator, and so the implementation
is somewhat complex. It trades off generality (and frankly usability)
against performance and control (compared to e.g. TF
dynamic_rnn, Theano scan, etc.).

See the usage examples for a flavor of how to use it.
)DOC");

REGISTER_CPU_OPERATOR(
    RecurrentNetworkGradient,
    RecurrentNetworkGradientOp<CPUContext>);
OPERATOR_SCHEMA(RecurrentNetworkGradient);

REGISTER_CPU_OPERATOR(
    rnn_internal_accumulate_gradient_input,
    AccumulateInputGradientOp<CPUContext>);
OPERATOR_SCHEMA(rnn_internal_accumulate_gradient_input)
    .NumInputs(3)
    .NumOutputs(1, INT_MAX)
    .EnforceInplace({{2, 0}})
    .SetDoc(R"DOC(
Internal RNN operator.
)DOC");

REGISTER_CPU_OPERATOR(
    rnn_internal_apply_link,
    RNNApplyLinkOp<CPUContext>);
OPERATOR_SCHEMA(rnn_internal_apply_link)
    .NumInputs(2)
    .NumOutputs(2)
    .EnforceInplace({{1, 1}})
    .SetDoc(R"DOC(
Internal RNN operator.
)DOC");
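
// The schema doc above refers to usage examples. As a rough sketch only (the
// blob names here are made up, and the full argument list expected by the op
// -- step net, links, aliases, recurrent state ids -- is defined in
// recurrent_network_op.h and is normally generated by the Python RNN builders
// rather than written by hand), a caller assembles something like:
//
//   OperatorDef rnn;
//   rnn.set_type("RecurrentNetwork");
//   rnn.add_input("input_sequence");  // T x N x D sequence blob
//   rnn.add_input("hidden_init");     // initial recurrent state
//   rnn.add_output("hidden_all");     // per-timestep recurrent states
//   rnn.add_output("hidden_last");    // final recurrent state
//   // plus arguments describing the step net, the per-timestep links and
//   // the output aliases, which this file reads back via extractNetDef()
//   // and extractLinks() below.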

struct GetRecurrentNetworkGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    ArgumentHelper argsHelper(def_);
    auto params = argsHelper.GetRepeatedArgument<int32_t>("param");
    auto recurrentInputs =
        argsHelper.GetRepeatedArgument<int32_t>("initial_recurrent_state_ids");

    std::vector<std::string> gradientInputs;

    // First come the gradient blobs of the outputs that actually receive an
    // external gradient, as listed in the "outputs_with_grads" argument.
    auto outputs_with_grads =
        argsHelper.GetRepeatedArgument<int32_t>("outputs_with_grads");
    CAFFE_ENFORCE(outputs_with_grads.size() > 0);
    for (auto id : outputs_with_grads) {
      gradientInputs.push_back(GO(id));
    }

    // All forward inputs and outputs are passed to the gradient op as well.
    for (int i = 0; i < def_.input_size(); ++i) {
      gradientInputs.push_back(I(i));
    }
    for (int i = 0; i < def_.output_size(); ++i) {
      gradientInputs.push_back(O(i));
    }

    // Gradients are computed only for the sequence input, the parameters and
    // the initial recurrent states.
    std::vector<std::string> gradientOutputs;
    gradientOutputs.push_back(GI(0));
    for (auto id : params) {
      gradientOutputs.push_back(GI(id));
    }
    for (auto id : recurrentInputs) {
      gradientOutputs.push_back(GI(id));
    }

    VLOG(1) << "Gradient blobs: " << c10::Join(", ", gradientOutputs);

    return SingleGradientDef(
        "RecurrentNetworkGradient", "", gradientInputs, gradientOutputs);
  }
};

REGISTER_GRADIENT(RecurrentNetwork, GetRecurrentNetworkGradient);

namespace detail {

std::map<string, string> GetRecurrentMapping(
    const std::vector<detail::Link>& links,
    bool backward) {
  std::map<string, string> mappings;
  for (auto it = links.begin(); it != links.end(); ++it) {
    const auto& l1 = *it;

    // A recurrent blob shows up as a pair of links onto the same external
    // tensor, one at the current step and one shifted by a single timestep.
    // Which of the two acts as the source of the mapping depends on the
    // direction.
    const int offset_l1 = backward ? 1 : 0;
    const int offset_l2 = 1 - offset_l1;
    if (l1.offset == offset_l1) {
      // Look for the matching link with the complementary offset.
      for (auto it2 = it + 1; it2 != links.end(); ++it2) {
        const auto& l2 = *it2;
        if (l2.offset == offset_l2 && l2.external == l1.external) {
          mappings[l2.internal] = l1.internal;
          break;
        }
      }
    }
  }
  return mappings;
}
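
// For example (hypothetical link names), with the two links
//   {internal: "hidden_t_prev", external: "hidden", offset: 0}
//   {internal: "hidden_t",      external: "hidden", offset: 1}
// and backward == false, the function returns
//   {"hidden_t" -> "hidden_t_prev"},
// i.e. the blob written at step t feeds the blob read at step t + 1.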

void PrependOps(std::vector<OperatorDef> ops, NetDef* netdef) {
  // Append the existing ops after the new ones, then rewrite the net.
  for (auto& o : netdef->op()) {
    ops.push_back(o);
  }
  netdef->mutable_op()->Clear();
  for (auto& o : ops) {
    auto* ao = netdef->add_op();
    ao->CopyFrom(o);
  }
}

void AddApplyLinkOps(
    const vector<Link>& links,
    std::string timestep,
    const DeviceOption& device_option,
    NetDef* netdef) {
  std::vector<OperatorDef> ops;
  for (auto& link : links) {
    OperatorDef opdef;
    opdef.set_type("rnn_internal_apply_link");
    opdef.add_input(timestep);
    opdef.add_input(link.external);
    opdef.add_output(link.internal);
    opdef.add_output(link.external);
    opdef.mutable_device_option()->CopyFrom(device_option);

    Argument* offset_arg = opdef.add_arg();
    offset_arg->set_name("offset");
    offset_arg->set_i(link.offset);

    Argument* window_arg = opdef.add_arg();
    window_arg->set_name("window");
    window_arg->set_i(link.window);

    // If the first use of the linked blob inside the step net is as an
    // output, add it as a control input of that op so it is ordered after
    // the rnn_internal_apply_link op, which also produces this blob.
    for (auto& op : *netdef->mutable_op()) {
      if (HasInput(op, link.internal)) {
        // First used as an input: the data dependency already orders things.
        break;
      }
      if (HasOutput(op, link.internal)) {
        op.add_control_input(link.internal);
        break;
      }
    }

    ops.push_back(opdef);

    netdef->add_external_input(link.internal);
    netdef->add_external_input(link.external);
  }

  detail::PrependOps(ops, netdef);
}
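
// As a sketch (hypothetical blob names), a link
//   {internal: "input_t", external: "input_sequence", offset: 0, window: 1}
// turns into an op prepended to the step net roughly like
//   op {
//     type: "rnn_internal_apply_link"
//     input: "timestep"   input: "input_sequence"
//     output: "input_t"   output: "input_sequence"
//     arg { name: "offset" i: 0 }
//     arg { name: "window" i: 1 }
//   }
// which, at step t, points "input_t" at a window of "input_sequence"
// starting at row t + offset (see RNNApplyLinkOp for the exact semantics).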

void extractLinks(
    OperatorBase* op,
    const std::string& internalArg,
    const std::string& externalArg,
    const std::string& offsetArg,
    const std::string& windowArg,
    std::vector<detail::Link>* links) {
  const auto& internal = op->GetRepeatedArgument<std::string>(internalArg);
  const auto& external = op->GetRepeatedArgument<std::string>(externalArg);
  const auto& offset = op->GetRepeatedArgument<int32_t>(offsetArg);
  const auto& window = op->GetRepeatedArgument<int32_t>(
      windowArg, vector<int32_t>(offset.size(), 1));
  CAFFE_ENFORCE_EQ(
      internal.size(),
      offset.size(),
      "internal/offset mismatch: ",
      internalArg);
  CAFFE_ENFORCE_EQ(
      external.size(),
      offset.size(),
      "external/offset mismatch: ",
      externalArg);
  CAFFE_ENFORCE_EQ(
      external.size(),
      window.size(),
      "external/window mismatch: ",
      externalArg);
  for (auto i = 0; i < internal.size(); ++i) {
    detail::Link l;
    l.internal = internal[i];
    l.external = external[i];
    l.offset = offset[i];
    l.window = window[i];
    links->push_back(l);
  }
}

NetDef extractNetDef(const OperatorDef& op, const std::string& argName) {
  if (ArgumentHelper::HasSingleArgumentOfType<OperatorDef, NetDef>(
          op, argName)) {
    return ArgumentHelper::GetSingleArgument<OperatorDef, NetDef>(
        op, argName, NetDef());
  } else {
#ifndef CAFFE2_RNN_NO_TEXT_FORMAT
    NetDef result;
    const auto netString =
        ArgumentHelper::GetSingleArgument<OperatorDef, string>(op, argName, "");
    CAFFE_ENFORCE(
        TextFormat::ParseFromString(netString, &result),
        "Invalid NetDef");
    return result;
#else
    CAFFE_THROW("No valid NetDef for argument ", argName);
#endif
  }
}
 