Add support for multiple child nodes to SkImageFilters::RuntimeShader

Bug: skia:12766
Change-Id: I9dfe07a71961ab952c1593b9cc68c61191fbc13c
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/489536
Reviewed-by: Michael Ludwig <michaelludwig@google.com>
Reviewed-by: Derek Sollenberger <djsollen@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
diff --git a/RELEASE_NOTES.txt b/RELEASE_NOTES.txt
index 1c68bb3..ec71e24 100644
--- a/RELEASE_NOTES.txt
+++ b/RELEASE_NOTES.txt
@@ -15,6 +15,8 @@
     Note that if the destination surface has no color space (color space is `nullptr`), these
     intrinsics will do no conversion, and return the input color unchanged.
     https://review.skia.org/481416
+  * Added a new variant of SkImageFilters::RuntimeShader that supports multiple child nodes.
+    https://review.skia.org/489536
 
 * * *
 
diff --git a/gm/runtimeimagefilter.cpp b/gm/runtimeimagefilter.cpp
index 75522b3..e0166ae 100644
--- a/gm/runtimeimagefilter.cpp
+++ b/gm/runtimeimagefilter.cpp
@@ -20,6 +20,7 @@
 #include "include/effects/SkRuntimeEffect.h"
 #include "include/utils/SkRandom.h"
 #include "src/effects/imagefilters/SkRuntimeImageFilter.h"
+#include "tools/Resources.h"
 #include "tools/ToolUtils.h"
 
 static sk_sp<SkImageFilter> make_filter() {
@@ -70,3 +71,35 @@
     p.setPerspY(-0.0015f);
     draw_layer(250, 500, p);
 }
+
+DEF_SIMPLE_GM(rtif_unsharp, canvas, 512, 256) {
+    // Similar to "unsharp_rt", which does the entire unsharp filter in a single shader. This uses
+    // the image filter DAG to compute the blurred version, then does the weighted subtraction.
+    sk_sp<SkRuntimeEffect> effect = SkRuntimeEffect::MakeForShader(SkString(R"(
+        uniform shader content;
+        uniform shader blurred;
+        vec4 main(vec2 coord) {
+            vec4 c = content.eval(coord);
+            vec4 b = blurred.eval(coord);
+            return c + (c - b) * 4;
+        }
+    )")).effect;
+    SkRuntimeShaderBuilder builder(std::move(effect));
+
+    auto image = GetResourceAsImage("images/mandrill_256.png");
+    auto blurredSrc = SkImageFilters::Blur(1, 1, /*input=*/nullptr);
+
+    const char* childNames[] = { "content", "blurred" };
+    sk_sp<SkImageFilter> childNodes[] = { nullptr, blurredSrc };
+
+    auto sharpened = SkImageFilters::RuntimeShader(builder, childNames, childNodes, 2);
+
+    canvas->drawImage(image, 0, 0);
+    canvas->translate(256, 0);
+
+    SkPaint paint;
+    paint.setImageFilter(sharpened);
+    canvas->saveLayer({ 0, 0, 256, 256 }, &paint);
+    canvas->drawImage(image, 0, 0);
+    canvas->restore();
+}
diff --git a/include/effects/SkImageFilters.h b/include/effects/SkImageFilters.h
index e91837e..144bfb8 100644
--- a/include/effects/SkImageFilters.h
+++ b/include/effects/SkImageFilters.h
@@ -352,6 +352,26 @@
     static sk_sp<SkImageFilter> RuntimeShader(const SkRuntimeShaderBuilder& builder,
                                               const char* childShaderName,
                                               sk_sp<SkImageFilter> input);
+
+    /**
+     *  Create a filter that fills the output with the per-pixel evaluation of the SkShader produced
+     *  by the SkRuntimeShaderBuilder. The shader is defined in the image filter's local coordinate
+     *  system, so it will automatically be affected by SkCanvas' transform.
+     *
+     *  @param builder          The builder used to produce the runtime shader, that will in turn
+     *                          fill the result image
+     *  @param childShaderNames The names of the child shaders defined in the builder that will be
+     *                          bound to the input params (or the source image if the input param
+     *                          is null). If any name is null, or appears more than once, factory
+     *                          fails and returns nullptr.
+     *  @param inputs           The image filters that will be provided as input to the runtime
+     *                          shader. If any are null, the implicit source image is used instead.
+     *  @param inputCount       How many entries are present in 'childShaderNames' and 'inputs'.
+     */
+    static sk_sp<SkImageFilter> RuntimeShader(const SkRuntimeShaderBuilder& builder,
+                                              const char* childShaderNames[],
+                                              const sk_sp<SkImageFilter> inputs[],
+                                              int inputCount);
 #endif  // SK_ENABLE_SKSL
 
     enum class Dither : bool {
diff --git a/src/effects/imagefilters/SkRuntimeImageFilter.cpp b/src/effects/imagefilters/SkRuntimeImageFilter.cpp
index 41b7598..7856aae 100644
--- a/src/effects/imagefilters/SkRuntimeImageFilter.cpp
+++ b/src/effects/imagefilters/SkRuntimeImageFilter.cpp
@@ -27,13 +27,17 @@
                          sk_sp<SkImageFilter> input)
             : INHERITED(&input, 1, /*cropRect=*/nullptr)
             , fShaderBuilder(std::move(effect), std::move(uniforms))
-            , fChildShaderName(fShaderBuilder.effect()->children().front().name) {}
+            , fChildShaderNames(&fShaderBuilder.effect()->children().front().name, 1) {}
     SkRuntimeImageFilter(const SkRuntimeShaderBuilder& builder,
-                         const char* childShaderName,
-                         sk_sp<SkImageFilter> input)
-            : INHERITED(&input, 1, /*cropRect=*/nullptr)
-            , fShaderBuilder(builder)
-            , fChildShaderName(childShaderName) {}
+                         const char* childShaderNames[],
+                         const sk_sp<SkImageFilter> inputs[],
+                         int inputCount)
+            : INHERITED(inputs, inputCount, /*cropRect=*/nullptr)
+            , fShaderBuilder(builder) {
+        for (int i = 0; i < inputCount; i++) {
+            fChildShaderNames.push_back(SkString(childShaderNames[i]));
+        }
+    }
 
     bool onAffectsTransparentBlack() const override { return true; }
     MatrixCapability onGetCTMCapability() const override { return MatrixCapability::kTranslate; }
@@ -48,7 +52,7 @@
 
     mutable SkSpinlock fShaderBuilderLock;
     mutable SkRuntimeShaderBuilder fShaderBuilder;
-    SkString fChildShaderName;
+    SkSTArray<1, SkString> fChildShaderNames;
 
     using INHERITED = SkImageFilter_Base;
 };
@@ -74,7 +78,8 @@
 }
 
 sk_sp<SkFlattenable> SkRuntimeImageFilter::CreateProc(SkReadBuffer& buffer) {
-    SK_IMAGEFILTER_UNFLATTEN_COMMON(common, 1);
+    // We don't know how many inputs to expect yet. Passing -1 allows any number of children.
+    SK_IMAGEFILTER_UNFLATTEN_COMMON(common, -1);
     if (common.cropRect()) {
         return nullptr;
     }
@@ -93,11 +98,14 @@
         return nullptr;
     }
 
-    // Read the child shader name and make sure it matches one declared in the effect
-    SkString childShaderName;
-    buffer.readString(&childShaderName);
-    if (!buffer.validate(effect->findChild(childShaderName.c_str()) != nullptr)) {
-        return nullptr;
+    // Read the child shader names
+    SkSTArray<4, const char*> childShaderNames;
+    SkSTArray<4, SkString> childShaderNameStrings;
+    childShaderNames.resize(common.inputCount());
+    childShaderNameStrings.resize(common.inputCount());
+    for (int i = 0; i < common.inputCount(); i++) {
+        buffer.readString(&childShaderNameStrings[i]);
+        childShaderNames[i] = childShaderNameStrings[i].c_str();
     }
 
     SkRuntimeShaderBuilder builder(std::move(effect), std::move(uniforms));
@@ -125,7 +133,8 @@
         return nullptr;
     }
 
-    return SkImageFilters::RuntimeShader(builder, childShaderName.c_str(), common.getInput(0));
+    return SkImageFilters::RuntimeShader(
+            builder, childShaderNames.data(), common.inputs(), common.inputCount());
 }
 
 void SkRuntimeImageFilter::flatten(SkWriteBuffer& buffer) const {
@@ -133,7 +142,9 @@
     fShaderBuilderLock.acquire();
     buffer.writeString(fShaderBuilder.effect()->source().c_str());
     buffer.writeDataAsByteArray(fShaderBuilder.uniforms().get());
-    buffer.writeString(fChildShaderName.c_str());
+    for (const SkString& name : fChildShaderNames) {
+        buffer.writeString(name.c_str());
+    }
     for (size_t x = 0; x < fShaderBuilder.numChildren(); x++) {
         buffer.writeFlattenable(fShaderBuilder.children()[x].flattenable());
     }
@@ -144,12 +155,6 @@
 
 sk_sp<SkSpecialImage> SkRuntimeImageFilter::onFilterImage(const Context& ctx,
                                                           SkIPoint* offset) const {
-    SkIPoint inputOffset = SkIPoint::Make(0, 0);
-    sk_sp<SkSpecialImage> input(this->filterInput(0, ctx, &inputOffset));
-    if (!input) {
-        return nullptr;
-    }
-
     SkIRect outputBounds = SkIRect(ctx.desiredOutput());
     sk_sp<SkSpecialSurface> surf(ctx.makeSurface(outputBounds.size()));
     if (!surf) {
@@ -160,20 +165,37 @@
     SkMatrix inverse;
     SkAssertResult(ctm.invert(&inverse));
 
-    SkMatrix localM = inverse *
-                      SkMatrix::Translate(inputOffset) *
-                      SkMatrix::Translate(-input->subset().topLeft());
-    sk_sp<SkShader> inputShader =
-            input->asImage()->makeShader(SkSamplingOptions(SkFilterMode::kLinear), &localM);
-    SkASSERT(inputShader);
+    const int inputCount = this->countInputs();
+    SkASSERT(inputCount == fChildShaderNames.count());
+
+    SkSTArray<1, sk_sp<SkShader>> inputShaders;
+    for (int i = 0; i < inputCount; i++) {
+        SkIPoint inputOffset = SkIPoint::Make(0, 0);
+        sk_sp<SkSpecialImage> input(this->filterInput(i, ctx, &inputOffset));
+        if (!input) {
+            return nullptr;
+        }
+
+        SkMatrix localM = inverse *
+                          SkMatrix::Translate(inputOffset) *
+                          SkMatrix::Translate(-input->subset().topLeft());
+        sk_sp<SkShader> inputShader =
+                input->asImage()->makeShader(SkSamplingOptions(SkFilterMode::kLinear), &localM);
+        SkASSERT(inputShader);
+        inputShaders.push_back(std::move(inputShader));
+    }
 
     // lock the mutation of the builder and creation of the shader so that the builder's state is
     // const and is safe for multi-threaded access.
     fShaderBuilderLock.acquire();
-    fShaderBuilder.child(fChildShaderName.c_str()) = inputShader;
-    sk_sp<SkShader>   shader = fShaderBuilder.makeShader(nullptr, false);
-    // Remove the shader from the builder to avoid unnecessarily prolonging the shader's lifetime
-    fShaderBuilder.child(fChildShaderName.c_str()) = nullptr;
+    for (int i = 0; i < inputCount; i++) {
+        fShaderBuilder.child(fChildShaderNames[i].c_str()) = inputShaders[i];
+    }
+    sk_sp<SkShader> shader = fShaderBuilder.makeShader(nullptr, false);
+    // Remove the inputs from the builder to avoid unnecessarily prolonging the shader's lifetime
+    for (int i = 0; i < inputCount; i++) {
+        fShaderBuilder.child(fChildShaderNames[i].c_str()) = nullptr;
+    }
     fShaderBuilderLock.release();
 
     SkASSERT(shader.get());
@@ -196,6 +218,10 @@
     return surf->makeImageSnapshot();
 }
 
+static bool child_is_shader(const SkRuntimeEffect::Child* child) {
+    return child && child->type == SkRuntimeEffect::ChildType::kShader;
+}
+
 sk_sp<SkImageFilter> SkImageFilters::RuntimeShader(const SkRuntimeShaderBuilder& builder,
                                                    const char* childShaderName,
                                                    sk_sp<SkImageFilter> input) {
@@ -207,13 +233,32 @@
             return nullptr;
         }
         childShaderName = children.front().name.c_str();
-    } else if (builder.effect()->findChild(childShaderName) == nullptr) {
-        // there was no child declared in the runtime effect that matches the provided name
-        return nullptr;
+    }
+
+    return SkImageFilters::RuntimeShader(builder, &childShaderName, &input, 1);
+}
+
+sk_sp<SkImageFilter> SkImageFilters::RuntimeShader(const SkRuntimeShaderBuilder& builder,
+                                                   const char* childShaderNames[],
+                                                   const sk_sp<SkImageFilter> inputs[],
+                                                   int inputCount) {
+    for (int i = 0; i < inputCount; i++) {
+        const char* name = childShaderNames[i];
+        // All names must be non-null, and present as a child shader in the effect:
+        if (!name || !child_is_shader(builder.effect()->findChild(name))) {
+            return nullptr;
+        }
+
+        // We don't allow duplicates, either:
+        for (int j = 0; j < i; j++) {
+            if (!strcmp(name, childShaderNames[j])) {
+                return nullptr;
+            }
+        }
     }
 
     return sk_sp<SkImageFilter>(
-            new SkRuntimeImageFilter(builder, childShaderName, std::move(input)));
+            new SkRuntimeImageFilter(builder, childShaderNames, inputs, inputCount));
 }
 
 #endif  // SK_ENABLE_SKSL