1 #ifndef CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_ 2 #define CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/tensor.h" 8 #include "caffe2/operators/rnn/recurrent_network_op.h" 9 #include "google/protobuf/text_format.h" 15 template <
class Context>
18 USE_OPERATOR_CONTEXT_FUNCTIONS;
22 prefix_ = this->
template GetSingleArgument<std::string>(
"prefix",
"rnn");
26 bool RunOnDevice()
override {
28 this->
template Input<detail::ScratchWorkspaces>(0);
29 const std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
30 scratch.stepWorkspaces;
32 std::vector<std::string> blob_names_vector = {};
34 for (int64_t i = 0; i < stepWorkspaces.size(); i++) {
35 Workspace* currentStepWorkspace = stepWorkspaces[i].get();
36 std::vector<std::string> blob_names = currentStepWorkspace->
LocalBlobs();
38 for (
auto& blob_name : blob_names) {
39 const Blob* currentBlob = currentStepWorkspace->
GetBlob(blob_name);
40 const auto& currentTensor = currentBlob->
Get<
Tensor>();
42 std::string newBlobName =
43 prefix_ + std::string(
"_") + blob_name + c10::to_string(i);
44 blob_names_vector.push_back(newBlobName);
46 BlobGetMutableTensor(ws_->
CreateBlob(newBlobName), CPU)
47 ->ResizeLike(currentTensor);
48 auto type = Context::GetDeviceType();
49 auto* newTensor = BlobGetMutableTensor(ws_->
GetBlob(newBlobName), type);
50 newTensor->CopyFrom(currentTensor);
55 Output(0, {
static_cast<int64_t
>(blob_names_vector.size())}, at::dtype<std::string>());
57 blob_names_vector.begin(),
58 blob_names_vector.end(),
59 output->template mutable_data<std::string>());
#endif // CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_

Blob is a general container that hosts a typed pointer.
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
vector< string > LocalBlobs() const
Returns the list of blobs owned by this Workspace, not including blobs shared from the parent workspace.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
const T & Get() const
Gets the const reference of the stored object.