Implement getBVecComponent.

Like getIVecComponent or getFVecComponent, this retrieves the n'th
element of a Boolean compile-time constant vector. This will be used in
followup CLs.

Change-Id: Ib41c9c89cb773251e4c0d6cdcaea0437d8074e48
Bug: skia:11141
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/350918
Auto-Submit: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index 7a3cfe5..d38c7f8 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -74,47 +74,88 @@
     }
     const Constructor& c = other.as<Constructor>();
     const Type& myType = this->type();
-    const Type& otherType = c.type();
-    SkASSERT(myType == otherType);
-    if (otherType.isVector()) {
-        bool isFloat = otherType.columns() > 1 ? otherType.componentType().isFloat()
-                                             : otherType.isFloat();
-        for (int i = 0; i < myType.columns(); i++) {
-            if (isFloat) {
+    SkASSERT(myType == c.type());
+
+    if (myType.isVector()) {
+        if (myType.componentType().isFloat()) {
+            for (int i = 0; i < myType.columns(); i++) {
                 if (this->getFVecComponent(i) != c.getFVecComponent(i)) {
                     return ComparisonResult::kNotEqual;
                 }
-            } else if (this->getIVecComponent(i) != c.getIVecComponent(i)) {
-                return ComparisonResult::kNotEqual;
+            }
+            return ComparisonResult::kEqual;
+        }
+        if (myType.componentType().isInteger()) {
+            for (int i = 0; i < myType.columns(); i++) {
+                if (this->getIVecComponent(i) != c.getIVecComponent(i)) {
+                    return ComparisonResult::kNotEqual;
+                }
+            }
+            return ComparisonResult::kEqual;
+        }
+        if (myType.componentType().isBoolean()) {
+            for (int i = 0; i < myType.columns(); i++) {
+                if (this->getBVecComponent(i) != c.getBVecComponent(i)) {
+                    return ComparisonResult::kNotEqual;
+                }
+            }
+            return ComparisonResult::kEqual;
+        }
+    }
+
+    if (myType.isMatrix()) {
+        for (int col = 0; col < myType.columns(); col++) {
+            for (int row = 0; row < myType.rows(); row++) {
+                if (getMatComponent(col, row) != c.getMatComponent(col, row)) {
+                    return ComparisonResult::kNotEqual;
+                }
             }
         }
         return ComparisonResult::kEqual;
     }
-    // shouldn't be possible to have a constant constructor that isn't a vector or matrix;
-    // a constant scalar constructor should have been collapsed down to the appropriate
-    // literal
-    SkASSERT(myType.isMatrix());
-    for (int col = 0; col < myType.columns(); col++) {
-        for (int row = 0; row < myType.rows(); row++) {
-            if (getMatComponent(col, row) != c.getMatComponent(col, row)) {
-                return ComparisonResult::kNotEqual;
-            }
-        }
-    }
-    return ComparisonResult::kEqual;
+
+    SkDEBUGFAILF("compareConstant unexpected type: %s", myType.description().c_str());
+    return ComparisonResult::kUnknown;
 }
 
 template <typename ResultType>
 ResultType Constructor::getVecComponent(int index) const {
+    static_assert(std::is_same<ResultType, SKSL_FLOAT>::value ||
+                  std::is_same<ResultType, SKSL_INT>::value ||
+                  std::is_same<ResultType, bool>::value);
+
     SkASSERT(this->type().isVector());
+    SkASSERT(this->isCompileTimeConstant());
+
+    auto getConstantValue = [](const Expression& expr) -> ResultType {
+        if constexpr (std::is_same<ResultType, SKSL_FLOAT>::value) {
+            return expr.getConstantFloat();
+        } else if constexpr (std::is_same<ResultType, SKSL_INT>::value) {
+            return expr.getConstantInt();
+        } else if constexpr (std::is_same<ResultType, bool>::value) {
+            return expr.getConstantBool();
+        }
+        SkDEBUGFAILF("unrecognized kind of constant value: %s", expr.description().c_str());
+        return ResultType(0);
+    };
+
+    auto getInnerVecComponent = [](const Expression& expr, int position) -> ResultType {
+        const Type& type = expr.type().componentType();
+        if (type.isFloat()) {
+            return ResultType(expr.getVecComponent<SKSL_FLOAT>(position));
+        } else if (type.isInteger()) {
+            return ResultType(expr.getVecComponent<SKSL_INT>(position));
+        } else if (type.isBoolean()) {
+            return ResultType(expr.getVecComponent<bool>(position));
+        }
+        SkDEBUGFAILF("unrecognized type of constant: %s", expr.description().c_str());
+        return ResultType(0);
+    };
+
     if (this->arguments().size() == 1 &&
         this->arguments()[0]->type().isScalar()) {
         // This constructor just wraps a scalar. Propagate out the value.
-        if (std::is_floating_point<ResultType>::value) {
-            return this->arguments()[0]->getConstantFloat();
-        } else {
-            return this->arguments()[0]->getConstantInt();
-        }
+        return getConstantValue(*this->arguments()[0]);
     }
 
     // Walk through all the constructor arguments until we reach the index we're searching for.
@@ -128,53 +169,16 @@
         if (arg->type().isScalar()) {
             if (index == current) {
                 // We're on the proper argument, and it's a scalar; fetch it.
-                if (std::is_floating_point<ResultType>::value) {
-                    return arg->getConstantFloat();
-                } else {
-                    return arg->getConstantInt();
-                }
+                return getConstantValue(*arg);
             }
             current++;
             continue;
         }
 
-        switch (arg->kind()) {
-            case Kind::kConstructor: {
-                const Constructor& constructor = arg->as<Constructor>();
-                if (current + constructor.type().columns() > index) {
-                    // We've found a constructor that overlaps the proper argument. Descend into
-                    // it, honoring the type.
-                    return constructor.componentType().isFloat()
-                              ? ResultType(constructor.getVecComponent<SKSL_FLOAT>(index - current))
-                              : ResultType(constructor.getVecComponent<SKSL_INT>(index - current));
-                }
-                break;
-            }
-            case Kind::kPrefix: {
-                const PrefixExpression& prefix = arg->as<PrefixExpression>();
-                if (current + prefix.type().columns() > index) {
-                    // We found a prefix operator that contains the proper argument. Descend
-                    // into it. We only support for constant propagation of the unary minus, so
-                    // we shouldn't see any other tokens here.
-                    SkASSERT(prefix.getOperator() == Token::Kind::TK_MINUS);
-
-                    const Expression& operand = *prefix.operand();
-                    if (operand.type().isVector()) {
-                        return operand.type().componentType().isFloat()
-                                ? -ResultType(operand.getVecComponent<SKSL_FLOAT>(index - current))
-                                : -ResultType(operand.getVecComponent<SKSL_INT>(index - current));
-                    } else {
-                        return operand.type().isFloat()
-                                ? -ResultType(operand.getConstantFloat())
-                                : -ResultType(operand.getConstantInt());
-                    }
-                }
-                break;
-            }
-            default: {
-                SkDEBUGFAILF("unexpected component %d { %s } in %s\n",
-                             index, arg->description().c_str(), description().c_str());
-                break;
+        if (arg->type().isVector()) {
+            if (current + arg->type().columns() > index) {
+                // We've found an expression that encompasses the proper argument. Descend into it.
+                return getInnerVecComponent(*arg, index - current);
             }
         }
 
@@ -182,11 +186,12 @@
     }
 
     SkDEBUGFAILF("failed to find vector component %d in %s\n", index, description().c_str());
-    return -1;
+    return ResultType(0);
 }
 
-template int Constructor::getVecComponent(int) const;
-template float Constructor::getVecComponent(int) const;
+template SKSL_INT Constructor::getVecComponent(int) const;
+template SKSL_FLOAT Constructor::getVecComponent(int) const;
+template bool Constructor::getVecComponent(int) const;
 
 SKSL_FLOAT Constructor::getMatComponent(int col, int row) const {
     SkDEBUGCODE(const Type& myType = this->type();)
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index 1355131..2eabede 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -107,12 +107,13 @@
     ComparisonResult compareConstant(const Context& context,
                                      const Expression& other) const override;
 
-    template <typename resultType>
-    resultType getVecComponent(int index) const;
+    template <typename ResultType>
+    ResultType getVecComponent(int index) const;
 
     /**
      * For a literal vector expression, return the float value of the n'th vector component. It is
-     * an error to call this method on an expression which is not a vector of FloatLiterals.
+     * an error to call this method on an expression which is not a compile-time constant vector of
+     * floating-point type.
      */
     SKSL_FLOAT getFVecComponent(int n) const override {
         return this->getVecComponent<SKSL_FLOAT>(n);
@@ -120,12 +121,22 @@
 
     /**
      * For a literal vector expression, return the integer value of the n'th vector component. It is
-     * an error to call this method on an expression which is not a vector of IntLiterals.
+     * an error to call this method on an expression which is not a compile-time constant vector of
+     * integer type.
      */
     SKSL_INT getIVecComponent(int n) const override {
         return this->getVecComponent<SKSL_INT>(n);
     }
 
+    /**
+     * For a literal vector expression, return the boolean value of the n'th vector component. It is
+     * an error to call this method on an expression which is not a compile-time constant vector of
+     * Boolean type.
+     */
+    bool getBVecComponent(int n) const override {
+        return this->getVecComponent<bool>(n);
+    }
+
     SKSL_FLOAT getMatComponent(int col, int row) const override;
 
     SKSL_INT getConstantInt() const override;
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index d021897..a3c16b0 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -179,23 +179,36 @@
 
     /**
      * For a vector of floating point values, return the value of the n'th vector component. It is
-     * an error to call this method on an expression which is not a vector of FloatLiterals.
+     * an error to call this method on an expression which is not a vector of floating-point
+     * constant expressions.
      */
     virtual SKSL_FLOAT getFVecComponent(int n) const {
-        SkASSERT(false);
+        SkDEBUGFAILF("expression does not support getVecComponent: %s",
+                     this->description().c_str());
         return 0;
     }
 
     /**
      * For a vector of integer values, return the value of the n'th vector component. It is an error
-     * to call this method on an expression which is not a vector of IntLiterals.
+     * to call this method on an expression which is not a vector of integer constant expressions.
      */
     virtual SKSL_INT getIVecComponent(int n) const {
-        SkASSERT(false);
+        SkDEBUGFAILF("expression does not support getVecComponent: %s",
+                     this->description().c_str());
         return 0;
     }
 
     /**
+     * For a vector of Boolean values, return the value of the n'th vector component. It is an error
+     * to call this method on an expression which is not a vector of Boolean constant expressions.
+     */
+    virtual bool getBVecComponent(int n) const {
+        SkDEBUGFAILF("expression does not support getVecComponent: %s",
+                     this->description().c_str());
+        return false;
+    }
+
+    /**
      * For a vector of literals, return the value of the n'th vector component. It is an error to
      * call this method on an expression which is not a vector of Literal<T>.
      */
@@ -227,6 +240,10 @@
     return this->getIVecComponent(index);
 }
 
+template <> inline bool Expression::getVecComponent<bool>(int index) const {
+    return this->getBVecComponent(index);
+}
+
 }  // namespace SkSL
 
 #endif
diff --git a/src/sksl/ir/SkSLPrefixExpression.h b/src/sksl/ir/SkSLPrefixExpression.h
index 354e64b..1cb982b 100644
--- a/src/sksl/ir/SkSLPrefixExpression.h
+++ b/src/sksl/ir/SkSLPrefixExpression.h
@@ -71,6 +71,11 @@
         return -this->operand()->getIVecComponent(index);
     }
 
+    bool getBVecComponent(int index) const override {
+        SkDEBUGFAIL("negation of boolean values is not allowed");
+        return this->operand()->getBVecComponent(index);
+    }
+
     SKSL_FLOAT getMatComponent(int col, int row) const override {
         SkASSERT(this->getOperator() == Token::Kind::TK_MINUS);
         return -this->operand()->getMatComponent(col, row);