Caffe2 - C++ API
A deep learning, cross-platform ML framework
image_input_op_gpu.cc
1 #include "caffe2/core/common_gpu.h"
2 #include "caffe2/core/context_gpu.h"
3 #include "caffe2/image/image_input_op.h"
4 
5 namespace caffe2 {
6 
7 template <>
8 bool ImageInputOp<CUDAContext>::ApplyTransformOnGPU(
9  const std::vector<std::int64_t>& dims,
10  const c10::Device& type) {
11  // GPU transform kernel allows explicitly setting output type
12  if (output_type_ == TensorProto_DataType_FLOAT) {
13  auto* image_output =
14  OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(type));
15  TransformOnGPU<uint8_t, float, CUDAContext>(
16  prefetched_image_on_device_,
17  image_output,
18  mean_gpu_,
19  std_gpu_,
20  &context_);
21  } else if (output_type_ == TensorProto_DataType_FLOAT16) {
22  auto* image_output =
23  OperatorBase::OutputTensor(0, dims, at::dtype<at::Half>().device(type));
24  TransformOnGPU<uint8_t, at::Half, CUDAContext>(
25  prefetched_image_on_device_,
26  image_output,
27  mean_gpu_,
28  std_gpu_,
29  &context_);
30  } else {
31  return false;
32  }
33  return true;
34 }
35 
36 REGISTER_CUDA_OPERATOR(ImageInput, ImageInputOp<CUDAContext>);
37 
38 } // namespace caffe2
Represents a compute device on which a tensor is located.
Definition: Device.h:30
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13