7 #include <unordered_map> 9 #include <ATen/core/ATenGeneral.h> 10 #include <c10/core/Allocator.h> 11 #include <c10/util/typeid.h> 12 #include <c10/util/Exception.h> 13 #include <c10/util/Registry.h> 14 #include <c10/core/CopyBytes.h> 36 virtual Device device()
const = 0;
39 virtual DeviceType device_type()
const = 0;
41 virtual void SwitchToDevice(
int ) = 0;
43 inline void SwitchToDevice() {
49 virtual void Record(
caffe2::Event* ev,
const char* err_msg =
nullptr)
52 virtual void FinishDeviceComputation() = 0;
58 virtual void CopyBytesSameDevice(
63 virtual void CopyBytesFromCPU(
size_t nbytes,
const void* src,
void* dst) = 0;
65 virtual void CopyBytesToCPU(
size_t nbytes,
const void* src,
void* dst) = 0;
68 inline void CopySameDevice(
size_t n,
const T* src,
T* dst) {
70 std::is_fundamental<T>::value,
71 "CopySameDevice requires fundamental types");
73 n *
sizeof(
T), static_cast<const void*>(src), static_cast<void*>(dst));
77 inline void CopyFromCPU(
size_t n,
const T* src,
T* dst) {
79 std::is_fundamental<T>::value,
80 "CopyFromCPU requires fundamental types");
82 n *
sizeof(
T), static_cast<const void*>(src), static_cast<void*>(dst));
86 inline void CopyToCPU(
size_t n,
const T* src,
T* dst) {
88 std::is_fundamental<T>::value,
"CopyToCPU requires fundamental types");
90 n *
sizeof(
T), static_cast<const void*>(src), static_cast<void*>(dst));
93 virtual bool SupportsNonFundamentalTypes()
const {
97 inline void EnforceMetaCopyOK() {
99 SupportsNonFundamentalTypes(),
"Context requires fundamental types");
102 void CopyItemsSameDevice(
109 meta.
copy()(src, dst, n);
111 CopyBytesSameDevice(n * meta.
itemsize(), src, dst);
115 void CopyItemsFromCPU(
122 meta.
copy()(src, dst, n);
124 CopyBytesFromCPU(n * meta.
itemsize(), src, dst);
135 meta.
copy()(src, dst, n);
137 CopyBytesToCPU(n * meta.
itemsize(), src, dst);
143 C10_DECLARE_TYPED_REGISTRY(
150 #define REGISTER_CONTEXT(type, ...) \ 151 C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__) 153 inline std::unique_ptr<at::BaseContext> CreateContext(
155 return at::ContextRegistry()->Create(device.
type(), device);
163 using at::CreateContext;
Virtual interface for the Context class in Caffe2.
Represents a a compute device on which a tensor is located.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Flush-To-Zero and Denormals-Are-Zero mode.
DeviceType type() const noexcept
Returns the type of device this is.