HDK
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
ort_mutex.h
Go to the documentation of this file.
1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT License.
3 
4 #pragma once
5 #ifdef _WIN32
6 #include <Windows.h>
7 #include <mutex>
8 namespace onnxruntime {
// Q: Why is OrtMutex better than std::mutex?
// A: OrtMutex supports static (constant) initialization, but the VC runtime's std::mutex doesn't. Static
// initialization helps us avoid the "static initialization order problem".

// Q: Why can't std::mutex do that?
// A: The VC runtime has to remain ABI-compatible with Windows XP, which has no SRW locks, so its std::mutex cannot
// simply wrap a statically initializable SRWLOCK. We don't have that requirement.

// Q: Is OrtMutex faster than std::mutex?
// A: Yes. It is a thin wrapper that calls the SRW lock APIs directly.
18 
19 class OrtMutex {
20  private:
21  SRWLOCK data_ = SRWLOCK_INIT;
22 
23  public:
24  constexpr OrtMutex() = default;
25  // SRW locks do not need to be explicitly destroyed.
26  ~OrtMutex() = default;
27  OrtMutex(const OrtMutex&) = delete;
28  OrtMutex& operator=(const OrtMutex&) = delete;
29  void lock() { AcquireSRWLockExclusive(native_handle()); }
30  bool try_lock() noexcept { return TryAcquireSRWLockExclusive(native_handle()) == TRUE; }
31  void unlock() noexcept { ReleaseSRWLockExclusive(native_handle()); }
32  using native_handle_type = SRWLOCK*;
33 
34  __forceinline native_handle_type native_handle() { return &data_; }
35 };
36 
37 class OrtCondVar {
38  CONDITION_VARIABLE native_cv_object = CONDITION_VARIABLE_INIT;
39 
40  public:
41  constexpr OrtCondVar() noexcept = default;
42  ~OrtCondVar() = default;
43 
44  OrtCondVar(const OrtCondVar&) = delete;
45  OrtCondVar& operator=(const OrtCondVar&) = delete;
46 
47  void notify_one() noexcept { WakeConditionVariable(&native_cv_object); }
48  void notify_all() noexcept { WakeAllConditionVariable(&native_cv_object); }
49 
50  void wait(std::unique_lock<OrtMutex>& lk) {
51  if (SleepConditionVariableSRW(&native_cv_object, lk.mutex()->native_handle(), INFINITE, 0) != TRUE) {
52  std::terminate();
53  }
54  }
55  template <class _Predicate>
56  void wait(std::unique_lock<OrtMutex>& __lk, _Predicate __pred);
57 
58  /**
59  * returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns
60  * cv_status::no_timeout.
61  * @param cond_mutex A unique_lock<OrtMutex> object.
62  * @param rel_time A chrono::duration object that specifies the amount of time before the thread wakes up.
63  * @return returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns
64  * cv_status::no_timeout
65  */
66  template <class Rep, class Period>
67  std::cv_status wait_for(std::unique_lock<OrtMutex>& cond_mutex, const std::chrono::duration<Rep, Period>& rel_time);
68  using native_handle_type = CONDITION_VARIABLE*;
69 
70  native_handle_type native_handle() { return &native_cv_object; }
71 
72  private:
73  void timed_wait_impl(std::unique_lock<OrtMutex>& __lk,
74  std::chrono::time_point<std::chrono::system_clock, std::chrono::nanoseconds>);
75 };
76 
77 template <class _Predicate>
78 void OrtCondVar::wait(std::unique_lock<OrtMutex>& __lk, _Predicate __pred) {
79  while (!__pred()) wait(__lk);
80 }
81 
82 template <class Rep, class Period>
83 std::cv_status OrtCondVar::wait_for(std::unique_lock<OrtMutex>& cond_mutex,
84  const std::chrono::duration<Rep, Period>& rel_time) {
85  // TODO: is it possible to use nsync_from_time_point_ ?
86  using namespace std::chrono;
87  if (rel_time <= duration<Rep, Period>::zero())
89  using SystemTimePointFloat = time_point<system_clock, duration<long double, std::nano> >;
90  using SystemTimePoint = time_point<system_clock, nanoseconds>;
91  SystemTimePointFloat max_time = SystemTimePoint::max();
92  steady_clock::time_point steady_now = steady_clock::now();
93  system_clock::time_point system_now = system_clock::now();
94  if (max_time - rel_time > system_now) {
95  nanoseconds remain = duration_cast<nanoseconds>(rel_time);
96  if (remain < rel_time)
97  ++remain;
98  timed_wait_impl(cond_mutex, system_now + remain);
99  } else
100  timed_wait_impl(cond_mutex, SystemTimePoint::max());
101  return steady_clock::now() - steady_now < rel_time ? std::cv_status::no_timeout : std::cv_status::timeout;
102 }
103 } // namespace onnxruntime
104 #else
105 #include "nsync.h"
106 #include <mutex> //for unique_lock
107 #include <condition_variable> //for cv_status
108 namespace onnxruntime {
109 
110 class OrtMutex {
111  nsync::nsync_mu data_ = NSYNC_MU_INIT;
112 
113  public:
114  constexpr OrtMutex() = default;
115  ~OrtMutex() = default;
116  OrtMutex(const OrtMutex&) = delete;
117  OrtMutex& operator=(const OrtMutex&) = delete;
118 
119  void lock() { nsync::nsync_mu_lock(&data_); }
120  bool try_lock() noexcept { return nsync::nsync_mu_trylock(&data_) == 0; }
121  void unlock() noexcept { nsync::nsync_mu_unlock(&data_); }
122 
123  using native_handle_type = nsync::nsync_mu*;
124  native_handle_type native_handle() { return &data_; }
125 };
126 
127 class OrtCondVar {
128  nsync::nsync_cv native_cv_object = NSYNC_CV_INIT;
129 
130  public:
131  constexpr OrtCondVar() noexcept = default;
132 
133  ~OrtCondVar() = default;
134  OrtCondVar(const OrtCondVar&) = delete;
135  OrtCondVar& operator=(const OrtCondVar&) = delete;
136 
137  void notify_one() noexcept { nsync::nsync_cv_signal(&native_cv_object); }
138  void notify_all() noexcept { nsync::nsync_cv_broadcast(&native_cv_object); }
139 
140  void wait(std::unique_lock<OrtMutex>& lk);
141  template <class _Predicate>
142  void wait(std::unique_lock<OrtMutex>& __lk, _Predicate __pred);
143 
144  /**
145  * returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns
146  * cv_status::no_timeout.
147  * @param cond_mutex A unique_lock<OrtMutex> object.
148  * @param rel_time A chrono::duration object that specifies the amount of time before the thread wakes up.
149  * @return returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns
150  * cv_status::no_timeout
151  */
152  template <class Rep, class Period>
153  std::cv_status wait_for(std::unique_lock<OrtMutex>& cond_mutex, const std::chrono::duration<Rep, Period>& rel_time);
154  using native_handle_type = nsync::nsync_cv*;
155  native_handle_type native_handle() { return &native_cv_object; }
156 
157  private:
158  void timed_wait_impl(std::unique_lock<OrtMutex>& __lk,
159  std::chrono::time_point<std::chrono::system_clock, std::chrono::nanoseconds>);
160 };
161 
162 template <class _Predicate>
163 void OrtCondVar::wait(std::unique_lock<OrtMutex>& __lk, _Predicate __pred) {
164  while (!__pred()) wait(__lk);
165 }
166 
167 template <class Rep, class Period>
168 std::cv_status OrtCondVar::wait_for(std::unique_lock<OrtMutex>& cond_mutex,
169  const std::chrono::duration<Rep, Period>& rel_time) {
170  // TODO: is it possible to use nsync_from_time_point_ ?
171  using namespace std::chrono;
172  if (rel_time <= duration<Rep, Period>::zero())
174  using SystemTimePointFloat = time_point<system_clock, duration<long double, std::nano> >;
175  using SystemTimePoint = time_point<system_clock, nanoseconds>;
176  SystemTimePointFloat max_time = SystemTimePoint::max();
177  steady_clock::time_point steady_now = steady_clock::now();
178  system_clock::time_point system_now = system_clock::now();
179  if (max_time - rel_time > system_now) {
180  nanoseconds remain = duration_cast<nanoseconds>(rel_time);
181  if (remain < rel_time)
182  ++remain;
183  timed_wait_impl(cond_mutex, system_now + remain);
184  } else
185  timed_wait_impl(cond_mutex, SystemTimePoint::max());
186  return steady_clock::now() - steady_now < rel_time ? std::cv_status::no_timeout : std::cv_status::timeout;
187 }
188 }; // namespace onnxruntime
189 #endif
void notify_all() noexcept
Definition: ort_mutex.h:138
native_handle_type native_handle()
Definition: ort_mutex.h:124
native_handle_type native_handle()
Definition: ort_mutex.h:155
constexpr OrtCondVar() noexcept=default
OrtCondVar & operator=(const OrtCondVar &)=delete
constexpr OrtMutex()=default
void unlock() noexcept
Definition: ort_mutex.h:121
nsync::nsync_cv * native_handle_type
Definition: ort_mutex.h:154
nsync::nsync_mu * native_handle_type
Definition: ort_mutex.h:123
GLbitfield GLuint64 timeout
Definition: glcorearb.h:1599
void wait(std::unique_lock< OrtMutex > &lk)
OrtMutex & operator=(const OrtMutex &)=delete
std::cv_status wait_for(std::unique_lock< OrtMutex > &cond_mutex, const std::chrono::duration< Rep, Period > &rel_time)
Definition: ort_mutex.h:168
void notify_one() noexcept
Definition: ort_mutex.h:137
ImageBuf OIIO_API max(Image_or_Const A, Image_or_Const B, ROI roi={}, int nthreads=0)
bool try_lock() noexcept
Definition: ort_mutex.h:120
ImageBuf OIIO_API zero(ROI roi, int nthreads=0)