Implement scalar `reflect` intrinsic in Metal.

Change-Id: I954af70f545a2258babd82af0d43d509201fdc59
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/348889
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 0eb047f..d694716 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -56,6 +56,7 @@
     fIntrinsicMap[String("mod")]                = SPECIAL(Mod);
     fIntrinsicMap[String("normalize")]          = SPECIAL(Normalize);
     fIntrinsicMap[String("radians")]            = SPECIAL(Radians);
+    fIntrinsicMap[String("reflect")]            = SPECIAL(Reflect);
     fIntrinsicMap[String("sample")]             = SPECIAL(Texture);
     fIntrinsicMap[String("equal")]              = METAL(Equal);
     fIntrinsicMap[String("notEqual")]           = METAL(NotEqual);
@@ -582,7 +583,20 @@
     return tempVar;
 }
 
-void MetalCodeGenerator::writeSpecialIntrinsic(const FunctionCall & c, SpecialIntrinsic kind) {
+void MetalCodeGenerator::writeSimpleIntrinsic(const FunctionCall& c) {
+    // Write out an intrinsic function call exactly as-is. No muss no fuss.
+    this->write(c.function().name());
+    this->write("(");
+    const char* separator = "";
+    for (const std::unique_ptr<Expression>& arg : c.arguments()) {
+        this->write(separator);
+        separator = ", ";
+        this->writeExpression(*arg, kSequence_Precedence);
+    }
+    this->write(")");
+}
+
+void MetalCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind) {
     const ExpressionArray& arguments = c.arguments();
     switch (kind) {
         case kTexture_SpecialIntrinsic: {
@@ -625,11 +639,7 @@
                 this->writeExpression(*arguments[1], kAdditive_Precedence);
                 this->write(")");
             } else {
-                this->write("distance(");
-                this->writeExpression(*arguments[0], kSequence_Precedence);
-                this->write(", ");
-                this->writeExpression(*arguments[1], kSequence_Precedence);
-                this->write(")");
+                this->writeSimpleIntrinsic(c);
             }
             break;
         }
@@ -641,11 +651,7 @@
                 this->writeExpression(*arguments[1], kMultiplicative_Precedence);
                 this->write(")");
             } else {
-                this->write("dot(");
-                this->writeExpression(*arguments[0], kSequence_Precedence);
-                this->write(", ");
-                this->writeExpression(*arguments[1], kSequence_Precedence);
-                this->write(")");
+                this->writeSimpleIntrinsic(c);
             }
             break;
         }
@@ -660,13 +666,7 @@
                 this->writeExpression(*arguments[0], kSequence_Precedence);
                 this->write("))");
             } else {
-                this->write("faceforward(");
-                this->writeExpression(*arguments[0], kSequence_Precedence);
-                this->write(", ");
-                this->writeExpression(*arguments[1], kSequence_Precedence);
-                this->write(", ");
-                this->writeExpression(*arguments[2], kSequence_Precedence);
-                this->write(")");
+                this->writeSimpleIntrinsic(c);
             }
             break;
         }
@@ -701,6 +701,27 @@
             this->write(") * 0.0174532925)");
             break;
         }
+        case kReflect_SpecialIntrinsic: {
+            if (arguments[0]->type().columns() == 1) {
+                // We need to synthesize `I - 2 * N * I * N`.
+                String tmpI = this->getTempVariable(arguments[0]->type());
+                String tmpN = this->getTempVariable(arguments[1]->type());
+
+                // (_skTempI = ...
+                this->write("(" + tmpI + " = ");
+                this->writeExpression(*arguments[0], kSequence_Precedence);
+
+                // , _skTempN = ...
+                this->write(", " + tmpN + " = ");
+                this->writeExpression(*arguments[1], kSequence_Precedence);
+
+                // , _skTempI - 2 * _skTempN * _skTempI * _skTempN)
+                this->write(", " + tmpI + " - 2 * " + tmpN + " * " + tmpI + " * " + tmpN + ")");
+            } else {
+                this->writeSimpleIntrinsic(c);
+            }
+            break;
+        }
         case kBitCount_SpecialIntrinsic: {
             this->write("popcount(");
             this->writeExpression(*arguments[0], kSequence_Precedence);
diff --git a/src/sksl/SkSLMetalCodeGenerator.h b/src/sksl/SkSLMetalCodeGenerator.h
index ca5f3bd..96b55c7 100644
--- a/src/sksl/SkSLMetalCodeGenerator.h
+++ b/src/sksl/SkSLMetalCodeGenerator.h
@@ -115,6 +115,7 @@
         kMod_SpecialIntrinsic,
         kNormalize_SpecialIntrinsic,
         kRadians_SpecialIntrinsic,
+        kReflect_SpecialIntrinsic,
         kTexture_SpecialIntrinsic,
     };
 
@@ -235,6 +236,8 @@
     void writeMatrixCompMult();
     void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);
 
+    void writeSimpleIntrinsic(const FunctionCall& c);
+
     void writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind);
 
     bool canCoerce(const Type& t1, const Type& t2);
diff --git a/tests/sksl/intrinsics/golden/Reflect.metal b/tests/sksl/intrinsics/golden/Reflect.metal
index 290ccb1..664eacf 100644
--- a/tests/sksl/intrinsics/golden/Reflect.metal
+++ b/tests/sksl/intrinsics/golden/Reflect.metal
@@ -17,7 +17,9 @@
 fragment Outputs fragmentMain(Inputs _in [[stage_in]], bool _frontFacing [[front_facing]], float4 _fragCoord [[position]]) {
     Outputs _outputStruct;
     thread Outputs* _out = &_outputStruct;
-    _out->sk_FragColor.x = reflect(_in.a, _in.b);
+    float _skTemp0;
+    float _skTemp1;
+    _out->sk_FragColor.x = (_skTemp0 = _in.a, _skTemp1 = _in.b, _skTemp0 - 2 * _skTemp1 * _skTemp0 * _skTemp1);
     _out->sk_FragColor = reflect(_in.c, _in.d);
     return *_out;
 }