Add constant folding support for geometric built-ins
* re-land after mac compilation fix (unsupported c++11 initializer) *
This change adds constant folding support for following geometric
built-ins:
- length, distance, dot, cross, normalize, faceforward,
reflect and refract.
BUG=angleproject:913
TEST=angle_unittests, dEQP Tests
dEQP-GLES3.functional.shaders.constant_expressions.builtin_functions.geometric.*
(56 tests started passing with this change)
Change-Id: I236fc0c1af47a63f359564500c711e6bedf1c808
Reviewed-on: https://chromium-review.googlesource.com/274789
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Tested-by: Olli Etuaho <oetuaho@nvidia.com>
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 5319bc7..5aed4f8 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -168,6 +168,25 @@
}
}
+float VectorLength(TConstantUnion *paramArray, size_t paramArraySize)
+{
+ float result = 0.0f;
+ for (size_t i = 0; i < paramArraySize; i++)
+ {
+ float f = paramArray[i].getFConst();
+ result += f * f;
+ }
+ return sqrtf(result);
+}
+
+float VectorDotProduct(TConstantUnion *paramArray1, TConstantUnion *paramArray2, size_t paramArraySize)
+{
+ float result = 0.0f;
+ for (size_t i = 0; i < paramArraySize; i++)
+ result += paramArray1[i].getFConst() * paramArray2[i].getFConst();
+ return result;
+}
+
} // namespace anonymous
@@ -1179,14 +1198,15 @@
return tempNode;
}
- else if (op == EOpAny || op == EOpAll)
+ else if (op == EOpAny || op == EOpAll || op == EOpLength)
{
// Do operations where the return type is different from the operand type.
- TType returnType(EbtBool, EbpUndefined, EvqConst);
+ TType returnType;
TConstantUnion *tempConstArray = nullptr;
- if (op == EOpAny)
+ switch (op)
{
+ case EOpAny:
if (getType().getBasicType() == EbtBool)
{
tempConstArray = new TConstantUnion();
@@ -1199,17 +1219,16 @@
break;
}
}
+ returnType = TType(EbtBool, EbpUndefined, EvqConst);
+ break;
}
else
{
- infoSink.info.message(
- EPrefixInternalError, getLine(),
- "Unary operation not folded into constant");
+ infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
return nullptr;
}
- }
- else if (op == EOpAll)
- {
+
+ case EOpAll:
if (getType().getBasicType() == EbtBool)
{
tempConstArray = new TConstantUnion();
@@ -1222,15 +1241,33 @@
break;
}
}
+ returnType = TType(EbtBool, EbpUndefined, EvqConst);
+ break;
}
else
{
- infoSink.info.message(
- EPrefixInternalError, getLine(),
- "Unary operation not folded into constant");
+ infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
return nullptr;
}
+
+ case EOpLength:
+ if (getType().getBasicType() == EbtFloat)
+ {
+ tempConstArray = new TConstantUnion();
+ tempConstArray->setFConst(VectorLength(unionArray, objectSize));
+ returnType = TType(EbtFloat, getType().getPrecision(), EvqConst);
+ break;
+ }
+ else
+ {
+ infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
+ return nullptr;
+ }
+
+ default:
+ break;
}
+
TIntermConstantUnion *tempNode = new TIntermConstantUnion(tempConstArray, returnType);
tempNode->setLine(getLine());
return tempNode;
@@ -1575,6 +1612,20 @@
"Unary operation not folded into constant");
return nullptr;
+ case EOpNormalize:
+ if (getType().getBasicType() == EbtFloat)
+ {
+ float x = unionArray[i].getFConst();
+ float length = VectorLength(unionArray, objectSize);
+ if (length)
+ tempConstArray[i].setFConst(x / length);
+ else
+ UndefinedConstantFoldingError(getLine(), op, getType().getBasicType(), infoSink, &tempConstArray[i]);
+ break;
+ }
+ infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
+ return nullptr;
+
default:
return nullptr;
}
@@ -1918,6 +1969,70 @@
}
break;
+ case EOpDistance:
+ if (basicType == EbtFloat)
+ {
+ TConstantUnion *distanceArray = new TConstantUnion[maxObjectSize];
+ tempConstArray = new TConstantUnion();
+ for (size_t i = 0; i < maxObjectSize; i++)
+ {
+ float x = unionArrays[0][i].getFConst();
+ float y = unionArrays[1][i].getFConst();
+ distanceArray[i].setFConst(x - y);
+ }
+ tempConstArray->setFConst(VectorLength(distanceArray, maxObjectSize));
+ }
+ else
+ UNREACHABLE();
+ break;
+
+ case EOpDot:
+ if (basicType == EbtFloat)
+ {
+ tempConstArray = new TConstantUnion();
+ tempConstArray->setFConst(VectorDotProduct(unionArrays[0], unionArrays[1], maxObjectSize));
+ }
+ else
+ UNREACHABLE();
+ break;
+
+ case EOpCross:
+ if (basicType == EbtFloat && maxObjectSize == 3)
+ {
+ tempConstArray = new TConstantUnion[maxObjectSize];
+ float x0 = unionArrays[0][0].getFConst();
+ float x1 = unionArrays[0][1].getFConst();
+ float x2 = unionArrays[0][2].getFConst();
+ float y0 = unionArrays[1][0].getFConst();
+ float y1 = unionArrays[1][1].getFConst();
+ float y2 = unionArrays[1][2].getFConst();
+ tempConstArray[0].setFConst(x1 * y2 - y1 * x2);
+ tempConstArray[1].setFConst(x2 * y0 - y2 * x0);
+ tempConstArray[2].setFConst(x0 * y1 - y0 * x1);
+ }
+ else
+ UNREACHABLE();
+ break;
+
+ case EOpReflect:
+ if (basicType == EbtFloat)
+ {
+ // genType reflect (genType I, genType N) :
+ // For the incident vector I and surface orientation N, returns the reflection direction:
+ // I - 2 * dot(N, I) * N.
+ tempConstArray = new TConstantUnion[maxObjectSize];
+ float dotProduct = VectorDotProduct(unionArrays[1], unionArrays[0], maxObjectSize);
+ for (size_t i = 0; i < maxObjectSize; i++)
+ {
+ float result = unionArrays[0][i].getFConst() -
+ 2.0f * dotProduct * unionArrays[1][i].getFConst();
+ tempConstArray[i].setFConst(result);
+ }
+ }
+ else
+ UNREACHABLE();
+ break;
+
default:
UNREACHABLE();
// TODO: Add constant folding support for other built-in operations that take 2 parameters and not handled above.
@@ -2043,6 +2158,53 @@
}
break;
+ case EOpFaceForward:
+ if (basicType == EbtFloat)
+ {
+ // genType faceforward(genType N, genType I, genType Nref) :
+ // If dot(Nref, I) < 0 return N, otherwise return -N.
+ tempConstArray = new TConstantUnion[maxObjectSize];
+ float dotProduct = VectorDotProduct(unionArrays[2], unionArrays[1], maxObjectSize);
+ for (size_t i = 0; i < maxObjectSize; i++)
+ {
+ if (dotProduct < 0)
+ tempConstArray[i].setFConst(unionArrays[0][i].getFConst());
+ else
+ tempConstArray[i].setFConst(-unionArrays[0][i].getFConst());
+ }
+ }
+ else
+ UNREACHABLE();
+ break;
+
+ case EOpRefract:
+ if (basicType == EbtFloat)
+ {
+ // genType refract(genType I, genType N, float eta) :
+ // For the incident vector I and surface normal N, and the ratio of indices of refraction eta,
+ // return the refraction vector. The result is computed by
+ // k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+ // if (k < 0.0)
+ // return genType(0.0)
+ // else
+ // return eta * I - (eta * dot(N, I) + sqrt(k)) * N
+ tempConstArray = new TConstantUnion[maxObjectSize];
+ float dotProduct = VectorDotProduct(unionArrays[1], unionArrays[0], maxObjectSize);
+ for (size_t i = 0; i < maxObjectSize; i++)
+ {
+ float eta = unionArrays[2][i].getFConst();
+ float k = 1.0f - eta * eta * (1.0f - dotProduct * dotProduct);
+ if (k < 0.0f)
+ tempConstArray[i].setFConst(0.0f);
+ else
+ tempConstArray[i].setFConst(eta * unionArrays[0][i].getFConst() -
+ (eta * dotProduct + sqrtf(k)) * unionArrays[1][i].getFConst());
+ }
+ }
+ else
+ UNREACHABLE();
+ break;
+
default:
UNREACHABLE();
// TODO: Add constant folding support for other built-in operations that take 3 parameters and not handled above.
diff --git a/src/compiler/translator/Intermediate.cpp b/src/compiler/translator/Intermediate.cpp
index a5f83a4..1d8b36b 100644
--- a/src/compiler/translator/Intermediate.cpp
+++ b/src/compiler/translator/Intermediate.cpp
@@ -477,6 +477,12 @@
case EOpGreaterThanEqual:
case EOpVectorEqual:
case EOpVectorNotEqual:
+ case EOpDistance:
+ case EOpDot:
+ case EOpCross:
+ case EOpFaceForward:
+ case EOpReflect:
+ case EOpRefract:
return TIntermConstantUnion::FoldAggregateBuiltIn(op, aggregate, mInfoSink);
default:
// Constant folding not supported for the built-in.
diff --git a/src/tests/compiler_tests/ConstantFolding_test.cpp b/src/tests/compiler_tests/ConstantFolding_test.cpp
index 40ddfc8..ce66bdd 100644
--- a/src/tests/compiler_tests/ConstantFolding_test.cpp
+++ b/src/tests/compiler_tests/ConstantFolding_test.cpp
@@ -7,32 +7,53 @@
// Tests for constant folding
//
+#include <vector>
+
#include "angle_gl.h"
#include "gtest/gtest.h"
#include "GLSLANG/ShaderLang.h"
#include "compiler/translator/PoolAlloc.h"
#include "compiler/translator/TranslatorESSL.h"
+template <typename T>
class ConstantFinder : public TIntermTraverser
{
public:
- ConstantFinder(TConstantUnion constToFind)
- : mConstToFind(constToFind),
+ ConstantFinder(const std::vector<T> &constantVector)
+ : mConstantVector(constantVector),
mFound(false)
{}
- virtual void visitConstantUnion(TIntermConstantUnion *node)
+ ConstantFinder(const T &value)
+ : mFound(false)
{
- if (node->getUnionArrayPointer()[0] == mConstToFind)
+ mConstantVector.push_back(value);
+ }
+
+ void visitConstantUnion(TIntermConstantUnion *node)
+ {
+ if (node->getType().getObjectSize() == mConstantVector.size())
{
- mFound = true;
+ bool found = true;
+ for (size_t i = 0; i < mConstantVector.size(); i++)
+ {
+ if (node->getUnionArrayPointer()[i] != mConstantVector[i])
+ {
+ found = false;
+ break;
+ }
+ }
+ if (found)
+ {
+ mFound = found;
+ }
}
}
bool found() const { return mFound; }
private:
- TConstantUnion mConstToFind;
+ std::vector<T> mConstantVector;
bool mFound;
};
@@ -72,18 +93,20 @@
}
}
- bool constantFoundInAST(TConstantUnion c)
+ template <typename T>
+ bool constantFoundInAST(T constant)
{
- ConstantFinder finder(c);
+ ConstantFinder<T> finder(constant);
mASTRoot->traverse(&finder);
return finder.found();
}
- bool constantFoundInAST(int i)
+ template <typename T>
+ bool constantVectorFoundInAST(const std::vector<T> &constantVector)
{
- TConstantUnion c;
- c.setIConst(i);
- return constantFoundInAST(c);
+ ConstantFinder<T> finder(constantVector);
+ mASTRoot->traverse(&finder);
+ return finder.found();
}
private:
@@ -173,3 +196,28 @@
ASSERT_FALSE(constantFoundInAST(5));
ASSERT_TRUE(constantFoundInAST(4));
}
+
+TEST_F(ConstantFoldingTest, FoldVectorCrossProduct)
+{
+ const std::string &shaderString =
+ "#version 300 es\n"
+ "precision mediump float;\n"
+ "out vec3 my_Vec3;"
+ "void main() {\n"
+ " const vec3 v3 = cross(vec3(1.0f, 1.0f, 1.0f), vec3(1.0f, -1.0f, 1.0f));\n"
+ " my_Vec3 = v3;\n"
+ "}\n";
+ compile(shaderString);
+ std::vector<float> input1(3, 1.0f);
+ ASSERT_FALSE(constantVectorFoundInAST(input1));
+ std::vector<float> input2;
+ input2.push_back(1.0f);
+ input2.push_back(-1.0f);
+ input2.push_back(1.0f);
+ ASSERT_FALSE(constantVectorFoundInAST(input2));
+ std::vector<float> result;
+ result.push_back(2.0f);
+ result.push_back(0.0f);
+ result.push_back(-2.0f);
+ ASSERT_TRUE(constantVectorFoundInAST(result));
+}