Merge pull request #103613 from stuartcarnie/fix_101696

Metal: Use uniform set index passed by `RenderingDevice`
This commit is contained in:
Thaddeus Crews 2025-03-05 12:08:11 -06:00
commit 59d75a704e
No known key found for this signature in database
GPG Key ID: 62181B86FE9E5D84
2 changed files with 43 additions and 41 deletions

View File

@ -787,18 +787,18 @@ struct BoundUniformSet {
class API_AVAILABLE(macos(11.0), ios(14.0), tvos(14.0)) MDUniformSet {
private:
void bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::RenderState &p_state);
void bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::RenderState &p_state);
void bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state);
void bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state);
void bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index);
void bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index);
void bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index);
void bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index);
public:
uint32_t index;
LocalVector<RDD::BoundUniform> uniforms;
HashMap<MDShader *, BoundUniformSet> bound_uniforms;
void bind_uniforms(MDShader *p_shader, MDCommandBuffer::RenderState &p_state);
void bind_uniforms(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state);
void bind_uniforms(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index);
void bind_uniforms(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index);
BoundUniformSet &bound_uniform_set(MDShader *p_shader, id<MTLDevice> p_device, ResourceUsageMap &p_resource_usage);
};

View File

@ -213,36 +213,38 @@ void MDCommandBuffer::render_bind_uniform_set(RDD::UniformSetID p_uniform_set, R
DEV_ASSERT(type == MDCommandBufferStateType::Render);
MDUniformSet *set = (MDUniformSet *)(p_uniform_set.id);
if (render.uniform_sets.size() <= set->index) {
if (render.uniform_sets.size() <= p_set_index) {
uint32_t s = render.uniform_sets.size();
render.uniform_sets.resize(set->index + 1);
render.uniform_sets.resize(p_set_index + 1);
// Set intermediate values to null.
std::fill(&render.uniform_sets[s], &render.uniform_sets[set->index] + 1, nullptr);
std::fill(&render.uniform_sets[s], &render.uniform_sets[p_set_index] + 1, nullptr);
}
if (render.uniform_sets[set->index] != set) {
if (render.uniform_sets[p_set_index] != set) {
render.dirty.set_flag(RenderState::DIRTY_UNIFORMS);
render.uniform_set_mask |= 1ULL << set->index;
render.uniform_sets[set->index] = set;
render.uniform_set_mask |= 1ULL << p_set_index;
render.uniform_sets[p_set_index] = set;
}
}
void MDCommandBuffer::render_bind_uniform_sets(VectorView<RDD::UniformSetID> p_uniform_sets, RDD::ShaderID p_shader, uint32_t p_first_set_index, uint32_t p_set_count) {
DEV_ASSERT(type == MDCommandBufferStateType::Render);
for (size_t i = 0u; i < p_set_count; ++i) {
for (size_t i = 0; i < p_set_count; ++i) {
MDUniformSet *set = (MDUniformSet *)(p_uniform_sets[i].id);
if (render.uniform_sets.size() <= set->index) {
uint32_t index = p_first_set_index + i;
if (render.uniform_sets.size() <= index) {
uint32_t s = render.uniform_sets.size();
render.uniform_sets.resize(set->index + 1);
render.uniform_sets.resize(index + 1);
// Set intermediate values to null.
std::fill(&render.uniform_sets[s], &render.uniform_sets[set->index] + 1, nullptr);
std::fill(&render.uniform_sets[s], &render.uniform_sets[index] + 1, nullptr);
}
if (render.uniform_sets[set->index] != set) {
if (render.uniform_sets[index] != set) {
render.dirty.set_flag(RenderState::DIRTY_UNIFORMS);
render.uniform_set_mask |= 1ULL << set->index;
render.uniform_sets[set->index] = set;
render.uniform_set_mask |= 1ULL << index;
render.uniform_sets[index] = set;
}
}
}
@ -474,14 +476,14 @@ void MDCommandBuffer::_render_bind_uniform_sets() {
while (set_uniforms != 0) {
// Find the index of the next set bit.
int index = __builtin_ctzll(set_uniforms);
uint32_t index = (uint32_t)__builtin_ctzll(set_uniforms);
// Clear the set bit.
set_uniforms &= (set_uniforms - 1);
MDUniformSet *set = render.uniform_sets[index];
if (set == nullptr || set->index >= (uint32_t)shader->sets.size()) {
if (set == nullptr || index >= (uint32_t)shader->sets.size()) {
continue;
}
set->bind_uniforms(shader, render);
set->bind_uniforms(shader, render, index);
}
}
@ -955,7 +957,7 @@ void MDCommandBuffer::compute_bind_uniform_set(RDD::UniformSetID p_uniform_set,
MDShader *shader = (MDShader *)(p_shader.id);
MDUniformSet *set = (MDUniformSet *)(p_uniform_set.id);
set->bind_uniforms(shader, compute);
set->bind_uniforms(shader, compute, p_set_index);
}
void MDCommandBuffer::compute_bind_uniform_sets(VectorView<RDD::UniformSetID> p_uniform_sets, RDD::ShaderID p_shader, uint32_t p_first_set_index, uint32_t p_set_count) {
@ -966,7 +968,7 @@ void MDCommandBuffer::compute_bind_uniform_sets(VectorView<RDD::UniformSetID> p_
// TODO(sgc): Bind multiple buffers using [encoder setBuffers:offsets:withRange:]
for (size_t i = 0u; i < p_set_count; ++i) {
MDUniformSet *set = (MDUniformSet *)(p_uniform_sets[i].id);
set->bind_uniforms(shader, compute);
set->bind_uniforms(shader, compute, p_first_set_index + i);
}
}
@ -1052,11 +1054,11 @@ void MDRenderShader::encode_push_constant_data(VectorView<uint32_t> p_data, MDCo
}
}
void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::RenderState &p_state) {
void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index) {
DEV_ASSERT(p_shader->uses_argument_buffers);
DEV_ASSERT(p_state.encoder != nil);
UniformSet const &set_info = p_shader->sets[index];
UniformSet const &set_info = p_shader->sets[p_set_index];
id<MTLRenderCommandEncoder> __unsafe_unretained enc = p_state.encoder;
id<MTLDevice> __unsafe_unretained device = enc.device;
@ -1067,25 +1069,25 @@ void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandB
{
uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_VERTEX);
if (offset) {
[enc setVertexBuffer:bus.buffer offset:*offset atIndex:index];
[enc setVertexBuffer:bus.buffer offset:*offset atIndex:p_set_index];
}
}
// Set the buffer for the fragment stage.
{
uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_FRAGMENT);
if (offset) {
[enc setFragmentBuffer:bus.buffer offset:*offset atIndex:index];
[enc setFragmentBuffer:bus.buffer offset:*offset atIndex:p_set_index];
}
}
}
void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::RenderState &p_state) {
void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index) {
DEV_ASSERT(!p_shader->uses_argument_buffers);
DEV_ASSERT(p_state.encoder != nil);
id<MTLRenderCommandEncoder> __unsafe_unretained enc = p_state.encoder;
UniformSet const &set = p_shader->sets[index];
UniformSet const &set = p_shader->sets[p_set_index];
for (uint32_t i = 0; i < MIN(uniforms.size(), set.uniforms.size()); i++) {
RDD::BoundUniform const &uniform = uniforms[i];
@ -1256,19 +1258,19 @@ void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::Ren
}
}
void MDUniformSet::bind_uniforms(MDShader *p_shader, MDCommandBuffer::RenderState &p_state) {
void MDUniformSet::bind_uniforms(MDShader *p_shader, MDCommandBuffer::RenderState &p_state, uint32_t p_set_index) {
if (p_shader->uses_argument_buffers) {
bind_uniforms_argument_buffers(p_shader, p_state);
bind_uniforms_argument_buffers(p_shader, p_state, p_set_index);
} else {
bind_uniforms_direct(p_shader, p_state);
bind_uniforms_direct(p_shader, p_state, p_set_index);
}
}
void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state) {
void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index) {
DEV_ASSERT(p_shader->uses_argument_buffers);
DEV_ASSERT(p_state.encoder != nil);
UniformSet const &set_info = p_shader->sets[index];
UniformSet const &set_info = p_shader->sets[p_set_index];
id<MTLComputeCommandEncoder> enc = p_state.encoder;
id<MTLDevice> device = enc.device;
@ -1277,17 +1279,17 @@ void MDUniformSet::bind_uniforms_argument_buffers(MDShader *p_shader, MDCommandB
uint32_t const *offset = set_info.offsets.getptr(RDD::SHADER_STAGE_COMPUTE);
if (offset) {
[enc setBuffer:bus.buffer offset:*offset atIndex:index];
[enc setBuffer:bus.buffer offset:*offset atIndex:p_set_index];
}
}
void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state) {
void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index) {
DEV_ASSERT(!p_shader->uses_argument_buffers);
DEV_ASSERT(p_state.encoder != nil);
id<MTLComputeCommandEncoder> __unsafe_unretained enc = p_state.encoder;
UniformSet const &set = p_shader->sets[index];
UniformSet const &set = p_shader->sets[p_set_index];
for (uint32_t i = 0; i < uniforms.size(); i++) {
RDD::BoundUniform const &uniform = uniforms[i];
@ -1407,11 +1409,11 @@ void MDUniformSet::bind_uniforms_direct(MDShader *p_shader, MDCommandBuffer::Com
}
}
void MDUniformSet::bind_uniforms(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state) {
void MDUniformSet::bind_uniforms(MDShader *p_shader, MDCommandBuffer::ComputeState &p_state, uint32_t p_set_index) {
if (p_shader->uses_argument_buffers) {
bind_uniforms_argument_buffers(p_shader, p_state);
bind_uniforms_argument_buffers(p_shader, p_state, p_set_index);
} else {
bind_uniforms_direct(p_shader, p_state);
bind_uniforms_direct(p_shader, p_state, p_set_index);
}
}