Updated the convolution and contraction evaluators to follow the new EvalSubExprsIfNeeded apu.

This commit is contained in:
Benoit Steiner 2014-08-13 08:33:18 -07:00
parent 72e7529708
commit f8fad09301
2 changed files with 10 additions and 8 deletions

View File

@ -184,9 +184,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
buffer[i] += coeff(i);
}
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
m_leftImpl.evalSubExprsIfNeeded();
m_rightImpl.evalSubExprsIfNeeded();
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded(Scalar*) {
m_leftImpl.evalSubExprsIfNeeded(NULL);
m_rightImpl.evalSubExprsIfNeeded(NULL);
return true;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_leftImpl.cleanup();

View File

@ -110,7 +110,7 @@ struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelAr
false,
};
TensorEvaluator(const XprType& op, const Device& device)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_inputImpl(op.inputExpression(), device), m_kernelImpl(op.kernelExpression(), device), m_dimensions(op.inputExpression().dimensions())
{
const typename TensorEvaluator<InputArgType, Device>::Dimensions& input_dims = m_inputImpl.dimensions();
@ -151,11 +151,12 @@ struct TensorEvaluator<const TensorConvolutionOp<Indices, InputArgType, KernelAr
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
m_inputImpl.evalSubExprsIfNeeded();
m_kernelImpl.evalSubExprsIfNeeded();
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded(Scalar*) {
m_inputImpl.evalSubExprsIfNeeded(NULL);
m_kernelImpl.evalSubExprsIfNeeded(NULL);
return true;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
m_inputImpl.cleanup();