Caffe2 - C++ API
A deep learning, cross-platform ML framework
Backend.h
1 #pragma once
2 
3 #include <c10/core/DeviceType.h>
4 #include <c10/core/TensorTypeId.h>
5 #include <c10/core/TensorTypeIdRegistration.h>
6 #include <c10/util/Exception.h>
7 
8 #include <stdexcept>
9 
10 namespace c10 {
11 
/**
 * Legacy enum listing the backends supported by the old, code-generated,
 * Type-based dispatch.
 *
 * Enumerator values are implicit (CPU == 0, CUDA == 1, ...), so the order below
 * is significant and must not be changed.
 *
 * NumOptions is not a backend itself; it sits last so that, presumably, it can
 * serve as the count of the preceding enumerators.
 */
enum class Backend {
  CPU,
  CUDA,
  HIP,
  SparseCPU,
  SparseCUDA,
  SparseHIP,
  MSNPU,
  XLA,
  Undefined,
  NumOptions
};
24 
25 static inline Backend toSparse(Backend b) {
26  switch (b) {
27  case Backend::CPU:
28  return Backend::SparseCPU;
29  case Backend::CUDA:
30  return Backend::SparseCUDA;
31  case Backend::HIP:
32  return Backend::SparseHIP;
33  case Backend::SparseCPU:
34  return Backend::SparseCPU;
35  case Backend::SparseCUDA:
36  return Backend::SparseCUDA;
37  case Backend::SparseHIP:
38  return Backend::SparseHIP;
39  default:
40  throw std::runtime_error("Unknown backend");
41  }
42 }
43 
44 static inline Backend toDense(Backend b) {
45  switch (b) {
46  case Backend::CPU:
47  return Backend::CPU;
48  case Backend::CUDA:
49  return Backend::CUDA;
50  case Backend::HIP:
51  return Backend::HIP;
52  case Backend::MSNPU:
53  return Backend::MSNPU;
54  case Backend::XLA:
55  return Backend::XLA;
56  case Backend::SparseCPU:
57  return Backend::CPU;
58  case Backend::SparseCUDA:
59  return Backend::CUDA;
60  case Backend::SparseHIP:
61  return Backend::HIP;
62  default:
63  throw std::runtime_error("Unknown backend");
64  }
65 }
66 
67 static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
68  if (t == CPUTensorId()) {
69  return Backend::CPU;
70  } else if (t == CUDATensorId()) {
71  return Backend::CUDA;
72  } else if (t == HIPTensorId()) {
73  return Backend::HIP;
74  } else if (t == MSNPUTensorId()) {
75  return Backend::MSNPU;
76  } else if (t == XLATensorId()) {
77  return Backend::XLA;
78  } else if (t == SparseCPUTensorId()) {
79  return Backend::SparseCPU;
80  } else if (t == SparseCUDATensorId()) {
81  return Backend::SparseCUDA;
82  } else if (t == SparseHIPTensorId()) {
83  return Backend::SparseHIP;
84  } else if (t == UndefinedTensorId()) {
85  return Backend::Undefined;
86  } else {
87  AT_ERROR("Unrecognized tensor type ID: ", t);
88  }
89 }
90 
91 static inline TensorTypeId backendToTensorTypeId(Backend b) {
92  switch (b) {
93  case Backend::CPU:
94  return CPUTensorId();
95  case Backend::CUDA:
96  return CUDATensorId();
97  case Backend::HIP:
98  return HIPTensorId();
99  case Backend::MSNPU:
100  return MSNPUTensorId();
101  case Backend::XLA:
102  return XLATensorId();
103  case Backend::SparseCPU:
104  return SparseCPUTensorId();
105  case Backend::SparseCUDA:
106  return SparseCUDATensorId();
107  case Backend::SparseHIP:
108  return SparseHIPTensorId();
109  case Backend::Undefined:
110  return UndefinedTensorId();
111  default:
112  throw std::runtime_error("Unknown backend");
113  }
114 }
115 
116 static inline DeviceType backendToDeviceType(Backend b) {
117  switch (b) {
118  case Backend::CPU:
119  return DeviceType::CPU;
120  case Backend::CUDA:
121  return DeviceType::CUDA;
122  case Backend::HIP:
123  return DeviceType::HIP;
124  case Backend::MSNPU:
125  return DeviceType::MSNPU;
126  case Backend::XLA:
127  return DeviceType::XLA;
128  case Backend::SparseCPU:
129  return DeviceType::CPU;
130  case Backend::SparseCUDA:
131  return DeviceType::CUDA;
132  case Backend::SparseHIP:
133  return DeviceType::HIP;
134  case Backend::Undefined:
135  AT_ERROR("Undefined backend is not a valid device type");
136  default:
137  AT_ERROR("Unknown backend");
138  }
139 }
140 
141 static inline Backend backendToCPU(Backend b) {
142  switch (b) {
143  case Backend::CPU:
144  return Backend::CPU;
145  case Backend::CUDA:
146  return Backend::CPU;
147  case Backend::HIP:
148  return Backend::CPU;
149  case Backend::SparseCPU:
150  return Backend::SparseCPU;
151  case Backend::SparseCUDA:
152  return Backend::SparseCPU;
153  case Backend::SparseHIP:
154  return Backend::SparseCPU;
155  case Backend::MSNPU:
156  case Backend::XLA:
157  return Backend::CPU;
158  case Backend::Undefined:
159  return Backend::Undefined;
160  default:
161  AT_ERROR("Unknown backend");
162  }
163 }
164 
165 static inline Backend backendToCUDA(Backend b) {
166  switch (b) {
167  case Backend::CPU:
168  case Backend::CUDA:
169  case Backend::HIP:
170  case Backend::MSNPU:
171  case Backend::XLA:
172  return Backend::CUDA;
173  case Backend::SparseCPU:
174  case Backend::SparseCUDA:
175  case Backend::SparseHIP:
176  return Backend::SparseCUDA;
177  case Backend::Undefined:
178  return Backend::Undefined;
179  default:
180  AT_ERROR("Unknown backend");
181  }
182 }
183 
184 static inline Backend backendToHIP(Backend b) {
185  switch (b) {
186  case Backend::CPU:
187  case Backend::CUDA:
188  case Backend::HIP:
189  case Backend::MSNPU:
190  case Backend::XLA:
191  return Backend::HIP;
192  case Backend::SparseCPU:
193  case Backend::SparseCUDA:
194  case Backend::SparseHIP:
195  return Backend::SparseHIP;
196  case Backend::Undefined:
197  return Backend::Undefined;
198  default:
199  AT_ERROR("Unknown backend");
200  }
201 }
202 
203 constexpr DeviceType kCPU = DeviceType::CPU;
204 constexpr DeviceType kCUDA = DeviceType::CUDA;
205 constexpr DeviceType kHIP = DeviceType::HIP;
206 constexpr DeviceType kMSNPU = DeviceType::MSNPU;
207 constexpr DeviceType kXLA = DeviceType::XLA;
208 
209 static inline const char* toString(Backend b) {
210  switch (b) {
211  case Backend::CPU:
212  return "CPU";
213  case Backend::CUDA:
214  return "CUDA";
215  case Backend::HIP:
216  return "HIP";
217  case Backend::MSNPU:
218  return "MSNPU";
219  case Backend::XLA:
220  return "XLA";
221  case Backend::SparseCPU:
222  return "SparseCPU";
223  case Backend::SparseCUDA:
224  return "SparseCUDA";
225  case Backend::SparseHIP:
226  return "SparseHIP";
227  default:
228  return "UNKNOWN_BACKEND";
229  }
230 }
231 
232 } // namespace c10
Backend
This legacy enum class defines the set of backends supported by old-school, code-generated, Type-based ATen.
Definition: Backend.h:23
Dynamic type ID of a Tensor argument.
Definition: TensorTypeId.h:19
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7