Added prototypes for DSLFunction

Previously, there was no way to create a forward declaration for a DSL
function. To avoid introducing new API and make this work in an
intuitive fashion, we now create prototypes for all DSL functions and
remove them when the function is promptly defined.

Change-Id: Ief36164ceb303a3d76a57dc073f2e9b8409bb45f
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/436562
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/dsl/DSLFunction.cpp b/src/sksl/dsl/DSLFunction.cpp
index 6eafc99..8bd93ad 100644
--- a/src/sksl/dsl/DSLFunction.cpp
+++ b/src/sksl/dsl/DSLFunction.cpp
@@ -12,6 +12,7 @@
 #include "src/sksl/SkSLCompiler.h"
 #include "src/sksl/SkSLIRGenerator.h"
 #include "src/sksl/dsl/priv/DSLWriter.h"
+#include "src/sksl/ir/SkSLFunctionPrototype.h"
 #include "src/sksl/ir/SkSLReturnStatement.h"
 
 namespace SkSL {
@@ -60,6 +61,11 @@
         for (size_t i = 0; i < params.size(); ++i) {
             params[i]->fVar = fDecl->parameters()[i];
         }
+        // We don't know when this function is going to be defined; go ahead and add a prototype in
+        // case the definition is delayed. If we end up defining the function immediately, we'll
+        // remove the prototype in define().
+        DSLWriter::ProgramElements().push_back(std::make_unique<SkSL::FunctionPrototype>(
+                /*offset=*/-1, fDecl, DSLWriter::IsModule()));
     }
 }
 
@@ -70,6 +76,17 @@
         block.release();
         return;
     }
+    if (!DSLWriter::ProgramElements().empty()) {
+        // If the last ProgramElement was the prototype for this function, it was unnecessary and we
+        // can remove it.
+        const SkSL::ProgramElement& last = *DSLWriter::ProgramElements().back();
+        if (last.is<SkSL::FunctionPrototype>()) {
+            const SkSL::FunctionPrototype& prototype = last.as<SkSL::FunctionPrototype>();
+            if (&prototype.declaration() == fDecl) {
+                DSLWriter::ProgramElements().pop_back();
+            }
+        }
+    }
     SkASSERTF(!fDecl->definition(), "function '%s' already defined", fDecl->description().c_str());
     std::unique_ptr<Block> body = block.release();
     body = DSLWriter::IRGenerator().finalizeFunction(*fDecl, std::move(body));
diff --git a/tests/SkSLDSLTest.cpp b/tests/SkSLDSLTest.cpp
index 94095f7..8e1bfd2 100644
--- a/tests/SkSLDSLTest.cpp
+++ b/tests/SkSLDSLTest.cpp
@@ -2051,3 +2051,45 @@
     // Ensure that we can safely destroy statements and expressions despite being unused while
     // settings.fAssertDSLObjectsReleased is disabled.
 }
+
+DEF_GPUTEST_FOR_MOCK_CONTEXT(DSLPrototypes, r, ctxInfo) {
+    AutoDSLContext context(ctxInfo.directContext()->priv().getGpu(), no_mark_vars_declared());
+    {
+        DSLParameter x(kFloat_Type, "x");
+        DSLFunction sqr(kFloat_Type, "sqr", x);
+        REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
+        EXPECT_EQUAL(*DSLWriter::ProgramElements()[0], "float sqr(float x);");
+        sqr.define(
+            Return(x * x)
+        );
+        REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
+        EXPECT_EQUAL(*DSLWriter::ProgramElements()[0], "float sqr(float x) { return (x * x); }");
+    }
+
+    {
+        DSLWriter::Reset();
+            DSLParameter x(kFloat_Type, "x");
+        DSLFunction sqr(kFloat_Type, "sqr", x);
+        REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
+        EXPECT_EQUAL(*DSLWriter::ProgramElements()[0], "float sqr(float x);");
+        DSLFunction(kVoid_Type, "main").define(sqr(5));
+        REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 2);
+        EXPECT_EQUAL(*DSLWriter::ProgramElements()[0], "float sqr(float x);");
+        EXPECT_EQUAL(*DSLWriter::ProgramElements()[1], "void main() { sqr(5.0); }");
+        sqr.define(
+            Return(x * x)
+        );
+        REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 3);
+        EXPECT_EQUAL(*DSLWriter::ProgramElements()[2], "float sqr(float x) { return (x * x); }");
+
+        const char* source = "source test";
+        std::unique_ptr<SkSL::Program> p = ReleaseProgram(std::make_unique<SkSL::String>(source));
+        EXPECT_EQUAL(*p,
+            "layout (builtin = 17) in bool sk_Clockwise;"
+            "float sqr(float x);"
+            "void main() {"
+            "/* inlined: sqr */;"
+            "25.0;"
+            "}");
+    }
+}