Get rid of void* casting when calling EvalRange::run.

This commit is contained in:
Rasmus Munk Larsen 2016-04-15 12:51:33 -07:00
parent 07ac4f7e02
commit 3718bf654b

View File

@ -85,8 +85,8 @@ class TensorExecutor<Expression, DefaultDevice, true>
#ifdef EIGEN_USE_THREADS
template <typename Evaluator, typename Index, bool Vectorizable>
struct EvalRange {
static void run(void* evaluator_in, const Index first, const Index last) {
Evaluator evaluator(*static_cast<Evaluator*>(evaluator_in));
static void run(Evaluator* evaluator_in, const Index first, const Index last) {
Evaluator evaluator = *evaluator_in;
eigen_assert(last >= first);
for (Index i = first; i < last; ++i) {
evaluator.evalScalar(i);
@ -96,10 +96,9 @@ struct EvalRange {
template <typename Evaluator, typename Index>
struct EvalRange<Evaluator, Index, true> {
static void run(void* evaluator_in, const Index first, const Index last) {
Evaluator evaluator(*static_cast<Evaluator*>(evaluator_in));
static void run(Evaluator* evaluator_in, const Index first, const Index last) {
Evaluator evaluator = *evaluator_in;
eigen_assert(last >= first);
Index i = first;
const int PacketSize = unpacket_traits<typename Evaluator::PacketReturnType>::size;
if (last - first >= PacketSize) {
@ -123,16 +122,6 @@ struct EvalRange<Evaluator, Index, true> {
}
};
// Used to make an std::function to add to the ThreadPool with less templating
// than EvalRange::Run.
// This requires that this and EvalRange takes a void* to the evaluator that can
// be downcast to the right type by the EvalRange.
template <typename Index>
inline void InvokeEvalRange(void (*run_fn)(void*, const Index, const Index),
void* evaluator, const Index first, const Index last) {
run_fn(evaluator, first, last);
}
template <typename Expression, bool Vectorizable>
class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable> {
public:
@ -163,13 +152,12 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable> {
Barrier barrier(numblocks);
for (int i = 0; i < numblocks; ++i) {
device.enqueue_with_barrier(
&barrier, &InvokeEvalRange<Index>,
&EvalRange<Evaluator, Index, Vectorizable>::run,
static_cast<void*>(&evaluator), i * blocksize,
(i + 1) * blocksize);
&barrier, &EvalRange<Evaluator, Index, Vectorizable>::run,
&evaluator, i * blocksize, (i + 1) * blocksize);
}
if (numblocks * blocksize < size) {
EvalRange<Evaluator, Index, Vectorizable>::run(&evaluator, numblocks * blocksize, size);
EvalRange<Evaluator, Index, Vectorizable>::run(
&evaluator, numblocks * blocksize, size);
}
barrier.Wait();
}