Caffe2 - C++ API
A deep learning, cross platform ML framework
load_save_op_gpu.cc
1 #include "caffe2/core/context_gpu.h"
2 #include "caffe2/operators/load_save_op.h"
3 
4 namespace caffe2 {
5 
6 template <>
7 void LoadOp<CUDAContext>::SetCurrentDevice(BlobProto* proto) {
8  if (proto->has_tensor()) {
9  proto->mutable_tensor()->clear_device_detail();
10  auto* device_detail = proto->mutable_tensor()->mutable_device_detail();
11  device_detail->set_device_type(PROTO_CUDA);
12  device_detail->set_device_id(CaffeCudaGetDevice());
13  }
14 }
15 
// Registrations for the CUDAContext instantiations of the load/save operator
// templates under the names Load, Save and Checkpoint.
// NOTE(review): REGISTER_CUDA_OPERATOR is defined outside this file;
// presumably it adds each operator to the CUDA operator registry -- confirm.
16 REGISTER_CUDA_OPERATOR(Load, LoadOp<CUDAContext>);
17 REGISTER_CUDA_OPERATOR(Save, SaveOp<CUDAContext>);
18 REGISTER_CUDA_OPERATOR(Checkpoint, CheckpointOp<CUDAContext>);
19 } // namespace caffe2
int CaffeCudaGetDevice()
Gets the current GPU id.
Definition: common_gpu.cc:96
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13