1 #include "caffe2/core/common_gpu.h" 2 #include "caffe2/core/context_gpu.h" 3 #include "caffe2/image/image_input_op.h" 8 bool ImageInputOp<CUDAContext>::ApplyTransformOnGPU(
9 const std::vector<std::int64_t>& dims,
12 if (output_type_ == TensorProto_DataType_FLOAT) {
14 OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(type));
15 TransformOnGPU<uint8_t, float, CUDAContext>(
16 prefetched_image_on_device_,
21 }
else if (output_type_ == TensorProto_DataType_FLOAT16) {
23 OperatorBase::OutputTensor(0, dims, at::dtype<at::Half>().device(type));
24 TransformOnGPU<uint8_t, at::Half, CUDAContext>(
25 prefetched_image_on_device_,
36 REGISTER_CUDA_OPERATOR(ImageInput, ImageInputOp<CUDAContext>);
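// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; not part of the upstream file). With the
// CUDA registration above, an OperatorDef of type "ImageInput" whose device
// option is set to CUDA dispatches to ImageInputOp<CUDAContext> through
// CreateOperator. The blob names "reader", "data", and "label" below are
// hypothetical placeholders.
//
//   #include "caffe2/core/operator.h"
//   #include "caffe2/core/workspace.h"
//
//   std::unique_ptr<caffe2::OperatorBase> MakeCudaImageInput(
//       caffe2::Workspace* ws) {
//     caffe2::OperatorDef def;
//     def.set_type("ImageInput");
//     def.add_input("reader");                          // DBReader blob
//     def.add_output("data");
//     def.add_output("label");
//     def.mutable_device_option()->set_device_type(1);  // 1 == PROTO_CUDA
//     return caffe2::CreateOperator(def, ws);
//   }
// ---------------------------------------------------------------------------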