blob: 4a8a4f389e7f57dc7ac1f8d3fe6eed390b32aa46 [file] [log] [blame]
/*
* Copyright © 2024 Valve Corporation
* SPDX-License-Identifier: MIT
*/
#include "nir.h"
#include "nir_builder.h"
#include "nir_phi_builder.h"
/* One indirect call site to process, queued on the pass's call_list. */
struct call_liveness_entry {
/* Link in the pass-local call_list. */
struct list_head list;
/* The call instruction (only calls with an SSA indirect_callee are queued). */
nir_call_instr *instr;
/* Bitset, indexed by nir_def::index, of the defs live directly after the
 * call, as returned by nir_get_live_defs() before the pass changed anything.
 */
const BITSET_WORD *live_set;
};
/* Decides whether @instr may simply be cloned after a call instead of having
 * its result kept alive across the call.
 *
 * Accepted: every ALU instruction, load_const, undef, and an explicit
 * allow-list of intrinsics (ray launch id/size, Vulkan descriptor/resource
 * index handling, push-constant, global-constant, AMD SMEM and shader
 * argument loads). Everything else is rejected.
 * NOTE(review): the allow-list presumably covers values that cannot change
 * across the call -- confirm before extending it.
 */
static bool
can_remat_instr(nir_instr *instr)
{
   if (instr->type == nir_instr_type_alu ||
       instr->type == nir_instr_type_load_const ||
       instr->type == nir_instr_type_undef)
      return true;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   switch (nir_instr_as_intrinsic(instr)->intrinsic) {
   case nir_intrinsic_load_ray_launch_id:
   case nir_intrinsic_load_ray_launch_size:
   case nir_intrinsic_vulkan_resource_index:
   case nir_intrinsic_vulkan_resource_reindex:
   case nir_intrinsic_load_vulkan_descriptor:
   case nir_intrinsic_load_push_constant:
   case nir_intrinsic_load_global_constant:
   case nir_intrinsic_load_smem_amd:
   case nir_intrinsic_load_scalar_arg_amd:
   case nir_intrinsic_load_vector_arg_amd:
      return true;
   default:
      return false;
   }
}
/* Clones the instruction defining @def at the builder's cursor and registers
 * the copy as a second definition of the same value with @phi_builder.
 *
 * @remap_table maps original defs to their copies for the current call.
 * nir_instr_clone_deep() uses it to remap the clone's sources, and the new
 * copy of @def is added to it afterwards. @phi_value_table receives the
 * phi-builder value created for @def, keyed by the original def.
 * @def_blocks is caller-provided scratch space of at least
 * BITSET_WORDS(num_blocks) words. It is cleared and refilled here.
 */
static void
remat_ssa_def(nir_builder *b, nir_def *def, struct hash_table *remap_table,
              struct hash_table *phi_value_table,
              struct nir_phi_builder *phi_builder, BITSET_WORD *def_blocks)
{
   nir_block *orig_block = def->parent_instr->block;
   /* Inserting below keeps the cursor in this same block, so it can be
    * looked up once.
    */
   nir_block *remat_block = nir_cursor_current_block(b->cursor);

   /* The value is defined in exactly two places: where it originally was and
    * where the copy is about to go.
    */
   memset(def_blocks, 0,
          BITSET_WORDS(b->impl->num_blocks) * sizeof(BITSET_WORD));
   BITSET_SET(def_blocks, orig_block->index);
   BITSET_SET(def_blocks, remat_block->index);

   struct nir_phi_builder_value *phi_val =
      nir_phi_builder_add_value(phi_builder, def->num_components,
                                def->bit_size, def_blocks);
   _mesa_hash_table_insert(phi_value_table, def, phi_val);

   nir_instr *copy = nir_instr_clone_deep(b->shader, def->parent_instr,
                                          remap_table);
   nir_builder_instr_insert(b, copy);
   nir_def *copy_def = nir_instr_def(copy);
   _mesa_hash_table_insert(remap_table, def, copy_def);

   /* When both definitions share a block, the copy alone becomes that
    * block's definition. Earlier uses in the block are protected by the
    * ordering check done when sources are rewritten.
    */
   if (remat_block->index != orig_block->index)
      nir_phi_builder_value_set_block_def(phi_val, orig_block, def);
   nir_phi_builder_value_set_block_def(phi_val, remat_block, copy_def);
}
/* State threaded through can_remat_chain() while checking one candidate. */
struct remat_chain_check_data {
/* Defs already cloned for the current call; these need no further checks. */
struct hash_table *remap_table;
/* Running count of instructions examined, starting at 1 for the candidate
 * itself. can_remat_chain() gives up once it reaches 16.
 */
unsigned chain_length;
};
/* nir_foreach_src() callback: returns true when the def feeding @src is
 * either already cloned for the current call or can itself be cloned
 * together with all of its operands, recursively.
 *
 * The walk is bounded: with chain_length starting at 1 for the root, at most
 * 16 instructions (root included) are accepted. Operands reached through
 * several paths are counted once per path.
 */
static bool
can_remat_chain(nir_src *src, void *data)
{
   struct remat_chain_check_data *state = data;
   nir_def *def = src->ssa;

   if (_mesa_hash_table_search(state->remap_table, def) != NULL)
      return true;

   nir_instr *producer = def->parent_instr;
   if (!can_remat_instr(producer))
      return false;

   /* Check against the value before bumping it, then bump regardless. */
   unsigned examined = state->chain_length;
   state->chain_length = examined + 1;
   if (examined >= 16)
      return false;

   return nir_foreach_src(producer, can_remat_chain, state);
}
/* Arguments forwarded by do_remat_chain() to remat_ssa_def(). */
struct remat_chain_data {
/* Builder whose cursor marks where copies are inserted (after the call). */
nir_builder *b;
/* Original def -> copy, for the current call. */
struct hash_table *remap_table;
/* Original def -> nir_phi_builder_value, for the current call. */
struct hash_table *phi_value_table;
struct nir_phi_builder *phi_builder;
/* Scratch bitset of BITSET_WORDS(num_blocks) words, reused per def. */
BITSET_WORD *def_blocks;
};
/* nir_foreach_src() callback: clones the def feeding @src, after first
 * cloning its own operands, unless it was already cloned for this call.
 * Callers must have validated the chain with can_remat_chain(). Always
 * returns true so the surrounding source walk never stops early.
 */
static bool
do_remat_chain(nir_src *src, void *data)
{
   struct remat_chain_data *ctx = data;
   nir_def *def = src->ssa;

   if (!_mesa_hash_table_search(ctx->remap_table, def)) {
      /* Operands go first so the clone of @def can be remapped onto them. */
      nir_foreach_src(def->parent_instr, do_remat_chain, ctx);
      remat_ssa_def(ctx->b, def, ctx->remap_table, ctx->phi_value_table,
                    ctx->phi_builder, ctx->def_blocks);
   }

   return true;
}
/* nir_foreach_src() callback; @data is the phi_value_table of the current
 * call.
 *
 * - A constant source is replaced by a fresh immediate built directly in
 *   front of its user.
 * - A source whose def is a key of the table is replaced by the reaching
 *   definition the phi builder reports for the user's block. The exception
 *   is when that definition is an already-placed instruction later in the
 *   same block, because the use precedes the copy.
 *   NOTE(review): index == UINT32_MAX presumably identifies a phi the
 *   builder created but has not inserted yet -- confirm.
 * - Every other source is left alone.
 *
 * Always returns true.
 */
static bool
rewrite_instr_src_from_phi_builder(nir_src *src, void *data)
{
   struct hash_table *phi_value_table = data;
   nir_instr *user = nir_src_parent_instr(src);

   if (nir_src_is_const(*src)) {
      nir_builder b = nir_builder_at(nir_before_instr(user));
      nir_def *imm = nir_build_imm(&b, src->ssa->num_components,
                                   src->ssa->bit_size,
                                   nir_src_as_const_value(*src));
      nir_src_rewrite(src, imm);
      return true;
   }

   struct hash_entry *entry = _mesa_hash_table_search(phi_value_table, src->ssa);
   if (entry == NULL)
      return true;

   nir_def *reaching = nir_phi_builder_value_get_block_def(entry->data,
                                                           user->block);
   const bool placed_in_same_block =
      reaching->parent_instr->block == user->block &&
      reaching->index != UINT32_MAX;

   if (!placed_in_same_block ||
       !nir_instr_is_before(user, reaching->parent_instr))
      nir_src_rewrite(src, reaching);

   return true;
}
/* Returns whether @impl contains a call whose callee is an SSA value
 * (nir_call_instr::indirect_callee), the only kind of call this pass
 * handles.
 */
static bool
impl_has_indirect_call(nir_function_impl *impl)
{
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (instr->type == nir_instr_type_call &&
             nir_instr_as_call(instr)->indirect_callee.ssa)
            return true;
      }
   }
   return false;
}

/* For every indirect call in @impl, clones right after the call each
 * rematerializable def (see can_remat_instr) that is live across it, along
 * with its operand chain. Later uses are then rewritten to the copies, with
 * nir_phi_builder placing phis where paths merge.
 *
 * Returns true if at least one def was rematerialized.
 */
static bool
nir_minimize_call_live_states_impl(nir_function_impl *impl)
{
   /* Without an indirect call there is nothing to do: avoid computing
    * live-defs metadata and keep all existing metadata.
    */
   if (!impl_has_indirect_call(impl))
      return nir_progress(false, impl, nir_metadata_all);

   nir_metadata_require(impl, nir_metadata_block_index |
                                 nir_metadata_live_defs |
                                 nir_metadata_dominance);

   bool progress = false;
   void *mem_ctx = ralloc_context(NULL);

   struct list_head call_list;
   list_inithead(&call_list);

   /* Both rematerializable[] and the per-call live sets are indexed by the
    * def indices as they are before this pass changes anything.
    */
   unsigned num_defs = impl->ssa_alloc;
   nir_def **rematerializable =
      rzalloc_array_size(mem_ctx, sizeof(nir_def *), num_defs);

   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         nir_def *def = nir_instr_def(instr);
         if (def && can_remat_instr(instr))
            rematerializable[def->index] = def;

         if (instr->type != nir_instr_type_call)
            continue;

         nir_call_instr *call = nir_instr_as_call(instr);
         if (!call->indirect_callee.ssa)
            continue;

         struct call_liveness_entry *entry =
            ralloc_size(mem_ctx, sizeof(struct call_liveness_entry));
         entry->instr = call;
         entry->live_set = nir_get_live_defs(nir_after_instr(instr), mem_ctx);
         list_addtail(&entry->list, &call_list);
      }
   }

   const unsigned block_words = BITSET_WORDS(impl->num_blocks);
   BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);

   /* Walk the calls from last to first. The live sets, and the keys of each
    * iteration's phi_value_table, are the ORIGINAL defs.
    *
    * Walking forwards breaks this. Handling an earlier call first would
    * redirect every use beyond a later call to the copy made after the
    * earlier call. That copy is not a key when the later call is handled, so
    * the later copy would end up dead and the earlier copy would stay live
    * across the later call.
    *
    * Walking backwards avoids it. A copy made after a later call is built
    * only from clones made for that same call (can_remat_chain demands the
    * whole operand chain), so earlier iterations never touch it.
    */
   list_for_each_entry_rev(struct call_liveness_entry, entry, &call_list, list) {
      unsigned i;
      nir_builder b = nir_builder_at(nir_after_instr(&entry->instr->instr));

      struct nir_phi_builder *builder = nir_phi_builder_create(impl);
      struct hash_table *phi_value_table =
         _mesa_pointer_hash_table_create(mem_ctx);
      struct hash_table *remap_table =
         _mesa_pointer_hash_table_create(mem_ctx);

      BITSET_FOREACH_SET(i, entry->live_set, num_defs) {
         nir_def *def = rematerializable[i];

         /* Skip non-rematerializable defs and ones already cloned as part of
          * an earlier chain for this call.
          */
         if (!def || _mesa_hash_table_search(remap_table, def))
            continue;

         assert(!_mesa_hash_table_search(phi_value_table, def));

         struct remat_chain_check_data check_data = {
            .remap_table = remap_table,
            .chain_length = 1,
         };
         if (!nir_foreach_src(def->parent_instr, can_remat_chain, &check_data))
            continue;

         struct remat_chain_data remat_data = {
            .b = &b,
            .remap_table = remap_table,
            .phi_value_table = phi_value_table,
            .phi_builder = builder,
            .def_blocks = def_blocks,
         };
         nir_foreach_src(def->parent_instr, do_remat_chain, &remat_data);
         remat_ssa_def(&b, def, remap_table, phi_value_table, builder,
                       def_blocks);
         progress = true;
      }

      _mesa_hash_table_destroy(remap_table, NULL);

      nir_foreach_block(block, impl) {
         nir_foreach_instr(instr, block) {
            /* Phi sources are not rewritten. A constant cannot be built in
             * front of a phi, and a phi source would need the definition
             * reaching the end of its predecessor, not the phi's own block.
             * NOTE(review): a phi reading a rematerialized def therefore
             * still keeps the original def live.
             */
            if (instr->type == nir_instr_type_phi)
               continue;
            nir_foreach_src(instr, rewrite_instr_src_from_phi_builder,
                            phi_value_table);
         }
      }

      nir_phi_builder_finish(builder);
      _mesa_hash_table_destroy(phi_value_table, NULL);
   }

   ralloc_free(mem_ctx);

   /* The rewrite walk re-creates every non-phi constant source in front of
    * its user, so the IR can change even when nothing was rematerialized.
    * Metadata is therefore invalidated unconditionally.
    * NOTE(review): such constant-only changes are not reported as progress.
    */
   nir_progress(true, impl, nir_metadata_block_index | nir_metadata_dominance);
   return progress;
}
/* Runs the pass over every function implementation in @shader. Right after
 * each indirect call, it re-creates cheap values (see can_remat_instr) that
 * are live across that call, aiming to shrink the state kept over calls.
 *
 * Returns true if any implementation reported progress.
 *
 * nir_opt_cse would merge the copies back into the originals and undo this
 * pass, so it should not be run afterwards.
 */
bool
nir_minimize_call_live_states(nir_shader *shader)
{
   bool any_progress = false;

   nir_foreach_function_impl(impl, shader) {
      if (nir_minimize_call_live_states_impl(impl))
         any_progress = true;
   }

   return any_progress;
}