Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator_fallback_gpu.h
1 
17 #ifndef CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
18 #define CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
19 
20 #include "caffe2/core/common.h"
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/context_gpu.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/proto/caffe2.pb.h"
25 
26 namespace caffe2 {
27 
// GPUFallbackOp wraps a CPU operator so that it can be used as a CUDA
// operator. It keeps a private workspace (local_ws_) in which the wrapped CPU
// operator is created and run; RunOnDevice() moves inputs into that workspace,
// runs the CPU operator, and copies the results back to this operator's
// outputs.
//
// Template parameters:
//   CPUOp          - the CPU operator type that is constructed and run.
//   SkipOutputCopy - type whose static Contains(i) is queried for every output
//                    index i; outputs for which it returns true are not copied
//                    back after the base op has run. Defaults to
//                    SkipIndices<>.
55 template <class CPUOp, typename SkipOutputCopy = SkipIndices<>>
56 class GPUFallbackOp final : public Operator<CUDAContext> {
57  public:
58  USE_OPERATOR_FUNCTIONS(CUDAContext);
 // Constructs the fallback operator from `def`, registered in workspace `ws`.
 // Enforces that `def` targets a CUDA device, then builds a CPU copy of the
 // definition and instantiates CPUOp from it inside local_ws_.
59  GPUFallbackOp(const OperatorDef& def, Workspace* ws)
60  : Operator<CUDAContext>(def, ws) {
61  CAFFE_ENFORCE_EQ(def.device_option().device_type(), CUDA);
 // Despite the trailing underscore, base_def_ is a local copy of `def`, not a
 // class member.
62  OperatorDef base_def_(def);
63  // base_def_ runs on CPU, so we will set its device option to CPU.
64  base_def_.clear_device_option();
65  base_def_.mutable_device_option()->set_device_type(CPU);
66  // Set up the symbols for the local workspace.
 // One local blob is created per input name, stored in input order so that
 // RunOnDevice() can address it by input index.
67  for (const string& name : def.input()) {
68  local_input_blobs_.push_back(local_ws_.CreateBlob(name));
69  CHECK_NOTNULL(local_input_blobs_.back());
70  }
71  base_op_.reset(new CPUOp(base_def_, &local_ws_));
 // Output blobs are only looked up (GetBlob), never created here, so they
 // must already exist in local_ws_ at this point.
 // NOTE(review): presumably constructing the base op above creates them --
 // confirm against the operator construction code.
72  for (const string& name : def.output()) {
73  local_output_blobs_.push_back(local_ws_.GetBlob(name));
74  CHECK_NOTNULL(local_output_blobs_.back());
75  }
76  }
77 
 // Runs the wrapped CPU operator on behalf of this CUDA operator.
 // Returns false (after logging the operator definition) if the base op's
 // Run() fails, and true once all non-skipped outputs have been copied back.
78  bool RunOnDevice() override {
 // Set when at least one input had to be copied from a TensorCUDA.
79  bool need_sync = false;
80  for (int i = 0; i < InputSize(); ++i) {
81  if (OperatorBase::InputIsType<TensorCUDA>(i)) {
 // CUDA tensor input: copy it into a TensorCPU held by the matching local
 // blob, using this operator's context_ to carry out the copy.
82  local_input_blobs_[i]->template GetMutable<TensorCPU>()->CopyFrom(
83  Input(i), &context_);
84  need_sync = true;
85  } else {
86  VLOG(1) << "Input " << i << " is not TensorCUDA. Skipping copy.";
87  // Note(jiayq): This removes a const but conceptually
88  // local_input_blobs will only be used as const blob input for the
89  // base op so we are still fine.
 // The local blob is handed the original input's raw pointer and type meta
 // instead of a copy.
 // NOTE(review): presumably ShareExternal does not take ownership of the
 // pointer -- confirm against Blob::ShareExternal.
90  local_input_blobs_[i]->ShareExternal(
91  const_cast<void*>(OperatorBase::Inputs()[i]->GetRaw()),
92  OperatorBase::Inputs()[i]->meta());
93  }
94  }
95 
96  // Sync to make sure copies are done.
97  if (need_sync) {
98  context_.FinishDeviceComputation();
99  }
100 
101  if (!base_op_->Run()) {
102  LOG(ERROR) << "Base op run failed in GPUFallbackOp. Def: "
103  << ProtoDebugString(this->debug_def());
104  return false;
105  }
 // Copy each result produced in the local workspace back into this
 // operator's outputs, except for indices listed in SkipOutputCopy.
106  for (int i = 0; i < OutputSize(); ++i) {
107  if (SkipOutputCopy::Contains(i)) {
108  VLOG(1) << "Copy output: index " << i << " skipped.";
109  continue;
110  }
 // Only TensorCPU results can be copied back; any other blob type at a
 // non-skipped index fails the enforce below.
111  CAFFE_ENFORCE(
112  local_output_blobs_[i]->template IsType<TensorCPU>(),
113  "GPU fallback op currently does not support non-TensorCPU "
114  "output type who needs copying.");
115  Output(i)->CopyFrom(
116  local_output_blobs_[i]->template Get<TensorCPU>(), &context_);
117  }
118  return true;
119  }
120 
121  protected:
 // Private workspace in which the wrapped CPU operator is created and run.
122  Workspace local_ws_;
 // Blobs in local_ws_ for each input of the operator definition, in input
 // order.
123  vector<Blob*> local_input_blobs_;
 // Blobs in local_ws_ for each output of the operator definition, in output
 // order.
124  vector<Blob*> local_output_blobs_;
 // The wrapped CPU operator, constructed from the CPU copy of the definition.
125  std::unique_ptr<CPUOp> base_op_;
126 };
127 
128 } // namespace caffe2
129 
130 #endif // CAFFE2_OPERATORS_OPERATOR_FALLBACK_H_
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:120
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation.
Definition: tensor.h:182
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:180
A templated class to allow one to wrap a CPU operator as a CUDA operator.
Copyright (c) 2016-present, Facebook, Inc.