// NOTE(review): this header is a damaged copy of a code listing. Listing line
// numbers ("1", "10", "13", "14", ...) are fused into the source text, and
// several numbered lines are absent (e.g. 7-9, 11-12, 15-17 by the embedded
// numbering). It will not compile as shown; restore it from upstream before
// building.
1 #ifndef CAFFE2_OPERATORS_ENSURE_CPU_OUTPUT_OP_H_ 2 #define CAFFE2_OPERATORS_ENSURE_CPU_OUTPUT_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
class Context>
// Operator templated on the device Context it runs under. The class
// declaration line itself is not visible in this fragment.
// Presumably brings the operator-context helpers (such as the context_
// member used by CopyWithContext) into scope; confirm against operator.h.
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Head of a variadic constructor template. The constructor declaration and
// its body are missing from this fragment.
14 template <
class... Args>
// Dispatches on the runtime tensor type of input 0:
//   - a CPU tensor is copied with CopyWithContext<CPUContext>();
//   - a tensor of this operator's own device type
//     (Context::GetDeviceType()) is copied with CopyWithContext<Context>();
//   - any other input appears to reach an error built from
//     "Unexpected Input Blob: " and the input blob's meta().name().
// NOTE(review): the call that raises that error, the closing braces, and any
// trailing return are missing from this fragment, and listing line numbers
// ("18", "19", ...) are fused into the code. Restore from upstream before
// building.
18 bool RunOnDevice()
override {
19 if (this->InputIsTensorType(0, CPU)) {
// Input already lives on the CPU: copy it using a CPU context.
20 return CopyWithContext<CPUContext>();
21 }
else if (this->InputIsTensorType(0, Context::GetDeviceType())) {
// Input lives on this operator's device type: copy it using Context.
23 return CopyWithContext<Context>();
// Arguments of the error path: a fixed prefix plus the name reported by the
// first input blob's meta(). The call consuming them is not visible here.
26 "Unexpected Input Blob: ",
27 OperatorBase::Inputs().
at(0)->meta().name());
// Copies input 0 into output 0, which is always requested as a CPU Tensor.
// Input 0 is read as a Tensor of InputContext's device type. The output is
// resized like the input. context_.CopyItemsToCPU then writes into the output
// buffer obtained via raw_mutable_data(input.dtype()), and
// context_.FinishDeviceComputation() is called afterwards. That call
// presumably blocks until an asynchronous device copy has completed; confirm
// against the Context implementation.
// NOTE(review): some leading CopyItemsToCPU arguments (presumably the dtype,
// item count and source pointer), the function's return statement, and its
// closing brace are missing from this fragment. Listing line numbers are
// fused into the code. Restore from upstream before building.
33 template <
class InputContext>
34 bool CopyWithContext() {
// The destination is always a CPU tensor, regardless of InputContext.
36 auto* output = this->
template Output<Tensor>(0, CPU);
37 auto& input = this->
template Input<Tensor>(0, InputContext::GetDeviceType());
// Match the output's shape to the input before writing into its buffer.
38 output->ResizeLike(input);
39 context_.CopyItemsToCPU(
43 output->raw_mutable_data(input.dtype()));
44 context_.FinishDeviceComputation();
51 #endif // CAFFE2_OPERATORS_ENSURE_CPU_OUTPUT_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Flush-To-Zero and Denormals-Are-Zero mode.