Vulkan: SPIR-V Gen: Fix .length() vs ssbo arrays
Bug: angleproject:4889
Change-Id: Ib490a46fabf058064fc1b18d2c084a4bc5f9277d
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/3052682
Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org>
Reviewed-by: Tim Van Patten <timvp@google.com>
diff --git a/src/compiler/translator/OutputSPIRV.cpp b/src/compiler/translator/OutputSPIRV.cpp
index 246dc9b..1ea1ce2 100644
--- a/src/compiler/translator/OutputSPIRV.cpp
+++ b/src/compiler/translator/OutputSPIRV.cpp
@@ -304,6 +304,8 @@
spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
+ void visitArrayLength(TIntermUnary *node);
+
// Cast between types. There are two kinds of casts:
//
// - A constructor can cast between basic types, for example vec4(someInt).
@@ -2020,6 +2022,73 @@
return result;
}
+void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
+{
+ // .length() on sized arrays is already constant folded, so this operation only applies to
+ // ssbo[N].last_member.length(). OpArrayLength takes the ssbo block *pointer* and the field
+ // index of last_member, so those need to be extracted from the access chain. Additionally,
+ // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
+ // necessary.
+
+ // Inspect the children. There are two possibilities:
+ //
+ // - last_member.length(): In this case, the id of the nameless ssbo is used.
+ // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
+ // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
+ //
+ // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
+ // which is not useful.
+
+ spirv::IdRef accessChainId;
+ spirv::LiteralInteger fieldIndex;
+
+ if (node->getOperand()->getAsSymbolNode())
+ {
+ // If the operand is a symbol referencing the last member of a nameless interface block,
+ // visit the symbol to get the id of the interface block.
+ node->getOperand()->getAsSymbolNode()->traverse(this);
+
+ // The access chain must only include the base id + one literal field index.
+ ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
+
+ accessChainId = mNodeData.back().baseId;
+ fieldIndex = mNodeData.back().idList.back().literal;
+ }
+ else
+ {
+ // Otherwise make sure not to traverse the field index selection node so that the access
+ // chain would not include it.
+ TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
+ ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);
+
+ TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
+ TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
+ ASSERT(indexNode);
+
+ // Visit the expression.
+ interfaceBlockExpression->traverse(this);
+
+ accessChainId = accessChainCollapse(&mNodeData.back());
+ fieldIndex = spirv::LiteralInteger(indexNode->getIConst(0));
+ }
+
+ // Get the int and uint type ids.
+ const spirv::IdRef intTypeId = mBuilder.getBasicTypeId(EbtInt, 1);
+ const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
+
+ // Generate the instruction.
+ const spirv::IdRef resultId = mBuilder.getNewId({});
+ spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
+ accessChainId, fieldIndex);
+
+ // Cast to int.
+ const spirv::IdRef castResultId = mBuilder.getNewId({});
+ spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);
+
+ // Replace the access chain with an rvalue that's the result.
+ nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
+}
+
bool IsShortCircuitNeeded(TIntermOperator *node)
{
TOperator op = node->getOp();
@@ -4856,6 +4925,15 @@
// Constants are expected to be folded.
ASSERT(!node->hasConstantValue());
+ // Special case EOpArrayLength.
+ if (node->getOp() == EOpArrayLength)
+ {
+ visitArrayLength(node);
+
+ // Children already visited.
+ return false;
+ }
+
if (visit == PreVisit)
{
// Don't add an entry to the stack. The child will create one, which we won't pop.
@@ -4868,38 +4946,6 @@
// There is at least on entry for the child.
ASSERT(mNodeData.size() >= 1);
- // Special case EOpArrayLength. .length() on sized arrays is already constant folded, so this
- // operation only applies to ssbo.last_member.length(). OpArrayLength takes the ssbo block
- // *type* and the field index of last_member, so those need to be extracted from the access
- // chain. Additionally, OpArrayLength produces an unsigned int while GLSL produces an int, so a
- // final cast is necessary.
- if (node->getOp() == EOpArrayLength)
- {
- // The access chain must only include the base ssbo + one literal field index.
- ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
- const spirv::IdRef baseId = mNodeData.back().baseId;
- const spirv::LiteralInteger fieldIndex = mNodeData.back().idList.back().literal;
-
- // Get the int and uint type ids.
- const spirv::IdRef intTypeId = mBuilder.getBasicTypeId(EbtInt, 1);
- const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
-
- // Generate the instruction.
- const spirv::IdRef resultId = mBuilder.getNewId({});
- spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
- baseId, fieldIndex);
-
- // Cast to int.
- const spirv::IdRef castResultId = mBuilder.getNewId({});
- spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId,
- resultId);
-
- // Replace the access chain with an rvalue that's the result.
- nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
-
- return true;
- }
-
const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
const spirv::IdRef result = visitOperator(node, resultTypeId);