mirror of
https://gitlab.com/libeigen/eigen.git
synced 2024-12-21 07:19:46 +08:00
Get rid of void* casting when calling EvalRange::run.
This commit is contained in:
parent
07ac4f7e02
commit
3718bf654b
@ -85,8 +85,8 @@ class TensorExecutor<Expression, DefaultDevice, true>
|
||||
#ifdef EIGEN_USE_THREADS
|
||||
template <typename Evaluator, typename Index, bool Vectorizable>
|
||||
struct EvalRange {
|
||||
static void run(void* evaluator_in, const Index first, const Index last) {
|
||||
Evaluator evaluator(*static_cast<Evaluator*>(evaluator_in));
|
||||
static void run(Evaluator* evaluator_in, const Index first, const Index last) {
|
||||
Evaluator evaluator = *evaluator_in;
|
||||
eigen_assert(last >= first);
|
||||
for (Index i = first; i < last; ++i) {
|
||||
evaluator.evalScalar(i);
|
||||
@ -96,10 +96,9 @@ struct EvalRange {
|
||||
|
||||
template <typename Evaluator, typename Index>
|
||||
struct EvalRange<Evaluator, Index, true> {
|
||||
static void run(void* evaluator_in, const Index first, const Index last) {
|
||||
Evaluator evaluator(*static_cast<Evaluator*>(evaluator_in));
|
||||
static void run(Evaluator* evaluator_in, const Index first, const Index last) {
|
||||
Evaluator evaluator = *evaluator_in;
|
||||
eigen_assert(last >= first);
|
||||
|
||||
Index i = first;
|
||||
const int PacketSize = unpacket_traits<typename Evaluator::PacketReturnType>::size;
|
||||
if (last - first >= PacketSize) {
|
||||
@ -123,16 +122,6 @@ struct EvalRange<Evaluator, Index, true> {
|
||||
}
|
||||
};
|
||||
|
||||
// Used to make an std::function to add to the ThreadPool with less templating
|
||||
// than EvalRange::Run.
|
||||
// This requires that this and EvalRange takes a void* to the evaluator that can
|
||||
// be downcast to the right type by the EvalRange.
|
||||
template <typename Index>
|
||||
inline void InvokeEvalRange(void (*run_fn)(void*, const Index, const Index),
|
||||
void* evaluator, const Index first, const Index last) {
|
||||
run_fn(evaluator, first, last);
|
||||
}
|
||||
|
||||
template <typename Expression, bool Vectorizable>
|
||||
class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable> {
|
||||
public:
|
||||
@ -163,13 +152,12 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable> {
|
||||
Barrier barrier(numblocks);
|
||||
for (int i = 0; i < numblocks; ++i) {
|
||||
device.enqueue_with_barrier(
|
||||
&barrier, &InvokeEvalRange<Index>,
|
||||
&EvalRange<Evaluator, Index, Vectorizable>::run,
|
||||
static_cast<void*>(&evaluator), i * blocksize,
|
||||
(i + 1) * blocksize);
|
||||
&barrier, &EvalRange<Evaluator, Index, Vectorizable>::run,
|
||||
&evaluator, i * blocksize, (i + 1) * blocksize);
|
||||
}
|
||||
if (numblocks * blocksize < size) {
|
||||
EvalRange<Evaluator, Index, Vectorizable>::run(&evaluator, numblocks * blocksize, size);
|
||||
EvalRange<Evaluator, Index, Vectorizable>::run(
|
||||
&evaluator, numblocks * blocksize, size);
|
||||
}
|
||||
barrier.Wait();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user