Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_network_blob_fetcher_op.h
1 
17 #ifndef CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
18 #define CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/core/tensor.h"
24 #include "caffe2/operators/recurrent_network_op.h"
25 #include "google/protobuf/text_format.h"
26 
27 #include <string>
28 
29 namespace caffe2 {
30 
31 template <class Context>
32 class RecurrentNetworkBlobFetcherOp final : public Operator<Context> {
33  public:
34  USE_OPERATOR_CONTEXT_FUNCTIONS;
35 
36  RecurrentNetworkBlobFetcherOp(const OperatorDef& operator_def, Workspace* ws)
37  : Operator<Context>(operator_def, ws) {
38  prefix_ = OperatorBase::GetSingleArgument<std::string>("prefix", "rnn");
39  ws_ = ws;
40  }
41 
42  bool RunOnDevice() override {
43  const detail::ScratchWorkspaces& scratch =
44  OperatorBase::Input<detail::ScratchWorkspaces>(0);
45  const std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
46  scratch.stepWorkspaces;
47 
48  std::vector<std::string> blob_names_vector = {};
49 
50  for (TIndex i = 0; i < stepWorkspaces.size(); i++) {
51  Workspace* currentStepWorkspace = stepWorkspaces[i].get();
52  std::vector<std::string> blob_names = currentStepWorkspace->LocalBlobs();
53 
54  for (auto& blob_name : blob_names) {
55  const Blob* currentBlob = currentStepWorkspace->GetBlob(blob_name);
56  const auto& currentTensor = currentBlob->Get<Tensor<Context>>();
57 
58  std::string newBlobName =
59  prefix_ + std::string("_") + blob_name + caffe2::to_string(i);
60  blob_names_vector.push_back(newBlobName);
61 
62  ws_->CreateBlob(newBlobName)
63  ->template GetMutable<TensorCPU>()
64  ->ResizeLike(currentTensor);
65 
66  auto* newTensor =
67  ws_->GetBlob(newBlobName)->template GetMutable<Tensor<Context>>();
68  newTensor->template CopyFrom<Context>(currentTensor);
69  }
70  }
71 
72  auto* output = Output(0);
73  output->Resize(blob_names_vector.size());
74  std::copy(
75  blob_names_vector.begin(),
76  blob_names_vector.end(),
77  output->template mutable_data<std::string>());
78 
79  return true;
80  }
81 
82  private:
83  std::string prefix_;
84  Workspace* ws_;
85 };
86 } // namespace caffe2
87 
88 #endif // CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:120
Tensor is the basic class in Caffe2 that stores contiguous memory with its shape information...
Definition: tensor.h:109
vector< string > LocalBlobs() const
Returns the list of blobs owned by this Workspace, not including blobs shared from the parent workspace...
Definition: workspace.cc:91
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:180
Copyright (c) 2016-present, Facebook, Inc.
const T & Get() const
Gets the const reference of the stored object.
Definition: blob.h:91