HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ortdevice.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 
6 #include <sstream>
8 
9 // Struct to represent a physical device.
10 struct OrtDevice {
11  using DeviceType = int8_t;
12  using MemoryType = int8_t;
13  using DeviceId = int16_t;
14 
15  // Pre-defined device types.
16  static const DeviceType CPU = 0;
17  static const DeviceType GPU = 1; // Nvidia or AMD
18  static const DeviceType FPGA = 2;
19  static const DeviceType NPU = 3; // Ascend
20 
21  struct MemType {
22  // Pre-defined memory types.
23  static const MemoryType DEFAULT = 0;
24  static const MemoryType CUDA_PINNED = 1;
25  static const MemoryType HIP_PINNED = 2;
26  static const MemoryType CANN_PINNED = 3;
27  };
28 
29  constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
30  : device_type(device_type_),
31  memory_type(memory_type_),
32  device_id(device_id_) {}
33 
34  constexpr OrtDevice() : OrtDevice(CPU, MemType::DEFAULT, 0) {}
35 
36  DeviceType Type() const {
37  return device_type;
38  }
39 
40  MemoryType MemType() const {
41  return memory_type;
42  }
43 
44  DeviceId Id() const {
45  return device_id;
46  }
47 
49  std::ostringstream ostr;
50  ostr << "Device:["
51  << "DeviceType:" << static_cast<int>(device_type)
52  << " MemoryType:" << static_cast<int>(memory_type)
53  << " DeviceId:" << device_id
54  << "]";
55  return ostr.str();
56  }
57 
58  // This is to make OrtDevice a valid key in hash tables
59  size_t Hash() const {
60  auto h = std::hash<int>()(device_type);
61  onnxruntime::HashCombine(memory_type, h);
62  onnxruntime::HashCombine(device_id, h);
63  return h;
64  }
65 
66  // To make OrtDevice become a valid key in std map
67  bool operator<(const OrtDevice& other) const {
68  if (device_type != other.device_type)
69  return device_type < other.device_type;
70  if (memory_type != other.memory_type)
71  return memory_type < other.memory_type;
72 
73  return device_id < other.device_id;
74  }
75 
76  private:
77  // Device type.
78  int32_t device_type : 8;
79 
80  // Memory type.
81  int32_t memory_type : 8;
82 
83  // Device index.
84  int32_t device_id : 16;
85 };
86 
87 inline bool operator==(const OrtDevice& left, const OrtDevice& other) {
88  return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type();
89 }
90 
91 inline bool operator!=(const OrtDevice& left, const OrtDevice& other) {
92  return !(left == other);
93 }
94 
95 namespace std {
96 template <>
97 struct hash<OrtDevice> {
98  size_t operator()(const OrtDevice& i) const {
99  return i.Hash();
100  }
101 };
102 } // namespace std
GLint left
Definition: glcorearb.h:2005
GLsizei const GLchar *const * string
Definition: glcorearb.h:814
size_t Hash() const
Definition: ortdevice.h:59
static const DeviceType FPGA
Definition: ortdevice.h:18
size_t operator()(const OrtDevice &i) const
Definition: ortdevice.h:98
int8_t DeviceType
Definition: ortdevice.h:11
static const MemoryType DEFAULT
Definition: ortdevice.h:23
bool operator<(const OrtDevice &other) const
Definition: ortdevice.h:67
void HashCombine(const T &value, size_t &seed)
Definition: hash_combine.h:17
int16_t DeviceId
Definition: ortdevice.h:13
bool operator!=(const Mat3< T0 > &m0, const Mat3< T1 > &m1)
Inequality operator, does exact floating point comparisons.
Definition: Mat3.h:556
constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
Definition: ortdevice.h:29
MemoryType MemType() const
Definition: ortdevice.h:40
static const DeviceType GPU
Definition: ortdevice.h:17
int8_t MemoryType
Definition: ortdevice.h:12
static const DeviceType CPU
Definition: ortdevice.h:16
GLfloat GLfloat GLfloat GLfloat h
Definition: glcorearb.h:2002
DeviceId Id() const
Definition: ortdevice.h:44
static const MemoryType CUDA_PINNED
Definition: ortdevice.h:24
constexpr OrtDevice()
Definition: ortdevice.h:34
static const MemoryType HIP_PINNED
Definition: ortdevice.h:25
static const MemoryType CANN_PINNED
Definition: ortdevice.h:26
static const DeviceType NPU
Definition: ortdevice.h:19
bool operator==(const Mat3< T0 > &m0, const Mat3< T1 > &m1)
Equality operator, does exact floating point comparisons.
Definition: Mat3.h:542
std::string ToString() const
Definition: ortdevice.h:48
DeviceType Type() const
Definition: ortdevice.h:36