mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Moved the MemCopyFunctor back to TensorSyclDevice since it's the only caller and it makes TensorFlow compile again
This commit is contained in:
parent
fca27350eb
commit
e073de96dc
@ -170,6 +170,29 @@ struct SyclDevice {
|
||||
// some runtime conditions that can be applied here
|
||||
EIGEN_STRONG_INLINE bool isDeviceSuitable() const { return true; }
|
||||
|
||||
template <typename T> class MemCopyFunctor {
|
||||
public:
|
||||
typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read, cl::sycl::access::target::global_buffer> read_accessor;
|
||||
typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::discard_write, cl::sycl::access::target::global_buffer> write_accessor;
|
||||
|
||||
MemCopyFunctor(read_accessor src_acc, write_accessor dst_acc, size_t rng, size_t i, size_t offset): m_src_acc(src_acc), m_dst_acc(dst_acc), m_rng(rng), m_i(i), m_offset(offset) {}
|
||||
|
||||
void operator()(cl::sycl::nd_item<1> itemID) {
|
||||
auto src_ptr = ConvertToActualTypeSycl(T, m_src_acc);
|
||||
auto dst_ptr = ConvertToActualTypeSycl(T, m_dst_acc);
|
||||
auto globalid = itemID.get_global_linear_id();
|
||||
if (globalid < m_rng) {
|
||||
dst_ptr[globalid + m_i] = src_ptr[globalid + m_offset];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
read_accessor m_src_acc;
|
||||
write_accessor m_dst_acc;
|
||||
size_t m_rng;
|
||||
size_t m_i;
|
||||
size_t m_offset;
|
||||
};
|
||||
|
||||
/// the memcpy function
|
||||
template<typename T> EIGEN_STRONG_INLINE void memcpy(void *dst, const T *src, size_t n) const {
|
||||
@ -184,7 +207,7 @@ struct SyclDevice {
|
||||
sycl_queue().submit([&](cl::sycl::handler &cgh) {
|
||||
auto src_acc =it1->second.template get_access<cl::sycl::access::mode::read, cl::sycl::access::target::global_buffer>(cgh);
|
||||
auto dst_acc =it2->second.template get_access<cl::sycl::access::mode::discard_write, cl::sycl::access::target::global_buffer>(cgh);
|
||||
cgh.parallel_for(cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), TensorSycl::internal::MemCopyFunctor<T>(src_acc, dst_acc, rng, 0, offset));
|
||||
cgh.parallel_for(cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), MemCopyFunctor<T>(src_acc, dst_acc, rng, 0, offset));
|
||||
});
|
||||
synchronize();
|
||||
}
|
||||
@ -215,7 +238,7 @@ struct SyclDevice {
|
||||
sycl_queue().submit([&](cl::sycl::handler &cgh) {
|
||||
auto src_acc= it->second.template get_access<cl::sycl::access::mode::read, cl::sycl::access::target::global_buffer>(cgh);
|
||||
auto dst_acc =dest_buf.template get_access<cl::sycl::access::mode::discard_write, cl::sycl::access::target::global_buffer>(cgh);
|
||||
cgh.parallel_for( cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), TensorSycl::internal::MemCopyFunctor<T>(src_acc, dst_acc, rng, 0, offset));
|
||||
cgh.parallel_for( cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)), MemCopyFunctor<T>(src_acc, dst_acc, rng, 0, offset));
|
||||
});
|
||||
synchronize();
|
||||
}
|
||||
|
@ -55,27 +55,6 @@ template < typename HostExpr, typename PlaceHolderExpr, typename FunctorExpr, ty
|
||||
Index range;
|
||||
};
|
||||
|
||||
/// Memcopyfuncdeveicetohost
|
||||
template <typename T> class MemCopyFunctor {
|
||||
public:
|
||||
typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read, cl::sycl::access::target::global_buffer> read_accessor;
|
||||
typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::discard_write, cl::sycl::access::target::global_buffer> write_accessor;
|
||||
MemCopyFunctor(read_accessor src_acc, write_accessor dst_acc, size_t rng, size_t i, size_t offset): m_src_acc(src_acc), m_dst_acc(dst_acc), m_rng(rng), m_i(i), m_offset(offset) {}
|
||||
void operator()(cl::sycl::nd_item<1> itemID) {
|
||||
auto src_ptr = ConvertToActualTypeSycl(T, m_src_acc);
|
||||
auto dst_ptr = ConvertToActualTypeSycl(T, m_dst_acc);
|
||||
auto globalid = itemID.get_global_linear_id();
|
||||
if (globalid < m_rng) {
|
||||
dst_ptr[globalid + m_i] = src_ptr[globalid + m_offset];
|
||||
}
|
||||
}
|
||||
private:
|
||||
read_accessor m_src_acc;
|
||||
write_accessor m_dst_acc;
|
||||
size_t m_rng;
|
||||
size_t m_i;
|
||||
size_t m_offset;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user