Caffe2 - C++ API
A deep learning, cross platform ML framework
Device.h
1 #pragma once
2 
#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <string>
11 
12 namespace c10 {
13 
18 using DeviceIndex = int16_t;
19 
30 struct C10_API Device final {
31  using Type = DeviceType;
32 
35  /* implicit */ Device(DeviceType type, DeviceIndex index = -1)
36  : type_(type), index_(index) {
37  validate();
38  }
39 
45  /* implicit */ Device(const std::string& device_string);
46 
49  bool operator==(const Device& other) const noexcept {
50  return this->type_ == other.type_ && this->index_ == other.index_;
51  }
52 
55  bool operator!=(const Device& other) const noexcept {
56  return !(*this == other);
57  }
58 
60  void set_index(DeviceIndex index) {
61  index_ = index;
62  }
63 
65  DeviceType type() const noexcept {
66  return type_;
67  }
68 
70  DeviceIndex index() const noexcept {
71  return index_;
72  }
73 
75  bool has_index() const noexcept {
76  return index_ != -1;
77  }
78 
80  bool is_cuda() const noexcept {
81  return type_ == DeviceType::CUDA;
82  }
83 
85  bool is_cpu() const noexcept {
86  return type_ == DeviceType::CPU;
87  }
88 
89  private:
90  DeviceType type_;
91  DeviceIndex index_ = -1;
92  void validate();
93 };
94 
95 C10_API std::ostream& operator<<(
96  std::ostream& stream,
97  const Device& device);
98 
99 } // namespace c10
100 
101 namespace std {
102 template <>
103 struct hash<c10::Device> {
104  size_t operator()(c10::Device d) const noexcept {
105  // Are you here because this static assert failed? Make sure you ensure
106  // that the bitmasking code below is updated accordingly!
107  static_assert(sizeof(c10::DeviceType) == 2, "DeviceType is not 16-bit");
108  static_assert(sizeof(c10::DeviceIndex) == 2, "DeviceIndex is not 16-bit");
109  // Note [Hazard when concatenating signed integers]
110  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
111  // We must first convert to a same-sized unsigned type, before promoting to
112  // the result type, to prevent sign extension when any of the values is -1.
113  // If sign extension occurs, you'll clobber all of the values in the MSB
114  // half of the resulting integer.
115  //
116  // Technically, by C/C++ integer promotion rules, we only need one of the
117  // uint32_t casts to the result type, but we put in both for explicitness's sake.
118  uint32_t bits =
119  static_cast<uint32_t>(static_cast<uint16_t>(d.type())) << 16
120  | static_cast<uint32_t>(static_cast<uint16_t>(d.index()));
121  return std::hash<uint32_t>{}(bits);
122  }
123 };
124 } // namespace std
bool has_index() const noexcept
Returns true if the device has a non-default index.
Definition: Device.h:75
bool operator!=(const Device &other) const noexcept
Returns true if the type or index of this Device differs from that of other.
Definition: Device.h:55
bool operator==(const Device &other) const noexcept
Returns true if the type and index of this Device matches that of other.
Definition: Device.h:49
bool is_cuda() const noexcept
Returns true if the device is of CUDA type.
Definition: Device.h:80
TensorOptions device(Device device)
Convenience function that returns a TensorOptions object with the device set to the given one...
Represents a compute device on which a tensor is located.
Definition: Device.h:30
int16_t DeviceIndex
An index representing a specific device; e.g., the 1 in GPU 1.
Definition: Device.h:18
bool is_cpu() const noexcept
Returns true if the device is of CPU type.
Definition: Device.h:85
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
Device(DeviceType type, DeviceIndex index=-1)
Constructs a new Device from a DeviceType and an optional device index.
Definition: Device.h:35
DeviceIndex index() const noexcept
Returns the optional index.
Definition: Device.h:70
void set_index(DeviceIndex index)
Sets the device index.
Definition: Device.h:60
DeviceType type() const noexcept
Returns the type of device this is.
Definition: Device.h:65