Merged in rmlarsen/eigen_threadpool (pull request PR-596)

Improve EventCount used by the non-blocking threadpool.

Approved-by: Gael Guennebaud <g.gael@free.fr>
Commit: 9558f4c25f
@@ -20,7 +20,8 @@ namespace Eigen {
 //   if (predicate)
 //     return act();
 //   EventCount::Waiter& w = waiters[my_index];
-//   ec.Prewait(&w);
+//   if (!ec.Prewait(&w))
+//     return act();
 //   if (predicate) {
 //     ec.CancelWait(&w);
 //     return act();
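For readers following the API change, here is a minimal sketch of the waiter-side protocol under the new interface (illustrative only, not part of the patch; `predicate`, `act`, `waiters` and `my_index` are the caller's own placeholder names from the header comment). Note that the comment hunk above still writes `&w` arguments for Prewait and CancelWait, while the new signatures introduced in the hunks below take none:

    EventCount::Waiter& w = waiters[my_index];
    if (!ec.Prewait()) {
      // A concurrent Notify consumed the pre-wait; the predicate may already
      // hold, so neither CancelWait nor CommitWait may be called.
      return act();
    }
    if (predicate) {
      ec.CancelWait();   // condition became true while in pre-wait state
      return act();
    }
    ec.CommitWait(&w);   // park until a Notify pops this waiter off the stack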
@@ -50,78 +51,78 @@ class EventCount {
  public:
   class Waiter;
 
-  EventCount(MaxSizeVector<Waiter>& waiters) : waiters_(waiters) {
+  EventCount(MaxSizeVector<Waiter>& waiters)
+      : state_(kStackMask), waiters_(waiters) {
     eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
-    // Initialize epoch to something close to overflow to test overflow.
-    state_ = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
   }
 
   ~EventCount() {
     // Ensure there are no waiters.
-    eigen_plain_assert((state_.load() & (kStackMask | kWaiterMask)) == kStackMask);
+    eigen_plain_assert(state_.load() == kStackMask);
   }
 
   // Prewait prepares for waiting.
-  // After calling this function the thread must re-check the wait predicate
-  // and call either CancelWait or CommitWait passing the same Waiter object.
-  void Prewait(Waiter* w) {
-    w->epoch = state_.fetch_add(kWaiterInc, std::memory_order_relaxed);
-    std::atomic_thread_fence(std::memory_order_seq_cst);
+  // If Prewait returns true, the thread must re-check the wait predicate
+  // and then call either CancelWait or CommitWait.
+  // Otherwise, the thread should assume the predicate may be true
+  // and don't call CancelWait/CommitWait (there was a concurrent Notify call).
+  bool Prewait() {
+    uint64_t state = state_.load(std::memory_order_relaxed);
+    for (;;) {
+      CheckState(state);
+      uint64_t newstate = state + kWaiterInc;
+      if ((state & kSignalMask) != 0) {
+        // Consume the signal and cancel waiting.
+        newstate -= kSignalInc + kWaiterInc;
+      }
+      CheckState(newstate);
+      if (state_.compare_exchange_weak(state, newstate,
+                                       std::memory_order_seq_cst))
+        return (state & kSignalMask) == 0;
+    }
   }
 
-  // CommitWait commits waiting.
+  // CommitWait commits waiting after Prewait.
   void CommitWait(Waiter* w) {
+    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
     w->state = Waiter::kNotSignaled;
-    // Modification epoch of this waiter.
-    uint64_t epoch =
-        (w->epoch & kEpochMask) +
-        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+    const uint64_t me = (w - &waiters_[0]) | w->epoch;
     uint64_t state = state_.load(std::memory_order_seq_cst);
     for (;;) {
-      if (int64_t((state & kEpochMask) - epoch) < 0) {
-        // The preceding waiter has not decided on its fate. Wait until it
-        // calls either CancelWait or CommitWait, or is notified.
-        EIGEN_THREAD_YIELD();
-        state = state_.load(std::memory_order_seq_cst);
-        continue;
+      CheckState(state, true);
+      uint64_t newstate;
+      if ((state & kSignalMask) != 0) {
+        // Consume the signal and return immediately.
+        newstate = state - kWaiterInc - kSignalInc;
+      } else {
+        // Remove this thread from pre-wait counter and add to the waiter stack.
+        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
+        w->next.store(state & (kStackMask | kEpochMask),
+                      std::memory_order_relaxed);
       }
-      // We've already been notified.
-      if (int64_t((state & kEpochMask) - epoch) > 0) return;
-      // Remove this thread from prewait counter and add it to the waiter list.
-      eigen_plain_assert((state & kWaiterMask) != 0);
-      uint64_t newstate = state - kWaiterInc + kEpochInc;
-      newstate = (newstate & ~kStackMask) | (w - &waiters_[0]);
-      if ((state & kStackMask) == kStackMask)
-        w->next.store(nullptr, std::memory_order_relaxed);
-      else
-        w->next.store(&waiters_[state & kStackMask], std::memory_order_relaxed);
+      CheckState(newstate);
       if (state_.compare_exchange_weak(state, newstate,
-                                       std::memory_order_release))
-        break;
+                                       std::memory_order_acq_rel)) {
+        if ((state & kSignalMask) == 0) {
+          w->epoch += kEpochInc;
+          Park(w);
+        }
+        return;
+      }
     }
-    Park(w);
   }
 
   // CancelWait cancels effects of the previous Prewait call.
-  void CancelWait(Waiter* w) {
-    uint64_t epoch =
-        (w->epoch & kEpochMask) +
-        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+  void CancelWait() {
     uint64_t state = state_.load(std::memory_order_relaxed);
     for (;;) {
-      if (int64_t((state & kEpochMask) - epoch) < 0) {
-        // The preceding waiter has not decided on its fate. Wait until it
-        // calls either CancelWait or CommitWait, or is notified.
-        EIGEN_THREAD_YIELD();
-        state = state_.load(std::memory_order_relaxed);
-        continue;
-      }
-      // We've already been notified.
-      if (int64_t((state & kEpochMask) - epoch) > 0) return;
-      // Remove this thread from prewait counter.
-      eigen_plain_assert((state & kWaiterMask) != 0);
-      if (state_.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
-                                       std::memory_order_relaxed))
+      CheckState(state, true);
+      uint64_t newstate = state - kWaiterInc;
+      // Also take away a signal if any.
+      if ((state & kSignalMask) != 0) newstate -= kSignalInc;
+      CheckState(newstate);
+      if (state_.compare_exchange_weak(state, newstate,
+                                       std::memory_order_acq_rel))
         return;
     }
   }
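To make the counter arithmetic behind the new Prewait/CancelWait concrete, here is a standalone sketch (not part of the patch; the constants are copied from the layout introduced further down in this commit) that walks one pre-wait, notify, cancel round trip:

    #include <cassert>
    #include <cstdint>

    // Constants as defined later in this commit (kWaiterBits = 14 layout).
    static const uint64_t kWaiterBits = 14;
    static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
    static const uint64_t kWaiterShift = kWaiterBits;
    static const uint64_t kWaiterInc = 1ull << kWaiterShift;
    static const uint64_t kSignalShift = 2 * kWaiterBits;
    static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1) << kSignalShift;
    static const uint64_t kSignalInc = 1ull << kSignalShift;

    int main() {
      uint64_t state = kStackMask;       // empty stack, no waiters, no signals
      state += kWaiterInc;               // Prewait: one thread enters pre-wait
      state += kSignalInc;               // Notify(false): leaves a pending signal
      assert((state & kSignalMask) != 0);
      state -= kWaiterInc + kSignalInc;  // CancelWait consumes both counters
      assert(state == kStackMask);       // the invariant the new destructor asserts
      return 0;
    }

The same "waiter count versus signal count" bookkeeping is what lets the Notify hunk below decide whether to post a signal or pop a parked waiter.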
@@ -132,35 +133,33 @@ class EventCount {
     std::atomic_thread_fence(std::memory_order_seq_cst);
     uint64_t state = state_.load(std::memory_order_acquire);
     for (;;) {
+      CheckState(state);
+      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
       // Easy case: no waiters.
-      if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
-        return;
-      uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+      if ((state & kStackMask) == kStackMask && waiters == signals) return;
       uint64_t newstate;
       if (notifyAll) {
-        // Reset prewait counter and empty wait list.
-        newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
-      } else if (waiters) {
+        // Empty wait stack and set signal to number of pre-wait threads.
+        newstate =
+            (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
+      } else if (signals < waiters) {
         // There is a thread in pre-wait state, unblock it.
-        newstate = state + kEpochInc - kWaiterInc;
+        newstate = state + kSignalInc;
       } else {
         // Pop a waiter from list and unpark it.
         Waiter* w = &waiters_[state & kStackMask];
-        Waiter* wnext = w->next.load(std::memory_order_relaxed);
-        uint64_t next = kStackMask;
-        if (wnext != nullptr) next = wnext - &waiters_[0];
-        // Note: we don't add kEpochInc here. ABA problem on the lock-free stack
-        // can't happen because a waiter is re-pushed onto the stack only after
-        // it was in the pre-wait state which inevitably leads to epoch
-        // increment.
-        newstate = (state & kEpochMask) + next;
+        uint64_t next = w->next.load(std::memory_order_relaxed);
+        newstate = (state & (kWaiterMask | kSignalMask)) | next;
       }
+      CheckState(newstate);
       if (state_.compare_exchange_weak(state, newstate,
-                                       std::memory_order_acquire)) {
-        if (!notifyAll && waiters) return;  // unblocked pre-wait thread
+                                       std::memory_order_acq_rel)) {
+        if (!notifyAll && (signals < waiters))
+          return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
-        if (!notifyAll) w->next.store(nullptr, std::memory_order_relaxed);
+        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
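Notify's signature is unchanged by this patch; only its internal decision changes: return if there is nobody to wake, post a signal for a thread still in pre-wait, or pop and unpark a committed waiter from the stack. For completeness, a sketch of the notifier side that pairs with the waiter sketch earlier (illustrative placeholder names, not part of the patch):

    // The state change (the "predicate") must be published before Notify so a
    // waiter that re-checks it after Prewait cannot miss it.
    predicate = true;                // e.g. push work onto a queue
    ec.Notify(/*notifyAll=*/false);  // wake one pre-waiter or one parked waiter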
@@ -171,11 +170,11 @@ class EventCount {
   friend class EventCount;
   // Align to 128 byte boundary to prevent false sharing with other Waiter
   // objects in the same vector.
-  EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<Waiter*> next;
+  EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
   std::mutex mu;
   std::condition_variable cv;
-  uint64_t epoch;
-  unsigned state;
+  uint64_t epoch = 0;
+  unsigned state = kNotSignaled;
   enum {
     kNotSignaled,
     kWaiting,
@@ -185,23 +184,41 @@ class EventCount {
 
  private:
   // State_ layout:
-  // - low kStackBits is a stack of waiters committed wait.
+  // - low kWaiterBits is a stack of waiters committed wait
+  //   (indexes in waiters_ array are used as stack elements,
+  //   kStackMask means empty stack).
   // - next kWaiterBits is count of waiters in prewait state.
-  // - next kEpochBits is modification counter.
-  static const uint64_t kStackBits = 16;
-  static const uint64_t kStackMask = (1ull << kStackBits) - 1;
-  static const uint64_t kWaiterBits = 16;
-  static const uint64_t kWaiterShift = 16;
+  // - next kWaiterBits is count of pending signals.
+  // - remaining bits are ABA counter for the stack.
+  //   (stored in Waiter node and incremented on push).
+  static const uint64_t kWaiterBits = 14;
+  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
+  static const uint64_t kWaiterShift = kWaiterBits;
   static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                       << kWaiterShift;
-  static const uint64_t kWaiterInc = 1ull << kWaiterBits;
-  static const uint64_t kEpochBits = 32;
-  static const uint64_t kEpochShift = 32;
+  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
+  static const uint64_t kSignalShift = 2 * kWaiterBits;
+  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
+                                      << kSignalShift;
+  static const uint64_t kSignalInc = 1ull << kSignalShift;
+  static const uint64_t kEpochShift = 3 * kWaiterBits;
+  static const uint64_t kEpochBits = 64 - kEpochShift;
   static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
   static const uint64_t kEpochInc = 1ull << kEpochShift;
   std::atomic<uint64_t> state_;
   MaxSizeVector<Waiter>& waiters_;
 
+  static void CheckState(uint64_t state, bool waiter = false) {
+    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
+    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
+    eigen_plain_assert(waiters >= signals);
+    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
+    eigen_plain_assert(!waiter || waiters > 0);
+    (void)waiters;
+    (void)signals;
+  }
+
   void Park(Waiter* w) {
     std::unique_lock<std::mutex> lock(w->mu);
     while (w->state != Waiter::kSignaled) {
@@ -210,10 +227,10 @@ class EventCount {
     }
   }
 
-  void Unpark(Waiter* waiters) {
-    Waiter* next = nullptr;
-    for (Waiter* w = waiters; w; w = next) {
-      next = w->next.load(std::memory_order_relaxed);
+  void Unpark(Waiter* w) {
+    for (Waiter* next; w; w = next) {
+      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
+      next = wnext == kStackMask ? nullptr : &waiters_[wnext];
       unsigned state;
       {
         std::unique_lock<std::mutex> lock(w->mu);
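All field boundaries in the new layout are derived from the single constant kWaiterBits = 14, so the 64-bit state_ word splits as sketched below (illustration only; if placed inside EventCount, these static_asserts would follow directly from the constants in the hunk above):

    // bits  0..13  stack head    (index of the top Waiter; kStackMask = empty)
    // bits 14..27  waiter count  (threads in pre-wait state)
    // bits 28..41  signal count  (pending notifications)
    // bits 42..63  epoch         (ABA counter for the waiter stack)
    static_assert(kWaiterShift == 14 && kSignalShift == 28 && kEpochShift == 42,
                  "field boundaries implied by kWaiterBits == 14");
    static_assert(kEpochBits == 22 && kEpochBits >= 20,
                  "64 - 3 * kWaiterBits leaves 22 bits for the ABA counter");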
@@ -374,11 +374,11 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
     eigen_plain_assert(!t->f);
     // We already did best-effort emptiness check in Steal, so prepare for
     // blocking.
-    ec_.Prewait(waiter);
+    if (!ec_.Prewait()) return true;
     // Now do a reliable emptiness check.
     int victim = NonEmptyQueueIndex();
     if (victim != -1) {
-      ec_.CancelWait(waiter);
+      ec_.CancelWait();
       if (cancelled_) {
         return false;
       } else {
@@ -392,7 +392,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
     blocked_++;
     // TODO is blocked_ required to be unsigned?
     if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
-      ec_.CancelWait(waiter);
+      ec_.CancelWait();
       // Almost done, but need to re-check queues.
       // Consider that all queues are empty and all worker threads are preempted
       // right after incrementing blocked_ above. Now a free-standing thread
@@ -30,11 +30,11 @@ static void test_basic_eventcount()
   EventCount ec(waiters);
   EventCount::Waiter& w = waiters[0];
   ec.Notify(false);
-  ec.Prewait(&w);
+  VERIFY(ec.Prewait());
   ec.Notify(true);
   ec.CommitWait(&w);
-  ec.Prewait(&w);
-  ec.CancelWait(&w);
+  VERIFY(ec.Prewait());
+  ec.CancelWait();
 }
 
 // Fake bounded counter-based queue.
@@ -112,7 +112,7 @@ static void test_stress_eventcount()
       unsigned idx = rand_reentrant(&rnd) % kQueues;
       if (queues[idx].Pop()) continue;
       j--;
-      ec.Prewait(&w);
+      if (!ec.Prewait()) continue;
       bool empty = true;
       for (int q = 0; q < kQueues; q++) {
         if (!queues[q].Empty()) {
@@ -121,7 +121,7 @@ static void test_stress_eventcount()
         }
       }
       if (!empty) {
-        ec.CancelWait(&w);
+        ec.CancelWait();
         continue;
       }
       ec.CommitWait(&w);