Unroll for-loop if sampler array uses loop index as its index.

If inside a for-loop, sampler array index is the loop index, Mac cg compiler will crash.  This CL unroll the loop in such situation.  The behavior is:
1) If the for-loop index is a float, we reject the shader.
2) If it is an integer, we unroll the for-loop.

Things that should be done in the future are:
1) Add line number macros.
2) Add a limit to unroll iteration count.

anglebug=94
Review URL: http://codereview.appspot.com/4331048

git-svn-id: https://angleproject.googlecode.com/svn/trunk@606 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/build_angle.gyp b/src/build_angle.gyp
index 022d239..03fdee5 100644
--- a/src/build_angle.gyp
+++ b/src/build_angle.gyp
@@ -100,6 +100,8 @@
       ],
       'sources': [
         'compiler/CodeGenGLSL.cpp',
+        'compiler/ForLoopUnroll.cpp',
+        'compiler/ForLoopUnroll.h',
         'compiler/OutputGLSL.cpp',
         'compiler/OutputGLSL.h',
         'compiler/TranslatorGLSL.cpp',
diff --git a/src/common/version.h b/src/common/version.h
index 796b435..47564fa 100644
--- a/src/common/version.h
+++ b/src/common/version.h
@@ -1,7 +1,7 @@
 #define MAJOR_VERSION 0
 #define MINOR_VERSION 0
 #define BUILD_VERSION 0
-#define BUILD_REVISION 605
+#define BUILD_REVISION 606
 
 #define STRINGIFY(x) #x
 #define MACRO_STRINGIFY(x) STRINGIFY(x)
diff --git a/src/compiler/ForLoopUnroll.cpp b/src/compiler/ForLoopUnroll.cpp
new file mode 100644
index 0000000..d631af4
--- /dev/null
+++ b/src/compiler/ForLoopUnroll.cpp
@@ -0,0 +1,172 @@
+//
+// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#include "compiler/ForLoopUnroll.h"
+
+void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info)
+{
+    ASSERT(node->getType() == ELoopFor);
+    ASSERT(node->getUnrollFlag());
+
+    TIntermNode* init = node->getInit();
+    ASSERT(init != NULL);
+    TIntermAggregate* decl = init->getAsAggregate();
+    ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration));
+    TIntermSequence& declSeq = decl->getSequence();
+    ASSERT(declSeq.size() == 1);
+    TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
+    ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize));
+    TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
+    ASSERT(symbol != NULL);
+    ASSERT(symbol->getBasicType() == EbtInt);
+
+    info.id = symbol->getId();
+
+    ASSERT(declInit->getRight() != NULL);
+    TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion();
+    ASSERT(initNode != NULL);
+
+    info.initValue = evaluateIntConstant(initNode);
+    info.currentValue = info.initValue;
+
+    TIntermNode* cond = node->getCondition();
+    ASSERT(cond != NULL);
+    TIntermBinary* binOp = cond->getAsBinaryNode();
+    ASSERT(binOp != NULL);
+    ASSERT(binOp->getRight() != NULL);
+    ASSERT(binOp->getRight()->getAsConstantUnion() != NULL);
+
+    info.incrementValue = getLoopIncrement(node);
+    info.stopValue = evaluateIntConstant(
+        binOp->getRight()->getAsConstantUnion());
+    info.op = binOp->getOp();
+}
+
+void ForLoopUnroll::Step()
+{
+    ASSERT(mLoopIndexStack.size() > 0);
+    TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
+    info.currentValue += info.incrementValue;
+}
+
+bool ForLoopUnroll::SatisfiesLoopCondition()
+{
+    ASSERT(mLoopIndexStack.size() > 0);
+    TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
+    // Relational operator is one of: > >= < <= == or !=.
+    switch (info.op) {
+      case EOpEqual:
+        return (info.currentValue == info.stopValue);
+      case EOpNotEqual:
+        return (info.currentValue != info.stopValue);
+      case EOpLessThan:
+        return (info.currentValue < info.stopValue);
+      case EOpGreaterThan:
+        return (info.currentValue > info.stopValue);
+      case EOpLessThanEqual:
+        return (info.currentValue <= info.stopValue);
+      case EOpGreaterThanEqual:
+        return (info.currentValue >= info.stopValue);
+      default:
+        UNREACHABLE();
+    }
+    return false;
+}
+
+bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol)
+{
+    for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
+         i != mLoopIndexStack.end();
+         ++i) {
+        if (i->id == symbol->getId())
+            return true;
+    }
+    return false;
+}
+
+int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol)
+{
+    for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
+         i != mLoopIndexStack.end();
+         ++i) {
+        if (i->id == symbol->getId())
+            return i->currentValue;
+    }
+    UNREACHABLE();
+    return false;
+}
+
+void ForLoopUnroll::Push(TLoopIndexInfo& info)
+{
+    mLoopIndexStack.push_back(info);
+}
+
+void ForLoopUnroll::Pop()
+{
+    mLoopIndexStack.pop_back();
+}
+
+int ForLoopUnroll::getLoopIncrement(TIntermLoop* node)
+{
+    TIntermNode* expr = node->getExpression();
+    ASSERT(expr != NULL);
+    // for expression has one of the following forms:
+    //     loop_index++
+    //     loop_index--
+    //     loop_index += constant_expression
+    //     loop_index -= constant_expression
+    //     ++loop_index
+    //     --loop_index
+    // The last two forms are not specified in the spec, but I am assuming
+    // its an oversight.
+    TIntermUnary* unOp = expr->getAsUnaryNode();
+    TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
+
+    TOperator op = EOpNull;
+    TIntermConstantUnion* incrementNode = NULL;
+    if (unOp != NULL) {
+        op = unOp->getOp();
+    } else if (binOp != NULL) {
+        op = binOp->getOp();
+        ASSERT(binOp->getRight() != NULL);
+        incrementNode = binOp->getRight()->getAsConstantUnion();
+        ASSERT(incrementNode != NULL);
+    }
+
+    int increment = 0;
+    // The operator is one of: ++ -- += -=.
+    switch (op) {
+        case EOpPostIncrement:
+        case EOpPreIncrement:
+            ASSERT((unOp != NULL) && (binOp == NULL));
+            increment = 1;
+            break;
+        case EOpPostDecrement:
+        case EOpPreDecrement:
+            ASSERT((unOp != NULL) && (binOp == NULL));
+            increment = -1;
+            break;
+        case EOpAddAssign:
+            ASSERT((unOp == NULL) && (binOp != NULL));
+            increment = evaluateIntConstant(incrementNode);
+            break;
+        case EOpSubAssign:
+            ASSERT((unOp == NULL) && (binOp != NULL));
+            increment = - evaluateIntConstant(incrementNode);
+            break;
+        default:
+            ASSERT(false);
+    }
+
+    return increment;
+}
+
+int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node)
+{
+    ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL));
+    return node->getUnionArrayPointer()->getIConst();
+}
+
diff --git a/src/compiler/ForLoopUnroll.h b/src/compiler/ForLoopUnroll.h
new file mode 100644
index 0000000..b2b2b58
--- /dev/null
+++ b/src/compiler/ForLoopUnroll.h
@@ -0,0 +1,46 @@
+//
+// Copyright (c) 2011 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#include "compiler/intermediate.h"
+
+struct TLoopIndexInfo {
+    int id;
+    int initValue;
+    int stopValue;
+    int incrementValue;
+    TOperator op;
+    int currentValue;
+};
+
+class ForLoopUnroll {
+public:
+    ForLoopUnroll() { }
+
+    void FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info);
+
+    // Update the info.currentValue for the next loop iteration.
+    void Step();
+
+    // Return false if loop condition is no longer satisfied.
+    bool SatisfiesLoopCondition();
+
+    // Check if the symbol is the index of a loop that's unrolled.
+    bool NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol);
+
+    // Return the current value of a given loop index symbol.
+    int GetLoopIndexValue(TIntermSymbol* symbol);
+
+    void Push(TLoopIndexInfo& info);
+    void Pop();
+
+private:
+    int getLoopIncrement(TIntermLoop* node);
+
+    int evaluateIntConstant(TIntermConstantUnion* node);
+
+    TVector<TLoopIndexInfo> mLoopIndexStack;
+};
+
diff --git a/src/compiler/OutputGLSL.cpp b/src/compiler/OutputGLSL.cpp
index 23476f2..3224dfd 100644
--- a/src/compiler/OutputGLSL.cpp
+++ b/src/compiler/OutputGLSL.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
+// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 //
@@ -195,7 +195,10 @@
 void TOutputGLSL::visitSymbol(TIntermSymbol* node)
 {
     TInfoSinkBase& out = objSink();
-    out << node->getSymbol();
+    if (mLoopUnroll.NeedsToReplaceSymbolWithValue(node))
+        out << mLoopUnroll.GetLoopIndexValue(node);
+    else
+        out << node->getSymbol();
 
     if (mDeclaringVariables && node->getType().isArray())
         out << arrayBrackets(node->getType());
@@ -615,18 +618,20 @@
     TLoopType loopType = node->getType();
     if (loopType == ELoopFor)  // for loop
     {
-        out << "for (";
-        if (node->getInit())
-            node->getInit()->traverse(this);
-        out << "; ";
+        if (!node->getUnrollFlag()) {
+            out << "for (";
+            if (node->getInit())
+                node->getInit()->traverse(this);
+            out << "; ";
 
-        if (node->getCondition())
-            node->getCondition()->traverse(this);
-        out << "; ";
+            if (node->getCondition())
+                node->getCondition()->traverse(this);
+            out << "; ";
 
-        if (node->getExpression())
-            node->getExpression()->traverse(this);
-        out << ")\n";
+            if (node->getExpression())
+                node->getExpression()->traverse(this);
+            out << ")\n";
+        }
     }
     else if (loopType == ELoopWhile)  // while loop
     {
@@ -642,7 +647,22 @@
     }
 
     // Loop body.
-    visitCodeBlock(node->getBody());
+    if (node->getUnrollFlag())
+    {
+        TLoopIndexInfo indexInfo;
+        mLoopUnroll.FillLoopIndexInfo(node, indexInfo);
+        mLoopUnroll.Push(indexInfo);
+        while (mLoopUnroll.SatisfiesLoopCondition())
+        {
+            visitCodeBlock(node->getBody());
+            mLoopUnroll.Step();
+        }
+        mLoopUnroll.Pop();
+    }
+    else
+    {
+        visitCodeBlock(node->getBody());
+    }
 
     // Loop footer.
     if (loopType == ELoopDoWhile)  // do-while loop
diff --git a/src/compiler/OutputGLSL.h b/src/compiler/OutputGLSL.h
index aa203d4..ace110a 100644
--- a/src/compiler/OutputGLSL.h
+++ b/src/compiler/OutputGLSL.h
@@ -9,6 +9,7 @@
 
 #include <set>
 
+#include "compiler/ForLoopUnroll.h"
 #include "compiler/intermediate.h"
 #include "compiler/ParseHelper.h"
 
@@ -44,6 +45,8 @@
     // declared only once.
     typedef std::set<TString> DeclaredStructs;
     DeclaredStructs mDeclaredStructs;
+
+    ForLoopUnroll mLoopUnroll;
 };
 
 #endif  // CROSSCOMPILERGLSL_OUTPUTGLSL_H_
diff --git a/src/compiler/ValidateLimitations.cpp b/src/compiler/ValidateLimitations.cpp
index 886f693..b46e4b9 100644
--- a/src/compiler/ValidateLimitations.cpp
+++ b/src/compiler/ValidateLimitations.cpp
@@ -17,6 +17,17 @@
     return false;
 }
 
+void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
+    for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
+        if (i->index.id == symbol->getId()) {
+            ASSERT(i->loop != NULL);
+            i->loop->setUnrollFlag(true);
+            return;
+        }
+    }
+    UNREACHABLE();
+}
+
 // Traverses a node to check if it represents a constant index expression.
 // Definition:
 // constant-index-expressions are a superset of constant-expressions.
@@ -54,6 +65,48 @@
     bool mValid;
     const TLoopStack& mLoopStack;
 };
+
+// Traverses a node to check if it uses a loop index.
+// If an int loop index is used in its body as a sampler array index,
+// mark the loop for unroll.
+class ValidateLoopIndexExpr : public TIntermTraverser {
+public:
+    ValidateLoopIndexExpr(TLoopStack& stack)
+        : mUsesFloatLoopIndex(false),
+          mUsesIntLoopIndex(false),
+          mLoopStack(stack) {}
+
+    bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
+    bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
+
+    virtual void visitSymbol(TIntermSymbol* symbol) {
+        if (IsLoopIndex(symbol, mLoopStack)) {
+            switch (symbol->getBasicType()) {
+              case EbtFloat:
+                mUsesFloatLoopIndex = true;
+                break;
+              case EbtInt:
+                mUsesIntLoopIndex = true;
+                MarkLoopForUnroll(symbol, mLoopStack);
+                break;
+              default:
+                UNREACHABLE();
+            }
+        }
+    }
+    virtual void visitConstantUnion(TIntermConstantUnion*) {}
+    virtual bool visitBinary(Visit, TIntermBinary*) { return true; }
+    virtual bool visitUnary(Visit, TIntermUnary*) { return true; }
+    virtual bool visitSelection(Visit, TIntermSelection*) { return true; }
+    virtual bool visitAggregate(Visit, TIntermAggregate*) { return true; }
+    virtual bool visitLoop(Visit, TIntermLoop*) { return true; }
+    virtual bool visitBranch(Visit, TIntermBranch*) { return true; }
+
+private:
+    bool mUsesFloatLoopIndex;
+    bool mUsesIntLoopIndex;
+    TLoopStack& mLoopStack;
+};
 }  // namespace
 
 ValidateLimitations::ValidateLimitations(ShShaderType shaderType,
@@ -80,7 +133,28 @@
     // Check indexing.
     switch (node->getOp()) {
       case EOpIndexDirect:
+        validateIndexing(node);
+        break;
       case EOpIndexIndirect:
+#if defined(__APPLE__)
+        // Loop unrolling is a work-around for a Mac Cg compiler bug where it
+        // crashes when a sampler array's index is also the loop index.
+        // Once Apple fixes this bug, we should remove the code in this CL.
+        // See http://codereview.appspot.com/4331048/.
+        if ((node->getLeft() != NULL) && (node->getRight() != NULL) &&
+            (node->getLeft()->getAsSymbolNode())) {
+            TIntermSymbol* symbol = node->getLeft()->getAsSymbolNode();
+            if (IsSampler(symbol->getBasicType()) && symbol->isArray()) {
+                ValidateLoopIndexExpr validate(mLoopStack);
+                node->getRight()->traverse(&validate);
+                if (validate.usesFloatLoopIndex()) {
+                    error(node->getLine(),
+                          "sampler array index is float loop index",
+                          "for");
+                }
+            }
+        }
+#endif
         validateIndexing(node);
         break;
       default: break;
@@ -120,6 +194,7 @@
 
     TLoopInfo info;
     memset(&info, 0, sizeof(TLoopInfo));
+    info.loop = node;
     if (!validateForLoopHeader(node, &info))
         return false;
 
diff --git a/src/compiler/ValidateLimitations.h b/src/compiler/ValidateLimitations.h
index a4f5a28..dd2e5bf 100644
--- a/src/compiler/ValidateLimitations.h
+++ b/src/compiler/ValidateLimitations.h
@@ -13,6 +13,7 @@
     struct TIndex {
         int id;  // symbol id.
     } index;
+    TIntermLoop* loop;
 };
 typedef TVector<TLoopInfo> TLoopStack;
 
diff --git a/src/compiler/intermediate.h b/src/compiler/intermediate.h
index c3c073c..cf91061 100644
--- a/src/compiler/intermediate.h
+++ b/src/compiler/intermediate.h
@@ -279,7 +279,8 @@
             init(aInit),
             cond(aCond),
             expr(aExpr),
-            body(aBody) { }
+            body(aBody),
+            unrollFlag(false) { }
 
     virtual TIntermLoop* getAsLoopNode() { return this; }
     virtual void traverse(TIntermTraverser*);
@@ -290,12 +291,17 @@
     TIntermTyped* getExpression() { return expr; }
     TIntermNode* getBody() { return body; }
 
+    void setUnrollFlag(bool flag) { unrollFlag = flag; }
+    bool getUnrollFlag() { return unrollFlag; }
+
 protected:
     TLoopType type;
     TIntermNode* init;  // for-loop initialization
     TIntermTyped* cond; // loop exit condition
     TIntermTyped* expr; // for-loop expression
     TIntermNode* body;  // loop body
+
+    bool unrollFlag; // Whether the loop should be unrolled or not.
 };
 
 //