Add support for constant folding of matrix-times-matrix.

This code should be easily adaptable to matrix-times-vector as well;
just treat the vector as a 1-row or 1-column matrix. I haven't gotten
around to writing tests for this, though.

Change-Id: If59ae52cd12952b44d3574d54398b2dc66edbcc8
Bug: skia:12819
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/505221
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
Commit-Queue: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp
index e5138e5..f429078 100644
--- a/src/sksl/SkSLConstantFolder.cpp
+++ b/src/sksl/SkSLConstantFolder.cpp
@@ -94,6 +94,58 @@
     return nullptr;
 }
 
+static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
+                                                                const Expression& left,
+                                                                const Expression& right) {
+    const Type& leftType = left.type();
+    const Type& rightType = right.type();
+
+    SkASSERT(leftType.isMatrix());
+    SkASSERT(rightType.isMatrix());
+
+    const Type& componentType = leftType.componentType();
+    SkASSERT(componentType.matches(rightType.componentType()));
+
+    const int leftColumns  = leftType.columns(),
+              leftRows     = leftType.rows(),
+              rightColumns = rightType.columns(),
+              rightRows    = rightType.rows(),
+              outColumns   = rightColumns,
+              outRows      = leftRows;
+    SkASSERT(leftColumns == rightRows);
+    const Type& resultType = componentType.toCompound(context, outColumns, outRows);
+
+    // Fetch the left matrix.
+    double leftVals[4][4];
+    for (int c = 0; c < leftColumns; ++c) {
+        for (int r = 0; r < leftRows; ++r) {
+            leftVals[c][r] = *left.getConstantValue((c * leftRows) + r);
+        }
+    }
+    // Fetch the right matrix.
+    double rightVals[4][4];
+    for (int c = 0; c < rightColumns; ++c) {
+        for (int r = 0; r < rightRows; ++r) {
+            rightVals[c][r] = *right.getConstantValue((c * rightRows) + r);
+        }
+    }
+
+    ExpressionArray args;
+    args.reserve_back(outColumns * outRows);
+    for (int c = 0; c < outColumns; ++c) {
+        for (int r = 0; r < outRows; ++r) {
+            // Compute a dot product for this position.
+            double val = 0;
+            for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
+                val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
+            }
+            args.push_back(Literal::Make(left.fLine, val, &componentType));
+        }
+    }
+
+    return ConstructorCompound::Make(context, left.fLine, resultType, std::move(args));
+}
+
 static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
                                                           const Expression& left,
                                                           Operator op,
@@ -533,8 +585,7 @@
 
     // Perform matrix * matrix multiplication.
     if (op.kind() == Token::Kind::TK_STAR && leftType.isMatrix() && rightType.isMatrix()) {
-        // TODO(skia:12819): Implement matrix * matrix multiplication.
-        return nullptr;
+        return simplify_matrix_times_matrix(context, *left, *right);
     }
 
     // Perform constant folding on pairs of vectors/matrices.
diff --git a/tests/sksl/folding/MatrixFoldingES2.glsl b/tests/sksl/folding/MatrixFoldingES2.glsl
index c666cc8..7eb59df 100644
--- a/tests/sksl/folding/MatrixFoldingES2.glsl
+++ b/tests/sksl/folding/MatrixFoldingES2.glsl
@@ -14,14 +14,10 @@
 }
 bool test_matrix_op_matrix_float_b() {
     bool ok = true;
-    ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
-    ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
     return ok;
 }
 bool test_matrix_op_matrix_half_b() {
     bool ok = true;
-    ok = ok && mat2(1.0, 2.0, 7.0, 4.0) * mat2(3.0, 5.0, 3.0, 2.0) == mat2(38.0, 26.0, 17.0, 14.0);
-    ok = ok && mat3(10.0, 4.0, 2.0, 20.0, 5.0, 3.0, 10.0, 6.0, 5.0) * mat3(3.0, 3.0, 4.0, 2.0, 3.0, 4.0, 4.0, 9.0, 2.0) == mat3(130.0, 51.0, 35.0, 120.0, 47.0, 33.0, 240.0, 73.0, 45.0);
     return ok;
 }
 vec4 main() {
diff --git a/tests/sksl/folding/MatrixFoldingES3.glsl b/tests/sksl/folding/MatrixFoldingES3.glsl
index 9201ebb..5b733cc 100644
--- a/tests/sksl/folding/MatrixFoldingES3.glsl
+++ b/tests/sksl/folding/MatrixFoldingES3.glsl
@@ -8,12 +8,10 @@
 }
 bool test_matrix_op_matrix_float_b() {
     bool ok = true;
-    ok = ok && mat3x2(1.0, 4.0, 2.0, 5.0, 3.0, 6.0) * mat2x3(7.0, 9.0, 11.0, 8.0, 10.0, 12.0) == mat2(58.0, 139.0, 64.0, 154.0);
     return ok;
 }
 bool test_matrix_op_matrix_half_b() {
     bool ok = true;
-    ok = ok && mat3x2(1.0, 4.0, 2.0, 5.0, 3.0, 6.0) * mat2x3(7.0, 9.0, 11.0, 8.0, 10.0, 12.0) == mat2(58.0, 139.0, 64.0, 154.0);
     return ok;
 }
 vec4 main() {