Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_network_blob_fetcher_op.h
1 #ifndef CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
2 #define CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/tensor.h"
8 #include "caffe2/operators/rnn/recurrent_network_op.h"
9 #include "google/protobuf/text_format.h"
10 
11 #include <string>
12 
13 namespace caffe2 {
14 
15 template <class Context>
16 class RecurrentNetworkBlobFetcherOp final : public Operator<Context> {
17  public:
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19 
20  explicit RecurrentNetworkBlobFetcherOp(const OperatorDef& operator_def, Workspace* ws)
21  : Operator<Context>(operator_def, ws) {
22  prefix_ = this->template GetSingleArgument<std::string>("prefix", "rnn");
23  ws_ = ws;
24  }
25 
26  bool RunOnDevice() override {
27  const detail::ScratchWorkspaces& scratch =
28  this->template Input<detail::ScratchWorkspaces>(0);
29  const std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
30  scratch.stepWorkspaces;
31 
32  std::vector<std::string> blob_names_vector = {};
33 
34  for (int64_t i = 0; i < stepWorkspaces.size(); i++) {
35  Workspace* currentStepWorkspace = stepWorkspaces[i].get();
36  std::vector<std::string> blob_names = currentStepWorkspace->LocalBlobs();
37 
38  for (auto& blob_name : blob_names) {
39  const Blob* currentBlob = currentStepWorkspace->GetBlob(blob_name);
40  const auto& currentTensor = currentBlob->Get<Tensor>();
41 
42  std::string newBlobName =
43  prefix_ + std::string("_") + blob_name + c10::to_string(i);
44  blob_names_vector.push_back(newBlobName);
45 
46  BlobGetMutableTensor(ws_->CreateBlob(newBlobName), CPU)
47  ->ResizeLike(currentTensor);
48  auto type = Context::GetDeviceType();
49  auto* newTensor = BlobGetMutableTensor(ws_->GetBlob(newBlobName), type);
50  newTensor->CopyFrom(currentTensor);
51  }
52  }
53 
54  auto* output =
55  Output(0, {static_cast<int64_t>(blob_names_vector.size())}, at::dtype<std::string>());
56  std::copy(
57  blob_names_vector.begin(),
58  blob_names_vector.end(),
59  output->template mutable_data<std::string>());
60 
61  return true;
62  }
63 
64  private:
65  std::string prefix_;
66  Workspace* ws_;
67 };
68 } // namespace caffe2
69 
70 #endif // CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:100
vector< string > LocalBlobs() const
Return list of blobs owned by this Workspace, not including blobs shared from parent workspace...
Definition: workspace.cc:71
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:160
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
const T & Get() const
Gets the const reference of the stored object.
Definition: blob.h:71