Unroll for-loop if sampler array uses loop index as its index.
If inside a for-loop, sampler array index is the loop index, Mac cg compiler will crash. This CL unroll the loop in such situation. The behavior is:
1) If the for-loop index is a float, we reject the shader.
2) If it is an integer, we unroll the for-loop.
Things that should be done in the future are:
1) Add line number macros.
2) Add a limit to unroll iteration count.
anglebug=94
Review URL: http://codereview.appspot.com/4331048
git-svn-id: https://angleproject.googlecode.com/svn/trunk@606 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/build_angle.gyp b/src/build_angle.gyp
index 022d239..03fdee5 100644
--- a/src/build_angle.gyp
+++ b/src/build_angle.gyp
@@ -100,6 +100,8 @@
],
'sources': [
'compiler/CodeGenGLSL.cpp',
+ 'compiler/ForLoopUnroll.cpp',
+ 'compiler/ForLoopUnroll.h',
'compiler/OutputGLSL.cpp',
'compiler/OutputGLSL.h',
'compiler/TranslatorGLSL.cpp',
diff --git a/src/common/version.h b/src/common/version.h
index 796b435..47564fa 100644
--- a/src/common/version.h
+++ b/src/common/version.h
@@ -1,7 +1,7 @@
#define MAJOR_VERSION 0
#define MINOR_VERSION 0
#define BUILD_VERSION 0
-#define BUILD_REVISION 605
+#define BUILD_REVISION 606
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
diff --git a/src/compiler/ForLoopUnroll.cpp b/src/compiler/ForLoopUnroll.cpp
new file mode 100644
index 0000000..d631af4
--- /dev/null
+++ b/src/compiler/ForLoopUnroll.cpp
@@ -0,0 +1,172 @@
+//
+// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#include "compiler/ForLoopUnroll.h"
+
+void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info)
+{
+ ASSERT(node->getType() == ELoopFor);
+ ASSERT(node->getUnrollFlag());
+
+ TIntermNode* init = node->getInit();
+ ASSERT(init != NULL);
+ TIntermAggregate* decl = init->getAsAggregate();
+ ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration));
+ TIntermSequence& declSeq = decl->getSequence();
+ ASSERT(declSeq.size() == 1);
+ TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
+ ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize));
+ TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
+ ASSERT(symbol != NULL);
+ ASSERT(symbol->getBasicType() == EbtInt);
+
+ info.id = symbol->getId();
+
+ ASSERT(declInit->getRight() != NULL);
+ TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion();
+ ASSERT(initNode != NULL);
+
+ info.initValue = evaluateIntConstant(initNode);
+ info.currentValue = info.initValue;
+
+ TIntermNode* cond = node->getCondition();
+ ASSERT(cond != NULL);
+ TIntermBinary* binOp = cond->getAsBinaryNode();
+ ASSERT(binOp != NULL);
+ ASSERT(binOp->getRight() != NULL);
+ ASSERT(binOp->getRight()->getAsConstantUnion() != NULL);
+
+ info.incrementValue = getLoopIncrement(node);
+ info.stopValue = evaluateIntConstant(
+ binOp->getRight()->getAsConstantUnion());
+ info.op = binOp->getOp();
+}
+
+void ForLoopUnroll::Step()
+{
+ ASSERT(mLoopIndexStack.size() > 0);
+ TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
+ info.currentValue += info.incrementValue;
+}
+
+bool ForLoopUnroll::SatisfiesLoopCondition()
+{
+ ASSERT(mLoopIndexStack.size() > 0);
+ TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
+ // Relational operator is one of: > >= < <= == or !=.
+ switch (info.op) {
+ case EOpEqual:
+ return (info.currentValue == info.stopValue);
+ case EOpNotEqual:
+ return (info.currentValue != info.stopValue);
+ case EOpLessThan:
+ return (info.currentValue < info.stopValue);
+ case EOpGreaterThan:
+ return (info.currentValue > info.stopValue);
+ case EOpLessThanEqual:
+ return (info.currentValue <= info.stopValue);
+ case EOpGreaterThanEqual:
+ return (info.currentValue >= info.stopValue);
+ default:
+ UNREACHABLE();
+ }
+ return false;
+}
+
+bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol)
+{
+ for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
+ i != mLoopIndexStack.end();
+ ++i) {
+ if (i->id == symbol->getId())
+ return true;
+ }
+ return false;
+}
+
+int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol)
+{
+ for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
+ i != mLoopIndexStack.end();
+ ++i) {
+ if (i->id == symbol->getId())
+ return i->currentValue;
+ }
+ UNREACHABLE();
+ return false;
+}
+
+void ForLoopUnroll::Push(TLoopIndexInfo& info)
+{
+ mLoopIndexStack.push_back(info);
+}
+
+void ForLoopUnroll::Pop()
+{
+ mLoopIndexStack.pop_back();
+}
+
+int ForLoopUnroll::getLoopIncrement(TIntermLoop* node)
+{
+ TIntermNode* expr = node->getExpression();
+ ASSERT(expr != NULL);
+ // for expression has one of the following forms:
+ // loop_index++
+ // loop_index--
+ // loop_index += constant_expression
+ // loop_index -= constant_expression
+ // ++loop_index
+ // --loop_index
+ // The last two forms are not specified in the spec, but I am assuming
+ // its an oversight.
+ TIntermUnary* unOp = expr->getAsUnaryNode();
+ TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
+
+ TOperator op = EOpNull;
+ TIntermConstantUnion* incrementNode = NULL;
+ if (unOp != NULL) {
+ op = unOp->getOp();
+ } else if (binOp != NULL) {
+ op = binOp->getOp();
+ ASSERT(binOp->getRight() != NULL);
+ incrementNode = binOp->getRight()->getAsConstantUnion();
+ ASSERT(incrementNode != NULL);
+ }
+
+ int increment = 0;
+ // The operator is one of: ++ -- += -=.
+ switch (op) {
+ case EOpPostIncrement:
+ case EOpPreIncrement:
+ ASSERT((unOp != NULL) && (binOp == NULL));
+ increment = 1;
+ break;
+ case EOpPostDecrement:
+ case EOpPreDecrement:
+ ASSERT((unOp != NULL) && (binOp == NULL));
+ increment = -1;
+ break;
+ case EOpAddAssign:
+ ASSERT((unOp == NULL) && (binOp != NULL));
+ increment = evaluateIntConstant(incrementNode);
+ break;
+ case EOpSubAssign:
+ ASSERT((unOp == NULL) && (binOp != NULL));
+ increment = - evaluateIntConstant(incrementNode);
+ break;
+ default:
+ ASSERT(false);
+ }
+
+ return increment;
+}
+
+int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node)
+{
+ ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL));
+ return node->getUnionArrayPointer()->getIConst();
+}
+
diff --git a/src/compiler/ForLoopUnroll.h b/src/compiler/ForLoopUnroll.h
new file mode 100644
index 0000000..b2b2b58
--- /dev/null
+++ b/src/compiler/ForLoopUnroll.h
@@ -0,0 +1,46 @@
+//
+// Copyright (c) 2011 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+
+#include "compiler/intermediate.h"
+
+struct TLoopIndexInfo {
+ int id;
+ int initValue;
+ int stopValue;
+ int incrementValue;
+ TOperator op;
+ int currentValue;
+};
+
+class ForLoopUnroll {
+public:
+ ForLoopUnroll() { }
+
+ void FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info);
+
+ // Update the info.currentValue for the next loop iteration.
+ void Step();
+
+ // Return false if loop condition is no longer satisfied.
+ bool SatisfiesLoopCondition();
+
+ // Check if the symbol is the index of a loop that's unrolled.
+ bool NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol);
+
+ // Return the current value of a given loop index symbol.
+ int GetLoopIndexValue(TIntermSymbol* symbol);
+
+ void Push(TLoopIndexInfo& info);
+ void Pop();
+
+private:
+ int getLoopIncrement(TIntermLoop* node);
+
+ int evaluateIntConstant(TIntermConstantUnion* node);
+
+ TVector<TLoopIndexInfo> mLoopIndexStack;
+};
+
diff --git a/src/compiler/OutputGLSL.cpp b/src/compiler/OutputGLSL.cpp
index 23476f2..3224dfd 100644
--- a/src/compiler/OutputGLSL.cpp
+++ b/src/compiler/OutputGLSL.cpp
@@ -1,5 +1,5 @@
//
-// Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
+// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
@@ -195,7 +195,10 @@
void TOutputGLSL::visitSymbol(TIntermSymbol* node)
{
TInfoSinkBase& out = objSink();
- out << node->getSymbol();
+ if (mLoopUnroll.NeedsToReplaceSymbolWithValue(node))
+ out << mLoopUnroll.GetLoopIndexValue(node);
+ else
+ out << node->getSymbol();
if (mDeclaringVariables && node->getType().isArray())
out << arrayBrackets(node->getType());
@@ -615,18 +618,20 @@
TLoopType loopType = node->getType();
if (loopType == ELoopFor) // for loop
{
- out << "for (";
- if (node->getInit())
- node->getInit()->traverse(this);
- out << "; ";
+ if (!node->getUnrollFlag()) {
+ out << "for (";
+ if (node->getInit())
+ node->getInit()->traverse(this);
+ out << "; ";
- if (node->getCondition())
- node->getCondition()->traverse(this);
- out << "; ";
+ if (node->getCondition())
+ node->getCondition()->traverse(this);
+ out << "; ";
- if (node->getExpression())
- node->getExpression()->traverse(this);
- out << ")\n";
+ if (node->getExpression())
+ node->getExpression()->traverse(this);
+ out << ")\n";
+ }
}
else if (loopType == ELoopWhile) // while loop
{
@@ -642,7 +647,22 @@
}
// Loop body.
- visitCodeBlock(node->getBody());
+ if (node->getUnrollFlag())
+ {
+ TLoopIndexInfo indexInfo;
+ mLoopUnroll.FillLoopIndexInfo(node, indexInfo);
+ mLoopUnroll.Push(indexInfo);
+ while (mLoopUnroll.SatisfiesLoopCondition())
+ {
+ visitCodeBlock(node->getBody());
+ mLoopUnroll.Step();
+ }
+ mLoopUnroll.Pop();
+ }
+ else
+ {
+ visitCodeBlock(node->getBody());
+ }
// Loop footer.
if (loopType == ELoopDoWhile) // do-while loop
diff --git a/src/compiler/OutputGLSL.h b/src/compiler/OutputGLSL.h
index aa203d4..ace110a 100644
--- a/src/compiler/OutputGLSL.h
+++ b/src/compiler/OutputGLSL.h
@@ -9,6 +9,7 @@
#include <set>
+#include "compiler/ForLoopUnroll.h"
#include "compiler/intermediate.h"
#include "compiler/ParseHelper.h"
@@ -44,6 +45,8 @@
// declared only once.
typedef std::set<TString> DeclaredStructs;
DeclaredStructs mDeclaredStructs;
+
+ ForLoopUnroll mLoopUnroll;
};
#endif // CROSSCOMPILERGLSL_OUTPUTGLSL_H_
diff --git a/src/compiler/ValidateLimitations.cpp b/src/compiler/ValidateLimitations.cpp
index 886f693..b46e4b9 100644
--- a/src/compiler/ValidateLimitations.cpp
+++ b/src/compiler/ValidateLimitations.cpp
@@ -17,6 +17,17 @@
return false;
}
+void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
+ for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
+ if (i->index.id == symbol->getId()) {
+ ASSERT(i->loop != NULL);
+ i->loop->setUnrollFlag(true);
+ return;
+ }
+ }
+ UNREACHABLE();
+}
+
// Traverses a node to check if it represents a constant index expression.
// Definition:
// constant-index-expressions are a superset of constant-expressions.
@@ -54,6 +65,48 @@
bool mValid;
const TLoopStack& mLoopStack;
};
+
+// Traverses a node to check if it uses a loop index.
+// If an int loop index is used in its body as a sampler array index,
+// mark the loop for unroll.
+class ValidateLoopIndexExpr : public TIntermTraverser {
+public:
+ ValidateLoopIndexExpr(TLoopStack& stack)
+ : mUsesFloatLoopIndex(false),
+ mUsesIntLoopIndex(false),
+ mLoopStack(stack) {}
+
+ bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
+ bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
+
+ virtual void visitSymbol(TIntermSymbol* symbol) {
+ if (IsLoopIndex(symbol, mLoopStack)) {
+ switch (symbol->getBasicType()) {
+ case EbtFloat:
+ mUsesFloatLoopIndex = true;
+ break;
+ case EbtInt:
+ mUsesIntLoopIndex = true;
+ MarkLoopForUnroll(symbol, mLoopStack);
+ break;
+ default:
+ UNREACHABLE();
+ }
+ }
+ }
+ virtual void visitConstantUnion(TIntermConstantUnion*) {}
+ virtual bool visitBinary(Visit, TIntermBinary*) { return true; }
+ virtual bool visitUnary(Visit, TIntermUnary*) { return true; }
+ virtual bool visitSelection(Visit, TIntermSelection*) { return true; }
+ virtual bool visitAggregate(Visit, TIntermAggregate*) { return true; }
+ virtual bool visitLoop(Visit, TIntermLoop*) { return true; }
+ virtual bool visitBranch(Visit, TIntermBranch*) { return true; }
+
+private:
+ bool mUsesFloatLoopIndex;
+ bool mUsesIntLoopIndex;
+ TLoopStack& mLoopStack;
+};
} // namespace
ValidateLimitations::ValidateLimitations(ShShaderType shaderType,
@@ -80,7 +133,28 @@
// Check indexing.
switch (node->getOp()) {
case EOpIndexDirect:
+ validateIndexing(node);
+ break;
case EOpIndexIndirect:
+#if defined(__APPLE__)
+ // Loop unrolling is a work-around for a Mac Cg compiler bug where it
+ // crashes when a sampler array's index is also the loop index.
+ // Once Apple fixes this bug, we should remove the code in this CL.
+ // See http://codereview.appspot.com/4331048/.
+ if ((node->getLeft() != NULL) && (node->getRight() != NULL) &&
+ (node->getLeft()->getAsSymbolNode())) {
+ TIntermSymbol* symbol = node->getLeft()->getAsSymbolNode();
+ if (IsSampler(symbol->getBasicType()) && symbol->isArray()) {
+ ValidateLoopIndexExpr validate(mLoopStack);
+ node->getRight()->traverse(&validate);
+ if (validate.usesFloatLoopIndex()) {
+ error(node->getLine(),
+ "sampler array index is float loop index",
+ "for");
+ }
+ }
+ }
+#endif
validateIndexing(node);
break;
default: break;
@@ -120,6 +194,7 @@
TLoopInfo info;
memset(&info, 0, sizeof(TLoopInfo));
+ info.loop = node;
if (!validateForLoopHeader(node, &info))
return false;
diff --git a/src/compiler/ValidateLimitations.h b/src/compiler/ValidateLimitations.h
index a4f5a28..dd2e5bf 100644
--- a/src/compiler/ValidateLimitations.h
+++ b/src/compiler/ValidateLimitations.h
@@ -13,6 +13,7 @@
struct TIndex {
int id; // symbol id.
} index;
+ TIntermLoop* loop;
};
typedef TVector<TLoopInfo> TLoopStack;
diff --git a/src/compiler/intermediate.h b/src/compiler/intermediate.h
index c3c073c..cf91061 100644
--- a/src/compiler/intermediate.h
+++ b/src/compiler/intermediate.h
@@ -279,7 +279,8 @@
init(aInit),
cond(aCond),
expr(aExpr),
- body(aBody) { }
+ body(aBody),
+ unrollFlag(false) { }
virtual TIntermLoop* getAsLoopNode() { return this; }
virtual void traverse(TIntermTraverser*);
@@ -290,12 +291,17 @@
TIntermTyped* getExpression() { return expr; }
TIntermNode* getBody() { return body; }
+ void setUnrollFlag(bool flag) { unrollFlag = flag; }
+ bool getUnrollFlag() { return unrollFlag; }
+
protected:
TLoopType type;
TIntermNode* init; // for-loop initialization
TIntermTyped* cond; // loop exit condition
TIntermTyped* expr; // for-loop expression
TIntermNode* body; // loop body
+
+ bool unrollFlag; // Whether the loop should be unrolled or not.
};
//