mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-03-07 18:27:40 +08:00
Sped up the assignment of a tensor to a tensor slice, as well as the assigment of a constant slice to a tensor
This commit is contained in:
parent
43eb2ca6e1
commit
10a1f81822
@ -132,13 +132,20 @@ struct TensorEvaluator<const Derived, Device>
|
||||
CoordAccess = NumCoords > 0,
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&)
|
||||
: m_data(m.data()), m_dims(m.dimensions())
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
|
||||
: m_data(m.data()), m_dims(m.dimensions()), m_device(device)
|
||||
{ }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; }
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
|
||||
if (internal::is_arithmetic<typename internal::remove_const<Scalar>::type>::value && data) {
|
||||
m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
|
||||
@ -172,6 +179,7 @@ struct TensorEvaluator<const Derived, Device>
|
||||
protected:
|
||||
const Scalar* m_data;
|
||||
Dimensions m_dims;
|
||||
const Device& m_device;
|
||||
};
|
||||
|
||||
|
||||
|
@ -346,7 +346,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
|
||||
m_impl.evalSubExprsIfNeeded(NULL);
|
||||
if (internal::is_arithmetic<Scalar>::value && data && m_impl.data()) {
|
||||
if (internal::is_arithmetic<typename internal::remove_const<Scalar>::type>::value && data && m_impl.data()) {
|
||||
Index contiguous_values = 1;
|
||||
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||
for (int i = 0; i < NumDims; ++i) {
|
||||
|
Loading…
Reference in New Issue
Block a user