Vulkan: SPIR-V Gen: Fix .length() vs ssbo arrays

Bug: angleproject:4889
Change-Id: Ib490a46fabf058064fc1b18d2c084a4bc5f9277d
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/3052682
Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org>
Reviewed-by: Tim Van Patten <timvp@google.com>
diff --git a/src/compiler/translator/OutputSPIRV.cpp b/src/compiler/translator/OutputSPIRV.cpp
index 246dc9b..1ea1ce2 100644
--- a/src/compiler/translator/OutputSPIRV.cpp
+++ b/src/compiler/translator/OutputSPIRV.cpp
@@ -304,6 +304,8 @@
 
     spirv::IdRef createFunctionCall(TIntermAggregate *node, spirv::IdRef resultTypeId);
 
+    void visitArrayLength(TIntermUnary *node);
+
     // Cast between types.  There are two kinds of casts:
     //
     // - A constructor can cast between basic types, for example vec4(someInt).
@@ -2020,6 +2022,73 @@
     return result;
 }
 
+void OutputSPIRVTraverser::visitArrayLength(TIntermUnary *node)
+{
+    // .length() on sized arrays is already constant folded, so this operation only applies to
+    // ssbo[N].last_member.length().  OpArrayLength takes the ssbo block *pointer* and the field
+    // index of last_member, so those need to be extracted from the access chain.  Additionally,
+    // OpArrayLength produces an unsigned int while GLSL produces an int, so a final cast is
+    // necessary.
+
+    // Inspect the children.  There are two possibilities:
+    //
+    // - last_member.length(): In this case, the id of the nameless ssbo is used.
+    // - ssbo.last_member.length(): In this case, the id of the variable |ssbo| itself is used.
+    // - ssbo[N][M].last_member.length(): In this case, the access chain |ssbo N M| is used.
+    //
+    // We can't visit the child in its entirety as it will create the access chain |ssbo N M field|
+    // which is not useful.
+
+    spirv::IdRef accessChainId;
+    spirv::LiteralInteger fieldIndex;
+
+    if (node->getOperand()->getAsSymbolNode())
+    {
+        // If the operand is a symbol referencing the last member of a nameless interface block,
+        // visit the symbol to get the id of the interface block.
+        node->getOperand()->getAsSymbolNode()->traverse(this);
+
+        // The access chain must only include the base id + one literal field index.
+        ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
+
+        accessChainId = mNodeData.back().baseId;
+        fieldIndex    = mNodeData.back().idList.back().literal;
+    }
+    else
+    {
+        // Otherwise make sure not to traverse the field index selection node so that the access
+        // chain would not include it.
+        TIntermBinary *fieldSelectionNode = node->getOperand()->getAsBinaryNode();
+        ASSERT(fieldSelectionNode && fieldSelectionNode->getOp() == EOpIndexDirectInterfaceBlock);
+
+        TIntermTyped *interfaceBlockExpression = fieldSelectionNode->getLeft();
+        TIntermConstantUnion *indexNode = fieldSelectionNode->getRight()->getAsConstantUnion();
+        ASSERT(indexNode);
+
+        // Visit the expression.
+        interfaceBlockExpression->traverse(this);
+
+        accessChainId = accessChainCollapse(&mNodeData.back());
+        fieldIndex    = spirv::LiteralInteger(indexNode->getIConst(0));
+    }
+
+    // Get the int and uint type ids.
+    const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
+    const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
+
+    // Generate the instruction.
+    const spirv::IdRef resultId = mBuilder.getNewId({});
+    spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
+                            accessChainId, fieldIndex);
+
+    // Cast to int.
+    const spirv::IdRef castResultId = mBuilder.getNewId({});
+    spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId, resultId);
+
+    // Replace the access chain with an rvalue that's the result.
+    nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
+}
+
 bool IsShortCircuitNeeded(TIntermOperator *node)
 {
     TOperator op = node->getOp();
@@ -4856,6 +4925,15 @@
     // Constants are expected to be folded.
     ASSERT(!node->hasConstantValue());
 
+    // Special case EOpArrayLength.
+    if (node->getOp() == EOpArrayLength)
+    {
+        visitArrayLength(node);
+
+        // Children already visited.
+        return false;
+    }
+
     if (visit == PreVisit)
     {
         // Don't add an entry to the stack.  The child will create one, which we won't pop.
@@ -4868,38 +4946,6 @@
     // There is at least on entry for the child.
     ASSERT(mNodeData.size() >= 1);
 
-    // Special case EOpArrayLength.  .length() on sized arrays is already constant folded, so this
-    // operation only applies to ssbo.last_member.length().  OpArrayLength takes the ssbo block
-    // *type* and the field index of last_member, so those need to be extracted from the access
-    // chain.  Additionally, OpArrayLength produces an unsigned int while GLSL produces an int, so a
-    // final cast is necessary.
-    if (node->getOp() == EOpArrayLength)
-    {
-        // The access chain must only include the base ssbo + one literal field index.
-        ASSERT(mNodeData.back().idList.size() == 1 && !mNodeData.back().idList.back().id.valid());
-        const spirv::IdRef baseId              = mNodeData.back().baseId;
-        const spirv::LiteralInteger fieldIndex = mNodeData.back().idList.back().literal;
-
-        // Get the int and uint type ids.
-        const spirv::IdRef intTypeId  = mBuilder.getBasicTypeId(EbtInt, 1);
-        const spirv::IdRef uintTypeId = mBuilder.getBasicTypeId(EbtUInt, 1);
-
-        // Generate the instruction.
-        const spirv::IdRef resultId = mBuilder.getNewId({});
-        spirv::WriteArrayLength(mBuilder.getSpirvCurrentFunctionBlock(), uintTypeId, resultId,
-                                baseId, fieldIndex);
-
-        // Cast to int.
-        const spirv::IdRef castResultId = mBuilder.getNewId({});
-        spirv::WriteBitcast(mBuilder.getSpirvCurrentFunctionBlock(), intTypeId, castResultId,
-                            resultId);
-
-        // Replace the access chain with an rvalue that's the result.
-        nodeDataInitRValue(&mNodeData.back(), castResultId, intTypeId);
-
-        return true;
-    }
-
     const spirv::IdRef resultTypeId = mBuilder.getTypeData(node->getType(), {}).id;
     const spirv::IdRef result       = visitOperator(node, resultTypeId);