Convert StridedLinearBufferCopy::Kind to enum class

This commit is contained in:
Eugene Zhulenev 2020-01-13 11:43:24 -08:00
parent 5a8b97b401
commit b9362fb8f7

View File

@ -999,7 +999,7 @@ class StridedLinearBufferCopy {
public: public:
// Specifying linear copy kind statically gives ~30% speedup for small sizes. // Specifying linear copy kind statically gives ~30% speedup for small sizes.
enum Kind { enum class Kind {
Linear = 0, // src_stride == 1 && dst_stride == 1 Linear = 0, // src_stride == 1 && dst_stride == 1
Scatter = 1, // src_stride == 1 && dst_stride != 1 Scatter = 1, // src_stride == 1 && dst_stride != 1
FillLinear = 2, // src_stride == 0 && dst_stride == 1 FillLinear = 2, // src_stride == 0 && dst_stride == 1
@ -1053,7 +1053,7 @@ class StridedLinearBufferCopy {
const IndexType vectorized_size = count - PacketSize; const IndexType vectorized_size = count - PacketSize;
IndexType i = 0; IndexType i = 0;
if (kind == Linear) { if (kind == StridedLinearBufferCopy::Kind::Linear) {
// ******************************************************************** // // ******************************************************************** //
// Linear copy from `src` to `dst`. // Linear copy from `src` to `dst`.
const IndexType unrolled_size = count - 4 * PacketSize; const IndexType unrolled_size = count - 4 * PacketSize;
@ -1072,7 +1072,7 @@ class StridedLinearBufferCopy {
dst[i] = src[i]; dst[i] = src[i];
} }
// ******************************************************************** // // ******************************************************************** //
} else if (kind == Scatter) { } else if (kind == StridedLinearBufferCopy::Kind::Scatter) {
// Scatter from `src` to `dst`. // Scatter from `src` to `dst`.
eigen_assert(src_stride == 1 && dst_stride != 1); eigen_assert(src_stride == 1 && dst_stride != 1);
for (; i <= vectorized_size; i += PacketSize) { for (; i <= vectorized_size; i += PacketSize) {
@ -1083,7 +1083,7 @@ class StridedLinearBufferCopy {
dst[i * dst_stride] = src[i]; dst[i * dst_stride] = src[i];
} }
// ******************************************************************** // // ******************************************************************** //
} else if (kind == FillLinear) { } else if (kind == StridedLinearBufferCopy::Kind::FillLinear) {
// Fill `dst` with value at `*src`. // Fill `dst` with value at `*src`.
eigen_assert(src_stride == 0 && dst_stride == 1); eigen_assert(src_stride == 0 && dst_stride == 1);
const IndexType unrolled_size = count - 4 * PacketSize; const IndexType unrolled_size = count - 4 * PacketSize;
@ -1100,7 +1100,7 @@ class StridedLinearBufferCopy {
dst[i] = *src; dst[i] = *src;
} }
// ******************************************************************** // // ******************************************************************** //
} else if (kind == FillScatter) { } else if (kind == StridedLinearBufferCopy::Kind::FillScatter) {
// Scatter `*src` into `dst`. // Scatter `*src` into `dst`.
eigen_assert(src_stride == 0 && dst_stride != 1); eigen_assert(src_stride == 0 && dst_stride != 1);
Packet p = pload1<Packet>(src); Packet p = pload1<Packet>(src);
@ -1111,7 +1111,7 @@ class StridedLinearBufferCopy {
dst[i * dst_stride] = *src; dst[i * dst_stride] = *src;
} }
// ******************************************************************** // // ******************************************************************** //
} else if (kind == Gather) { } else if (kind == StridedLinearBufferCopy::Kind::Gather) {
// Gather from `src` into `dst`. // Gather from `src` into `dst`.
eigen_assert(dst_stride == 1); eigen_assert(dst_stride == 1);
for (; i <= vectorized_size; i += PacketSize) { for (; i <= vectorized_size; i += PacketSize) {
@ -1122,7 +1122,7 @@ class StridedLinearBufferCopy {
dst[i] = src[i * src_stride]; dst[i] = src[i * src_stride];
} }
// ******************************************************************** // // ******************************************************************** //
} else if (kind == Random) { } else if (kind == StridedLinearBufferCopy::Kind::Random) {
// Random. // Random.
for (; i < count; ++i) { for (; i < count; ++i) {
dst[i * dst_stride] = src[i * src_stride]; dst[i * dst_stride] = src[i * src_stride];
@ -1300,17 +1300,17 @@ class TensorBlockIO {
return num_copied; return num_copied;
if (input_stride == 1 && output_stride == 1) { if (input_stride == 1 && output_stride == 1) {
COPY_INNER_DIM(LinCopy::Linear); COPY_INNER_DIM(LinCopy::Kind::Linear);
} else if (input_stride == 1 && output_stride != 1) { } else if (input_stride == 1 && output_stride != 1) {
COPY_INNER_DIM(LinCopy::Scatter); COPY_INNER_DIM(LinCopy::Kind::Scatter);
} else if (input_stride == 0 && output_stride == 1) { } else if (input_stride == 0 && output_stride == 1) {
COPY_INNER_DIM(LinCopy::FillLinear); COPY_INNER_DIM(LinCopy::Kind::FillLinear);
} else if (input_stride == 0 && output_stride != 1) { } else if (input_stride == 0 && output_stride != 1) {
COPY_INNER_DIM(LinCopy::FillScatter); COPY_INNER_DIM(LinCopy::Kind::FillScatter);
} else if (output_stride == 1) { } else if (output_stride == 1) {
COPY_INNER_DIM(LinCopy::Gather); COPY_INNER_DIM(LinCopy::Kind::Gather);
} else { } else {
COPY_INNER_DIM(LinCopy::Random); COPY_INNER_DIM(LinCopy::Kind::Random);
} }
#undef COPY_INNER_DIM #undef COPY_INNER_DIM