Add support for constant folding of matrix-times-matrix.
This code should be easily adaptable to matrix-times-vector as well;
just treat the vector as a 1-row or 1-column matrix. I haven't gotten
around to writing tests for this, though.
Change-Id: If59ae52cd12952b44d3574d54398b2dc66edbcc8
Bug: skia:12819
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/505221
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Commit-Queue: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp
index e5138e5..f429078 100644
--- a/src/sksl/SkSLConstantFolder.cpp
+++ b/src/sksl/SkSLConstantFolder.cpp
@@ -94,6 +94,58 @@
return nullptr;
}
+static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
+ const Expression& left,
+ const Expression& right) {
+ const Type& leftType = left.type();
+ const Type& rightType = right.type();
+
+ SkASSERT(leftType.isMatrix());
+ SkASSERT(rightType.isMatrix());
+
+ const Type& componentType = leftType.componentType();
+ SkASSERT(componentType.matches(rightType.componentType()));
+
+ const int leftColumns = leftType.columns(),
+ leftRows = leftType.rows(),
+ rightColumns = rightType.columns(),
+ rightRows = rightType.rows(),
+ outColumns = rightColumns,
+ outRows = leftRows;
+ SkASSERT(leftColumns == rightRows);
+ const Type& resultType = componentType.toCompound(context, outColumns, outRows);
+
+ // Fetch the left matrix.
+ double leftVals[4][4];
+ for (int c = 0; c < leftColumns; ++c) {
+ for (int r = 0; r < leftRows; ++r) {
+ leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
+ }
+ }
+ // Fetch the right matrix.
+ double rightVals[4][4];
+ for (int c = 0; c < rightColumns; ++c) {
+ for (int r = 0; r < rightRows; ++r) {
+ rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
+ }
+ }
+
+ ExpressionArray args;
+ args.reserve_back(outColumns * outRows);
+ for (int c = 0; c < outColumns; ++c) {
+ for (int r = 0; r < outRows; ++r) {
+ // Compute a dot product for this position.
+ double val = 0;
+ for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
+ val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
+ }
+ args.push_back(Literal::Make(left.fLine, val, &componentType));
+ }
+ }
+
+ return ConstructorCompound::Make(context, left.fLine, resultType, std::move(args));
+}
+
static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
const Expression& left,
Operator op,
@@ -533,8 +585,7 @@
// Perform matrix * matrix multiplication.
if (op.kind() == Token::Kind::TK_STAR && leftType.isMatrix() && rightType.isMatrix()) {
- // TODO(skia:12819): Implement matrix * matrix multiplication.
- return nullptr;
+ return simplify_matrix_times_matrix(context, *left, *right);
}
// Perform constant folding on pairs of vectors/matrices.
diff --git a/tests/sksl/folding/MatrixFoldingES2.glsl b/tests/sksl/folding/MatrixFoldingES2.glsl
index c666cc8..7eb59df 100644
--- a/tests/sksl/folding/MatrixFoldingES2.glsl
+++ b/tests/sksl/folding/MatrixFoldingES2.glsl
@@ -14,14 +14,10 @@
}
bool test_matrix_op_matrix_float_b() {
bool ok = true;
- ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
- ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
return ok;
}
bool test_matrix_op_matrix_half_b() {
bool ok = true;
- ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
- ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
return ok;
}
vec4 main() {
diff --git a/tests/sksl/folding/MatrixFoldingES3.glsl b/tests/sksl/folding/MatrixFoldingES3.glsl
index 9201ebb..5b733cc 100644
--- a/tests/sksl/folding/MatrixFoldingES3.glsl
+++ b/tests/sksl/folding/MatrixFoldingES3.glsl
@@ -8,12 +8,10 @@
}
bool test_matrix_op_matrix_float_b() {
bool ok = true;
- ok = ok && mat3x2(1.0, 4.0, 2.0, 5.0, 3.0, 6.0) * mat2x3(7.0, 9.0, 11.0, 8.0, 10.0, 12.0) == mat2(58.0, 139.0, 64.0, 154.0);
return ok;
}
bool test_matrix_op_matrix_half_b() {
bool ok = true;
- ok = ok && mat3x2(1.0, 4.0, 2.0, 5.0, 3.0, 6.0) * mat2x3(7.0, 9.0, 11.0, 8.0, 10.0, 12.0) == mat2(58.0, 139.0, 64.0, 154.0);
return ok;
}
vec4 main() {