Optimize the case where broadcasting is a no-op.

Rasmus Munk Larsen 2018-07-13 16:12:38 -07:00
parent 4a3952fd55
commit 4222550e17
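
For context, the "no-op" case this commit targets is a broadcast whose factors are all 1, so the result has exactly the shape and values of the input. A minimal sketch (not part of this commit; the tensor sizes and values are arbitrary) using the unsupported Tensor module:

// Hedged sketch: a broadcast with all factors equal to 1 leaves the tensor
// unchanged; this is what the new isCopy flag detects so the evaluator can
// forward reads directly to its input.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> t(3, 4);
  t.setRandom();

  // Broadcast factors {1, 1}: the result has the same shape and values as t.
  Eigen::array<Eigen::Index, 2> bcast = {1, 1};
  Eigen::Tensor<float, 2> r = t.broadcast(bcast);

  std::cout << r.dimension(0) << "x" << r.dimension(1) << "\n";  // 3x4
  std::cout << (r(2, 3) == t(2, 3)) << "\n";                     // 1
  return 0;
}

With this change, evaluating such an expression takes the new isCopy branches in coeff()/packet() instead of recomputing per-dimension indices for every coefficient.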

@@ -105,7 +105,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
   typedef typename XprType::CoeffReturnType CoeffReturnType;
   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
   static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
-  bool nByOne = false, oneByN = false;
+  bool isCopy= false, nByOne = false, oneByN = false;
   enum {
     IsAligned = true,
@@ -123,9 +123,13 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
     EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
     const InputDimensions& input_dims = m_impl.dimensions();
     const Broadcast& broadcast = op.broadcast();
+    isCopy = true;
     for (int i = 0; i < NumDims; ++i) {
       eigen_assert(input_dims[i] > 0);
-      m_dimensions[i] = input_dims[i] * broadcast[i];
+      m_dimensions[i] = input_dims[i] * m_broadcast[i];
+      if (m_broadcast[i] != 1) {
+        isCopy = false;
+      }
     }
     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
@@ -197,9 +201,17 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
     }
     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
-      return coeffColMajor(index);
+      if (isCopy) {
+        return m_impl.coeff(index);
+      } else {
+        return coeffColMajor(index);
+      }
     } else {
-      return coeffRowMajor(index);
+      if (isCopy) {
+        return m_impl.coeff(index);
+      } else {
+        return coeffRowMajor(index);
+      }
     }
   }
@@ -272,7 +284,9 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
     }
     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
-      if (oneByN && !nByOne) {
+      if (isCopy) {
+        return m_impl.template packet<LoadMode>(index);
+      } else if (oneByN && !nByOne) {
         return packetNByOne<LoadMode>(index);
       } else if (!oneByN && nByOne) {
         return packetOneByN<LoadMode>(index);
@@ -282,7 +296,9 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
         return packetColMajor<LoadMode>(index);
       }
     } else {
-      if (oneByN && !nByOne) {
+      if (isCopy) {
+        return m_impl.template packet<LoadMode>(index);
+      } else if (oneByN && !nByOne) {
         return packetOneByN<LoadMode>(index);
       } else if (!oneByN && nByOne) {
         return packetNByOne<LoadMode>(index);
@@ -516,7 +532,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
   costPerCoeff(bool vectorized) const {
     double compute_cost = TensorOpCost::AddCost<Index>();
-    if (NumDims > 0) {
+    if (!isCopy && NumDims > 0) {
       for (int i = NumDims - 1; i > 0; --i) {
         compute_cost += TensorOpCost::DivCost<Index>();
         if (internal::index_statically_eq<Broadcast>(i, 1)) {
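
A rough way to see the effect (again a sketch, not from the commit; the sizes, the timing method, and the single-run measurement are arbitrary choices) is to compare a no-op broadcast against a plain copy of the same tensor, which the isCopy fast path should now roughly match:

#include <unsupported/Eigen/CXX11/Tensor>
#include <chrono>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> in(1000, 1000);
  in.setRandom();
  Eigen::Tensor<float, 2> out(1000, 1000);
  Eigen::array<Eigen::Index, 2> no_op = {1, 1};  // broadcast factors of 1: a no-op

  auto t0 = std::chrono::steady_clock::now();
  out = in.broadcast(no_op);   // exercises the isCopy fast path
  auto t1 = std::chrono::steady_clock::now();
  out = in;                    // plain copy, for comparison
  auto t2 = std::chrono::steady_clock::now();

  std::cout << "no-op broadcast: "
            << std::chrono::duration<double, std::milli>(t1 - t0).count() << " ms\n";
  std::cout << "plain copy:      "
            << std::chrono::duration<double, std::milli>(t2 - t1).count() << " ms\n";
  return 0;
}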