Please, help us to better know about our user community by answering the following short survey: https://forms.gle/wpyrxWi18ox9Z5ae9
 
Loading...
Searching...
No Matches
EventCount.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
11#define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
12
13namespace Eigen {
14
15// EventCount allows waiting for arbitrary predicates in non-blocking
16// algorithms. Think of condition variable, but wait predicate does not need to
17// be protected by a mutex. Usage:
18// Waiting thread does:
19//
20// if (predicate)
21// return act();
22// EventCount::Waiter& w = waiters[my_index];
23// ec.Prewait(&w);
24// if (predicate) {
25// ec.CancelWait(&w);
26// return act();
27// }
28// ec.CommitWait(&w);
29//
30// Notifying thread does:
31//
32// predicate = true;
33// ec.Notify(true);
34//
35// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
36// cheap, but they are executed only if the preceding predicate check has
37// failed.
38//
39// Algorithm outline:
40// There are two main variables: predicate (managed by user) and state_.
41// Operation closely resembles Dekker's mutual exclusion algorithm:
42// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
43// Waiting thread sets state_ then checks predicate, Notifying thread sets
44// predicate then checks state_. Due to seq_cst fences in between these
45// operations it is guaranteed that either the waiter will see the predicate change
46// and won't block, or notifying thread will see state_ change and will unblock
47// the waiter, or both. But it can't happen that both threads don't see each
48// other changes, which would lead to deadlock.
class EventCount {
 public:
  class Waiter;

  // The waiters vector must outlive this EventCount and must not be resized
  // while in use: each potential waiting thread owns one fixed slot, and slot
  // indexes are stored inside state_ (see "State_ layout" below).
  EventCount(MaxSizeVector<Waiter>& waiters)
      : state_(kStackMask), waiters_(waiters) {
    // Slot indexes must be representable in kWaiterBits, with the all-ones
    // value (kStackMask) reserved as the "empty stack" sentinel.
    eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
  }

  ~EventCount() {
    // Ensure there are no waiters.
    eigen_plain_assert(state_.load() == kStackMask);
  }

  // Prewait prepares for waiting.
  // After calling Prewait, the thread must re-check the wait predicate
  // and then call either CancelWait or CommitWait.
  void Prewait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state);
      // Bump the pre-wait counter (middle kWaiterBits field of state_).
      uint64_t newstate = state + kWaiterInc;
      CheckState(newstate);
      // seq_cst success ordering pairs with the seq_cst fence in Notify():
      // either the notifier observes our counter increment, or we observe the
      // predicate change on the caller's re-check — never neither (Dekker).
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_seq_cst))
        return;
    }
  }

  // CommitWait commits waiting after Prewait.
  // Either consumes a pending signal and returns immediately, or moves this
  // thread from the pre-wait counter onto the waiter stack and blocks in
  // Park() until a notifier calls Unpark().
  void CommitWait(Waiter* w) {
    // The epoch must occupy only the high (ABA-counter) bits.
    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
    w->state = Waiter::kNotSignaled;
    // 'me' encodes this waiter's slot index in waiters_ tagged with its
    // current epoch; together they form the stack-top field of state_.
    const uint64_t me = (w - &waiters_[0]) | w->epoch;
    uint64_t state = state_.load(std::memory_order_seq_cst);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate;
      if ((state & kSignalMask) != 0) {
        // Consume the signal and return immediately.
        newstate = state - kWaiterInc - kSignalInc;
      } else {
        // Remove this thread from pre-wait counter and add to the waiter stack.
        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
        // Link to the previous stack top (index + epoch); relaxed is enough
        // because the CAS below publishes it with release semantics.
        w->next.store(state & (kStackMask | kEpochMask),
                      std::memory_order_relaxed);
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if ((state & kSignalMask) == 0) {
          // Successfully pushed onto the stack: advance the epoch so the next
          // push by this slot is distinguishable (ABA protection), then block.
          w->epoch += kEpochInc;
          Park(w);
        }
        return;
      }
    }
  }

  // CancelWait cancels effects of the previous Prewait call.
  void CancelWait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate = state - kWaiterInc;
      // We don't know if the thread was also notified or not,
      // so we should not consume a signal unconditionally.
      // Only if number of waiters is equal to number of signals,
      // we know that the thread was notified and we must take away the signal.
      if (((state & kWaiterMask) >> kWaiterShift) ==
          ((state & kSignalMask) >> kSignalShift))
        newstate -= kSignalInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel))
        return;
    }
  }

  // Notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void Notify(bool notifyAll) {
    // Pairs with the seq_cst CAS in Prewait(): guarantees we see the waiter's
    // counter increment if the waiter did not see our predicate change.
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = state_.load(std::memory_order_acquire);
    for (;;) {
      CheckState(state);
      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
      // Easy case: no waiters.
      // (Empty stack and every pre-wait thread already has a pending signal.)
      if ((state & kStackMask) == kStackMask && waiters == signals) return;
      uint64_t newstate;
      if (notifyAll) {
        // Empty wait stack and set signal to number of pre-wait threads.
        newstate =
            (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
      } else if (signals < waiters) {
        // There is a thread in pre-wait state, unblock it.
        // (It will consume the signal in CommitWait instead of parking.)
        newstate = state + kSignalInc;
      } else {
        // Pop a waiter from list and unpark it.
        Waiter* w = &waiters_[state & kStackMask];
        uint64_t next = w->next.load(std::memory_order_relaxed);
        newstate = (state & (kWaiterMask | kSignalMask)) | next;
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if (!notifyAll && (signals < waiters))
          return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
        // For notify-one, detach the popped waiter from the rest of the list
        // so Unpark wakes only this one; for notify-all, Unpark walks the
        // whole (already detached) chain.
        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
    }
  }

  // Per-thread waiter slot. All fields are private; only EventCount uses them.
  class Waiter {
    friend class EventCount;
    // Align to 128 byte boundary to prevent false sharing with other Waiter
    // objects in the same vector.
    // 'next' holds (index of next waiter in waiters_) | (its epoch tag);
    // kStackMask in the index bits means end-of-list.
    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
    std::mutex mu;               // protects 'state' and pairs with 'cv'
    std::condition_variable cv;  // this thread blocks here in Park()
    uint64_t epoch = 0;          // ABA counter, bumped on every stack push
    unsigned state = kNotSignaled;  // guarded by 'mu'
    enum {
      kNotSignaled,  // pushed on stack, not yet blocked in cv.wait
      kWaiting,      // blocked in cv.wait, needs notify_one
      kSignaled,     // Unpark has run; Park must not (re)block
    };
  };

 private:
  // State_ layout:
  // - low kWaiterBits is a stack of waiters committed wait
  //   (indexes in waiters_ array are used as stack elements,
  //   kStackMask means empty stack).
  // - next kWaiterBits is count of waiters in prewait state.
  // - next kWaiterBits is count of pending signals.
  // - remaining bits are ABA counter for the stack.
  //   (stored in Waiter node and incremented on push).
  static const uint64_t kWaiterBits = 14;
  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
  static const uint64_t kWaiterShift = kWaiterBits;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
  static const uint64_t kSignalShift = 2 * kWaiterBits;
  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
                                      << kSignalShift;
  static const uint64_t kSignalInc = 1ull << kSignalShift;
  static const uint64_t kEpochShift = 3 * kWaiterBits;
  static const uint64_t kEpochBits = 64 - kEpochShift;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> state_;
  MaxSizeVector<Waiter>& waiters_;

  // Debug-only sanity checks of the packed state word; 'waiter' asserts that
  // at least one thread is in the pre-wait count (caller is one of them).
  static void CheckState(uint64_t state, bool waiter = false) {
    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
    eigen_plain_assert(waiters >= signals);
    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
    eigen_plain_assert(!waiter || waiters > 0);
    (void)waiters;
    (void)signals;
  }

  // Block the calling thread until Unpark marks it kSignaled.
  // The while-loop guards against spurious condition-variable wakeups.
  void Park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // Wake w and every waiter linked after it via the 'next' chain.
  void Unpark(Waiter* w) {
    for (Waiter* next; w; w = next) {
      // Read the link before signaling: once kSignaled, w may return from
      // CommitWait and be reused/re-pushed at any moment.
      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
      next = wnext == kStackMask ? nullptr : &waiters_[wnext];
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  // Non-copyable: state_ encodes identity (stack indexes into waiters_).
  EventCount(const EventCount&) = delete;
  void operator=(const EventCount&) = delete;
};
246
247} // namespace Eigen
248
249#endif // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
Namespace containing all symbols from the Eigen library.