Add constant folding support for geometric built-ins

* re-land after mac compilation fix (unsupported c++11 initializer) *

This change adds constant folding support for following geometric
built-ins:
    - length, distance, dot, cross, normalize, faceforward,
      reflect and refract.

BUG=angleproject:913
TEST=angle_unittests, dEQP Tests
dEQP-GLES3.functional.shaders.constant_expressions.builtin_functions.geometric.*
(56 tests started passing with this change)

Change-Id: I236fc0c1af47a63f359564500c711e6bedf1c808
Reviewed-on: https://chromium-review.googlesource.com/274789
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Tested-by: Olli Etuaho <oetuaho@nvidia.com>
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 5319bc7..5aed4f8 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -168,6 +168,25 @@
     }
 }
 
+float VectorLength(TConstantUnion *paramArray, size_t paramArraySize)
+{
+    float result = 0.0f;
+    for (size_t i = 0; i < paramArraySize; i++)
+    {
+        float f = paramArray[i].getFConst();
+        result += f * f;
+    }
+    return sqrtf(result);
+}
+
+float VectorDotProduct(TConstantUnion *paramArray1, TConstantUnion *paramArray2, size_t paramArraySize)
+{
+    float result = 0.0f;
+    for (size_t i = 0; i < paramArraySize; i++)
+        result += paramArray1[i].getFConst() * paramArray2[i].getFConst();
+    return result;
+}
+
 }  // namespace anonymous
 
 
@@ -1179,14 +1198,15 @@
 
         return tempNode;
     }
-    else if (op == EOpAny || op == EOpAll)
+    else if (op == EOpAny || op == EOpAll || op == EOpLength)
     {
         // Do operations where the return type is different from the operand type.
 
-        TType returnType(EbtBool, EbpUndefined, EvqConst);
+        TType returnType;
         TConstantUnion *tempConstArray = nullptr;
-        if (op == EOpAny)
+        switch (op)
         {
+          case EOpAny:
             if (getType().getBasicType() == EbtBool)
             {
                 tempConstArray = new TConstantUnion();
@@ -1199,17 +1219,16 @@
                         break;
                     }
                 }
+                returnType = TType(EbtBool, EbpUndefined, EvqConst);
+                break;
             }
             else
             {
-                infoSink.info.message(
-                    EPrefixInternalError, getLine(),
-                    "Unary operation not folded into constant");
+                infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
                 return nullptr;
             }
-        }
-        else if (op == EOpAll)
-        {
+
+          case EOpAll:
             if (getType().getBasicType() == EbtBool)
             {
                 tempConstArray = new TConstantUnion();
@@ -1222,15 +1241,33 @@
                         break;
                     }
                 }
+                returnType = TType(EbtBool, EbpUndefined, EvqConst);
+                break;
             }
             else
             {
-                infoSink.info.message(
-                    EPrefixInternalError, getLine(),
-                    "Unary operation not folded into constant");
+                infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
                 return nullptr;
             }
+
+          case EOpLength:
+            if (getType().getBasicType() == EbtFloat)
+            {
+                tempConstArray = new TConstantUnion();
+                tempConstArray->setFConst(VectorLength(unionArray, objectSize));
+                returnType = TType(EbtFloat, getType().getPrecision(), EvqConst);
+                break;
+            }
+            else
+            {
+                infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
+                return nullptr;
+            }
+
+          default:
+            break;
         }
+
         TIntermConstantUnion *tempNode = new TIntermConstantUnion(tempConstArray, returnType);
         tempNode->setLine(getLine());
         return tempNode;
@@ -1575,6 +1612,20 @@
                     "Unary operation not folded into constant");
                 return nullptr;
 
+              case EOpNormalize:
+                if (getType().getBasicType() == EbtFloat)
+                {
+                    float x = unionArray[i].getFConst();
+                    float length = VectorLength(unionArray, objectSize);
+                    if (length)
+                        tempConstArray[i].setFConst(x / length);
+                    else
+                        UndefinedConstantFoldingError(getLine(), op, getType().getBasicType(), infoSink, &tempConstArray[i]);
+                    break;
+                }
+                infoSink.info.message(EPrefixInternalError, getLine(), "Unary operation not folded into constant");
+                return nullptr;
+
               default:
                 return nullptr;
             }
@@ -1918,6 +1969,70 @@
             }
             break;
 
+          case EOpDistance:
+            if (basicType == EbtFloat)
+            {
+                TConstantUnion *distanceArray = new TConstantUnion[maxObjectSize];
+                tempConstArray = new TConstantUnion();
+                for (size_t i = 0; i < maxObjectSize; i++)
+                {
+                    float x = unionArrays[0][i].getFConst();
+                    float y = unionArrays[1][i].getFConst();
+                    distanceArray[i].setFConst(x - y);
+                }
+                tempConstArray->setFConst(VectorLength(distanceArray, maxObjectSize));
+            }
+            else
+                UNREACHABLE();
+            break;
+
+          case EOpDot:
+            if (basicType == EbtFloat)
+            {
+                tempConstArray = new TConstantUnion();
+                tempConstArray->setFConst(VectorDotProduct(unionArrays[0], unionArrays[1], maxObjectSize));
+            }
+            else
+                UNREACHABLE();
+            break;
+
+          case EOpCross:
+            if (basicType == EbtFloat && maxObjectSize == 3)
+            {
+                tempConstArray = new TConstantUnion[maxObjectSize];
+                float x0 = unionArrays[0][0].getFConst();
+                float x1 = unionArrays[0][1].getFConst();
+                float x2 = unionArrays[0][2].getFConst();
+                float y0 = unionArrays[1][0].getFConst();
+                float y1 = unionArrays[1][1].getFConst();
+                float y2 = unionArrays[1][2].getFConst();
+                tempConstArray[0].setFConst(x1 * y2 - y1 * x2);
+                tempConstArray[1].setFConst(x2 * y0 - y2 * x0);
+                tempConstArray[2].setFConst(x0 * y1 - y0 * x1);
+            }
+            else
+                UNREACHABLE();
+            break;
+
+          case EOpReflect:
+            if (basicType == EbtFloat)
+            {
+                // genType reflect (genType I, genType N) :
+                //     For the incident vector I and surface orientation N, returns the reflection direction:
+                //     I - 2 * dot(N, I) * N.
+                tempConstArray = new TConstantUnion[maxObjectSize];
+                float dotProduct = VectorDotProduct(unionArrays[1], unionArrays[0], maxObjectSize);
+                for (size_t i = 0; i < maxObjectSize; i++)
+                {
+                    float result = unionArrays[0][i].getFConst() -
+                                   2.0f * dotProduct * unionArrays[1][i].getFConst();
+                    tempConstArray[i].setFConst(result);
+                }
+            }
+            else
+                UNREACHABLE();
+            break;
+
           default:
             UNREACHABLE();
             // TODO: Add constant folding support for other built-in operations that take 2 parameters and not handled above.
@@ -2043,6 +2158,53 @@
             }
             break;
 
+          case EOpFaceForward:
+            if (basicType == EbtFloat)
+            {
+                // genType faceforward(genType N, genType I, genType Nref) :
+                //     If dot(Nref, I) < 0 return N, otherwise return -N.
+                tempConstArray = new TConstantUnion[maxObjectSize];
+                float dotProduct = VectorDotProduct(unionArrays[2], unionArrays[1], maxObjectSize);
+                for (size_t i = 0; i < maxObjectSize; i++)
+                {
+                    if (dotProduct < 0)
+                        tempConstArray[i].setFConst(unionArrays[0][i].getFConst());
+                    else
+                        tempConstArray[i].setFConst(-unionArrays[0][i].getFConst());
+                }
+            }
+            else
+                UNREACHABLE();
+            break;
+
+          case EOpRefract:
+            if (basicType == EbtFloat)
+            {
+                // genType refract(genType I, genType N, float eta) :
+                //     For the incident vector I and surface normal N, and the ratio of indices of refraction eta,
+                //     return the refraction vector. The result is computed by
+                //         k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+                //         if (k < 0.0)
+                //             return genType(0.0)
+                //         else
+                //             return eta * I - (eta * dot(N, I) + sqrt(k)) * N
+                tempConstArray = new TConstantUnion[maxObjectSize];
+                float dotProduct = VectorDotProduct(unionArrays[1], unionArrays[0], maxObjectSize);
+                for (size_t i = 0; i < maxObjectSize; i++)
+                {
+                    float eta = unionArrays[2][i].getFConst();
+                    float k = 1.0f - eta * eta * (1.0f - dotProduct * dotProduct);
+                    if (k < 0.0f)
+                        tempConstArray[i].setFConst(0.0f);
+                    else
+                        tempConstArray[i].setFConst(eta * unionArrays[0][i].getFConst() -
+                                                    (eta * dotProduct + sqrtf(k)) * unionArrays[1][i].getFConst());
+                }
+            }
+            else
+                UNREACHABLE();
+            break;
+
           default:
             UNREACHABLE();
             // TODO: Add constant folding support for other built-in operations that take 3 parameters and not handled above.
diff --git a/src/compiler/translator/Intermediate.cpp b/src/compiler/translator/Intermediate.cpp
index a5f83a4..1d8b36b 100644
--- a/src/compiler/translator/Intermediate.cpp
+++ b/src/compiler/translator/Intermediate.cpp
@@ -477,6 +477,12 @@
       case EOpGreaterThanEqual:
       case EOpVectorEqual:
       case EOpVectorNotEqual:
+      case EOpDistance:
+      case EOpDot:
+      case EOpCross:
+      case EOpFaceForward:
+      case EOpReflect:
+      case EOpRefract:
         return TIntermConstantUnion::FoldAggregateBuiltIn(op, aggregate, mInfoSink);
       default:
         // Constant folding not supported for the built-in.
diff --git a/src/tests/compiler_tests/ConstantFolding_test.cpp b/src/tests/compiler_tests/ConstantFolding_test.cpp
index 40ddfc8..ce66bdd 100644
--- a/src/tests/compiler_tests/ConstantFolding_test.cpp
+++ b/src/tests/compiler_tests/ConstantFolding_test.cpp
@@ -7,32 +7,53 @@
 //   Tests for constant folding
 //
 
+#include <vector>
+
 #include "angle_gl.h"
 #include "gtest/gtest.h"
 #include "GLSLANG/ShaderLang.h"
 #include "compiler/translator/PoolAlloc.h"
 #include "compiler/translator/TranslatorESSL.h"
 
+template <typename T>
 class ConstantFinder : public TIntermTraverser
 {
   public:
-    ConstantFinder(TConstantUnion constToFind)
-        : mConstToFind(constToFind),
+    ConstantFinder(const std::vector<T> &constantVector)
+        : mConstantVector(constantVector),
           mFound(false)
     {}
 
-    virtual void visitConstantUnion(TIntermConstantUnion *node)
+    ConstantFinder(const T &value)
+        : mFound(false)
     {
-        if (node->getUnionArrayPointer()[0] == mConstToFind)
+        mConstantVector.push_back(value);
+    }
+
+    void visitConstantUnion(TIntermConstantUnion *node)
+    {
+        if (node->getType().getObjectSize() == mConstantVector.size())
         {
-            mFound = true;
+            bool found = true;
+            for (size_t i = 0; i < mConstantVector.size(); i++)
+            {
+                if (node->getUnionArrayPointer()[i] != mConstantVector[i])
+                {
+                    found = false;
+                    break;
+                }
+            }
+            if (found)
+            {
+                mFound = found;
+            }
         }
     }
 
     bool found() const { return mFound; }
 
   private:
-    TConstantUnion mConstToFind;
+    std::vector<T> mConstantVector;
     bool mFound;
 };
 
@@ -72,18 +93,20 @@
         }
     }
 
-    bool constantFoundInAST(TConstantUnion c)
+    template <typename T>
+    bool constantFoundInAST(T constant)
     {
-        ConstantFinder finder(c);
+        ConstantFinder<T> finder(constant);
         mASTRoot->traverse(&finder);
         return finder.found();
     }
 
-    bool constantFoundInAST(int i)
+    template <typename T>
+    bool constantVectorFoundInAST(const std::vector<T> &constantVector)
     {
-        TConstantUnion c;
-        c.setIConst(i);
-        return constantFoundInAST(c);
+        ConstantFinder<T> finder(constantVector);
+        mASTRoot->traverse(&finder);
+        return finder.found();
     }
 
   private:
@@ -173,3 +196,28 @@
     ASSERT_FALSE(constantFoundInAST(5));
     ASSERT_TRUE(constantFoundInAST(4));
 }
+
+TEST_F(ConstantFoldingTest, FoldVectorCrossProduct)
+{
+    const std::string &shaderString =
+        "#version 300 es\n"
+        "precision mediump float;\n"
+        "out vec3 my_Vec3;"
+        "void main() {\n"
+        "   const vec3 v3 = cross(vec3(1.0f, 1.0f, 1.0f), vec3(1.0f, -1.0f, 1.0f));\n"
+        "   my_Vec3 = v3;\n"
+        "}\n";
+    compile(shaderString);
+    std::vector<float> input1(3, 1.0f);
+    ASSERT_FALSE(constantVectorFoundInAST(input1));
+    std::vector<float> input2;
+    input2.push_back(1.0f);
+    input2.push_back(-1.0f);
+    input2.push_back(1.0f);
+    ASSERT_FALSE(constantVectorFoundInAST(input2));
+    std::vector<float> result;
+    result.push_back(2.0f);
+    result.push_back(0.0f);
+    result.push_back(-2.0f);
+    ASSERT_TRUE(constantVectorFoundInAST(result));
+}