Implement getBVecComponent.
Like getIVecComponent or getFVecComponent, this retrieves the n'th
element of a Boolean compile-time constant vector. This will be used in
followup CLs.
Change-Id: Ib41c9c89cb773251e4c0d6cdcaea0437d8074e48
Bug: skia:11141
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/350918
Auto-Submit: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index 7a3cfe5..d38c7f8 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -74,47 +74,88 @@
}
const Constructor& c = other.as<Constructor>();
const Type& myType = this->type();
- const Type& otherType = c.type();
- SkASSERT(myType == otherType);
- if (otherType.isVector()) {
- bool isFloat = otherType.columns() > 1 ? otherType.componentType().isFloat()
- : otherType.isFloat();
- for (int i = 0; i < myType.columns(); i++) {
- if (isFloat) {
+ SkASSERT(myType == c.type());
+
+ if (myType.isVector()) {
+ if (myType.componentType().isFloat()) {
+ for (int i = 0; i < myType.columns(); i++) {
if (this->getFVecComponent(i) != c.getFVecComponent(i)) {
return ComparisonResult::kNotEqual;
}
- } else if (this->getIVecComponent(i) != c.getIVecComponent(i)) {
- return ComparisonResult::kNotEqual;
+ }
+ return ComparisonResult::kEqual;
+ }
+ if (myType.componentType().isInteger()) {
+ for (int i = 0; i < myType.columns(); i++) {
+ if (this->getIVecComponent(i) != c.getIVecComponent(i)) {
+ return ComparisonResult::kNotEqual;
+ }
+ }
+ return ComparisonResult::kEqual;
+ }
+ if (myType.componentType().isBoolean()) {
+ for (int i = 0; i < myType.columns(); i++) {
+ if (this->getBVecComponent(i) != c.getBVecComponent(i)) {
+ return ComparisonResult::kNotEqual;
+ }
+ }
+ return ComparisonResult::kEqual;
+ }
+ }
+
+ if (myType.isMatrix()) {
+ for (int col = 0; col < myType.columns(); col++) {
+ for (int row = 0; row < myType.rows(); row++) {
+ if (getMatComponent(col, row) != c.getMatComponent(col, row)) {
+ return ComparisonResult::kNotEqual;
+ }
}
}
return ComparisonResult::kEqual;
}
- // shouldn't be possible to have a constant constructor that isn't a vector or matrix;
- // a constant scalar constructor should have been collapsed down to the appropriate
- // literal
- SkASSERT(myType.isMatrix());
- for (int col = 0; col < myType.columns(); col++) {
- for (int row = 0; row < myType.rows(); row++) {
- if (getMatComponent(col, row) != c.getMatComponent(col, row)) {
- return ComparisonResult::kNotEqual;
- }
- }
- }
- return ComparisonResult::kEqual;
+
+ SkDEBUGFAILF("compareConstant unexpected type: %s", myType.description().c_str());
+ return ComparisonResult::kUnknown;
}
template <typename ResultType>
ResultType Constructor::getVecComponent(int index) const {
+ static_assert(std::is_same<ResultType, SKSL_FLOAT>::value ||
+ std::is_same<ResultType, SKSL_INT>::value ||
+ std::is_same<ResultType, bool>::value);
+
SkASSERT(this->type().isVector());
+ SkASSERT(this->isCompileTimeConstant());
+
+ auto getConstantValue = [](const Expression& expr) -> ResultType {
+ if constexpr (std::is_same<ResultType, SKSL_FLOAT>::value) {
+ return expr.getConstantFloat();
+ } else if constexpr (std::is_same<ResultType, SKSL_INT>::value) {
+ return expr.getConstantInt();
+ } else if constexpr (std::is_same<ResultType, bool>::value) {
+ return expr.getConstantBool();
+ }
+ SkDEBUGFAILF("unrecognized kind of constant value: %s", expr.description().c_str());
+ return ResultType(0);
+ };
+
+ auto getInnerVecComponent = [](const Expression& expr, int position) -> ResultType {
+ const Type& type = expr.type().componentType();
+ if (type.isFloat()) {
+ return ResultType(expr.getVecComponent<SKSL_FLOAT>(position));
+ } else if (type.isInteger()) {
+ return ResultType(expr.getVecComponent<SKSL_INT>(position));
+ } else if (type.isBoolean()) {
+ return ResultType(expr.getVecComponent<bool>(position));
+ }
+ SkDEBUGFAILF("unrecognized type of constant: %s", expr.description().c_str());
+ return ResultType(0);
+ };
+
if (this->arguments().size() == 1 &&
this->arguments()[0]->type().isScalar()) {
// This constructor just wraps a scalar. Propagate out the value.
- if (std::is_floating_point<ResultType>::value) {
- return this->arguments()[0]->getConstantFloat();
- } else {
- return this->arguments()[0]->getConstantInt();
- }
+ return getConstantValue(*this->arguments()[0]);
}
// Walk through all the constructor arguments until we reach the index we're searching for.
@@ -128,53 +169,16 @@
if (arg->type().isScalar()) {
if (index == current) {
// We're on the proper argument, and it's a scalar; fetch it.
- if (std::is_floating_point<ResultType>::value) {
- return arg->getConstantFloat();
- } else {
- return arg->getConstantInt();
- }
+ return getConstantValue(*arg);
}
current++;
continue;
}
- switch (arg->kind()) {
- case Kind::kConstructor: {
- const Constructor& constructor = arg->as<Constructor>();
- if (current + constructor.type().columns() > index) {
- // We've found a constructor that overlaps the proper argument. Descend into
- // it, honoring the type.
- return constructor.componentType().isFloat()
- ? ResultType(constructor.getVecComponent<SKSL_FLOAT>(index - current))
- : ResultType(constructor.getVecComponent<SKSL_INT>(index - current));
- }
- break;
- }
- case Kind::kPrefix: {
- const PrefixExpression& prefix = arg->as<PrefixExpression>();
- if (current + prefix.type().columns() > index) {
- // We found a prefix operator that contains the proper argument. Descend
- // into it. We only support for constant propagation of the unary minus, so
- // we shouldn't see any other tokens here.
- SkASSERT(prefix.getOperator() == Token::Kind::TK_MINUS);
-
- const Expression& operand = *prefix.operand();
- if (operand.type().isVector()) {
- return operand.type().componentType().isFloat()
- ? -ResultType(operand.getVecComponent<SKSL_FLOAT>(index - current))
- : -ResultType(operand.getVecComponent<SKSL_INT>(index - current));
- } else {
- return operand.type().isFloat()
- ? -ResultType(operand.getConstantFloat())
- : -ResultType(operand.getConstantInt());
- }
- }
- break;
- }
- default: {
- SkDEBUGFAILF("unexpected component %d { %s } in %s\n",
- index, arg->description().c_str(), description().c_str());
- break;
+ if (arg->type().isVector()) {
+ if (current + arg->type().columns() > index) {
+ // We've found an expression that encompasses the proper argument. Descend into it.
+ return getInnerVecComponent(*arg, index - current);
}
}
@@ -182,11 +186,12 @@
}
SkDEBUGFAILF("failed to find vector component %d in %s\n", index, description().c_str());
- return -1;
+ return ResultType(0);
}
-template int Constructor::getVecComponent(int) const;
-template float Constructor::getVecComponent(int) const;
+template SKSL_INT Constructor::getVecComponent(int) const;
+template SKSL_FLOAT Constructor::getVecComponent(int) const;
+template bool Constructor::getVecComponent(int) const;
SKSL_FLOAT Constructor::getMatComponent(int col, int row) const {
SkDEBUGCODE(const Type& myType = this->type();)
diff --git a/src/sksl/ir/SkSLConstructor.h b/src/sksl/ir/SkSLConstructor.h
index 1355131..2eabede 100644
--- a/src/sksl/ir/SkSLConstructor.h
+++ b/src/sksl/ir/SkSLConstructor.h
@@ -107,12 +107,13 @@
ComparisonResult compareConstant(const Context& context,
const Expression& other) const override;
- template <typename resultType>
- resultType getVecComponent(int index) const;
+ template <typename ResultType>
+ ResultType getVecComponent(int index) const;
/**
* For a literal vector expression, return the float value of the n'th vector component. It is
- * an error to call this method on an expression which is not a vector of FloatLiterals.
+ * an error to call this method on an expression which is not a compile-time constant vector of
+ * floating-point type.
*/
SKSL_FLOAT getFVecComponent(int n) const override {
return this->getVecComponent<SKSL_FLOAT>(n);
@@ -120,12 +121,22 @@
/**
* For a literal vector expression, return the integer value of the n'th vector component. It is
- * an error to call this method on an expression which is not a vector of IntLiterals.
+ * an error to call this method on an expression which is not a compile-time constant vector of
+ * integer type.
*/
SKSL_INT getIVecComponent(int n) const override {
return this->getVecComponent<SKSL_INT>(n);
}
+ /**
+ * For a literal vector expression, return the boolean value of the n'th vector component. It is
+ * an error to call this method on an expression which is not a compile-time constant vector of
+ * Boolean type.
+ */
+ bool getBVecComponent(int n) const override {
+ return this->getVecComponent<bool>(n);
+ }
+
SKSL_FLOAT getMatComponent(int col, int row) const override;
SKSL_INT getConstantInt() const override;
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index d021897..a3c16b0 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -179,23 +179,36 @@
/**
* For a vector of floating point values, return the value of the n'th vector component. It is
- * an error to call this method on an expression which is not a vector of FloatLiterals.
+ * an error to call this method on an expression which is not a vector of floating-point
+ * constant expressions.
*/
virtual SKSL_FLOAT getFVecComponent(int n) const {
- SkASSERT(false);
+ SkDEBUGFAILF("expression does not support getVecComponent: %s",
+ this->description().c_str());
return 0;
}
/**
* For a vector of integer values, return the value of the n'th vector component. It is an error
- * to call this method on an expression which is not a vector of IntLiterals.
+ * to call this method on an expression which is not a vector of integer constant expressions.
*/
virtual SKSL_INT getIVecComponent(int n) const {
- SkASSERT(false);
+ SkDEBUGFAILF("expression does not support getVecComponent: %s",
+ this->description().c_str());
return 0;
}
/**
+ * For a vector of Boolean values, return the value of the n'th vector component. It is an error
+ * to call this method on an expression which is not a vector of Boolean constant expressions.
+ */
+ virtual bool getBVecComponent(int n) const {
+ SkDEBUGFAILF("expression does not support getVecComponent: %s",
+ this->description().c_str());
+ return false;
+ }
+
+ /**
* For a vector of literals, return the value of the n'th vector component. It is an error to
* call this method on an expression which is not a vector of Literal<T>.
*/
@@ -227,6 +240,10 @@
return this->getIVecComponent(index);
}
+template <> inline bool Expression::getVecComponent<bool>(int index) const {
+ return this->getBVecComponent(index);
+}
+
} // namespace SkSL
#endif
diff --git a/src/sksl/ir/SkSLPrefixExpression.h b/src/sksl/ir/SkSLPrefixExpression.h
index 354e64b..1cb982b 100644
--- a/src/sksl/ir/SkSLPrefixExpression.h
+++ b/src/sksl/ir/SkSLPrefixExpression.h
@@ -71,6 +71,11 @@
return -this->operand()->getIVecComponent(index);
}
+ bool getBVecComponent(int index) const override {
+ SkDEBUGFAIL("negation of boolean values is not allowed");
+ return this->operand()->getBVecComponent(index);
+ }
+
SKSL_FLOAT getMatComponent(int col, int row) const override {
SkASSERT(this->getOperator() == Token::Kind::TK_MINUS);
return -this->operand()->getMatComponent(col, row);