Caffe2 - C++ API
A deep learning, cross platform ML framework
operator_fallback_gpu.h
1 #ifndef CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
2 #define CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
3 
4 #include "caffe2/core/common.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/context_gpu.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/proto/caffe2_pb.h"
9 
10 namespace caffe2 {
11 
41 template <typename SkipOutputCopy>
42 class GPUFallbackOpEx final : public Operator<CUDAContext> {
43  public:
44  USE_OPERATOR_FUNCTIONS(CUDAContext);
45  explicit GPUFallbackOpEx(const OperatorDef& def, Workspace* ws)
46  : Operator<CUDAContext>(def, ws) {
47  CAFFE_ENFORCE_EQ(def.device_option().device_type(), PROTO_CUDA);
48  OperatorDef base_def_(def);
49  // base_def_ runs on CPU, so we will set its device option to CPU.
50  base_def_.clear_device_option();
51  base_def_.mutable_device_option()->set_device_type(PROTO_CPU);
52  // Set up the symbols for the local workspace.
53  for (const string& name : def.input()) {
54  local_input_blobs_.push_back(local_ws_.CreateBlob(name));
55  CHECK_NOTNULL(local_input_blobs_.back());
56  }
57  base_op_ = CreateOperator(base_def_, &local_ws_);
58  for (const string& name : def.output()) {
59  local_output_blobs_.push_back(local_ws_.GetBlob(name));
60  CHECK_NOTNULL(local_output_blobs_.back());
61  }
62  }
63 
64  bool RunOnDevice() override {
65  for (int i = 0; i < InputSize(); ++i) {
66  if (this->InputIsTensorType(i, CUDA)) {
67  // use sync copy
68  BlobGetMutableTensor(local_input_blobs_[i], CPU)->CopyFrom(Input(i));
69  } else {
70  VLOG(1) << "Input " << i << " is not TensorCUDA. Skipping copy.";
71  // Note(jiayq): This removes a const but conceptually
72  // local_input_blobs will only be used as const blob input for the
73  // base op so we are still fine.
74  local_input_blobs_[i]->ShareExternal(
75  const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
76  OperatorBase::Inputs()[i]->meta());
77  }
78  }
79 
80  if (!base_op_->Run()) {
81  LOG(ERROR) << "Base op run failed in GPUFallbackOp. Def: "
82  << ProtoDebugString(this->debug_def());
83  return false;
84  }
85  for (int i = 0; i < OutputSize(); ++i) {
86  if (SkipOutputCopy::Contains(i)) {
87  VLOG(1) << "Copy output: index " << i << " skipped.";
88  continue;
89  }
90  CAFFE_ENFORCE(
91  BlobIsTensorType(*local_output_blobs_[i], CPU),
92  "GPU fallback op currently does not support non-TensorCPU "
93  "output type who needs copying.");
94  Output(i)->CopyFrom(local_output_blobs_[i]->template Get<TensorCPU>());
95  }
96  return true;
97  }
98 
99  protected:
100  Workspace local_ws_;
101  vector<Blob*> local_input_blobs_;
102  vector<Blob*> local_output_blobs_;
103  unique_ptr<OperatorBase> base_op_;
104 };
105 
107 
108 } // namespace caffe2
109 
110 #endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:100
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:160
const Tensor & Input(int idx, DeviceType type=CUDAContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
A templated class to allow one to wrap a CPU operator as a CUDA operator.