Mirror of https://gitlab.com/libeigen/eigen.git (synced 2025-01-30 17:40:05 +08:00)

Commit 02431cbe71: TensorBroadcasting support for random/uniform blocks
Parent commit: d380c23b2c
unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h:

@@ -45,7 +45,7 @@ EIGEN_ALWAYS_INLINE DSizes<IndexType, NumDims> strides(
   return strides;
 }
 
-template<int Layout, typename IndexType, size_t NumDims>
+template <int Layout, typename IndexType, size_t NumDims>
 EIGEN_ALWAYS_INLINE DSizes<IndexType, NumDims> strides(
     const Eigen::array<IndexType, NumDims>& dimensions) {
   return strides<Layout>(DSizes<IndexType, NumDims>(dimensions));
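For readers skimming the diff: `internal::strides<Layout>` computes the standard row-/column-major strides from a dimensions array, and the reformatted overload above simply forwards an `Eigen::array` of dimensions to the `DSizes` version. A minimal standalone sketch of the column-major case, using plain `std::array` stand-ins rather than Eigen's internal types (assumes `NumDims >= 1`):

#include <array>
#include <cstdio>

// Column-major strides: stride[0] = 1, stride[i] = stride[i-1] * dims[i-1].
template <std::size_t NumDims>
std::array<long, NumDims> col_major_strides(
    const std::array<long, NumDims>& dims) {
  std::array<long, NumDims> strides;
  strides[0] = 1;
  for (std::size_t i = 1; i < NumDims; ++i)
    strides[i] = strides[i - 1] * dims[i - 1];
  return strides;
}

int main() {
  const std::array<long, 3> s = col_major_strides<3>({9, 3, 6});
  std::printf("%ld %ld %ld\n", s[0], s[1], s[2]);  // prints: 1 9 27
}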
@@ -121,7 +121,7 @@ class TensorBlockDescriptor {
     // Compare strides ignoring dimensions of size `1`.
     for (int i = 0; i < NumDims; ++i) {
       if (desc_dims[i] == 1) continue;
-      if (desc_strides[i] != dst_strides[i]) return false;
+      if (desc_strides[i] != dst_strides[i]) return false;
     }
 
     return true;
@@ -507,8 +507,8 @@ class TensorCwiseUnaryBlock {
  public:
   typedef typename conditional<
       NoArgBlockAccess, void,
-      TensorCwiseUnaryOp<UnaryOp, const typename ArgTensorBlock::XprType> >::type
-      XprType;
+      TensorCwiseUnaryOp<UnaryOp, const typename ArgTensorBlock::XprType> >::
+      type XprType;
 
   typedef typename XprScalar<XprType>::type Scalar;
 
@@ -854,12 +854,13 @@ class TensorBlockIOV2 {
   //
   //   src_dimension_index = dst_to_src_dim_map[dst_dimension_index]
   //
-  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Copy(
+  // Returns the number of copied elements.
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType Copy(
       const Dst& dst, const Src& src, const DimensionsMap& dst_to_src_dim_map) {
     // Copy single scalar value from `src` to `dst`.
     if (NumDims == 0) {
       *(dst.data + dst.offset) = *(src.data + src.offset);
-      return;
+      return 1;
     }
 
     // Both `dst` and `src` must have contiguous innermost dimension. We also
@@ -898,13 +899,13 @@ class TensorBlockIOV2 {
     // If all dimensions are of size 1, just copy a scalar from `src` to `dst`.
     if (num_size_one_inner_dims == NumDims) {
       *(dst.data + dst.offset) = *(src.data + src.offset);
-      return;
+      return 1;
     }
 
     // Outermost dimension in the dst with `stride == 1` (contiguous in memory).
-    const int dst_stride1_dim =
-        IsColMajor ? num_size_one_inner_dims
-                   : NumDims - num_size_one_inner_dims - 1;
+    const int dst_stride1_dim = IsColMajor
+                                    ? num_size_one_inner_dims
+                                    : NumDims - num_size_one_inner_dims - 1;
 
     // Dimension in the src that corresponds to the dst innermost dimension.
     const int src_dim_for_dst_stride1_dim =
@@ -956,24 +957,27 @@ class TensorBlockIOV2 {
     // Iterate copying data from src to dst.
     const IndexType block_total_size = NumDims == 0 ? 1 : dst.dims.TotalSize();
 
-#define COPY_INNER_DIM(KIND)                                             \
-  for (IndexType i = 0; i < block_total_size; i += dst_inner_dim_size) { \
-    LinCopy::template Run<KIND>(                                         \
-        typename LinCopy::Dst(output_offset, output_stride, dst.data),   \
-        typename LinCopy::Src(input_offset, input_stride, src.data),     \
-        dst_inner_dim_size);                                             \
-                                                                         \
-    for (int j = 0; j < idx; ++j) {                                      \
-      if (++it[j].count < it[j].size) {                                  \
-        input_offset += it[j].input_stride;                              \
-        output_offset += it[j].output_stride;                            \
-        break;                                                           \
-      }                                                                  \
-      it[j].count = 0;                                                   \
-      input_offset -= it[j].input_span;                                  \
-      output_offset -= it[j].output_span;                                \
-    }                                                                    \
-  }
+#define COPY_INNER_DIM(KIND)                                           \
+  IndexType num_copied = 0;                                            \
+  for (num_copied = 0; num_copied < block_total_size;                  \
+       num_copied += dst_inner_dim_size) {                             \
+    LinCopy::template Run<KIND>(                                       \
+        typename LinCopy::Dst(output_offset, output_stride, dst.data), \
+        typename LinCopy::Src(input_offset, input_stride, src.data),   \
+        dst_inner_dim_size);                                           \
+                                                                       \
+    for (int j = 0; j < idx; ++j) {                                    \
+      if (++it[j].count < it[j].size) {                                \
+        input_offset += it[j].input_stride;                            \
+        output_offset += it[j].output_stride;                          \
+        break;                                                         \
+      }                                                                \
+      it[j].count = 0;                                                 \
+      input_offset -= it[j].input_span;                                \
+      output_offset -= it[j].output_span;                              \
+    }                                                                  \
+  }                                                                    \
+  return num_copied;
 
     if (input_stride == 1 && output_stride == 1) {
       COPY_INNER_DIM(LinCopy::Linear);
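The inner `for (int j = 0; j < idx; ++j)` loop in this macro is a multi-dimensional "odometer": after each inner-dimension copy it advances the first outer dimension that still has room, and wraps (subtracting the precomputed span) any dimension that overflows. A self-contained sketch of the same pattern, with illustrative names rather than Eigen's internals:

#include <cstdio>

struct DimIter {
  int count, size;
  long stride;  // offset step when this dimension advances by one
  long span;    // stride * (size - 1), undone when the dimension wraps
};

int main() {
  // Iterate a 2x3 outer index space while maintaining a linear offset.
  DimIter it[2] = {{0, 2, 1, 1}, {0, 3, 2, 4}};
  long offset = 0;
  for (int step = 0; step < 2 * 3; ++step) {
    std::printf("offset = %ld\n", offset);  // visits 0, 1, 2, 3, 4, 5
    for (int j = 0; j < 2; ++j) {
      if (++it[j].count < it[j].size) {  // advance dimension j
        offset += it[j].stride;
        break;
      }
      it[j].count = 0;  // wrap: reset this dimension and carry to j+1
      offset -= it[j].span;
    }
  }
}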
@@ -992,12 +996,13 @@ class TensorBlockIOV2 {
 #undef COPY_INNER_DIM
   }
 
-  // Copy from `src` to `dst` with an identity src->dst dimension map.
-  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void Copy(const Dst& dst,
-                                                         const Src& src) {
+  // Copy from `src` to `dst` with an identity src->dst dimension map. Returns
+  // the number of copied elements.
+  static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexType Copy(const Dst& dst,
+                                                              const Src& src) {
     DimensionsMap dst_to_src_map;
     for (int i = 0; i < NumDims; ++i) dst_to_src_map[i] = i;
-    Copy(dst, src, dst_to_src_map);
+    return Copy(dst, src, dst_to_src_map);
   }
 
  private:
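Both `Copy` overloads of `TensorBlockIOV2` now report how many coefficients they wrote, which lets callers account for partially materialized blocks (the broadcasting evaluator below accumulates these counts). A minimal analogue of the identity-map overload, with illustrative names rather than Eigen's internal types:

#include <array>
#include <cstdio>

// Copy a 2-D block given sizes and strides for both sides, and report how
// many elements moved (mirroring the new IndexType return value).
template <typename T>
long copy_block_2d(T* dst, const std::array<long, 2>& dst_strides,
                   const T* src, const std::array<long, 2>& src_strides,
                   const std::array<long, 2>& sizes) {
  long copied = 0;
  for (long j = 0; j < sizes[1]; ++j)
    for (long i = 0; i < sizes[0]; ++i, ++copied)
      dst[i * dst_strides[0] + j * dst_strides[1]] =
          src[i * src_strides[0] + j * src_strides[1]];
  return copied;  // sizes[0] * sizes[1]
}

int main() {
  float src[6] = {0, 1, 2, 3, 4, 5};  // 2x3, column-major strides {1, 2}
  float dst[6] = {};
  long n = copy_block_2d(dst, {1, 2}, src, {1, 2}, {2, 3});
  std::printf("copied %ld elements\n", n);  // copied 6 elements
}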
unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h:

@@ -884,60 +884,187 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
   blockV2(TensorBlockDesc& desc, TensorBlockScratch& scratch,
           bool /*root_of_expr_ast*/ = false) const {
-    static const bool
-        is_col_major = static_cast<int>(Layout) == static_cast<int>(ColMajor);
+    BlockBroadcastingParams params = blockBroadcastingParams(desc);
 
+    // Return a block with a single scalar.
+    if (NumDims <= 0) return scalarBlock(scratch);
+
-    // Because we only support kSkewedInnerDims blocking, block size should be
-    // equal to m_dimensions for inner dims, a smaller than m_dimensions[i] size
-    // for the first outer dim, and 1 for other outer dims. This is guaranteed
-    // by MergeResourceRequirements() in TensorBlock.h.
-    const Dimensions& output_dims = desc.dimensions();
-    const Dimensions output_strides = internal::strides<Layout>(output_dims);
-
-    // Find where outer dims start.
-    int outer_dim_start = 0;
-    Index outer_dim_size = 1;
-    Index inner_dim_size = 1;
-
-    for (int i = 0; i < NumDims; ++i) {
-      const int dim = is_col_major ? i : NumDims - i - 1;
-
-      if (i > outer_dim_start) {
-        eigen_assert(output_dims[dim] == 1);
-      } else if (output_dims[dim] != m_dimensions[dim]) {
-        eigen_assert(output_dims[dim] < m_dimensions[dim]);
-        outer_dim_size = output_dims[dim];
-      } else {
-        inner_dim_size *= output_dims[dim];
-        ++outer_dim_start;
-      }
-    }
-
-    if (inner_dim_size == 0 || outer_dim_size == 0) {
+    if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
       return emptyBlock();
     }
 
-    const Dimensions& input_dims = Dimensions(m_impl.dimensions());
+    // Check if we can reuse `desc` destination, or allocate new scratch buffer.
+    ScalarNoConst* materialized_output =
+        desc.template destination<ScalarNoConst, Layout>();
+    bool materialized_in_output;
 
-    // Pre-fill input_block_sizes, broadcast_block_sizes,
-    // broadcast_block_strides, and broadcast_tensor_strides. Later on we will
-    // only modify the outer_dim_start-th dimension on these arrays.
+    if (materialized_output != NULL) {
+      desc.DropDestinationBuffer();
+      materialized_in_output = true;
+
+    } else {
+      materialized_in_output = false;
+      const size_t materialized_output_size = desc.size() * sizeof(Scalar);
+      void* output_scratch_mem = scratch.allocate(materialized_output_size);
+      materialized_output = static_cast<ScalarNoConst*>(output_scratch_mem);
+    }
+
+    ScalarNoConst* materialized_input = NULL;
+    size_t materialized_input_size = 0;
+
+    // Initialize block broadcasting iterator state for outer dimensions (outer
+    // with regard to the bcast dimension). Dimensions in this array are always
+    // in inner_most -> outer_most order (col major layout).
+    array<BlockBroadcastingIteratorState, NumDims> it;
+    int idx = 0;
+
+    for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
+      const Index dim = IsColMajor ? i : NumDims - 1 - i;
+      it[idx].size = params.output_dims[dim];
+      it[idx].count = 0;
+      it[idx].output_stride = m_outputStrides[dim];
+      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
+      idx++;
+    }
+
+    // Write output into the beginning of `materialized_output`.
+    Index output_offset = 0;
+
+    // We will fill the output block by broadcasting along the bcast dim, and
+    // iterating over the outer dimensions.
+    const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();
+
+    for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
+      ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
+      Index bcast_offset = desc.offset() + output_offset;
+
+      // Broadcast along the bcast dimension.
+      num_output_coeffs += BroadcastBlockAlongBcastDim(
+          params, bcast_offset, scratch, bcast_output, &materialized_input,
+          &materialized_input_size);
+
+      // Switch to the next outer dimension.
+      for (int j = 0; j < idx; ++j) {
+        if (++it[j].count < it[j].size) {
+          output_offset += it[j].output_stride;
+          break;
+        }
+        it[j].count = 0;
+        output_offset -= it[j].output_span;
+      }
+    }
+
+    return TensorBlockV2(
+        materialized_in_output
+            ? internal::TensorBlockKind::kMaterializedInOutput
+            : internal::TensorBlockKind::kMaterializedInScratch,
+        materialized_output, desc.dimensions());
+  }
+
+  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
+
+  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
+
+  Broadcast functor() const { return m_broadcast; }
+#ifdef EIGEN_USE_SYCL
+  // binding placeholder accessors to a command group handler for SYCL
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
+      cl::sycl::handler& cgh) const {
+    m_impl.bind(cgh);
+  }
+#endif
+ private:
+  static const bool IsColMajor =
+      static_cast<int>(Layout) == static_cast<int>(ColMajor);
+
+  // We will build a general case block broadcasting on top of a broadcasting
+  // primitive that will do broadcasting only for the inner dimension(s) along
+  // the first dimension smaller than the input size (it's called `bcast_dim`).
+  //
+  // Example:
+  //   dim:  0  1  2   (ColMajor)
+  //   input size: [9, 3, 6]
+  //   block size: [9, 2, 6]
+  //
+  // We will compute the broadcasted block by iterating over the outer
+  // dimensions before `bcast_dim` (only dimension `2` in this example) and
+  // computing broadcasts along the `bcast_dim` (dimension `1` in this example).
+
+  // BlockBroadcastingParams holds precomputed parameters for broadcasting a
+  // single block along the broadcasting dimension. Sizes and strides along the
+  // `bcast_dim` might be invalid; they will be adjusted later in
+  // `BroadcastBlockAlongBcastDim`.
+  struct BlockBroadcastingParams {
+    Dimensions input_dims;      // input expression dimensions
+    Dimensions output_dims;     // output block sizes
+    Dimensions output_strides;  // output block strides
+
+    int inner_dim_count;   // count of inner dimensions matching in size
+    int bcast_dim;         // broadcasting dimension index
+    Index bcast_dim_size;  // broadcasting dimension size
+    Index inner_dim_size;  // inner dimensions size
+
+    // Block sizes and strides for the input block where all dimensions before
+    // `bcast_dim` are equal to `1`.
+    Dimensions input_block_sizes;
+    Dimensions input_block_strides;
+
+    // Block sizes and strides for blocks with extra dimensions and strides `0`.
+    BroadcastDimensions bcast_block_sizes;
+    BroadcastDimensions bcast_block_strides;
+    BroadcastDimensions bcast_input_strides;
+  };
+
+  struct BlockBroadcastingIteratorState {
+    Index size;
+    Index count;
+    Index output_stride;
+    Index output_span;
+  };
+
+  BlockBroadcastingParams blockBroadcastingParams(TensorBlockDesc& desc) const {
+    BlockBroadcastingParams params;
+
+    params.input_dims = Dimensions(m_impl.dimensions());
+
+    // Output block sizes and strides.
+    params.output_dims = desc.dimensions();
+    params.output_strides = internal::strides<Layout>(params.output_dims);
+
+    // Find the broadcasting dimension (first dimension with output size
+    // smaller than the input size).
+    params.bcast_dim = 0;
+    params.bcast_dim_size = 1;
+    params.inner_dim_size = 1;
+
+    // Count the number of inner dimensions that have the same size in the
+    // block and in the broadcast expression.
+    params.inner_dim_count = 0;
+
+    for (int i = 0; i < NumDims; ++i) {
+      const int dim = IsColMajor ? i : NumDims - i - 1;
+
+      if (params.output_dims[dim] == m_dimensions[dim]) {
+        params.inner_dim_size *= params.output_dims[dim];
+        ++params.inner_dim_count;
+        continue;
+      }
+
+      // First non-matching dimension is the broadcasting dimension.
+      eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
+      params.bcast_dim = dim;
+      params.bcast_dim_size = params.output_dims[dim];
+      break;
+    }
 
     // Calculate the input block size for looking into the input.
-    Dimensions input_block_sizes;
-    for (int i = 0; i < outer_dim_start; ++i) {
-      const int dim = is_col_major ? i : NumDims -i - 1;
-      input_block_sizes[dim] = input_dims[dim];
+    for (int i = 0; i < params.inner_dim_count; ++i) {
+      const int dim = IsColMajor ? i : NumDims - i - 1;
+      params.input_block_sizes[dim] = params.input_dims[dim];
     }
-    for (int i = outer_dim_start; i < NumDims; ++i) {
-      const int dim = is_col_major ? i : NumDims -i - 1;
-      input_block_sizes[dim] = 1;
+    for (int i = params.inner_dim_count; i < NumDims; ++i) {
+      const int dim = IsColMajor ? i : NumDims - i - 1;
+      params.input_block_sizes[dim] = 1;
     }
-    Dimensions input_block_strides =
-        internal::strides<Layout>(input_block_sizes);
+    params.input_block_strides =
+        internal::strides<Layout>(params.input_block_sizes);
 
     // Broadcast with the 0-stride trick: Create 1 extra dim for each
     // broadcast, set the input stride to 0.
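The "0-stride trick" referenced at the end of the hunk above: broadcasting can be expressed as a plain strided copy over an expanded index space in which each extra "broadcast" dimension reads the input with stride 0, so the same source value is written repeatedly. A standalone sketch with illustrative names:

#include <cstdio>

int main() {
  // Broadcast a length-3 input to a 3x2 output (each value repeated twice
  // along the new dimension) by giving the broadcast dimension input stride 0.
  const int input[3] = {10, 20, 30};
  int output[6] = {};

  const int sizes[2] = {3, 2};        // {copy dim, broadcast dim}
  const int in_strides[2] = {1, 0};   // broadcast dim re-reads the same value
  const int out_strides[2] = {1, 3};  // column-major output strides

  for (int b = 0; b < sizes[1]; ++b)
    for (int i = 0; i < sizes[0]; ++i)
      output[i * out_strides[0] + b * out_strides[1]] =
          input[i * in_strides[0] + b * in_strides[1]];

  for (int k = 0; k < 6; ++k) std::printf("%d ", output[k]);
  // prints: 10 20 30 10 20 30
}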
@@ -957,229 +1084,31 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
     //    input_block_strides[1], 0,
     //    ...].
     //
-    BroadcastDimensions bcast_block_sizes;
-    BroadcastDimensions bcast_block_strides;
-    BroadcastDimensions bcast_input_strides;
+    for (int i = 0; i < params.inner_dim_count; ++i) {
+      const int dim = IsColMajor ? i : NumDims - i - 1;
 
-    for (int i = 0; i < outer_dim_start; ++i) {
-      const int dim = is_col_major ? i : NumDims - i - 1;
+      const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
+      const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;
 
-      const int copy_dim = is_col_major ? 2 * i : 2 * NumDims - 2 * i - 1;
-      const int broadcast_dim = is_col_major ? copy_dim + 1 : copy_dim - 1;
-
-      bcast_block_sizes[copy_dim] = input_dims[dim];
-      bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
-      bcast_block_strides[copy_dim] = output_strides[dim];
-      bcast_block_strides[broadcast_dim] =
-          output_strides[dim] * input_dims[dim];
-      bcast_input_strides[copy_dim] = input_block_strides[dim];
-      bcast_input_strides[broadcast_dim] = 0;
-    }
-    for (int i = 2 * outer_dim_start; i < 2 * NumDims; ++i) {
-      const int dim = is_col_major ? i : 2 * NumDims - i - 1;
-      bcast_block_sizes[dim] = 1;
-      bcast_block_strides[dim] = 0;
-      bcast_input_strides[dim] = 0;
+      params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
+      params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
+      params.bcast_block_strides[copy_dim] = params.output_strides[dim];
+      params.bcast_block_strides[broadcast_dim] =
+          params.output_strides[dim] * params.input_dims[dim];
+      params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
+      params.bcast_input_strides[broadcast_dim] = 0;
     }
 
-    const int outer_dim =
-        is_col_major ? outer_dim_start : NumDims - outer_dim_start - 1;
-
-    // Check if we can reuse `desc` destination, or allocate new scratch buffer.
-    ScalarNoConst* materialized_output =
-        desc.template destination<ScalarNoConst, Layout>();
-    bool materialized_in_output;
-
-    if (materialized_output != NULL) {
-      desc.DropDestinationBuffer();
-      materialized_in_output = true;
-
-    } else {
-      materialized_in_output = false;
-      const size_t materialized_output_size = desc.size() * sizeof(Scalar);
-      void* output_scratch_mem = scratch.allocate(materialized_output_size);
-      materialized_output = static_cast<ScalarNoConst*>(output_scratch_mem);
+    for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
+      const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
+      params.bcast_block_sizes[dim] = 1;
+      params.bcast_block_strides[dim] = 0;
+      params.bcast_input_strides[dim] = 0;
     }
 
-    size_t materialized_input_size = 0;
-    ScalarNoConst* materialized_input = NULL;
-
-    if (outer_dim_size == 1) {
-      // We just need one block read using the ready-set values above.
-      BroadcastBlockV2(
-          input_block_sizes, input_block_strides, bcast_block_sizes,
-          bcast_block_strides, bcast_input_strides, 0, desc, scratch,
-          materialized_output, &materialized_input, &materialized_input_size);
-
-    } else if (input_dims[outer_dim] == 1) {
-      // Broadcast outer_dim_start-th dimension (< NumDims) by outer_dim_size.
-      const int broadcast_outer_dim =
-          is_col_major ? 2 * outer_dim_start + 1
-                       : 2 * NumDims - 2 * outer_dim_start - 2;
-
-      bcast_block_sizes[broadcast_outer_dim] = outer_dim_size;
-      bcast_input_strides[broadcast_outer_dim] = 0;
-      bcast_block_strides[broadcast_outer_dim] = output_strides[outer_dim];
-
-      BroadcastBlockV2(
-          input_block_sizes, input_block_strides, bcast_block_sizes,
-          bcast_block_strides, bcast_input_strides, 0, desc, scratch,
-          materialized_output, &materialized_input, &materialized_input_size);
-
-    } else {
-      // The general case. Let's denote the output block as x[...,
-      // a:a+outer_dim_size, :, ..., :], where a:a+outer_dim_size is a slice on
-      // the outer_dim_start-th dimension (< NumDims). We need to split the
-      // a:a+outer_dim_size into possibly 3 sub-blocks:
-      //
-      // (1) a:b, where b is the smallest multiple of
-      //     input_dims[outer_dim_start] in [a, a+outer_dim_size].
-      //
-      // (2) b:c, where c is the largest multiple of input_dims[outer_dim_start]
-      //     in [a, a+outer_dim_size].
-      //
-      // (3) c:a+outer_dim_size .
-      //
-      // Or, when b and c do not exist, we just need to process the whole block
-      // together.
-
-      // Find a.
-      const Index outer_dim_left_index =
-          desc.offset() / m_outputStrides[outer_dim];
-
-      // Find b and c.
-      const Index input_outer_dim_size = input_dims[outer_dim];
-
-      // First multiple after a. This is b when <= outer_dim_left_index +
-      // outer_dim_size.
-      const Index first_multiple =
-          divup<Index>(outer_dim_left_index, input_outer_dim_size) *
-          input_outer_dim_size;
-
-      if (first_multiple <= outer_dim_left_index + outer_dim_size) {
-        // b exists, so does c. Find it.
-        const Index last_multiple = (outer_dim_left_index + outer_dim_size) /
-                                    input_outer_dim_size * input_outer_dim_size;
-        const int copy_outer_dim = is_col_major
-                                       ? 2 * outer_dim_start
-                                       : 2 * NumDims - 2 * outer_dim_start - 1;
-        const int broadcast_outer_dim =
-            is_col_major ? 2 * outer_dim_start + 1
-                         : 2 * NumDims - 2 * outer_dim_start - 2;
-
-        if (first_multiple > outer_dim_left_index) {
-          const Index head_size = first_multiple - outer_dim_left_index;
-          input_block_sizes[outer_dim] = head_size;
-          bcast_block_sizes[copy_outer_dim] = head_size;
-          bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
-          bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
-          bcast_block_sizes[broadcast_outer_dim] = 1;
-          bcast_input_strides[broadcast_outer_dim] = 0;
-          bcast_block_strides[broadcast_outer_dim] =
-              output_strides[outer_dim] * input_dims[outer_dim];
-
-          BroadcastBlockV2(input_block_sizes, input_block_strides,
-                           bcast_block_sizes, bcast_block_strides,
-                           bcast_input_strides, 0, desc, scratch,
-                           materialized_output, &materialized_input,
-                           &materialized_input_size);
-        }
-        if (first_multiple < last_multiple) {
-          input_block_sizes[outer_dim] = input_outer_dim_size;
-          bcast_block_sizes[copy_outer_dim] = input_outer_dim_size;
-          bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
-          bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
-          bcast_block_sizes[broadcast_outer_dim] =
-              (last_multiple - first_multiple) / input_outer_dim_size;
-          bcast_input_strides[broadcast_outer_dim] = 0;
-          bcast_block_strides[broadcast_outer_dim] =
-              output_strides[outer_dim] * input_dims[outer_dim];
-          const Index offset = (first_multiple - outer_dim_left_index) *
-                               m_outputStrides[outer_dim];
-
-          BroadcastBlockV2(input_block_sizes, input_block_strides,
-                           bcast_block_sizes, bcast_block_strides,
-                           bcast_input_strides, offset, desc, scratch,
-                           materialized_output, &materialized_input,
-                           &materialized_input_size);
-        }
-        if (last_multiple < outer_dim_left_index + outer_dim_size) {
-          const Index tail_size =
-              outer_dim_left_index + outer_dim_size - last_multiple;
-          input_block_sizes[outer_dim] = tail_size;
-          bcast_block_sizes[copy_outer_dim] = tail_size;
-          bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
-          bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
-          bcast_block_sizes[broadcast_outer_dim] = 1;
-          bcast_input_strides[broadcast_outer_dim] = 0;
-          bcast_block_strides[broadcast_outer_dim] =
-              output_strides[outer_dim] * input_dims[outer_dim];
-          const Index offset = (last_multiple - outer_dim_left_index) *
-                               m_outputStrides[outer_dim];
-
-          BroadcastBlockV2(input_block_sizes, input_block_strides,
-                           bcast_block_sizes, bcast_block_strides,
-                           bcast_input_strides, offset, desc, scratch,
-                           materialized_output, &materialized_input,
-                           &materialized_input_size);
-        }
-      } else {
-        // b and c do not exist.
-        const int copy_outer_dim = is_col_major
-                                       ? 2 * outer_dim_start
-                                       : 2 * NumDims - 2 * outer_dim_start - 1;
-        input_block_sizes[outer_dim] = outer_dim_size;
-        bcast_block_sizes[copy_outer_dim] = outer_dim_size;
-        bcast_input_strides[copy_outer_dim] = input_block_strides[outer_dim];
-        bcast_block_strides[copy_outer_dim] = output_strides[outer_dim];
-
-        BroadcastBlockV2(
-            input_block_sizes, input_block_strides, bcast_block_sizes,
-            bcast_block_strides, bcast_input_strides, 0, desc, scratch,
-            materialized_output, &materialized_input, &materialized_input_size);
-      }
-    }
-
-    return TensorBlockV2(materialized_in_output
-                             ? internal::TensorBlockKind::kMaterializedInOutput
-                             : internal::TensorBlockKind::kMaterializedInScratch,
-                         materialized_output,
-                         desc.dimensions());
+    return params;
   }
 
+  // This is a special case for `NumDims == 0`; in practice this should not
+  // happen often, so it's fine to do a memory allocation just for a scalar.
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2
+  scalarBlock(TensorBlockScratch& scratch) const {
+    void* mem = scratch.allocate(sizeof(Scalar));
+    ScalarNoConst* buf = static_cast<ScalarNoConst*>(mem);
+    *buf = m_impl.coeff(0);
+
+    DSizes<Index, NumDims> dimensions;
+    for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
+
+    return TensorBlockV2(internal::TensorBlockKind::kMaterializedInScratch, buf,
+                         dimensions);
+  }
+
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 emptyBlock() const {
-    DSizes<Index, NumDims> dimensions;
-    for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
-    return TensorBlockV2(internal::TensorBlockKind::kView, NULL, dimensions);
-  }
-
-  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
-
-  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
-
-  Broadcast functor() const { return m_broadcast; }
-#ifdef EIGEN_USE_SYCL
-  // binding placeholder accessors to a command group handler for SYCL
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
-    m_impl.bind(cgh);
-  }
-#endif
- private:
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void BroadcastBlock(
       const Dimensions& input_block_sizes,
       const BroadcastDimensions& broadcast_block_sizes,
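The two loops added to `blockBroadcastingParams` build the interleaved `2 * NumDims` broadcast dimensions: in the ColMajor case every input dimension `i` expands into a pair `(copy_dim = 2 * i, broadcast_dim = 2 * i + 1)` holding the input size and the broadcast factor respectively. A small sketch of the resulting sizes (the broadcast factors here are illustrative):

#include <cstdio>

int main() {
  const int NumDims = 3;
  const int input_dims[3] = {9, 3, 6};
  const int bcast[3] = {2, 2, 2};  // illustrative broadcast factors

  int bcast_block_sizes[6];
  for (int i = 0; i < NumDims; ++i) {
    bcast_block_sizes[2 * i] = input_dims[i];  // copy this many elements...
    bcast_block_sizes[2 * i + 1] = bcast[i];   // ...then repeat them
  }
  for (int d = 0; d < 2 * NumDims; ++d)
    std::printf("%d ", bcast_block_sizes[d]);
  // prints: 9 2 3 2 6 2
}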
@@ -1202,23 +1131,194 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
     BroadcastTensorBlockReader::Run(&broadcast_block, input_block.data());
   }
 
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void BroadcastBlockV2(
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlockV2 emptyBlock() const {
+    DSizes<Index, NumDims> dimensions;
+    for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
+    return TensorBlockV2(internal::TensorBlockKind::kView, NULL, dimensions);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
+      BlockBroadcastingParams params, Index bcast_offset,
+      TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
+      ScalarNoConst** materialized_input,
+      size_t* materialized_input_size) const {
+    if (params.bcast_dim_size == 1) {
+      // We just need one block read using the ready-set values above.
+      return BroadcastBlockV2(
+          params.input_block_sizes, params.input_block_strides,
+          params.bcast_block_sizes, params.bcast_block_strides,
+          params.bcast_input_strides, bcast_offset, 0, scratch,
+          materialized_output, materialized_input, materialized_input_size);
+
+    } else if (params.input_dims[params.bcast_dim] == 1) {
+      // Broadcast the bcast dimension (< NumDims) by bcast_dim_size.
+      const int broadcast_bcast_dim =
+          IsColMajor ? 2 * params.inner_dim_count + 1
+                     : 2 * NumDims - 2 * params.inner_dim_count - 2;
+
+      params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
+      params.bcast_input_strides[broadcast_bcast_dim] = 0;
+      params.bcast_block_strides[broadcast_bcast_dim] =
+          params.output_strides[params.bcast_dim];
+
+      return BroadcastBlockV2(
+          params.input_block_sizes, params.input_block_strides,
+          params.bcast_block_sizes, params.bcast_block_strides,
+          params.bcast_input_strides, bcast_offset, 0, scratch,
+          materialized_output, materialized_input, materialized_input_size);
+
+    } else {
+      // Keep track of the total number of coefficients written to the
+      // output block.
+      Index num_output_coeffs = 0;
+
+      // The general case. Let's denote the output block as
+      //
+      //   x[..., a:a+bcast_dim_size, :, ..., :]
+      //
+      // where a:a+bcast_dim_size is a slice on the bcast_dim dimension
+      // (< NumDims). We need to split the a:a+bcast_dim_size into possibly 3
+      // sub-blocks:
+      //
+      // (1) a:b, where b is the smallest multiple of
+      //     input_dims[bcast_dim] in [a, a+bcast_dim_size].
+      //
+      // (2) b:c, where c is the largest multiple of input_dims[bcast_dim]
+      //     in [a, a+bcast_dim_size].
+      //
+      // (3) c:a+bcast_dim_size.
+      //
+      // Or, when b and c do not exist, we just need to process the whole block
+      // together.
+
+      // Find a.
+      const Index bcast_dim_left_index =
+          bcast_offset / m_outputStrides[params.bcast_dim];
+
+      // Find b and c.
+      const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];
+
+      // First multiple after a. This is b when <= bcast_dim_left_index +
+      // bcast_dim_size.
+      const Index first_multiple =
+          divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
+          input_bcast_dim_size;
+
+      if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
+        // b exists, so does c. Find it.
+        const Index last_multiple =
+            (bcast_dim_left_index + params.bcast_dim_size) /
+            input_bcast_dim_size * input_bcast_dim_size;
+        const int copy_bcast_dim =
+            IsColMajor ? 2 * params.inner_dim_count
+                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
+        const int broadcast_bcast_dim =
+            IsColMajor ? 2 * params.inner_dim_count + 1
+                       : 2 * NumDims - 2 * params.inner_dim_count - 2;
+
+        if (first_multiple > bcast_dim_left_index) {
+          const Index head_size = first_multiple - bcast_dim_left_index;
+          params.input_block_sizes[params.bcast_dim] = head_size;
+          params.bcast_block_sizes[copy_bcast_dim] = head_size;
+          params.bcast_input_strides[copy_bcast_dim] =
+              params.input_block_strides[params.bcast_dim];
+          params.bcast_block_strides[copy_bcast_dim] =
+              params.output_strides[params.bcast_dim];
+          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
+          params.bcast_input_strides[broadcast_bcast_dim] = 0;
+          params.bcast_block_strides[broadcast_bcast_dim] =
+              params.output_strides[params.bcast_dim] *
+              params.input_dims[params.bcast_dim];
+
+          num_output_coeffs += BroadcastBlockV2(
+              params.input_block_sizes, params.input_block_strides,
+              params.bcast_block_sizes, params.bcast_block_strides,
+              params.bcast_input_strides, bcast_offset, 0, scratch,
+              materialized_output, materialized_input, materialized_input_size);
+        }
+        if (first_multiple < last_multiple) {
+          params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
+          params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
+          params.bcast_input_strides[copy_bcast_dim] =
+              params.input_block_strides[params.bcast_dim];
+          params.bcast_block_strides[copy_bcast_dim] =
+              params.output_strides[params.bcast_dim];
+          params.bcast_block_sizes[broadcast_bcast_dim] =
+              (last_multiple - first_multiple) / input_bcast_dim_size;
+          params.bcast_input_strides[broadcast_bcast_dim] = 0;
+          params.bcast_block_strides[broadcast_bcast_dim] =
+              params.output_strides[params.bcast_dim] *
+              params.input_dims[params.bcast_dim];
+          const Index offset = (first_multiple - bcast_dim_left_index) *
+                               m_outputStrides[params.bcast_dim];
+
+          num_output_coeffs += BroadcastBlockV2(
+              params.input_block_sizes, params.input_block_strides,
+              params.bcast_block_sizes, params.bcast_block_strides,
+              params.bcast_input_strides, bcast_offset, offset, scratch,
+              materialized_output, materialized_input, materialized_input_size);
+        }
+        if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
+          const Index tail_size =
+              bcast_dim_left_index + params.bcast_dim_size - last_multiple;
+          params.input_block_sizes[params.bcast_dim] = tail_size;
+          params.bcast_block_sizes[copy_bcast_dim] = tail_size;
+          params.bcast_input_strides[copy_bcast_dim] =
+              params.input_block_strides[params.bcast_dim];
+          params.bcast_block_strides[copy_bcast_dim] =
+              params.output_strides[params.bcast_dim];
+          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
+          params.bcast_input_strides[broadcast_bcast_dim] = 0;
+          params.bcast_block_strides[broadcast_bcast_dim] =
+              params.output_strides[params.bcast_dim] *
+              params.input_dims[params.bcast_dim];
+          const Index offset = (last_multiple - bcast_dim_left_index) *
+                               m_outputStrides[params.bcast_dim];
+
+          num_output_coeffs += BroadcastBlockV2(
+              params.input_block_sizes, params.input_block_strides,
+              params.bcast_block_sizes, params.bcast_block_strides,
+              params.bcast_input_strides, bcast_offset, offset, scratch,
+              materialized_output, materialized_input, materialized_input_size);
+        }
+      } else {
+        // b and c do not exist.
+        const int copy_bcast_dim =
+            IsColMajor ? 2 * params.inner_dim_count
+                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
+        params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
+        params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
+        params.bcast_input_strides[copy_bcast_dim] =
+            params.input_block_strides[params.bcast_dim];
+        params.bcast_block_strides[copy_bcast_dim] =
+            params.output_strides[params.bcast_dim];
+
+        num_output_coeffs += BroadcastBlockV2(
+            params.input_block_sizes, params.input_block_strides,
+            params.bcast_block_sizes, params.bcast_block_strides,
+            params.bcast_input_strides, bcast_offset, 0, scratch,
+            materialized_output, materialized_input, materialized_input_size);
+      }
+
+      return num_output_coeffs;
+    }
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockV2(
       const Dimensions& input_block_sizes,
       const Dimensions& input_block_strides,
       const BroadcastDimensions& bcast_block_sizes,
       const BroadcastDimensions& bcast_block_strides,
-      const BroadcastDimensions& bcast_input_strides, Index offset,
-      const TensorBlockDesc& output_desc, TensorBlockScratch& scratch,
+      const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
+      Index offset, TensorBlockScratch& scratch,
       ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
       size_t* materialized_input_size) const {
     // ---------------------------------------------------------------------- //
     // Tensor block descriptor for reading block from the input.
-    const Index input_offset = output_desc.offset() + offset;
-    static const bool is_col_major = static_cast<int>(Layout) == static_cast<int>(ColMajor);
-    TensorBlockDesc input_desc(is_col_major
-                                   ? indexColMajor(input_offset)
-                                   : indexRowMajor(input_offset),
-                               input_block_sizes);
+    const Index input_offset = bcast_offset + offset;
+    TensorBlockDesc input_desc(
+        IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
+        input_block_sizes);
 
     ArgTensorBlock input_block = m_impl.blockV2(input_desc, scratch);
 
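A worked instance of the head/body/tail split performed by `BroadcastBlockAlongBcastDim` above: with an input period of 5 on the bcast dimension and an output slice starting at a = 3 with bcast_dim_size = 9, we get b = 5 and c = 10, so the slice [3, 12) splits into head [3, 5), one full period [5, 10), and tail [10, 12). A sketch of the arithmetic (values are illustrative):

#include <cstdio>

long divup(long a, long b) { return (a + b - 1) / b; }

int main() {
  const long a = 3, size = 9, period = 5;       // period = input size on bcast dim
  const long b = divup(a, period) * period;     // first multiple >= a  -> 5
  const long c = (a + size) / period * period;  // last multiple <= a + size -> 10
  if (b <= a + size) {
    std::printf("head [%ld, %ld), body [%ld, %ld), tail [%ld, %ld)\n",
                a, b, b, c, c, a + size);  // head [3,5), body [5,10), tail [10,12)
  } else {
    std::printf("no full period inside the slice\n");
  }
}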
@@ -1266,7 +1366,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
     typename TensorBlockIOV2::Dst dst(bcast_block_sizes, bcast_block_strides,
                                       materialized_output + offset);
 
-    TensorBlockIOV2::Copy(dst, src);
+    return TensorBlockIOV2::Copy(dst, src);
   }
 
  protected:
unsupported/test/cxx11_tensor_block_eval.cpp:

@@ -264,6 +264,10 @@ static void test_eval_tensor_broadcast() {
       input.broadcast(bcast),
       [&bcasted_dims]() { return SkewedInnerBlock<Layout>(bcasted_dims); });
 
+  VerifyBlockEvaluator<T, NumDims, Layout>(
+      input.broadcast(bcast),
+      [&bcasted_dims]() { return RandomBlock<Layout>(bcasted_dims, 5, 10); });
+
   VerifyBlockEvaluator<T, NumDims, Layout>(
       input.broadcast(bcast),
       [&bcasted_dims]() { return FixedSizeBlock(bcasted_dims); });
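For context, these tests drive block evaluation of ordinary broadcast expressions; `RandomBlock` and `FixedSizeBlock` are the test suite's block generators. A minimal, self-contained broadcast example using the public Tensor API, independent of the test harness:

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> t(1, 3);
  t.setValues({{1.f, 2.f, 3.f}});

  // Repeat the single row 4 times along dimension 0.
  Eigen::array<Eigen::Index, 2> bcast = {4, 1};
  Eigen::Tensor<float, 2> b = t.broadcast(bcast);

  std::cout << b.dimension(0) << "x" << b.dimension(1) << "\n";  // 4x3
  std::cout << b(2, 1) << "\n";                                  // 2
}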
@@ -534,6 +538,33 @@ static void test_eval_tensor_forced_eval() {
       [dims]() { return RandomBlock<Layout, 2>(dims, 1, 50); });
 }
 
+template <typename T, int Layout>
+static void test_eval_tensor_chipping_of_bcast() {
+  if (Layout != static_cast<int>(RowMajor)) return;
+
+  Index dim0 = internal::random<Index>(1, 10);
+  Index dim1 = internal::random<Index>(1, 10);
+  Index dim2 = internal::random<Index>(1, 10);
+
+  Tensor<T, 3, Layout> input(1, dim1, dim2);
+  input.setRandom();
+
+  Eigen::array<Index, 3> bcast({dim0, 1, 1});
+  DSizes<Index, 2> chipped_dims(dim0, dim2);
+
+  VerifyBlockEvaluator<T, 2, Layout>(
+      input.broadcast(bcast).chip(0, 1),
+      [chipped_dims]() { return FixedSizeBlock(chipped_dims); });
+
+  VerifyBlockEvaluator<T, 2, Layout>(
+      input.broadcast(bcast).chip(0, 1),
+      [chipped_dims]() { return SkewedInnerBlock<Layout, 2>(chipped_dims); });
+
+  VerifyBlockEvaluator<T, 2, Layout>(
+      input.broadcast(bcast).chip(0, 1),
+      [chipped_dims]() { return RandomBlock<Layout, 2>(chipped_dims, 1, 5); });
+}
+
 // -------------------------------------------------------------------------- //
 // Verify that assigning block to a Tensor expression produces the same result
 // as an assignment to TensorSliceOp (writing a block is identical to
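Shape check for the new test above: broadcasting a (1, dim1, dim2) tensor by (dim0, 1, 1) yields (dim0, dim1, dim2), and `chip(0, 1)` fixes index 0 of dimension 1, so the chipped expression is 2-D with dimensions (dim0, dim2), matching `chipped_dims`. A small standalone sketch (the concrete sizes are illustrative):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // (1, 2, 3) broadcast by (4, 1, 1) -> (4, 2, 3); chip(0, 1) -> (4, 3).
  Eigen::Tensor<float, 3, Eigen::RowMajor> input(1, 2, 3);
  input.setRandom();

  Eigen::array<Eigen::Index, 3> bcast = {4, 1, 1};
  Eigen::Tensor<float, 2, Eigen::RowMajor> chipped =
      input.broadcast(bcast).chip(0, 1);

  std::cout << chipped.dimension(0) << "x" << chipped.dimension(1) << "\n";  // 4x3
}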
@@ -760,6 +791,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_block_eval) {
 
   CALL_SUBTESTS_LAYOUTS(6, test_eval_tensor_reshape_with_bcast);
   CALL_SUBTESTS_LAYOUTS(6, test_eval_tensor_forced_eval);
+  CALL_SUBTESTS_LAYOUTS(6, test_eval_tensor_chipping_of_bcast);
 
   CALL_SUBTESTS_DIMS_LAYOUTS(7, test_assign_to_tensor);
   CALL_SUBTESTS_DIMS_LAYOUTS(7, test_assign_to_tensor_reshape);