1 #ifndef CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ 2 #define CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ 4 #include "caffe2/core/common.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/context_gpu.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/proto/caffe2_pb.h" 41 template <
typename SkipOutputCopy>
// Constructor body (fragment): the enclosing signature and closing braces are
// not present in this extract, so only the visible statements are described.
//
// The incoming operator definition must target CUDA; the enforce fails otherwise.
47 CAFFE_ENFORCE_EQ(def.device_option().device_type(), PROTO_CUDA);
// Work on a copy of the definition so the caller's `def` is left unmodified.
48 OperatorDef base_def_(def);
// Reset the copy's device option and mark it as CPU; this CPU definition is
// what the wrapped base operator is created from below.
50 base_def_.clear_device_option();
51 base_def_.mutable_device_option()->set_device_type(PROTO_CPU);
// Create one blob in the local workspace per input name, keeping the pointers
// in input order so they can be indexed by input position later.
53 for (
const string& name : def.input()) {
54 local_input_blobs_.push_back(local_ws_.
CreateBlob(name));
55 CHECK_NOTNULL(local_input_blobs_.back());
// Instantiate the base operator from the CPU definition against the local
// workspace.
57 base_op_ = CreateOperator(base_def_, &local_ws_);
// Output blobs are looked up (GetBlob), not created, after the base operator
// exists. NOTE(review): presumably CreateOperator registers the outputs in
// local_ws_ -- confirm. A missing blob trips CHECK_NOTNULL.
58 for (
const string& name : def.output()) {
59 local_output_blobs_.push_back(local_ws_.
GetBlob(name));
60 CHECK_NOTNULL(local_output_blobs_.back());
// Runs the wrapped base operator on behalf of this CUDA operator in three
// visible phases: stage inputs into the local blobs, run the base operator,
// then copy outputs back. Several original lines (else/continue/return and
// closing braces) are missing from this extract; notes below hedge accordingly.
64 bool RunOnDevice()
override {
// Phase 1: stage every input into the matching local input blob.
65 for (
int i = 0; i < InputSize(); ++i) {
66 if (this->InputIsTensorType(i, CUDA)) {
// CUDA tensor input: copy its contents into the local blob's CPU tensor.
68 BlobGetMutableTensor(local_input_blobs_[i], CPU)->CopyFrom(
Input(i));
// NOTE(review): the statements below appear to be the non-CUDA branch (the
// `else` line is not visible here) -- confirm against the full file.
70 VLOG(1) <<
"Input " << i <<
" is not TensorCUDA. Skipping copy.";
// Share the input blob's raw pointer and type meta with the local blob instead
// of copying. const_cast is needed because GetRaw() is accessed through a
// const input while ShareExternal takes a non-const void*.
74 local_input_blobs_[i]->ShareExternal(
75 const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
76 OperatorBase::Inputs()[i]->meta());
// Phase 2: run the base operator; on failure log the operator definition.
// NOTE(review): the failure return statement is not visible in this extract.
80 if (!base_op_->Run()) {
81 LOG(ERROR) <<
"Base op run failed in GPUFallbackOp. Def: " 82 << ProtoDebugString(this->debug_def());
// Phase 3: copy each local output into this operator's output, except for
// indices that SkipOutputCopy::Contains reports as skipped.
85 for (
int i = 0; i < OutputSize(); ++i) {
86 if (SkipOutputCopy::Contains(i)) {
87 VLOG(1) <<
"Copy output: index " << i <<
" skipped.";
// The next two lines are the arguments of a check whose macro name is on a
// line missing from this extract: the local output blob must hold a CPU tensor.
91 BlobIsTensorType(*local_output_blobs_[i], CPU),
92 "GPU fallback op currently does not support non-TensorCPU " 93 "output type who needs copying.");
// Copy the local TensorCPU result into Output(i).
94 Output(i)->CopyFrom(local_output_blobs_[i]->
template Get<TensorCPU>());
// Blobs created in local_ws_, one per input name, in input order.
101 vector<Blob*> local_input_blobs_;
// Blobs fetched from local_ws_, one per output name, in output order.
102 vector<Blob*> local_output_blobs_;
// The wrapped operator, created from the CPU copy of the operator definition.
103 unique_ptr<OperatorBase> base_op_;
110 #endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_ Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
const Tensor & Input(int idx, DeviceType type=CUDAContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
A templated class to allow one to wrap a CPU operator as a CUDA operator.