Metal: Cache compute pipelines with render pipelines.

Support caching of compute pipelines in the same LRU cache as render
pipelines in mtl::PipelineCache.

Bug: chromium:1329376
Change-Id: I93bbfadb8f5c1461144f1c222362c174402cced1
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/4628673
Commit-Queue: Geoff Lang <geofflang@chromium.org>
Reviewed-by: Shahbaz Youssefi <syoussefi@chromium.org>
Reviewed-by: Quyen Le <lehoangquyen@chromium.org>
diff --git a/src/libANGLE/renderer/metal/mtl_pipeline_cache.h b/src/libANGLE/renderer/metal/mtl_pipeline_cache.h
index b4993e6..8888c41 100644
--- a/src/libANGLE/renderer/metal/mtl_pipeline_cache.h
+++ b/src/libANGLE/renderer/metal/mtl_pipeline_cache.h
@@ -19,13 +19,17 @@
 namespace mtl
 {
 
-struct RenderPipelineKey
+struct PipelineKey
 {
     AutoObjCPtr<id<MTLFunction>> vertexShader;
     AutoObjCPtr<id<MTLFunction>> fragmentShader;
     RenderPipelineDesc pipelineDesc;
 
-    bool operator==(const RenderPipelineKey &rhs) const;
+    AutoObjCPtr<id<MTLFunction>> computeShader;
+
+    bool isRenderPipeline() const;
+
+    bool operator==(const PipelineKey &rhs) const;
     size_t hash() const;
 };
 
@@ -36,9 +40,9 @@
 {
 
 template <>
-struct hash<rx::mtl::RenderPipelineKey>
+struct hash<rx::mtl::PipelineKey>
 {
-    size_t operator()(const rx::mtl::RenderPipelineKey &key) const { return key.hash(); }
+    size_t operator()(const rx::mtl::PipelineKey &key) const { return key.hash(); }
 };
 
 }  // namespace std
@@ -58,6 +62,9 @@
                                     id<MTLFunction> fragmentShader,
                                     const RenderPipelineDesc &desc,
                                     AutoObjCPtr<id<MTLRenderPipelineState>> *outRenderPipeline);
+    angle::Result getComputePipeline(ContextMtl *context,
+                                     id<MTLFunction> computeShader,
+                                     AutoObjCPtr<id<MTLComputePipelineState>> *outComputePipeline);
 
   private:
     static constexpr unsigned int kMaxPipelines = 128;
@@ -65,9 +72,14 @@
     // The cache tries to clean up this many states at once.
     static constexpr unsigned int kGCLimit = 32;
 
-    using RenderPipelineMap =
-        angle::base::HashingMRUCache<RenderPipelineKey, AutoObjCPtr<id<MTLRenderPipelineState>>>;
-    RenderPipelineMap mRenderPiplineCache;
+    struct PipelineVariant
+    {
+        AutoObjCPtr<id<MTLRenderPipelineState>> renderPipeline;
+        AutoObjCPtr<id<MTLComputePipelineState>> computePipeline;
+    };
+
+    using RenderPipelineMap = angle::base::HashingMRUCache<PipelineKey, PipelineVariant>;
+    RenderPipelineMap mPipelineCache;
 };
 
 }  // namespace mtl
diff --git a/src/libANGLE/renderer/metal/mtl_pipeline_cache.mm b/src/libANGLE/renderer/metal/mtl_pipeline_cache.mm
index 3526d95..6ea4c75 100644
--- a/src/libANGLE/renderer/metal/mtl_pipeline_cache.mm
+++ b/src/libANGLE/renderer/metal/mtl_pipeline_cache.mm
@@ -92,11 +92,12 @@
 }
 
 angle::Result CreateRenderPipelineState(ContextMtl *context,
-                                        const RenderPipelineKey &key,
+                                        const PipelineKey &key,
                                         AutoObjCPtr<id<MTLRenderPipelineState>> *outRenderPipeline)
 {
     ANGLE_MTL_OBJC_SCOPE
     {
+        ASSERT(key.isRenderPipeline());
         if (!key.vertexShader)
         {
             // Render pipeline without vertex shader is invalid.
@@ -138,20 +139,85 @@
     }
 }
 
+angle::Result CreateComputePipelineState(
+    ContextMtl *context,
+    const PipelineKey &key,
+    AutoObjCPtr<id<MTLComputePipelineState>> *outComputePipeline)
+{
+    ANGLE_MTL_OBJC_SCOPE
+    {
+        ASSERT(!key.isRenderPipeline());
+        if (!key.computeShader)
+        {
+            ANGLE_MTL_HANDLE_ERROR(context, "Compute pipeline without a shader is invalid.",
+                                   GL_INVALID_OPERATION);
+            return angle::Result::Stop;
+        }
+
+        const mtl::ContextDevice &metalDevice = context->getMetalDevice();
+
+        // Create pipeline state
+        NSError *err  = nil;
+        auto newState = metalDevice.newComputePipelineStateWithFunction(key.computeShader, &err);
+        if (err)
+        {
+            ANGLE_MTL_HANDLE_ERROR(context, mtl::FormatMetalErrorMessage(err).c_str(),
+                                   GL_INVALID_OPERATION);
+            return angle::Result::Stop;
+        }
+
+        *outComputePipeline = newState;
+        return angle::Result::Continue;
+    }
+}
+
 }  // namespace
 
-bool RenderPipelineKey::operator==(const RenderPipelineKey &rhs) const
+bool PipelineKey::isRenderPipeline() const
 {
-    return std::tie(vertexShader, fragmentShader, pipelineDesc) ==
-           std::tie(rhs.vertexShader, rhs.fragmentShader, rhs.pipelineDesc);
+    if (vertexShader)
+    {
+        ASSERT(!computeShader);
+        return true;
+    }
+    else
+    {
+        ASSERT(computeShader);
+        return false;
+    }
 }
 
-size_t RenderPipelineKey::hash() const
+bool PipelineKey::operator==(const PipelineKey &rhs) const
 {
-    return angle::HashMultiple(vertexShader.get(), fragmentShader.get(), pipelineDesc);
+    if (isRenderPipeline() != rhs.isRenderPipeline())
+    {
+        return false;
+    }
+
+    if (isRenderPipeline())
+    {
+        return std::tie(vertexShader, fragmentShader, pipelineDesc) ==
+               std::tie(rhs.vertexShader, rhs.fragmentShader, rhs.pipelineDesc);
+    }
+    else
+    {
+        return computeShader == rhs.computeShader;
+    }
 }
 
-PipelineCache::PipelineCache() : mRenderPiplineCache(kMaxPipelines) {}
+size_t PipelineKey::hash() const
+{
+    if (isRenderPipeline())
+    {
+        return angle::HashMultiple(vertexShader.get(), fragmentShader.get(), pipelineDesc);
+    }
+    else
+    {
+        return angle::HashMultiple(computeShader.get());
+    }
+}
+
+PipelineCache::PipelineCache() : mPipelineCache(kMaxPipelines) {}
 
 angle::Result PipelineCache::getRenderPipeline(
     ContextMtl *context,
@@ -160,26 +226,56 @@
     const RenderPipelineDesc &desc,
     AutoObjCPtr<id<MTLRenderPipelineState>> *outRenderPipeline)
 {
-    RenderPipelineKey key;
+    PipelineKey key;
     key.vertexShader.retainAssign(vertexShader);
     key.fragmentShader.retainAssign(fragmentShader);
     key.pipelineDesc = desc;
 
-    auto iter = mRenderPiplineCache.Get(key);
-    if (iter != mRenderPiplineCache.end())
+    auto iter = mPipelineCache.Get(key);
+    if (iter != mPipelineCache.end())
     {
-        *outRenderPipeline = iter->second;
+        // Should be no way that this key matched a compute pipeline entry
+        ASSERT(iter->second.renderPipeline);
+        *outRenderPipeline = iter->second.renderPipeline;
         return angle::Result::Continue;
     }
 
-    angle::TrimCache(kMaxPipelines, kGCLimit, "render pipeline", &mRenderPiplineCache);
+    angle::TrimCache(kMaxPipelines, kGCLimit, "render pipeline", &mPipelineCache);
 
-    AutoObjCPtr<id<MTLRenderPipelineState>> newPipeline;
-    ANGLE_TRY(CreateRenderPipelineState(context, key, &newPipeline));
+    PipelineVariant newPipeline;
+    ANGLE_TRY(CreateRenderPipelineState(context, key, &newPipeline.renderPipeline));
 
-    iter = mRenderPiplineCache.Put(std::move(key), std::move(newPipeline));
+    iter = mPipelineCache.Put(std::move(key), std::move(newPipeline));
 
-    *outRenderPipeline = iter->second;
+    *outRenderPipeline = iter->second.renderPipeline;
+    return angle::Result::Continue;
+}
+
+angle::Result PipelineCache::getComputePipeline(
+    ContextMtl *context,
+    id<MTLFunction> computeShader,
+    AutoObjCPtr<id<MTLComputePipelineState>> *outComputePipeline)
+{
+    PipelineKey key;
+    key.computeShader.retainAssign(computeShader);
+
+    auto iter = mPipelineCache.Get(key);
+    if (iter != mPipelineCache.end())
+    {
+        // Should be no way that this key matched a render pipeline entry
+        ASSERT(iter->second.computePipeline);
+        *outComputePipeline = iter->second.computePipeline;
+        return angle::Result::Continue;
+    }
+
+    angle::TrimCache(kMaxPipelines, kGCLimit, "render pipeline", &mPipelineCache);
+
+    PipelineVariant newPipeline;
+    ANGLE_TRY(CreateComputePipelineState(context, key, &newPipeline.computePipeline));
+
+    iter = mPipelineCache.Put(std::move(key), std::move(newPipeline));
+
+    *outComputePipeline = iter->second.computePipeline;
     return angle::Result::Continue;
 }