Make shader recompile while parallel linking safe

Prior to this change, Program* objects held references to Shader*
objects.  This poses a problem where a shader recompile can race with a
program link, if the program link is done in parallel.

As a result, a good chunk of the link job is done serially and under the
share group lock.  After this change, that is no longer a problem, and
most of the link can be made lockless/parallelized.

This change separates out the "compiled state" from the rest of the
shader state.  This was already done for the front-end state (for the
sake of caching), but is also now done for the backends that need it.
The compiled state in turn is placed in a shared_ptr, and is shared with
the program.  When a shader is compiled, its own shared_ptr is replaced
with a new object, leaving all programs currently compiling unaffected
and using the previous compilation results.

Once a program is linked, its references to compiled shader states is
updated.

Bug: angleproject:8297
Change-Id: Iff7094a37088fbad99c6241f1c48b0bd4c820eb2
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/4791065
Reviewed-by: Geoff Lang <geofflang@chromium.org>
Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org>
Reviewed-by: Charlie Lao <cclao@google.com>
diff --git a/src/common/CompiledShaderState.cpp b/src/common/CompiledShaderState.cpp
index 4bff565..aba469c 100644
--- a/src/common/CompiledShaderState.cpp
+++ b/src/common/CompiledShaderState.cpp
@@ -181,7 +181,18 @@
 }
 
 CompiledShaderState::CompiledShaderState(gl::ShaderType type)
-    : shaderType(type), shaderVersion(100), numViews(-1), geometryShaderInvocations(1)
+    : shaderType(type),
+      shaderVersion(100),
+      hasClipDistance(false),
+      hasDiscard(false),
+      enablesPerSampleShading(false),
+      numViews(-1),
+      geometryShaderInvocations(1),
+      tessControlShaderVertices(0),
+      tessGenMode(0),
+      tessGenSpacing(0),
+      tessGenVertexOrder(0),
+      tessGenPointMode(0)
 {
     localSize.fill(-1);
 }
diff --git a/src/common/CompiledShaderState.h b/src/common/CompiledShaderState.h
index b7ac430..b5abaa3 100644
--- a/src/common/CompiledShaderState.h
+++ b/src/common/CompiledShaderState.h
@@ -18,6 +18,7 @@
 #include <GLSLANG/ShaderLang.h>
 #include <GLSLANG/ShaderVars.h>
 
+#include <memory>
 #include <string>
 
 namespace sh
@@ -89,6 +90,8 @@
     GLenum tessGenVertexOrder;
     GLenum tessGenPointMode;
 };
+
+using SharedCompiledShaderState = std::shared_ptr<CompiledShaderState>;
 }  // namespace gl
 
 #endif  // COMMON_COMPILEDSHADERSTATE_H_
diff --git a/src/libANGLE/Context.cpp b/src/libANGLE/Context.cpp
index 800061d..af019f2 100644
--- a/src/libANGLE/Context.cpp
+++ b/src/libANGLE/Context.cpp
@@ -1105,7 +1105,7 @@
     const ShaderProgramID shaderID = PackParam<ShaderProgramID>(createShader(type));
     if (shaderID.value)
     {
-        Shader *shaderObject = getShader(shaderID);
+        Shader *shaderObject = getShaderNoResolveCompile(shaderID);
         ASSERT(shaderObject);
         shaderObject->setSource(this, count, strings, nullptr);
         shaderObject->compile(this);
@@ -1115,6 +1115,9 @@
             gl::Program *programObject = getProgramNoResolveLink(programID);
             ASSERT(programObject);
 
+            // Note: this call serializes the compilation with the following link.  For backends
+            // that prefer parallel compile and link, it's more efficient to remove this check, and
+            // let link fail instead.
             if (shaderObject->isCompiled(this))
             {
                 // As per Khronos issue 2261:
@@ -1355,7 +1358,7 @@
             return getBuffer({name});
         case GL_SHADER:
         case GL_SHADER_OBJECT_EXT:
-            return getShader({name});
+            return getShaderNoResolveCompile({name});
         case GL_PROGRAM:
         case GL_PROGRAM_OBJECT_EXT:
             return getProgramNoResolveLink({name});
@@ -6825,7 +6828,7 @@
     Program *programObject = getProgramNoResolveLink(program);
     ASSERT(programObject);
 
-    Shader *shaderObject = getShader(shader);
+    Shader *shaderObject = getShaderNoResolveCompile(shader);
     ASSERT(shaderObject);
 
     programObject->detachShader(this, shaderObject);
@@ -7042,7 +7045,7 @@
     Shader *shaderObject = nullptr;
     if (!isContextLost())
     {
-        shaderObject = getShader(shader);
+        shaderObject = getShaderNoResolveCompile(shader);
         ASSERT(shaderObject);
     }
     QueryShaderiv(this, shaderObject, pname, params);
@@ -7062,7 +7065,7 @@
                                GLsizei *length,
                                GLchar *infolog)
 {
-    Shader *shaderObject = getShader(shader);
+    Shader *shaderObject = getShaderNoResolveCompile(shader);
     ASSERT(shaderObject);
     shaderObject->getInfoLog(this, bufsize, length, infolog);
 }
@@ -7143,7 +7146,7 @@
                               GLsizei *length,
                               GLchar *source)
 {
-    Shader *shaderObject = getShader(shader);
+    Shader *shaderObject = getShaderNoResolveCompile(shader);
     ASSERT(shaderObject);
     shaderObject->getSource(bufsize, length, source);
 }
@@ -7234,7 +7237,7 @@
         return GL_FALSE;
     }
 
-    return ConvertToGLBoolean(getShader(shader));
+    return ConvertToGLBoolean(getShaderNoResolveCompile(shader));
 }
 
 GLboolean Context::isTexture(TextureID texture) const
@@ -7266,7 +7269,7 @@
                            const void *binary,
                            GLsizei length)
 {
-    Shader *shaderObject = getShader(*shaders);
+    Shader *shaderObject = getShaderNoResolveCompile(*shaders);
     ASSERT(shaderObject != nullptr);
     ANGLE_CONTEXT_TRY(shaderObject->loadShaderBinary(this, binary, length));
 }
@@ -7306,7 +7309,7 @@
                            const GLchar *const *string,
                            const GLint *length)
 {
-    Shader *shaderObject = getShader(shader);
+    Shader *shaderObject = getShaderNoResolveCompile(shader);
     ASSERT(shaderObject);
     shaderObject->setSource(this, count, string, length);
 }
@@ -8438,7 +8441,7 @@
                                         GLsizei *length,
                                         GLchar *source)
 {
-    Shader *shaderObject = getShader(shader);
+    Shader *shaderObject = getShaderNoResolveCompile(shader);
     ASSERT(shaderObject);
     shaderObject->getTranslatedSourceWithDebugInfo(this, bufsize, length, source);
 }
@@ -9299,7 +9302,17 @@
     return mState.mShaderProgramManager->getProgram(handle);
 }
 
-Shader *Context::getShader(ShaderProgramID handle) const
+Shader *Context::getShaderResolveCompile(ShaderProgramID handle) const
+{
+    Shader *shader = getShaderNoResolveCompile(handle);
+    if (shader)
+    {
+        shader->resolveCompile(this);
+    }
+    return shader;
+}
+
+Shader *Context::getShaderNoResolveCompile(ShaderProgramID handle) const
 {
     return mState.mShaderProgramManager->getShader(handle);
 }
diff --git a/src/libANGLE/Context.h b/src/libANGLE/Context.h
index a57b1f8..b783815 100644
--- a/src/libANGLE/Context.h
+++ b/src/libANGLE/Context.h
@@ -646,7 +646,8 @@
     }
 
     Program *getProgramNoResolveLink(ShaderProgramID handle) const;
-    Shader *getShader(ShaderProgramID handle) const;
+    Shader *getShaderResolveCompile(ShaderProgramID handle) const;
+    Shader *getShaderNoResolveCompile(ShaderProgramID handle) const;
 
     ANGLE_INLINE bool isTextureGenerated(TextureID texture) const
     {
diff --git a/src/libANGLE/Program.cpp b/src/libANGLE/Program.cpp
index ff5a3a1..543aebe 100644
--- a/src/libANGLE/Program.cpp
+++ b/src/libANGLE/Program.cpp
@@ -300,30 +300,26 @@
     }
 }
 
-void InitUniformBlockLinker(const Context *context,
-                            const ProgramState &state,
-                            UniformBlockLinker *blockLinker)
+void InitUniformBlockLinker(const ProgramState &state, UniformBlockLinker *blockLinker)
 {
     for (ShaderType shaderType : AllShaderTypes())
     {
-        Shader *shader = state.getAttachedShader(shaderType);
+        const SharedCompiledShaderState &shader = state.getAttachedShader(shaderType);
         if (shader)
         {
-            blockLinker->addShaderBlocks(shaderType, &shader->getUniformBlocks(context));
+            blockLinker->addShaderBlocks(shaderType, &shader->uniformBlocks);
         }
     }
 }
 
-void InitShaderStorageBlockLinker(const Context *context,
-                                  const ProgramState &state,
-                                  ShaderStorageBlockLinker *blockLinker)
+void InitShaderStorageBlockLinker(const ProgramState &state, ShaderStorageBlockLinker *blockLinker)
 {
     for (ShaderType shaderType : AllShaderTypes())
     {
-        Shader *shader = state.getAttachedShader(shaderType);
-        if (shader != nullptr)
+        const SharedCompiledShaderState &shader = state.getAttachedShader(shaderType);
+        if (shader)
         {
-            blockLinker->addShaderBlocks(shaderType, &shader->getShaderStorageBlocks(context));
+            blockLinker->addShaderBlocks(shaderType, &shader->shaderStorageBlocks);
         }
     }
 }
@@ -954,7 +950,7 @@
 
 ProgramState::~ProgramState()
 {
-    ASSERT(!hasAttachedShader());
+    ASSERT(!hasAnyAttachedShader());
 }
 
 const std::string &ProgramState::getLabel()
@@ -962,7 +958,7 @@
     return mLabel;
 }
 
-Shader *ProgramState::getAttachedShader(ShaderType shaderType) const
+SharedCompiledShaderState ProgramState::getAttachedShader(ShaderType shaderType) const
 {
     ASSERT(shaderType != ShaderType::InvalidEnum);
     return mAttachedShaders[shaderType];
@@ -1035,9 +1031,9 @@
     return static_cast<GLuint>(-1);
 }
 
-bool ProgramState::hasAttachedShader() const
+bool ProgramState::hasAnyAttachedShader() const
 {
-    for (const Shader *shader : mAttachedShaders)
+    for (const SharedCompiledShaderState &shader : mAttachedShaders)
     {
         if (shader)
         {
@@ -1090,7 +1086,8 @@
       mDeleteStatus(false),
       mRefCount(0),
       mResourceManager(manager),
-      mHandle(handle)
+      mHandle(handle),
+      mAttachedShaders{}
 {
     ASSERT(mProgram);
 
@@ -1107,16 +1104,18 @@
     resolveLink(context);
     for (ShaderType shaderType : AllShaderTypes())
     {
-        if (mState.mAttachedShaders[shaderType])
+        Shader *shader = getAttachedShader(shaderType);
+        if (shader != nullptr)
         {
-            mState.mAttachedShaders[shaderType]->release(context);
-            mState.mAttachedShaders[shaderType] = nullptr;
+            shader->release(context);
         }
+        mState.mAttachedShaders[shaderType].reset();
+        mAttachedShaders[shaderType] = nullptr;
     }
 
     mProgram->destroy(context);
 
-    ASSERT(!mState.hasAttachedShader());
+    ASSERT(!mState.hasAnyAttachedShader());
     SafeDelete(mProgram);
 
     delete this;
@@ -1150,8 +1149,8 @@
     ShaderType shaderType = shader->getType();
     ASSERT(shaderType != ShaderType::InvalidEnum);
 
-    mState.mAttachedShaders[shaderType] = shader;
-    mState.mAttachedShaders[shaderType]->addRef();
+    shader->addRef();
+    mAttachedShaders[shaderType] = shader;
 }
 
 void Program::detachShader(const Context *context, Shader *shader)
@@ -1160,18 +1159,19 @@
     ShaderType shaderType = shader->getType();
     ASSERT(shaderType != ShaderType::InvalidEnum);
 
-    ASSERT(mState.mAttachedShaders[shaderType] == shader);
+    ASSERT(mAttachedShaders[shaderType] == shader);
     shader->release(context);
-    mState.mAttachedShaders[shaderType] = nullptr;
+    mAttachedShaders[shaderType] = nullptr;
+    mState.mAttachedShaders[shaderType].reset();
 }
 
 int Program::getAttachedShadersCount() const
 {
     ASSERT(!mLinkingState);
     int numAttachedShaders = 0;
-    for (const Shader *shader : mState.mAttachedShaders)
+    for (const Shader *shader : mAttachedShaders)
     {
-        if (shader)
+        if (shader != nullptr)
         {
             ++numAttachedShaders;
         }
@@ -1183,7 +1183,7 @@
 Shader *Program::getAttachedShader(ShaderType shaderType) const
 {
     ASSERT(!mLinkingState);
-    return mState.getAttachedShader(shaderType);
+    return mAttachedShaders[shaderType];
 }
 
 void Program::bindAttributeLocation(GLuint index, const char *name)
@@ -1210,10 +1210,33 @@
 
 angle::Result Program::link(const Context *context)
 {
+    // Make sure no compile jobs are pending.
+    //
+    // For every attached shader, get the compiled state.  This is done at link time (instead of
+    // earlier, such as attachShader time), because the shader could get recompiled between attach
+    // and link.
+    //
+    // Additionally, make sure the backend is also able to cache the compiled state of its own
+    // ShaderImpl objects.
+    ShaderMap<rx::ShaderImpl *> shaderImpls = {};
+    for (ShaderType shaderType : AllShaderTypes())
+    {
+        Shader *shader = mAttachedShaders[shaderType];
+        SharedCompiledShaderState shaderCompiledState;
+        if (shader != nullptr)
+        {
+            shader->resolveCompile(context);
+            shaderCompiledState     = shader->getCompiledState();
+            shaderImpls[shaderType] = shader->getImplementation();
+        }
+        mState.mAttachedShaders[shaderType] = std::move(shaderCompiledState);
+    }
+    mProgram->prepareForLink(shaderImpls);
+
     const angle::FrontendFeatures &frontendFeatures = context->getFrontendFeatures();
     if (frontendFeatures.dumpShaderSource.enabled)
     {
-        dumpProgramInfo();
+        dumpProgramInfo(context);
     }
 
     angle::Result result = linkImpl(context);
@@ -1284,7 +1307,7 @@
 
     std::unique_ptr<LinkingState> linkingState(new LinkingState());
     ProgramMergedVaryings mergedVaryings;
-    LinkingVariables linkingVariables(context, mState);
+    LinkingVariables linkingVariables(mState);
     ProgramLinkedResources &resources = linkingState->resources;
 
     resources.init(&mState.mExecutable->mUniformBlocks, &mState.mExecutable->mUniforms,
@@ -1295,8 +1318,8 @@
     // TODO: Fix incomplete linking. http://anglebug.com/6358
     updateLinkedShaderStages();
 
-    InitUniformBlockLinker(context, mState, &resources.uniformBlockLinker);
-    InitShaderStorageBlockLinker(context, mState, &resources.shaderStorageBlockLinker);
+    InitUniformBlockLinker(mState, &resources.uniformBlockLinker);
+    InitShaderStorageBlockLinker(mState, &resources.shaderStorageBlockLinker);
 
     if (mState.mAttachedShaders[ShaderType::Compute])
     {
@@ -1337,7 +1360,7 @@
             return angle::Result::Continue;
         }
 
-        if (!linkVaryings(context, infoLog))
+        if (!linkVaryings(infoLog))
         {
             return angle::Result::Continue;
         }
@@ -1361,33 +1384,33 @@
             return angle::Result::Continue;
         }
 
-        gl::Shader *vertexShader = mState.mAttachedShaders[ShaderType::Vertex];
+        const SharedCompiledShaderState &vertexShader = mState.mAttachedShaders[ShaderType::Vertex];
         if (vertexShader)
         {
-            mState.mNumViews                               = vertexShader->getNumViews(context);
-            mState.mExecutable->mPODStruct.hasClipDistance = vertexShader->hasClipDistance();
-            mState.mSpecConstUsageBits |= vertexShader->getSpecConstUsageBits();
+            mState.mNumViews                               = vertexShader->numViews;
+            mState.mExecutable->mPODStruct.hasClipDistance = vertexShader->hasClipDistance;
+            mState.mSpecConstUsageBits |= vertexShader->specConstUsageBits;
         }
 
-        gl::Shader *fragmentShader = mState.mAttachedShaders[ShaderType::Fragment];
+        const SharedCompiledShaderState &fragmentShader =
+            mState.mAttachedShaders[ShaderType::Fragment];
         if (fragmentShader)
         {
             if (!mState.mExecutable->linkValidateOutputVariables(
                     context->getCaps(), context->getExtensions(), context->getClientVersion(),
                     combinedImageUniforms, combinedShaderStorageBlocks,
-                    fragmentShader->getActiveOutputVariables(context),
-                    fragmentShader->getShaderVersion(context), mFragmentOutputLocations,
-                    mFragmentOutputIndexes))
+                    fragmentShader->activeOutputVariables, fragmentShader->shaderVersion,
+                    mFragmentOutputLocations, mFragmentOutputIndexes))
             {
                 return angle::Result::Continue;
             }
 
-            mState.mExecutable->mPODStruct.hasDiscard = fragmentShader->hasDiscard();
+            mState.mExecutable->mPODStruct.hasDiscard = fragmentShader->hasDiscard;
             mState.mExecutable->mPODStruct.enablesPerSampleShading =
-                fragmentShader->enablesPerSampleShading();
+                fragmentShader->enablesPerSampleShading;
             mState.mExecutable->mPODStruct.advancedBlendEquations =
-                fragmentShader->getAdvancedBlendEquations();
-            mState.mSpecConstUsageBits |= fragmentShader->getSpecConstUsageBits();
+                fragmentShader->advancedBlendEquations;
+            mState.mSpecConstUsageBits |= fragmentShader->specConstUsageBits;
         }
 
         mergedVaryings = GetMergedVaryingsFromLinkingVariables(linkingVariables);
@@ -1407,8 +1430,8 @@
     mLinkingState->linkEvent         = mProgram->link(context, resources, infoLog, mergedVaryings);
 
     // Must be after mProgram->link() to avoid misleading the linker about output variables.
-    mState.updateProgramInterfaceInputs(context);
-    mState.updateProgramInterfaceOutputs(context);
+    mState.updateProgramInterfaceInputs();
+    mState.updateProgramInterfaceOutputs();
 
     if (mState.mSeparable)
     {
@@ -1478,11 +1501,11 @@
 {
     mState.mExecutable->resetLinkedShaderStages();
 
-    for (const Shader *shader : mState.mAttachedShaders)
+    for (ShaderType shaderType : AllShaderTypes())
     {
-        if (shader)
+        if (mState.mAttachedShaders[shaderType])
         {
-            mState.mExecutable->setLinkedShaderStages(shader->getType());
+            mState.mExecutable->setLinkedShaderStages(shaderType);
         }
     }
 }
@@ -1493,7 +1516,7 @@
     mExecutable->updateActiveSamplers(*this);
 }
 
-void ProgramState::updateProgramInterfaceInputs(const Context *context)
+void ProgramState::updateProgramInterfaceInputs()
 {
     const ShaderType firstAttachedShaderType = getFirstAttachedShaderStageType();
 
@@ -1503,13 +1526,13 @@
         return;
     }
 
-    Shader *shader = getAttachedShader(firstAttachedShaderType);
+    const SharedCompiledShaderState &shader = getAttachedShader(firstAttachedShaderType);
     ASSERT(shader);
 
     // Copy over each input varying, since the Shader could go away
-    if (shader->getType() == ShaderType::Compute)
+    if (shader->shaderType == ShaderType::Compute)
     {
-        for (const sh::ShaderVariable &attribute : shader->getAllAttributes(context))
+        for (const sh::ShaderVariable &attribute : shader->allAttributes)
         {
             // Compute Shaders have the following built-in input variables.
             //
@@ -1524,14 +1547,14 @@
     }
     else
     {
-        for (const sh::ShaderVariable &varying : shader->getInputVaryings(context))
+        for (const sh::ShaderVariable &varying : shader->inputVaryings)
         {
             UpdateInterfaceVariable(&mExecutable->mProgramInputs, varying);
         }
     }
 }
 
-void ProgramState::updateProgramInterfaceOutputs(const Context *context)
+void ProgramState::updateProgramInterfaceOutputs()
 {
     const ShaderType lastAttachedShaderType = getLastAttachedShaderStageType();
 
@@ -1546,11 +1569,11 @@
         return;
     }
 
-    Shader *shader = getAttachedShader(lastAttachedShaderType);
+    const SharedCompiledShaderState &shader = getAttachedShader(lastAttachedShaderType);
     ASSERT(shader);
 
     // Copy over each output varying, since the Shader could go away
-    for (const sh::ShaderVariable &varying : shader->getOutputVaryings(context))
+    for (const sh::ShaderVariable &varying : shader->outputVaryings)
     {
         UpdateInterfaceVariable(&mExecutable->mOutputVariables, varying);
     }
@@ -1760,9 +1783,9 @@
     ASSERT(!mLinkingState);
     int total = 0;
 
-    for (const Shader *shader : mState.mAttachedShaders)
+    for (const Shader *shader : mAttachedShaders)
     {
-        if (shader && (total < maxCount))
+        if (shader != nullptr && total < maxCount)
         {
             shaders[total] = shader->getHandle();
             ++total;
@@ -2847,13 +2870,14 @@
 
 bool Program::linkValidateShaders(const Context *context, InfoLog &infoLog)
 {
-    const ShaderMap<Shader *> &shaders = mState.mAttachedShaders;
+    const ShaderMap<SharedCompiledShaderState> &shaders = mState.mAttachedShaders;
 
-    bool isComputeShaderAttached  = shaders[ShaderType::Compute] != nullptr;
-    bool isGraphicsShaderAttached = shaders[ShaderType::Vertex] ||
-                                    shaders[ShaderType::TessControl] ||
-                                    shaders[ShaderType::TessEvaluation] ||
-                                    shaders[ShaderType::Geometry] || shaders[ShaderType::Fragment];
+    bool isComputeShaderAttached  = shaders[ShaderType::Compute].get() != nullptr;
+    bool isGraphicsShaderAttached = shaders[ShaderType::Vertex].get() != nullptr ||
+                                    shaders[ShaderType::TessControl].get() != nullptr ||
+                                    shaders[ShaderType::TessEvaluation].get() != nullptr ||
+                                    shaders[ShaderType::Geometry].get() != nullptr ||
+                                    shaders[ShaderType::Fragment].get() != nullptr;
     // Check whether we both have a compute and non-compute shaders attached.
     // If there are of both types attached, then linking should fail.
     // OpenGL ES 3.10, 7.3 Program Objects, under LinkProgram
@@ -2866,14 +2890,19 @@
     Optional<int> version;
     for (ShaderType shaderType : kAllGraphicsShaderTypes)
     {
-        Shader *shader = shaders[shaderType];
-        ASSERT(!shader || shader->getType() == shaderType);
-        if (!shader)
+        Shader *shaderObj = getAttachedShader(shaderType);
+        ASSERT(!shaderObj || shaderObj->getType() == shaderType);
+
+        const SharedCompiledShaderState &shader = shaders[shaderType];
+        ASSERT(!shader || shader->shaderType == shaderType);
+
+        if (!shaderObj)
         {
+            ASSERT(!shader);
             continue;
         }
 
-        if (!shader->isCompiled(context))
+        if (!shaderObj->isCompiled(context))
         {
             infoLog << ShaderTypeToString(shaderType) << " shader is not compiled.";
             return false;
@@ -2881,9 +2910,9 @@
 
         if (!version.valid())
         {
-            version = shader->getShaderVersion(context);
+            version = shader->shaderVersion;
         }
-        else if (version != shader->getShaderVersion(context))
+        else if (version != shader->shaderVersion)
         {
             infoLog << ShaderTypeToString(shaderType)
                     << " shader version does not match other shader versions.";
@@ -2893,9 +2922,9 @@
 
     if (isComputeShaderAttached)
     {
-        ASSERT(shaders[ShaderType::Compute]->getType() == ShaderType::Compute);
+        ASSERT(shaders[ShaderType::Compute]->shaderType == ShaderType::Compute);
 
-        mState.mComputeShaderLocalSize = shaders[ShaderType::Compute]->getWorkGroupSize(context);
+        mState.mComputeShaderLocalSize = shaders[ShaderType::Compute]->localSize;
 
         // GLSL ES 3.10, 4.4.1.1 Compute Shader Inputs
         // If the work group size is not specified, a link time error should occur.
@@ -2913,8 +2942,8 @@
             return false;
         }
 
-        bool hasVertex   = shaders[ShaderType::Vertex] != nullptr;
-        bool hasFragment = shaders[ShaderType::Fragment] != nullptr;
+        bool hasVertex   = shaders[ShaderType::Vertex].get() != nullptr;
+        bool hasFragment = shaders[ShaderType::Fragment].get() != nullptr;
         if (!isSeparable() && (!hasVertex || !hasFragment))
         {
             infoLog
@@ -2922,16 +2951,16 @@
             return false;
         }
 
-        bool hasTessControl    = shaders[ShaderType::TessControl] != nullptr;
-        bool hasTessEvaluation = shaders[ShaderType::TessEvaluation] != nullptr;
+        bool hasTessControl    = shaders[ShaderType::TessControl].get() != nullptr;
+        bool hasTessEvaluation = shaders[ShaderType::TessEvaluation].get() != nullptr;
         if (!isSeparable() && (hasTessControl != hasTessEvaluation))
         {
             infoLog << "Tessellation control and evaluation shaders must be specified together.";
             return false;
         }
 
-        Shader *geometryShader = shaders[ShaderType::Geometry];
-        if (shaders[ShaderType::Geometry])
+        const SharedCompiledShaderState &geometryShader = shaders[ShaderType::Geometry];
+        if (geometryShader)
         {
             // [GL_EXT_geometry_shader] Chapter 7
             // Linking can fail for a variety of reasons as specified in the OpenGL ES Shading
@@ -2943,10 +2972,8 @@
             //   - <program> is not separable and contains no objects to form a vertex shader; or
             //   - the input primitive type, output primitive type, or maximum output vertex count
             //     is not specified in the compiled geometry shader object.
-            ASSERT(geometryShader->getType() == ShaderType::Geometry);
-
             Optional<PrimitiveMode> inputPrimitive =
-                geometryShader->getGeometryShaderInputPrimitiveType(context);
+                geometryShader->geometryShaderInputPrimitiveType;
             if (!inputPrimitive.valid())
             {
                 infoLog << "Input primitive type is not specified in the geometry shader.";
@@ -2954,14 +2981,14 @@
             }
 
             Optional<PrimitiveMode> outputPrimitive =
-                geometryShader->getGeometryShaderOutputPrimitiveType(context);
+                geometryShader->geometryShaderOutputPrimitiveType;
             if (!outputPrimitive.valid())
             {
                 infoLog << "Output primitive type is not specified in the geometry shader.";
                 return false;
             }
 
-            Optional<GLint> maxVertices = geometryShader->getGeometryShaderMaxVertices(context);
+            Optional<GLint> maxVertices = geometryShader->geometryShaderMaxVertices;
             if (!maxVertices.valid())
             {
                 infoLog << "'max_vertices' is not specified in the geometry shader.";
@@ -2974,13 +3001,13 @@
                 outputPrimitive.value();
             mState.mExecutable->mPODStruct.geometryShaderMaxVertices = maxVertices.value();
             mState.mExecutable->mPODStruct.geometryShaderInvocations =
-                geometryShader->getGeometryShaderInvocations(context);
+                geometryShader->geometryShaderInvocations;
         }
 
-        Shader *tessControlShader = shaders[ShaderType::TessControl];
+        const SharedCompiledShaderState &tessControlShader = shaders[ShaderType::TessControl];
         if (tessControlShader)
         {
-            int tcsShaderVertices = tessControlShader->getTessControlShaderVertices(context);
+            int tcsShaderVertices = tessControlShader->tessControlShaderVertices;
             if (tcsShaderVertices == 0)
             {
                 // In tessellation control shader, output vertices should be specified at least
@@ -2999,10 +3026,10 @@
             mState.mExecutable->mPODStruct.tessControlShaderVertices = tcsShaderVertices;
         }
 
-        Shader *tessEvaluationShader = shaders[ShaderType::TessEvaluation];
+        const SharedCompiledShaderState &tessEvaluationShader = shaders[ShaderType::TessEvaluation];
         if (tessEvaluationShader)
         {
-            GLenum tesPrimitiveMode = tessEvaluationShader->getTessGenMode(context);
+            GLenum tesPrimitiveMode = tessEvaluationShader->tessGenMode;
             if (tesPrimitiveMode == 0)
             {
                 // In tessellation evaluation shader, a primitive mode should be specified at least
@@ -3019,13 +3046,12 @@
                 return false;
             }
 
-            mState.mExecutable->mPODStruct.tessGenMode = tesPrimitiveMode;
-            mState.mExecutable->mPODStruct.tessGenSpacing =
-                tessEvaluationShader->getTessGenSpacing(context);
+            mState.mExecutable->mPODStruct.tessGenMode    = tesPrimitiveMode;
+            mState.mExecutable->mPODStruct.tessGenSpacing = tessEvaluationShader->tessGenSpacing;
             mState.mExecutable->mPODStruct.tessGenVertexOrder =
-                tessEvaluationShader->getTessGenVertexOrder(context);
+                tessEvaluationShader->tessGenVertexOrder;
             mState.mExecutable->mPODStruct.tessGenPointMode =
-                tessEvaluationShader->getTessGenPointMode(context);
+                tessEvaluationShader->tessGenPointMode;
         }
     }
 
@@ -3104,12 +3130,12 @@
     mProgram->setUniform1iv(mState.mBaseInstanceLocation, 1, &baseInstanceInt);
 }
 
-bool Program::linkVaryings(const Context *context, InfoLog &infoLog) const
+bool Program::linkVaryings(InfoLog &infoLog) const
 {
     ShaderType previousShaderType = ShaderType::InvalidEnum;
     for (ShaderType shaderType : kAllGraphicsShaderTypes)
     {
-        Shader *currentShader = mState.mAttachedShaders[shaderType];
+        const SharedCompiledShaderState &currentShader = mState.mAttachedShaders[shaderType];
         if (!currentShader)
         {
             continue;
@@ -3117,33 +3143,32 @@
 
         if (previousShaderType != ShaderType::InvalidEnum)
         {
-            Shader *previousShader = mState.mAttachedShaders[previousShaderType];
-            const std::vector<sh::ShaderVariable> &outputVaryings =
-                previousShader->getOutputVaryings(context);
+            const SharedCompiledShaderState &previousShader =
+                mState.mAttachedShaders[previousShaderType];
+            const std::vector<sh::ShaderVariable> &outputVaryings = previousShader->outputVaryings;
 
             if (!LinkValidateShaderInterfaceMatching(
-                    outputVaryings, currentShader->getInputVaryings(context), previousShaderType,
-                    currentShader->getType(), previousShader->getShaderVersion(context),
-                    currentShader->getShaderVersion(context), isSeparable(), infoLog))
+                    outputVaryings, currentShader->inputVaryings, previousShaderType,
+                    currentShader->shaderType, previousShader->shaderVersion,
+                    currentShader->shaderVersion, isSeparable(), infoLog))
             {
                 return false;
             }
         }
-        previousShaderType = currentShader->getType();
+        previousShaderType = currentShader->shaderType;
     }
 
     // TODO: http://anglebug.com/3571 and http://anglebug.com/3572
     // Need to move logic of validating builtin varyings inside the for-loop above.
     // This is because the built-in symbols `gl_ClipDistance` and `gl_CullDistance`
     // can be redeclared in Geometry or Tessellation shaders as well.
-    Shader *vertexShader   = mState.mAttachedShaders[ShaderType::Vertex];
-    Shader *fragmentShader = mState.mAttachedShaders[ShaderType::Fragment];
+    const SharedCompiledShaderState &vertexShader   = mState.mAttachedShaders[ShaderType::Vertex];
+    const SharedCompiledShaderState &fragmentShader = mState.mAttachedShaders[ShaderType::Fragment];
     if (vertexShader && fragmentShader &&
-        !LinkValidateBuiltInVaryings(vertexShader->getOutputVaryings(context),
-                                     fragmentShader->getInputVaryings(context),
-                                     vertexShader->getType(), fragmentShader->getType(),
-                                     vertexShader->getShaderVersion(context),
-                                     fragmentShader->getShaderVersion(context), infoLog))
+        !LinkValidateBuiltInVaryings(vertexShader->outputVaryings, fragmentShader->inputVaryings,
+                                     vertexShader->shaderType, fragmentShader->shaderType,
+                                     vertexShader->shaderVersion, fragmentShader->shaderVersion,
+                                     infoLog))
     {
         return false;
     }
@@ -3158,11 +3183,11 @@
 {
     // Initialize executable shader map.
     ShaderMap<std::vector<sh::ShaderVariable>> shaderUniforms;
-    for (Shader *shader : mState.mAttachedShaders)
+    for (const SharedCompiledShaderState &shader : mState.mAttachedShaders)
     {
         if (shader)
         {
-            shaderUniforms[shader->getType()] = shader->getUniforms(context);
+            shaderUniforms[shader->shaderType] = shader->uniforms;
         }
     }
 
@@ -3196,7 +3221,8 @@
     int shaderVersion              = -1;
     unsigned int usedLocations     = 0;
 
-    Shader *vertexShader = mState.getAttachedShader(gl::ShaderType::Vertex);
+    const SharedCompiledShaderState &vertexShader =
+        mState.getAttachedShader(gl::ShaderType::Vertex);
 
     if (!vertexShader)
     {
@@ -3208,10 +3234,9 @@
     // see GLSL ES 3.00.6 section 12.46. Inactive attributes will be pruned after
     // aliasing checks.
     // In GLSL ES 1.00.17 we only do aliasing checks for active attributes.
-    shaderVersion = vertexShader->getShaderVersion(context);
+    shaderVersion = vertexShader->shaderVersion;
     const std::vector<sh::ShaderVariable> &shaderAttributes =
-        shaderVersion >= 300 ? vertexShader->getAllAttributes(context)
-                             : vertexShader->getActiveAttributes(context);
+        shaderVersion >= 300 ? vertexShader->allAttributes : vertexShader->activeAttributes;
 
     ASSERT(mState.mExecutable->mProgramInputs.empty());
     mState.mExecutable->mProgramInputs.reserve(shaderAttributes.size());
@@ -3844,12 +3869,12 @@
     }
 }
 
-void Program::dumpProgramInfo() const
+void Program::dumpProgramInfo(const Context *context) const
 {
     std::stringstream dumpStream;
     for (ShaderType shaderType : angle::AllEnums<ShaderType>())
     {
-        gl::Shader *shader = mState.mAttachedShaders[shaderType];
+        gl::Shader *shader = getAttachedShader(shaderType);
         if (shader)
         {
             dumpStream << shader->getType() << ": "
diff --git a/src/libANGLE/Program.h b/src/libANGLE/Program.h
index 069850f..66b3ed5 100644
--- a/src/libANGLE/Program.h
+++ b/src/libANGLE/Program.h
@@ -15,6 +15,7 @@
 
 #include <array>
 #include <map>
+#include <memory>
 #include <set>
 #include <sstream>
 #include <string>
@@ -222,8 +223,11 @@
 
     const std::string &getLabel();
 
-    Shader *getAttachedShader(ShaderType shaderType) const;
-    const gl::ShaderMap<Shader *> &getAttachedShaders() const { return mAttachedShaders; }
+    SharedCompiledShaderState getAttachedShader(ShaderType shaderType) const;
+    const ShaderMap<SharedCompiledShaderState> &getAttachedShaders() const
+    {
+        return mAttachedShaders;
+    }
     const std::vector<std::string> &getTransformFeedbackVaryingNames() const
     {
         return mTransformFeedbackVaryingNames;
@@ -325,7 +329,7 @@
     int getNumViews() const { return mNumViews; }
     bool usesMultiview() const { return mNumViews != -1; }
 
-    bool hasAttachedShader() const;
+    bool hasAnyAttachedShader() const;
 
     ShaderType getFirstAttachedShaderStageType() const;
     ShaderType getLastAttachedShaderStageType() const;
@@ -374,8 +378,8 @@
     friend class Program;
 
     void updateActiveSamplers();
-    void updateProgramInterfaceInputs(const Context *context);
-    void updateProgramInterfaceOutputs(const Context *context);
+    void updateProgramInterfaceInputs();
+    void updateProgramInterfaceOutputs();
 
     // Scans the sampler bindings for type conflicts with sampler 'textureUnitIndex'.
     void setSamplerUniformTextureTypeAndFormat(size_t textureUnitIndex);
@@ -384,7 +388,7 @@
 
     sh::WorkGroupSize mComputeShaderLocalSize;
 
-    ShaderMap<Shader *> mAttachedShaders;
+    ShaderMap<SharedCompiledShaderState> mAttachedShaders;
 
     uint32_t mLocationsUsedForXfbExtension;
     std::vector<std::string> mTransformFeedbackVaryingNames;
@@ -819,7 +823,7 @@
 
     bool linkValidateShaders(const Context *context, InfoLog &infoLog);
     bool linkAttributes(const Context *context, InfoLog &infoLog);
-    bool linkVaryings(const Context *context, InfoLog &infoLog) const;
+    bool linkVaryings(InfoLog &infoLog) const;
 
     bool linkUniforms(const Context *context,
                       std::vector<UnusedUniform> *unusedUniformsOutOrNull,
@@ -889,7 +893,7 @@
                                  GLboolean transpose,
                                  const UniformT *v);
 
-    void dumpProgramInfo() const;
+    void dumpProgramInfo(const Context *context) const;
 
     rx::UniqueSerial mSerial;
     ProgramState mState;
@@ -912,6 +916,12 @@
     ShaderProgramManager *mResourceManager;
     const ShaderProgramID mHandle;
 
+    // ProgramState::mAttachedShaders holds a reference to shaders' compiled state, which is all the
+    // program and the backends require after link.  The actual shaders linked to the program are
+    // stored here to support shader attach/detach and link without providing access to them in the
+    // backends.
+    ShaderMap<Shader *> mAttachedShaders;
+
     DirtyBits mDirtyBits;
 
     // To simplify dirty bits handling, instead of tracking dirtiness of both uniform block index
diff --git a/src/libANGLE/ProgramExecutable.cpp b/src/libANGLE/ProgramExecutable.cpp
index 1d87605..4825f64 100644
--- a/src/libANGLE/ProgramExecutable.cpp
+++ b/src/libANGLE/ProgramExecutable.cpp
@@ -813,13 +813,13 @@
 {
     for (ShaderType shaderType : getLinkedShaderStages())
     {
-        Shader *shader = state.getAttachedShader(shaderType);
+        const SharedCompiledShaderState &shader = state.getAttachedShader(shaderType);
         ASSERT(shader);
-        mPODStruct.linkedShaderVersions[shaderType] = shader->getShaderVersion(context);
-        mLinkedOutputVaryings[shaderType]           = shader->getOutputVaryings(context);
-        mLinkedInputVaryings[shaderType]            = shader->getInputVaryings(context);
-        mLinkedUniforms[shaderType]                 = shader->getUniforms(context);
-        mLinkedUniformBlocks[shaderType]            = shader->getUniformBlocks(context);
+        mPODStruct.linkedShaderVersions[shaderType] = shader->shaderVersion;
+        mLinkedOutputVaryings[shaderType]           = shader->outputVaryings;
+        mLinkedInputVaryings[shaderType]            = shader->inputVaryings;
+        mLinkedUniforms[shaderType]                 = shader->uniforms;
+        mLinkedUniformBlocks[shaderType]            = shader->uniformBlocks;
     }
 }
 
diff --git a/src/libANGLE/ProgramLinkedResources.cpp b/src/libANGLE/ProgramLinkedResources.cpp
index 06f8814..b8a0317 100644
--- a/src/libANGLE/ProgramLinkedResources.cpp
+++ b/src/libANGLE/ProgramLinkedResources.cpp
@@ -1611,17 +1611,17 @@
 
 ProgramLinkedResources::~ProgramLinkedResources() = default;
 
-LinkingVariables::LinkingVariables(const Context *context, const ProgramState &state)
+LinkingVariables::LinkingVariables(const ProgramState &state)
 {
     for (ShaderType shaderType : kAllGraphicsShaderTypes)
     {
-        Shader *shader = state.getAttachedShader(shaderType);
+        const SharedCompiledShaderState &shader = state.getAttachedShader(shaderType);
         if (shader)
         {
-            outputVaryings[shaderType] = shader->getOutputVaryings(context);
-            inputVaryings[shaderType]  = shader->getInputVaryings(context);
-            uniforms[shaderType]       = shader->getUniforms(context);
-            uniformBlocks[shaderType]  = shader->getUniformBlocks(context);
+            outputVaryings[shaderType] = shader->outputVaryings;
+            inputVaryings[shaderType]  = shader->inputVaryings;
+            uniforms[shaderType]       = shader->uniforms;
+            uniformBlocks[shaderType]  = shader->uniformBlocks;
             isShaderStageUsedBitset.set(shaderType);
         }
     }
@@ -1659,18 +1659,17 @@
     atomicCounterBufferLinker.init(atomicCounterBuffersOut);
 }
 
-void ProgramLinkedResourcesLinker::linkResources(const Context *context,
-                                                 const ProgramState &programState,
+void ProgramLinkedResourcesLinker::linkResources(const ProgramState &programState,
                                                  const ProgramLinkedResources &resources) const
 {
     // Gather uniform interface block info.
     InterfaceBlockInfo uniformBlockInfo(mCustomEncoderFactory);
     for (const ShaderType shaderType : AllShaderTypes())
     {
-        Shader *shader = programState.getAttachedShader(shaderType);
+        const SharedCompiledShaderState &shader = programState.getAttachedShader(shaderType);
         if (shader)
         {
-            uniformBlockInfo.getShaderBlockInfo(shader->getUniformBlocks(context));
+            uniformBlockInfo.getShaderBlockInfo(shader->uniformBlocks);
         }
     }
 
@@ -1692,10 +1691,10 @@
     InterfaceBlockInfo shaderStorageBlockInfo(mCustomEncoderFactory);
     for (const ShaderType shaderType : AllShaderTypes())
     {
-        Shader *shader = programState.getAttachedShader(shaderType);
+        const SharedCompiledShaderState &shader = programState.getAttachedShader(shaderType);
         if (shader)
         {
-            shaderStorageBlockInfo.getShaderBlockInfo(shader->getShaderStorageBlocks(context));
+            shaderStorageBlockInfo.getShaderBlockInfo(shader->shaderStorageBlocks);
         }
     }
     auto getShaderStorageBlockSize = [&shaderStorageBlockInfo](const std::string &name,
diff --git a/src/libANGLE/ProgramLinkedResources.h b/src/libANGLE/ProgramLinkedResources.h
index 393c260..ecffc77 100644
--- a/src/libANGLE/ProgramLinkedResources.h
+++ b/src/libANGLE/ProgramLinkedResources.h
@@ -296,7 +296,7 @@
 
 struct LinkingVariables final : private angle::NonCopyable
 {
-    LinkingVariables(const Context *context, const ProgramState &state);
+    LinkingVariables(const ProgramState &state);
     LinkingVariables(const ProgramPipelineState &state);
     ~LinkingVariables();
 
@@ -324,8 +324,7 @@
         : mCustomEncoderFactory(customEncoderFactory)
     {}
 
-    void linkResources(const Context *context,
-                       const ProgramState &programState,
+    void linkResources(const ProgramState &programState,
                        const ProgramLinkedResources &resources) const;
 
   private:
diff --git a/src/libANGLE/Shader.cpp b/src/libANGLE/Shader.cpp
index ae54435..4dfff79 100644
--- a/src/libANGLE/Shader.cpp
+++ b/src/libANGLE/Shader.cpp
@@ -127,7 +127,9 @@
     ShCompilerInstance shCompilerInstance;
 };
 
-ShaderState::ShaderState(ShaderType shaderType) : mCompiledShaderState(shaderType) {}
+ShaderState::ShaderState(ShaderType shaderType)
+    : mCompiledState(std::make_shared<CompiledShaderState>(shaderType))
+{}
 
 ShaderState::~ShaderState() {}
 
@@ -140,7 +142,6 @@
       mImplementation(implFactory->createShader(mState)),
       mRendererLimitations(rendererLimitations),
       mHandle(handle),
-      mType(type),
       mRefCount(0),
       mDeleteStatus(false),
       mResourceManager(manager),
@@ -304,12 +305,12 @@
 {
     resolveCompile(context);
 
-    if (mState.getTranslatedSource().empty())
+    if (mState.mCompiledState->translatedSource.empty())
     {
         return 0;
     }
 
-    return (static_cast<int>(mState.getTranslatedSource().length()) + 1);
+    return static_cast<int>(mState.mCompiledState->translatedSource.length()) + 1;
 }
 
 int Shader::getTranslatedSourceWithDebugInfoLength(const Context *context)
@@ -363,13 +364,7 @@
 const std::string &Shader::getTranslatedSource(const Context *context)
 {
     resolveCompile(context);
-    return mState.getTranslatedSource();
-}
-
-const sh::BinaryBlob &Shader::getCompiledBinary(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getCompiledBinary();
+    return mState.mCompiledState->translatedSource;
 }
 
 size_t Shader::getSourceHash() const
@@ -391,32 +386,11 @@
 {
     resolveCompile(context);
 
-    mState.mCompiledShaderState.translatedSource.clear();
-    mState.mCompiledShaderState.compiledBinary.clear();
+    // Create a new compiled shader state.  If any programs are currently linking using this shader,
+    // they would use the old compiled state, and this shader is free to recompile in the meantime.
+    mState.mCompiledState = std::make_shared<CompiledShaderState>(mState.getShaderType());
+
     mInfoLog.clear();
-    mState.mCompiledShaderState.shaderVersion = 100;
-    mState.mCompiledShaderState.inputVaryings.clear();
-    mState.mCompiledShaderState.outputVaryings.clear();
-    mState.mCompiledShaderState.uniforms.clear();
-    mState.mCompiledShaderState.uniformBlocks.clear();
-    mState.mCompiledShaderState.shaderStorageBlocks.clear();
-    mState.mCompiledShaderState.activeAttributes.clear();
-    mState.mCompiledShaderState.activeOutputVariables.clear();
-    mState.mCompiledShaderState.numViews = -1;
-    mState.mCompiledShaderState.geometryShaderInputPrimitiveType.reset();
-    mState.mCompiledShaderState.geometryShaderOutputPrimitiveType.reset();
-    mState.mCompiledShaderState.geometryShaderMaxVertices.reset();
-    mState.mCompiledShaderState.geometryShaderInvocations = 1;
-    mState.mCompiledShaderState.tessControlShaderVertices = 0;
-    mState.mCompiledShaderState.tessGenMode               = 0;
-    mState.mCompiledShaderState.tessGenSpacing            = 0;
-    mState.mCompiledShaderState.tessGenVertexOrder        = 0;
-    mState.mCompiledShaderState.tessGenPointMode          = 0;
-    mState.mCompiledShaderState.advancedBlendEquations.reset();
-    mState.mCompiledShaderState.hasClipDistance         = false;
-    mState.mCompiledShaderState.hasDiscard              = false;
-    mState.mCompiledShaderState.enablesPerSampleShading = false;
-    mState.mCompiledShaderState.specConstUsageBits.reset();
 
     mCurrentMaxComputeWorkGroupInvocations =
         static_cast<GLuint>(context->getCaps().maxComputeWorkGroupInvocations);
@@ -460,7 +434,7 @@
     mBoundCompiler.set(context, context->getCompiler());
 
     ASSERT(mBoundCompiler.get());
-    ShCompilerInstance compilerInstance = mBoundCompiler->getInstance(mType);
+    ShCompilerInstance compilerInstance = mBoundCompiler->getInstance(mState.getShaderType());
     ShHandle compilerHandle             = compilerInstance.getHandle();
     ASSERT(compilerHandle);
 
@@ -497,6 +471,7 @@
     }
 
     ASSERT(mCompilingState.get());
+    mState.mCompileStatus = CompileStatus::IS_RESOLVING;
 
     mCompilingState->compileEvent->wait();
 
@@ -519,7 +494,7 @@
 
     const ShShaderOutput outputType = mCompilingState->shCompilerInstance.getShaderOutputType();
     bool isBinaryOutput             = outputType == SH_SPIRV_VULKAN_OUTPUT;
-    mState.mCompiledShaderState.buildCompiledShaderState(compilerHandle, isBinaryOutput);
+    mState.mCompiledState->buildCompiledShaderState(compilerHandle, isBinaryOutput);
 
     const angle::FrontendFeatures &frontendFeatures = context->getFrontendFeatures();
     bool substitutedTranslatedShader                = false;
@@ -541,8 +516,8 @@
             std::string substituteShader;
             if (angle::ReadFileToString(substituteShaderPath, &substituteShader))
             {
-                mState.mCompiledShaderState.translatedSource = std::move(substituteShader);
-                substitutedTranslatedShader                  = true;
+                mState.mCompiledState->translatedSource = std::move(substituteShader);
+                substitutedTranslatedShader             = true;
                 INFO() << "Trasnslated shader substitute found, loading from "
                        << substituteShaderPath;
             }
@@ -561,7 +536,7 @@
         {
             std::string dumpFile = GetShaderDumpFilePath(mState.mSourceHash, suffix);
 
-            const std::string &translatedSource = mState.mCompiledShaderState.translatedSource;
+            const std::string &translatedSource = mState.mCompiledState->translatedSource;
             writeFile(dumpFile.c_str(), translatedSource.c_str(), translatedSource.length());
             INFO() << "Dumped translated source: " << dumpFile;
         }
@@ -594,19 +569,19 @@
             shaderStream << std::endl;
         }
         shaderStream << "\n\n";
-        shaderStream << mState.mCompiledShaderState.translatedSource;
-        mState.mCompiledShaderState.translatedSource = shaderStream.str();
+        shaderStream << mState.mCompiledState->translatedSource;
+        mState.mCompiledState->translatedSource = shaderStream.str();
     }
 #endif  // !defined(NDEBUG)
 
     // Validation checks for compute shaders
-    if (mState.mCompiledShaderState.shaderType == ShaderType::Compute &&
-        mState.mCompiledShaderState.localSize.isDeclared())
+    if (mState.mCompiledState->shaderType == ShaderType::Compute &&
+        mState.mCompiledState->localSize.isDeclared())
     {
         angle::CheckedNumeric<uint32_t> checked_local_size_product(
-            mState.mCompiledShaderState.localSize[0]);
-        checked_local_size_product *= mState.mCompiledShaderState.localSize[1];
-        checked_local_size_product *= mState.mCompiledShaderState.localSize[2];
+            mState.mCompiledState->localSize[0]);
+        checked_local_size_product *= mState.mCompiledState->localSize[1];
+        checked_local_size_product *= mState.mCompiledState->localSize[2];
 
         if (!checked_local_size_product.IsValid())
         {
@@ -634,8 +609,8 @@
         return;
     }
 
-    ASSERT(!mState.mCompiledShaderState.translatedSource.empty() ||
-           !mState.mCompiledShaderState.compiledBinary.empty());
+    ASSERT(!mState.mCompiledState->translatedSource.empty() ||
+           !mState.mCompiledState->compiledBinary.empty());
 
     bool success          = mCompilingState->compileEvent->postTranslate(&mInfoLog);
     mState.mCompileStatus = success ? CompileStatus::COMPILED : CompileStatus::NOT_COMPILED;
@@ -693,186 +668,12 @@
     return (!mState.compilePending() || mCompilingState->compileEvent->isReady());
 }
 
-int Shader::getShaderVersion(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.shaderVersion;
-}
-
-const std::vector<sh::ShaderVariable> &Shader::getInputVaryings(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getInputVaryings();
-}
-
-const std::vector<sh::ShaderVariable> &Shader::getOutputVaryings(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getOutputVaryings();
-}
-
-const std::vector<sh::ShaderVariable> &Shader::getUniforms(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getUniforms();
-}
-
-const std::vector<sh::InterfaceBlock> &Shader::getUniformBlocks(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getUniformBlocks();
-}
-
-const std::vector<sh::InterfaceBlock> &Shader::getShaderStorageBlocks(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getShaderStorageBlocks();
-}
-
-const std::vector<sh::ShaderVariable> &Shader::getActiveAttributes(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getActiveAttributes();
-}
-
-const std::vector<sh::ShaderVariable> &Shader::getAllAttributes(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getAllAttributes();
-}
-
-const std::vector<sh::ShaderVariable> &Shader::getActiveOutputVariables(const Context *context)
-{
-    resolveCompile(context);
-    return mState.getActiveOutputVariables();
-}
-
-std::string Shader::getTransformFeedbackVaryingMappedName(const Context *context,
-                                                          const std::string &tfVaryingName)
-{
-    ASSERT(mState.getShaderType() != ShaderType::Fragment &&
-           mState.getShaderType() != ShaderType::Compute);
-    const auto &varyings = getOutputVaryings(context);
-    auto bracketPos      = tfVaryingName.find("[");
-    if (bracketPos != std::string::npos)
-    {
-        auto tfVaryingBaseName = tfVaryingName.substr(0, bracketPos);
-        for (const auto &varying : varyings)
-        {
-            if (varying.name == tfVaryingBaseName)
-            {
-                std::string mappedNameWithArrayIndex =
-                    varying.mappedName + tfVaryingName.substr(bracketPos);
-                return mappedNameWithArrayIndex;
-            }
-        }
-    }
-    else
-    {
-        for (const auto &varying : varyings)
-        {
-            if (varying.name == tfVaryingName)
-            {
-                return varying.mappedName;
-            }
-            else if (varying.isStruct())
-            {
-                GLuint fieldIndex = 0;
-                const auto *field = varying.findField(tfVaryingName, &fieldIndex);
-                if (field == nullptr)
-                {
-                    continue;
-                }
-                ASSERT(field != nullptr && !field->isStruct() &&
-                       (!field->isArray() || varying.isShaderIOBlock));
-                std::string mappedName;
-                // If it's an I/O block without an instance name, don't include the block name.
-                if (!varying.isShaderIOBlock || !varying.name.empty())
-                {
-                    mappedName = varying.isShaderIOBlock ? varying.mappedStructOrBlockName
-                                                         : varying.mappedName;
-                    mappedName += '.';
-                }
-                return mappedName + field->mappedName;
-            }
-        }
-    }
-    UNREACHABLE();
-    return std::string();
-}
-
-const sh::WorkGroupSize &Shader::getWorkGroupSize(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.localSize;
-}
-
-int Shader::getNumViews(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.numViews;
-}
-
-Optional<PrimitiveMode> Shader::getGeometryShaderInputPrimitiveType(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.geometryShaderInputPrimitiveType;
-}
-
-Optional<PrimitiveMode> Shader::getGeometryShaderOutputPrimitiveType(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.geometryShaderOutputPrimitiveType;
-}
-
-int Shader::getGeometryShaderInvocations(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.geometryShaderInvocations;
-}
-
-Optional<GLint> Shader::getGeometryShaderMaxVertices(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.geometryShaderMaxVertices;
-}
-
-int Shader::getTessControlShaderVertices(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.tessControlShaderVertices;
-}
-
-GLenum Shader::getTessGenMode(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.tessGenMode;
-}
-
-GLenum Shader::getTessGenSpacing(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.tessGenSpacing;
-}
-
-GLenum Shader::getTessGenVertexOrder(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.tessGenVertexOrder;
-}
-
-GLenum Shader::getTessGenPointMode(const Context *context)
-{
-    resolveCompile(context);
-    return mState.mCompiledShaderState.tessGenPointMode;
-}
-
 angle::Result Shader::serialize(const Context *context, angle::MemoryBuffer *binaryOut) const
 {
     BinaryOutputStream stream;
 
     stream.writeInt(kShaderCacheIdentifier);
-    mState.mCompiledShaderState.serialize(stream);
+    mState.mCompiledState->serialize(stream);
 
     ASSERT(binaryOut);
     if (!binaryOut->resize(stream.length()))
@@ -892,7 +693,7 @@
 
 angle::Result Shader::deserialize(BinaryInputStream &stream)
 {
-    mState.mCompiledShaderState.deserialize(stream);
+    mState.mCompiledState->deserialize(stream);
 
     if (stream.error())
     {
@@ -900,6 +701,10 @@
         return angle::Result::Stop;
     }
 
+    // Note: Currently, shader binaries are only supported on backends that don't happen to have any
+    // additional state used at link time.  If other backends implement this functionality, this
+    // function should call into the backend object to deserialize their part.
+
     return angle::Result::Continue;
 }
 
@@ -920,6 +725,8 @@
 {
     BinaryInputStream stream(binary, length);
 
+    mState.mCompiledState = std::make_shared<CompiledShaderState>(mState.getShaderType());
+
     // Shader binaries generated with offline compiler have additional fields
     if (generatedWithOfflineCompiler)
     {
@@ -933,7 +740,7 @@
 
         gl::ShaderType shaderType;
         stream.readEnum(&shaderType);
-        ASSERT(mType == shaderType);
+        ASSERT(mState.getShaderType() == shaderType);
 
         // Get fields needed to generate the key for memory caches.
         ShShaderOutput outputType;
@@ -979,7 +786,7 @@
     BinaryOutputStream hashStream;
 
     // Start with the shader type and source.
-    hashStream.writeEnum(mType);
+    hashStream.writeEnum(mState.getShaderType());
     hashStream.writeString(mState.getSource());
 
     // Include the shader program version hash.
diff --git a/src/libANGLE/Shader.h b/src/libANGLE/Shader.h
index 38ab93a..cb64da2 100644
--- a/src/libANGLE/Shader.h
+++ b/src/libANGLE/Shader.h
@@ -57,8 +57,15 @@
 // We defer the compile until link time, or until properties are queried.
 enum class CompileStatus
 {
+    // Compilation never done, or has failed.
     NOT_COMPILED,
+    // Compile is in progress.
     COMPILE_REQUESTED,
+    // Compilation job is done, but is being resolved.  This enum value is there to allow access to
+    // compiled state during resolve without triggering threading-related assertions (which ensure
+    // no compile job is in progress).
+    IS_RESOLVING,
+    // Compilation was successful.
     COMPILED,
 };
 
@@ -71,82 +78,17 @@
     const std::string &getLabel() const { return mLabel; }
 
     const std::string &getSource() const { return mSource; }
-    bool isCompiledToBinary() const { return !mCompiledShaderState.compiledBinary.empty(); }
-    const std::string &getTranslatedSource() const { return mCompiledShaderState.translatedSource; }
-    const sh::BinaryBlob &getCompiledBinary() const { return mCompiledShaderState.compiledBinary; }
-
-    ShaderType getShaderType() const { return mCompiledShaderState.shaderType; }
-    int getShaderVersion() const { return mCompiledShaderState.shaderVersion; }
-
-    const std::vector<sh::ShaderVariable> &getInputVaryings() const
-    {
-        return mCompiledShaderState.inputVaryings;
-    }
-    const std::vector<sh::ShaderVariable> &getOutputVaryings() const
-    {
-        return mCompiledShaderState.outputVaryings;
-    }
-    const std::vector<sh::ShaderVariable> &getUniforms() const
-    {
-        return mCompiledShaderState.uniforms;
-    }
-    const std::vector<sh::InterfaceBlock> &getUniformBlocks() const
-    {
-        return mCompiledShaderState.uniformBlocks;
-    }
-    const std::vector<sh::InterfaceBlock> &getShaderStorageBlocks() const
-    {
-        return mCompiledShaderState.shaderStorageBlocks;
-    }
-    const std::vector<sh::ShaderVariable> &getActiveAttributes() const
-    {
-        return mCompiledShaderState.activeAttributes;
-    }
-    const std::vector<sh::ShaderVariable> &getAllAttributes() const
-    {
-        return mCompiledShaderState.allAttributes;
-    }
-    const std::vector<sh::ShaderVariable> &getActiveOutputVariables() const
-    {
-        return mCompiledShaderState.activeOutputVariables;
-    }
-
     bool compilePending() const { return mCompileStatus == CompileStatus::COMPILE_REQUESTED; }
-
-    const sh::WorkGroupSize &getLocalSize() const { return mCompiledShaderState.localSize; }
-
-    bool hasClipDistance() const { return mCompiledShaderState.hasClipDistance; }
-    bool hasDiscard() const { return mCompiledShaderState.hasDiscard; }
-    bool enablesPerSampleShading() const { return mCompiledShaderState.enablesPerSampleShading; }
-    rx::SpecConstUsageBits getSpecConstUsageBits() const
-    {
-        return mCompiledShaderState.specConstUsageBits;
-    }
-
-    int getNumViews() const { return mCompiledShaderState.numViews; }
-
-    Optional<PrimitiveMode> getGeometryShaderInputPrimitiveType() const
-    {
-        return mCompiledShaderState.geometryShaderInputPrimitiveType;
-    }
-
-    Optional<PrimitiveMode> getGeometryShaderOutputPrimitiveType() const
-    {
-        return mCompiledShaderState.geometryShaderOutputPrimitiveType;
-    }
-
-    Optional<GLint> getGeometryShaderMaxVertices() const
-    {
-        return mCompiledShaderState.geometryShaderMaxVertices;
-    }
-
-    Optional<GLint> getGeometryShaderInvocations() const
-    {
-        return mCompiledShaderState.geometryShaderInvocations;
-    }
-
     CompileStatus getCompileStatus() const { return mCompileStatus; }
 
+    ShaderType getShaderType() const { return mCompiledState->shaderType; }
+
+    const SharedCompiledShaderState &getCompiledState() const
+    {
+        ASSERT(!compilePending());
+        return mCompiledState;
+    }
+
   private:
     friend class Shader;
 
@@ -154,7 +96,7 @@
     std::string mSource;
     size_t mSourceHash = 0;
 
-    gl::CompiledShaderState mCompiledShaderState;
+    SharedCompiledShaderState mCompiledState;
 
     // Indicates if this shader has been successfully compiled
     CompileStatus mCompileStatus = CompileStatus::NOT_COMPILED;
@@ -174,7 +116,7 @@
     angle::Result setLabel(const Context *context, const std::string &label) override;
     const std::string &getLabel() const override;
 
-    ShaderType getType() const { return mType; }
+    ShaderType getType() const { return mState.getShaderType(); }
     ShaderProgramID getHandle() const;
 
     rx::ShaderImpl *getImplementation() const { return mImplementation.get(); }
@@ -200,7 +142,6 @@
                                           GLsizei bufSize,
                                           GLsizei *length,
                                           char *buffer);
-    const sh::BinaryBlob &getCompiledBinary(const Context *context);
 
     size_t getSourceHash() const;
 
@@ -208,56 +149,15 @@
     bool isCompiled(const Context *context);
     bool isCompleted();
 
+    // Return the compiled shader state for the program.  The program holds a reference to this
+    // state, so the shader is free to recompile, get deleted, etc.
+    const SharedCompiledShaderState &getCompiledState() const { return mState.getCompiledState(); }
+
     void addRef();
     void release(const Context *context);
     unsigned int getRefCount() const;
     bool isFlaggedForDeletion() const;
     void flagForDeletion();
-    bool hasClipDistance() const { return mState.mCompiledShaderState.hasClipDistance; }
-    bool hasDiscard() const { return mState.mCompiledShaderState.hasDiscard; }
-    bool enablesPerSampleShading() const
-    {
-        return mState.mCompiledShaderState.enablesPerSampleShading;
-    }
-    BlendEquationBitSet getAdvancedBlendEquations() const
-    {
-        return mState.mCompiledShaderState.advancedBlendEquations;
-    }
-    rx::SpecConstUsageBits getSpecConstUsageBits() const
-    {
-        return mState.mCompiledShaderState.specConstUsageBits;
-    }
-
-    int getShaderVersion(const Context *context);
-
-    const std::vector<sh::ShaderVariable> &getInputVaryings(const Context *context);
-    const std::vector<sh::ShaderVariable> &getOutputVaryings(const Context *context);
-    const std::vector<sh::ShaderVariable> &getUniforms(const Context *context);
-    const std::vector<sh::InterfaceBlock> &getUniformBlocks(const Context *context);
-    const std::vector<sh::InterfaceBlock> &getShaderStorageBlocks(const Context *context);
-    const std::vector<sh::ShaderVariable> &getActiveAttributes(const Context *context);
-    const std::vector<sh::ShaderVariable> &getAllAttributes(const Context *context);
-    const std::vector<sh::ShaderVariable> &getActiveOutputVariables(const Context *context);
-
-    // Returns mapped name of a transform feedback varying. The original name may contain array
-    // brackets with an index inside, which will get copied to the mapped name. The varying must be
-    // known to be declared in the shader.
-    std::string getTransformFeedbackVaryingMappedName(const Context *context,
-                                                      const std::string &tfVaryingName);
-
-    const sh::WorkGroupSize &getWorkGroupSize(const Context *context);
-
-    int getNumViews(const Context *context);
-
-    Optional<PrimitiveMode> getGeometryShaderInputPrimitiveType(const Context *context);
-    Optional<PrimitiveMode> getGeometryShaderOutputPrimitiveType(const Context *context);
-    int getGeometryShaderInvocations(const Context *context);
-    Optional<GLint> getGeometryShaderMaxVertices(const Context *context);
-    int getTessControlShaderVertices(const Context *context);
-    GLenum getTessGenMode(const Context *context);
-    GLenum getTessGenSpacing(const Context *context);
-    GLenum getTessGenVertexOrder(const Context *context);
-    GLenum getTessGenPointMode(const Context *context);
 
     const ShaderState &getState() const { return mState; }
 
@@ -314,7 +214,6 @@
     std::unique_ptr<rx::ShaderImpl> mImplementation;
     const gl::Limitations mRendererLimitations;
     const ShaderProgramID mHandle;
-    const ShaderType mType;
     unsigned int mRefCount;  // Number of program objects this shader is attached to
     bool mDeleteStatus;  // Flag to indicate that the shader can be deleted when no longer in use
     std::string mInfoLog;
diff --git a/src/libANGLE/capture/FrameCapture.cpp b/src/libANGLE/capture/FrameCapture.cpp
index 600247e..8dc017d 100644
--- a/src/libANGLE/capture/FrameCapture.cpp
+++ b/src/libANGLE/capture/FrameCapture.cpp
@@ -1865,7 +1865,7 @@
     replayWriter.saveSetupFile();
 }
 
-ProgramSources GetAttachedProgramSources(const gl::Program *program)
+ProgramSources GetAttachedProgramSources(const gl::Context *context, const gl::Program *program)
 {
     ProgramSources sources;
     for (gl::ShaderType shaderType : gl::AllShaderTypes())
@@ -7410,7 +7410,7 @@
             gl::ShaderProgramID shaderID =
                 call.params.getParam("shaderPacked", ParamType::TShaderProgramID, 0)
                     .value.ShaderProgramIDVal;
-            const gl::Shader *shader = context->getShader(shaderID);
+            const gl::Shader *shader = context->getShaderNoResolveCompile(shaderID);
             // Shaders compiled for ProgramBinary will not have a shader created
             if (shader)
             {
@@ -7427,9 +7427,9 @@
                     .value.ShaderProgramIDVal;
             const gl::Program *program = context->getProgramResolveLink(programID);
             // Programs linked in support of ProgramBinary will not have attached shaders
-            if (program->getState().hasAttachedShader())
+            if (program->getState().hasAnyAttachedShader())
             {
-                setProgramSources(programID, GetAttachedProgramSources(program));
+                setProgramSources(programID, GetAttachedProgramSources(context, program));
             }
             break;
         }
diff --git a/src/libANGLE/capture/serialize.cpp b/src/libANGLE/capture/serialize.cpp
index 65c8ce9..1aea4bd 100644
--- a/src/libANGLE/capture/serialize.cpp
+++ b/src/libANGLE/capture/serialize.cpp
@@ -915,41 +915,41 @@
     }
 }
 
+void SerializeCompiledShaderState(JsonSerializer *json, const gl::SharedCompiledShaderState &state)
+{
+    json->addCString("Type", gl::ShaderTypeToString(state->shaderType));
+    json->addScalar("Version", state->shaderVersion);
+    json->addString("TranslatedSource", state->translatedSource);
+    json->addVectorAsHash("CompiledBinary", state->compiledBinary);
+    SerializeWorkGroupSize(json, state->localSize);
+    SerializeShaderVariablesVector(json, state->inputVaryings);
+    SerializeShaderVariablesVector(json, state->outputVaryings);
+    SerializeShaderVariablesVector(json, state->uniforms);
+    SerializeInterfaceBlocksVector(json, state->uniformBlocks);
+    SerializeInterfaceBlocksVector(json, state->shaderStorageBlocks);
+    SerializeShaderVariablesVector(json, state->allAttributes);
+    SerializeShaderVariablesVector(json, state->activeAttributes);
+    SerializeShaderVariablesVector(json, state->activeOutputVariables);
+    json->addScalar("NumViews", state->numViews);
+    json->addScalar("SpecConstUsageBits", state->specConstUsageBits.bits());
+    if (state->geometryShaderInputPrimitiveType.valid())
+    {
+        json->addString("GeometryShaderInputPrimitiveType",
+                        ToString(state->geometryShaderInputPrimitiveType.value()));
+    }
+    if (state->geometryShaderOutputPrimitiveType.valid())
+    {
+        json->addString("GeometryShaderOutputPrimitiveType",
+                        ToString(state->geometryShaderOutputPrimitiveType.value()));
+    }
+    json->addScalar("GeometryShaderInvocations", state->geometryShaderInvocations);
+}
+
 void SerializeShaderState(JsonSerializer *json, const gl::ShaderState &shaderState)
 {
     GroupScope group(json, "ShaderState");
     json->addString("Label", shaderState.getLabel());
-    json->addCString("Type", gl::ShaderTypeToString(shaderState.getShaderType()));
-    json->addScalar("Version", shaderState.getShaderVersion());
-    json->addString("TranslatedSource", shaderState.getTranslatedSource());
-    json->addVectorAsHash("CompiledBinary", shaderState.getCompiledBinary());
     json->addString("Source", shaderState.getSource());
-    SerializeWorkGroupSize(json, shaderState.getLocalSize());
-    SerializeShaderVariablesVector(json, shaderState.getInputVaryings());
-    SerializeShaderVariablesVector(json, shaderState.getOutputVaryings());
-    SerializeShaderVariablesVector(json, shaderState.getUniforms());
-    SerializeInterfaceBlocksVector(json, shaderState.getUniformBlocks());
-    SerializeInterfaceBlocksVector(json, shaderState.getShaderStorageBlocks());
-    SerializeShaderVariablesVector(json, shaderState.getAllAttributes());
-    SerializeShaderVariablesVector(json, shaderState.getActiveAttributes());
-    SerializeShaderVariablesVector(json, shaderState.getActiveOutputVariables());
-    json->addScalar("NumViews", shaderState.getNumViews());
-    json->addScalar("SpecConstUsageBits", shaderState.getSpecConstUsageBits().bits());
-    if (shaderState.getGeometryShaderInputPrimitiveType().valid())
-    {
-        json->addString("GeometryShaderInputPrimitiveType",
-                        ToString(shaderState.getGeometryShaderInputPrimitiveType().value()));
-    }
-    if (shaderState.getGeometryShaderOutputPrimitiveType().valid())
-    {
-        json->addString("GeometryShaderOutputPrimitiveType",
-                        ToString(shaderState.getGeometryShaderOutputPrimitiveType().value()));
-    }
-    if (shaderState.getGeometryShaderInvocations().valid())
-    {
-        json->addScalar("GeometryShaderInvocations",
-                        shaderState.getGeometryShaderInvocations().value());
-    }
     json->addCString("CompileStatus", CompileStatusToString(shaderState.getCompileStatus()));
 }
 
@@ -963,12 +963,13 @@
 
     GroupScope group(json, "Shader", id);
     SerializeShaderState(json, shader->getState());
+    SerializeCompiledShaderState(json, shader->getCompiledState());
     json->addScalar("Handle", shader->getHandle().value);
     // TODO: implement MEC context validation only after all contexts have been initialized
     // http://anglebug.com/8029
     // json->addScalar("RefCount", shader->getRefCount());
     json->addScalar("FlaggedForDeletion", shader->isFlaggedForDeletion());
-    // Do not serialize mType because it is already serialized in SerializeShaderState.
+    // Do not serialize mType because it is already serialized in SerializeCompiledShaderState.
     json->addString("InfoLogString", shader->getInfoLogString());
     // Do not serialize compiler resources string because it can vary between test modes.
     json->addScalar("CurrentMaxComputeWorkGroupInvocations",
@@ -1043,11 +1044,6 @@
     json->addString("Label", programState.getLabel());
     SerializeWorkGroupSize(json, programState.getComputeShaderLocalSize());
 
-    auto attachedShaders = programState.getAttachedShaders();
-    std::vector<GLint> shaderHandles(attachedShaders.size());
-    std::transform(attachedShaders.begin(), attachedShaders.end(), shaderHandles.begin(),
-                   [](gl::Shader *shader) { return shader ? shader->getHandle().value : 0; });
-    json->addVector("Handle", shaderHandles);
     json->addScalar("LocationsUsedForXfbExtension", programState.getLocationsUsedForXfbExtension());
 
     json->addVectorOfStrings("TransformFeedbackVaryingNames",
@@ -1101,6 +1097,15 @@
     program->resolveLink(context);
 
     GroupScope group(json, "Program", id);
+
+    std::vector<GLint> shaderHandles;
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())
+    {
+        gl::Shader *shader = program->getAttachedShader(shaderType);
+        shaderHandles.push_back(shader ? shader->getHandle().value : 0);
+    }
+    json->addVector("Handle", shaderHandles);
+
     SerializeProgramState(json, program->getState());
     json->addScalar("IsValidated", program->isValidated());
     SerializeProgramBindings(json, program->getAttributeBindings());
diff --git a/src/libANGLE/renderer/ProgramImpl.h b/src/libANGLE/renderer/ProgramImpl.h
index 18448a1..960681c 100644
--- a/src/libANGLE/renderer/ProgramImpl.h
+++ b/src/libANGLE/renderer/ProgramImpl.h
@@ -82,6 +82,7 @@
     virtual void setBinaryRetrievableHint(bool retrievable)                       = 0;
     virtual void setSeparable(bool separable)                                     = 0;
 
+    virtual void prepareForLink(const gl::ShaderMap<ShaderImpl *> &shaders) {}
     virtual std::unique_ptr<LinkEvent> link(const gl::Context *context,
                                             const gl::ProgramLinkedResources &resources,
                                             gl::InfoLog &infoLog,
diff --git a/src/libANGLE/renderer/d3d/DynamicHLSL.cpp b/src/libANGLE/renderer/d3d/DynamicHLSL.cpp
index ff0b4f1..75cc293 100644
--- a/src/libANGLE/renderer/d3d/DynamicHLSL.cpp
+++ b/src/libANGLE/renderer/d3d/DynamicHLSL.cpp
@@ -427,8 +427,8 @@
 
 std::string DynamicHLSL::generateShaderForImage2DBindSignature(
     ProgramD3D &programD3D,
-    const gl::ProgramState &programData,
     gl::ShaderType shaderType,
+    const SharedCompiledShaderStateD3D &shaderData,
     const std::string &shaderHLSL,
     std::vector<sh::ShaderVariable> &image2DUniforms,
     const gl::ImageUnitTextureTypeMap &image2DBindLayout,
@@ -439,7 +439,7 @@
         return shaderHLSL;
     }
 
-    return GenerateShaderForImage2DBindSignature(programD3D, programData, shaderType, shaderHLSL,
+    return GenerateShaderForImage2DBindSignature(programD3D, shaderType, shaderData, shaderHLSL,
                                                  image2DUniforms, image2DBindLayout,
                                                  baseUAVRegister);
 }
@@ -564,27 +564,24 @@
     hlslStream << "};\n";
 }
 
-void DynamicHLSL::generateShaderLinkHLSL(const gl::Context *context,
-                                         const gl::Caps &caps,
-                                         const gl::ProgramState &programData,
-                                         const ProgramD3DMetadata &programMetadata,
-                                         const VaryingPacking &varyingPacking,
-                                         const BuiltinVaryingsD3D &builtinsD3D,
-                                         gl::ShaderMap<std::string> *shaderHLSL) const
+void DynamicHLSL::generateShaderLinkHLSL(
+    const gl::Caps &caps,
+    const gl::ShaderMap<gl::SharedCompiledShaderState> &shaderData,
+    const gl::ShaderMap<SharedCompiledShaderStateD3D> &shaderDataD3D,
+    const ProgramD3DMetadata &programMetadata,
+    const VaryingPacking &varyingPacking,
+    const BuiltinVaryingsD3D &builtinsD3D,
+    gl::ShaderMap<std::string> *shaderHLSL) const
 {
     ASSERT(shaderHLSL);
     ASSERT((*shaderHLSL)[gl::ShaderType::Vertex].empty() &&
            (*shaderHLSL)[gl::ShaderType::Fragment].empty());
 
-    gl::Shader *vertexShaderGL   = programData.getAttachedShader(ShaderType::Vertex);
-    gl::Shader *fragmentShaderGL = programData.getAttachedShader(ShaderType::Fragment);
-    const int shaderModel        = mRenderer->getMajorShaderModel();
+    const gl::SharedCompiledShaderState &vertexShader   = shaderData[ShaderType::Vertex];
+    const gl::SharedCompiledShaderState &fragmentShader = shaderData[ShaderType::Fragment];
+    const int shaderModel                               = mRenderer->getMajorShaderModel();
 
-    const ShaderD3D *fragmentShader = nullptr;
-    if (fragmentShaderGL)
-    {
-        fragmentShader = GetImplAs<ShaderD3D>(fragmentShaderGL);
-    }
+    const SharedCompiledShaderStateD3D &fragmentShaderD3D = shaderDataD3D[ShaderType::Fragment];
 
     // usesViewScale() isn't supported in the D3D9 renderer
     ASSERT(shaderModel >= 4 || !programMetadata.usesViewScale());
@@ -594,7 +591,8 @@
         mRenderer->getFeatures().useInstancedPointSpriteEmulation.enabled;
 
     // Validation done in the compiler
-    ASSERT(!fragmentShader || !fragmentShader->usesFragColor() || !fragmentShader->usesFragData());
+    ASSERT(!fragmentShaderD3D || !fragmentShaderD3D->usesFragColor ||
+           !fragmentShaderD3D->usesFragData);
 
     std::ostringstream vertexStream;
     vertexStream << "struct VS_OUTPUT\n";
@@ -870,9 +868,9 @@
                          << "    return output;\n"
                          << "}";
 
-    if (vertexShaderGL)
+    if (vertexShader)
     {
-        std::string vertexSource = vertexShaderGL->getTranslatedSource(context);
+        std::string vertexSource = vertexShader->translatedSource;
         angle::ReplaceSubstring(&vertexSource, std::string(MAIN_PROLOGUE_STUB_STRING),
                                 "    initAttributes(input);\n");
         angle::ReplaceSubstring(&vertexSource, std::string(VERTEX_OUTPUT_STUB_STRING),
@@ -889,7 +887,7 @@
     pixelStream << "\n";
 
     std::ostringstream pixelPrologue;
-    if (fragmentShader && fragmentShader->usesViewID())
+    if (fragmentShaderD3D && fragmentShaderD3D->usesViewID)
     {
         ASSERT(pixelBuiltins.glViewIDOVR.enabled);
         pixelPrologue << "    ViewID_OVR = input.gl_ViewID_OVR;\n";
@@ -982,7 +980,7 @@
                       << "    gl_PointCoord.y = 1.0 - input.gl_PointCoord.y;\n";
     }
 
-    if (fragmentShader && fragmentShader->usesFrontFacing())
+    if (fragmentShaderD3D && fragmentShaderD3D->usesFrontFacing)
     {
         if (shaderModel <= 3)
         {
@@ -995,19 +993,19 @@
     }
 
     bool declareSampleID = false;
-    if (fragmentShader && fragmentShader->usesSampleID())
+    if (fragmentShaderD3D && fragmentShaderD3D->usesSampleID)
     {
         declareSampleID = true;
         pixelPrologue << "    gl_SampleID = sampleID;\n";
     }
 
-    if (fragmentShader && fragmentShader->usesSamplePosition())
+    if (fragmentShaderD3D && fragmentShaderD3D->usesSamplePosition)
     {
         declareSampleID = true;
         pixelPrologue << "    gl_SamplePosition = GetRenderTargetSamplePosition(sampleID) + 0.5;\n";
     }
 
-    if (fragmentShader && fragmentShader->getClipDistanceArraySize())
+    if (fragmentShaderD3D && fragmentShaderD3D->clipDistanceSize)
     {
         ASSERT(vertexBuiltins.glClipDistance.indexOrSize > 0 &&
                vertexBuiltins.glClipDistance.indexOrSize < 9);
@@ -1045,7 +1043,7 @@
         }
     }
 
-    if (fragmentShader && fragmentShader->getCullDistanceArraySize())
+    if (fragmentShaderD3D && fragmentShaderD3D->cullDistanceSize)
     {
         ASSERT(vertexBuiltins.glCullDistance.indexOrSize > 0 &&
                vertexBuiltins.glCullDistance.indexOrSize < 9);
@@ -1153,7 +1151,7 @@
         pixelPrologue << ";\n";
     }
 
-    if (fragmentShader && fragmentShader->usesSampleMaskIn())
+    if (fragmentShaderD3D && fragmentShaderD3D->usesSampleMaskIn)
     {
         // When per-sample shading is active due to the use of a fragment input qualified
         // by sample or due to the use of the gl_SampleID or gl_SamplePosition variables,
@@ -1163,14 +1161,14 @@
                       << (declareSampleID ? "1 << sampleID" : "sampleMaskIn") << ";\n";
     }
 
-    if (fragmentShaderGL)
+    if (fragmentShader)
     {
-        std::string pixelSource = fragmentShaderGL->getTranslatedSource(context);
+        std::string pixelSource = fragmentShader->translatedSource;
 
         std::ostringstream pixelMainParametersStream;
         pixelMainParametersStream << "PS_INPUT input";
 
-        if (fragmentShader->usesFrontFacing())
+        if (fragmentShaderD3D->usesFrontFacing)
         {
             pixelMainParametersStream << (shaderModel >= 4 ? ", bool isFrontFace : SV_IsFrontFace"
                                                            : ", float vFace : VFACE");
@@ -1180,7 +1178,7 @@
         {
             pixelMainParametersStream << ", uint sampleID : SV_SampleIndex";
         }
-        else if (fragmentShader->usesSampleMaskIn())
+        else if (fragmentShaderD3D->usesSampleMaskIn)
         {
             pixelMainParametersStream << ", uint sampleMaskIn : SV_Coverage";
         }
@@ -1277,7 +1275,6 @@
 
 std::string DynamicHLSL::generateGeometryShaderHLSL(const gl::Caps &caps,
                                                     gl::PrimitiveMode primitiveType,
-                                                    const gl::ProgramState &programData,
                                                     const bool useViewScale,
                                                     const bool hasMultiviewEnabled,
                                                     const bool selectViewInVS,
@@ -1540,14 +1537,14 @@
     }
     else
     {
-        const ShaderD3D *fragmentShader = metadata.getFragmentShader();
+        const gl::SharedCompiledShaderState &fragmentShader = metadata.getFragmentShader();
 
         if (!fragmentShader)
         {
             return;
         }
 
-        const auto &shaderOutputVars = fragmentShader->getState().getActiveOutputVariables();
+        const auto &shaderOutputVars = fragmentShader->activeOutputVariables;
 
         for (size_t outputLocationIndex = 0u;
              outputLocationIndex < programData.getOutputLocations().size(); ++outputLocationIndex)
diff --git a/src/libANGLE/renderer/d3d/DynamicHLSL.h b/src/libANGLE/renderer/d3d/DynamicHLSL.h
index 5f409d5..88b255b 100644
--- a/src/libANGLE/renderer/d3d/DynamicHLSL.h
+++ b/src/libANGLE/renderer/d3d/DynamicHLSL.h
@@ -165,15 +165,15 @@
         size_t baseUAVRegister) const;
     std::string generateShaderForImage2DBindSignature(
         ProgramD3D &programD3D,
-        const gl::ProgramState &programData,
         gl::ShaderType shaderType,
+        const SharedCompiledShaderStateD3D &shaderData,
         const std::string &shaderHLSL,
         std::vector<sh::ShaderVariable> &image2DUniforms,
         const gl::ImageUnitTextureTypeMap &image2DBindLayout,
         unsigned int baseUAVRegister) const;
-    void generateShaderLinkHLSL(const gl::Context *context,
-                                const gl::Caps &caps,
-                                const gl::ProgramState &programData,
+    void generateShaderLinkHLSL(const gl::Caps &caps,
+                                const gl::ShaderMap<gl::SharedCompiledShaderState> &shaderData,
+                                const gl::ShaderMap<SharedCompiledShaderStateD3D> &shaderDataD3D,
                                 const ProgramD3DMetadata &programMetadata,
                                 const gl::VaryingPacking &varyingPacking,
                                 const BuiltinVaryingsD3D &builtinsD3D,
@@ -186,7 +186,6 @@
 
     std::string generateGeometryShaderHLSL(const gl::Caps &caps,
                                            gl::PrimitiveMode primitiveType,
-                                           const gl::ProgramState &programData,
                                            const bool useViewScale,
                                            const bool hasMultiviewEnabled,
                                            const bool selectViewInVS,
diff --git a/src/libANGLE/renderer/d3d/DynamicImage2DHLSL.cpp b/src/libANGLE/renderer/d3d/DynamicImage2DHLSL.cpp
index cf775d0..dcc61a4 100644
--- a/src/libANGLE/renderer/d3d/DynamicImage2DHLSL.cpp
+++ b/src/libANGLE/renderer/d3d/DynamicImage2DHLSL.cpp
@@ -688,8 +688,8 @@
 }
 
 void OutputHLSLImage2DUniformGroup(ProgramD3D &programD3D,
-                                   const gl::ProgramState &programData,
                                    gl::ShaderType shaderType,
+                                   const SharedCompiledShaderStateD3D &shaderData,
                                    std::ostringstream &out,
                                    const Image2DHLSLGroup textureGroup,
                                    const std::vector<sh::ShaderVariable> &group,
@@ -827,20 +827,17 @@
         out << "};\n";
     }
 
-    gl::Shader *shaderGL       = programData.getAttachedShader(shaderType);
-    const ShaderD3D *shaderD3D = GetImplAs<ShaderD3D>(shaderGL);
-
-    if (shaderD3D->useImage2DFunction(Image2DHLSLGroupFunctionName(textureGroup, IMAGE2DSIZE)))
+    if (shaderData->useImage2DFunction(Image2DHLSLGroupFunctionName(textureGroup, IMAGE2DSIZE)))
     {
         OutputImage2DSizeFunction(out, textureGroup, totalCount, texture2DCount, texture3DCount,
                                   texture2DArrayCount, offsetStr, declarationStr);
     }
-    if (shaderD3D->useImage2DFunction(Image2DHLSLGroupFunctionName(textureGroup, IMAGE2DLOAD)))
+    if (shaderData->useImage2DFunction(Image2DHLSLGroupFunctionName(textureGroup, IMAGE2DLOAD)))
     {
         OutputImage2DLoadFunction(out, textureGroup, totalCount, texture2DCount, texture3DCount,
                                   texture2DArrayCount, offsetStr, declarationStr);
     }
-    if (shaderD3D->useImage2DFunction(Image2DHLSLGroupFunctionName(textureGroup, IMAGE2DSTORE)))
+    if (shaderData->useImage2DFunction(Image2DHLSLGroupFunctionName(textureGroup, IMAGE2DSTORE)))
     {
         OutputImage2DStoreFunction(out, textureGroup, totalCount, texture2DCount, texture3DCount,
                                    texture2DArrayCount, offsetStr, declarationStr);
@@ -853,8 +850,8 @@
 
 std::string GenerateShaderForImage2DBindSignature(
     ProgramD3D &programD3D,
-    const gl::ProgramState &programData,
     gl::ShaderType shaderType,
+    const SharedCompiledShaderStateD3D &shaderData,
     const std::string &shaderHLSL,
     std::vector<sh::ShaderVariable> &image2DUniforms,
     const gl::ImageUnitTextureTypeMap &image2DBindLayout,
@@ -891,10 +888,8 @@
         groupedImage2DUniforms[group].push_back(image2D);
     }
 
-    gl::Shader *shaderGL                     = programData.getAttachedShader(shaderType);
-    const ShaderD3D *shaderD3D               = GetImplAs<ShaderD3D>(shaderGL);
-    unsigned int groupTextureRegisterIndex   = shaderD3D->getReadonlyImage2DRegisterIndex();
-    unsigned int groupRWTextureRegisterIndex = shaderD3D->getImage2DRegisterIndex();
+    unsigned int groupTextureRegisterIndex   = shaderData->readonlyImage2DRegisterIndex;
+    unsigned int groupRWTextureRegisterIndex = shaderData->image2DRegisterIndex;
     unsigned int image2DTexture3DIndex       = 0;
     unsigned int image2DTexture2DArrayIndex  = image2DTexture3DCount;
     unsigned int image2DTexture2DIndex       = image2DTexture3DCount + image2DTexture2DArrayCount;
@@ -903,7 +898,7 @@
     for (int groupId = IMAGE2D_MIN; groupId < IMAGE2D_MAX; ++groupId)
     {
         OutputHLSLImage2DUniformGroup(
-            programD3D, programData, shaderType, out, Image2DHLSLGroup(groupId),
+            programD3D, shaderType, shaderData, out, Image2DHLSLGroup(groupId),
             groupedImage2DUniforms[groupId], image2DBindLayout, baseUAVRegister,
             &groupTextureRegisterIndex, &groupRWTextureRegisterIndex, &image2DTexture3DIndex,
             &image2DTexture2DArrayIndex, &image2DTexture2DIndex);
diff --git a/src/libANGLE/renderer/d3d/DynamicImage2DHLSL.h b/src/libANGLE/renderer/d3d/DynamicImage2DHLSL.h
index 6533c97..fe128a8 100644
--- a/src/libANGLE/renderer/d3d/DynamicImage2DHLSL.h
+++ b/src/libANGLE/renderer/d3d/DynamicImage2DHLSL.h
@@ -11,13 +11,14 @@
 
 #include "common/angleutils.h"
 #include "libANGLE/renderer/d3d/RendererD3D.h"
+#include "libANGLE/renderer/d3d/ShaderD3D.h"
 
 namespace rx
 {
 std::string GenerateShaderForImage2DBindSignature(
     ProgramD3D &programD3D,
-    const gl::ProgramState &programData,
     gl::ShaderType shaderType,
+    const SharedCompiledShaderStateD3D &shaderData,
     const std::string &shaderHLSL,
     std::vector<sh::ShaderVariable> &image2DUniforms,
     const gl::ImageUnitTextureTypeMap &image2DBindLayout,
diff --git a/src/libANGLE/renderer/d3d/ProgramD3D.cpp b/src/libANGLE/renderer/d3d/ProgramD3D.cpp
index 60b718a..d557531 100644
--- a/src/libANGLE/renderer/d3d/ProgramD3D.cpp
+++ b/src/libANGLE/renderer/d3d/ProgramD3D.cpp
@@ -39,8 +39,7 @@
 namespace
 {
 
-void GetDefaultInputLayoutFromShader(const gl::Context *context,
-                                     gl::Shader *vertexShader,
+void GetDefaultInputLayoutFromShader(const gl::SharedCompiledShaderState &vertexShader,
                                      gl::InputLayout *inputLayoutOut)
 {
     inputLayoutOut->clear();
@@ -50,7 +49,7 @@
         return;
     }
 
-    for (const sh::ShaderVariable &shaderAttr : vertexShader->getActiveAttributes(context))
+    for (const sh::ShaderVariable &shaderAttr : vertexShader->activeAttributes)
     {
         if (shaderAttr.type != GL_NONE)
         {
@@ -172,36 +171,35 @@
     return false;
 }
 
-bool FindFlatInterpolationVaryingPerShader(const gl::Context *context, gl::Shader *shader)
+bool FindFlatInterpolationVaryingPerShader(const gl::SharedCompiledShaderState &shader)
 {
     ASSERT(shader);
-    switch (shader->getType())
+    switch (shader->shaderType)
     {
         case gl::ShaderType::Vertex:
-            return HasFlatInterpolationVarying(shader->getOutputVaryings(context));
+            return HasFlatInterpolationVarying(shader->outputVaryings);
         case gl::ShaderType::Fragment:
-            return HasFlatInterpolationVarying(shader->getInputVaryings(context));
+            return HasFlatInterpolationVarying(shader->inputVaryings);
         case gl::ShaderType::Geometry:
-            return HasFlatInterpolationVarying(shader->getInputVaryings(context)) ||
-                   HasFlatInterpolationVarying(shader->getOutputVaryings(context));
+            return HasFlatInterpolationVarying(shader->inputVaryings) ||
+                   HasFlatInterpolationVarying(shader->outputVaryings);
         default:
             UNREACHABLE();
             return false;
     }
 }
 
-bool FindFlatInterpolationVarying(const gl::Context *context,
-                                  const gl::ShaderMap<gl::Shader *> &shaders)
+bool FindFlatInterpolationVarying(const gl::ShaderMap<gl::SharedCompiledShaderState> &shaders)
 {
     for (gl::ShaderType shaderType : gl::kAllGraphicsShaderTypes)
     {
-        gl::Shader *shader = shaders[shaderType];
+        const gl::SharedCompiledShaderState &shader = shaders[shaderType];
         if (!shader)
         {
             continue;
         }
 
-        if (FindFlatInterpolationVaryingPerShader(context, shader))
+        if (FindFlatInterpolationVaryingPerShader(shader))
         {
             return true;
         }
@@ -400,17 +398,22 @@
 
 // ProgramD3DMetadata Implementation
 
-ProgramD3DMetadata::ProgramD3DMetadata(RendererD3D *renderer,
-                                       const gl::ShaderMap<const ShaderD3D *> &attachedShaders,
-                                       EGLenum clientType)
+ProgramD3DMetadata::ProgramD3DMetadata(
+    RendererD3D *renderer,
+    const gl::SharedCompiledShaderState &fragmentShader,
+    const gl::ShaderMap<SharedCompiledShaderStateD3D> &attachedShaders,
+    EGLenum clientType,
+    int shaderVersion)
     : mRendererMajorShaderModel(renderer->getMajorShaderModel()),
       mShaderModelSuffix(renderer->getShaderModelSuffix()),
       mUsesInstancedPointSpriteEmulation(
           renderer->getFeatures().useInstancedPointSpriteEmulation.enabled),
       mUsesViewScale(renderer->presentPathFastEnabled()),
       mCanSelectViewInVertexShader(renderer->canSelectViewInVertexShader()),
+      mFragmentShader(fragmentShader),
       mAttachedShaders(attachedShaders),
-      mClientType(clientType)
+      mClientType(clientType),
+      mShaderVersion(shaderVersion)
 {}
 
 ProgramD3DMetadata::~ProgramD3DMetadata() = default;
@@ -422,39 +425,39 @@
 
 bool ProgramD3DMetadata::usesBroadcast(const gl::State &data) const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Fragment];
-    return (shader && shader->usesFragColor() && shader->usesMultipleRenderTargets() &&
-            data.getClientMajorVersion() < 3);
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Fragment];
+    return shader && shader->usesFragColor && shader->usesMultipleRenderTargets &&
+           data.getClientMajorVersion() < 3;
 }
 
 bool ProgramD3DMetadata::usesSecondaryColor() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Fragment];
-    return (shader && shader->usesSecondaryColor());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Fragment];
+    return shader && shader->usesSecondaryColor;
 }
 
 FragDepthUsage ProgramD3DMetadata::getFragDepthUsage() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Fragment];
-    return shader ? shader->getFragDepthUsage() : FragDepthUsage::Unused;
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Fragment];
+    return shader ? shader->fragDepthUsage : FragDepthUsage::Unused;
 }
 
 bool ProgramD3DMetadata::usesPointCoord() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Fragment];
-    return (shader && shader->usesPointCoord());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Fragment];
+    return shader && shader->usesPointCoord;
 }
 
 bool ProgramD3DMetadata::usesFragCoord() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Fragment];
-    return (shader && shader->usesFragCoord());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Fragment];
+    return shader && shader->usesFragCoord;
 }
 
 bool ProgramD3DMetadata::usesPointSize() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Vertex];
-    return (shader && shader->usesPointSize());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Vertex];
+    return shader && shader->usesPointSize;
 }
 
 bool ProgramD3DMetadata::usesInsertedPointCoordValue() const
@@ -470,20 +473,20 @@
 
 bool ProgramD3DMetadata::hasMultiviewEnabled() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Vertex];
-    return (shader && shader->hasMultiviewEnabled());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Vertex];
+    return shader && shader->hasMultiviewEnabled;
 }
 
 bool ProgramD3DMetadata::usesVertexID() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Vertex];
-    return (shader && shader->usesVertexID());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Vertex];
+    return shader && shader->usesVertexID;
 }
 
 bool ProgramD3DMetadata::usesViewID() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Fragment];
-    return (shader && shader->usesViewID());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Fragment];
+    return shader && shader->usesViewID;
 }
 
 bool ProgramD3DMetadata::canSelectViewInVertexShader() const
@@ -518,46 +521,42 @@
 
 bool ProgramD3DMetadata::usesMultipleFragmentOuts() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Fragment];
-    return (shader && shader->usesMultipleRenderTargets());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Fragment];
+    return shader && shader->usesMultipleRenderTargets;
 }
 
 bool ProgramD3DMetadata::usesCustomOutVars() const
 {
-
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Vertex];
-    int version                 = shader ? shader->getState().getShaderVersion() : -1;
-
     switch (mClientType)
     {
         case EGL_OPENGL_API:
-            return version >= 130;
+            return mShaderVersion >= 130;
         default:
-            return version >= 300;
+            return mShaderVersion >= 300;
     }
 }
 
 bool ProgramD3DMetadata::usesSampleMask() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Fragment];
-    return (shader && shader->usesSampleMask());
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Fragment];
+    return shader && shader->usesSampleMask;
 }
 
-const ShaderD3D *ProgramD3DMetadata::getFragmentShader() const
+const gl::SharedCompiledShaderState &ProgramD3DMetadata::getFragmentShader() const
 {
-    return mAttachedShaders[gl::ShaderType::Fragment];
+    return mFragmentShader;
 }
 
 uint8_t ProgramD3DMetadata::getClipDistanceArraySize() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Vertex];
-    return shader ? shader->getClipDistanceArraySize() : 0;
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Vertex];
+    return shader ? shader->clipDistanceSize : 0;
 }
 
 uint8_t ProgramD3DMetadata::getCullDistanceArraySize() const
 {
-    const rx::ShaderD3D *shader = mAttachedShaders[gl::ShaderType::Vertex];
-    return shader ? shader->getCullDistanceArraySize() : 0;
+    const SharedCompiledShaderStateD3D &shader = mAttachedShaders[gl::ShaderType::Vertex];
+    return shader ? shader->cullDistanceSize : 0;
 }
 
 // ProgramD3D::GetExecutableTask class
@@ -1584,7 +1583,7 @@
         mPixelShaderKey.size());
 
     std::string finalPixelHLSL = mDynamicHLSL->generateShaderForImage2DBindSignature(
-        *this, mState, gl::ShaderType::Fragment, pixelHLSL,
+        *this, gl::ShaderType::Fragment, mAttachedShaders[gl::ShaderType::Fragment], pixelHLSL,
         mImage2DUniforms[gl::ShaderType::Fragment],
         mImage2DBindLayoutCache[gl::ShaderType::Fragment],
         static_cast<unsigned int>(mPixelShaderKey.size()));
@@ -1689,7 +1688,7 @@
     }
     const gl::Caps &caps     = state.getCaps();
     std::string geometryHLSL = mDynamicHLSL->generateGeometryShaderHLSL(
-        caps, geometryShaderType, mState, mRenderer->presentPathFastEnabled(), mHasMultiviewEnabled,
+        caps, geometryShaderType, mRenderer->presentPathFastEnabled(), mHasMultiviewEnabled,
         mRenderer->canSelectViewInVertexShader(), usesGeometryShaderForPointSpriteEmulation(),
         mGeometryShaderPreamble);
 
@@ -1736,9 +1735,9 @@
     }
 };
 
-void ProgramD3D::updateCachedInputLayoutFromShader(const gl::Context *context)
+void ProgramD3D::updateCachedInputLayoutFromShader()
 {
-    GetDefaultInputLayoutFromShader(context, mState.getAttachedShader(gl::ShaderType::Vertex),
+    GetDefaultInputLayoutFromShader(mState.getAttachedShader(gl::ShaderType::Vertex),
                                     &mCachedInputLayout);
     VertexExecutable::getSignature(mRenderer, mCachedInputLayout, &mCachedVertexSignature);
     updateCachedVertexExecutableIndex();
@@ -1848,8 +1847,8 @@
                              std::shared_ptr<ProgramD3D::GetPixelExecutableTask> pixelTask,
                              std::shared_ptr<ProgramD3D::GetGeometryExecutableTask> geometryTask,
                              bool useGS,
-                             const ShaderD3D *vertexShader,
-                             const ShaderD3D *fragmentShader)
+                             const SharedCompiledShaderStateD3D &vertexShader,
+                             const SharedCompiledShaderStateD3D &fragmentShader)
         : mInfoLog(infoLog),
           mVertexTask(vertexTask),
           mPixelTask(pixelTask),
@@ -1938,8 +1937,8 @@
     std::shared_ptr<ProgramD3D::GetGeometryExecutableTask> mGeometryTask;
     std::array<std::shared_ptr<WaitableEvent>, 3> mWaitEvents;
     bool mUseGS;
-    const ShaderD3D *mVertexShader;
-    const ShaderD3D *mFragmentShader;
+    SharedCompiledShaderStateD3D mVertexShader;
+    SharedCompiledShaderStateD3D mFragmentShader;
 };
 
 // The LinkEvent implementation for linking a computing program.
@@ -1988,12 +1987,10 @@
     auto pixelTask  = std::make_shared<GetPixelExecutableTask>(context, this);
     auto geometryTask =
         std::make_shared<GetGeometryExecutableTask>(context, this, context->getState());
-    bool useGS                 = usesGeometryShader(context->getState(), gl::PrimitiveMode::Points);
-    gl::Shader *vertexShader   = mState.getAttachedShader(gl::ShaderType::Vertex);
-    gl::Shader *fragmentShader = mState.getAttachedShader(gl::ShaderType::Fragment);
-    const ShaderD3D *vertexShaderD3D = vertexShader ? GetImplAs<ShaderD3D>(vertexShader) : nullptr;
-    const ShaderD3D *fragmentShaderD3D =
-        fragmentShader ? GetImplAs<ShaderD3D>(fragmentShader) : nullptr;
+    bool useGS = usesGeometryShader(context->getState(), gl::PrimitiveMode::Points);
+    const SharedCompiledShaderStateD3D &vertexShaderD3D = mAttachedShaders[gl::ShaderType::Vertex];
+    const SharedCompiledShaderStateD3D &fragmentShaderD3D =
+        mAttachedShaders[gl::ShaderType::Fragment];
 
     return std::make_unique<GraphicsProgramLinkEvent>(
         infoLog, context->getShaderCompileThreadPool(), vertexTask, pixelTask, geometryTask, useGS,
@@ -2044,11 +2041,10 @@
         return angle::Result::Continue;
     }
 
-    std::string computeHLSL =
-        mState.getAttachedShader(gl::ShaderType::Compute)->getTranslatedSource(glContext);
+    std::string computeHLSL = mState.getAttachedShader(gl::ShaderType::Compute)->translatedSource;
 
     std::string finalComputeHLSL = mDynamicHLSL->generateShaderForImage2DBindSignature(
-        *this, mState, gl::ShaderType::Compute, computeHLSL,
+        *this, gl::ShaderType::Compute, mAttachedShaders[gl::ShaderType::Compute], computeHLSL,
         mImage2DUniforms[gl::ShaderType::Compute], mImage2DBindLayoutCache[gl::ShaderType::Compute],
         0u);
 
@@ -2079,6 +2075,20 @@
     return angle::Result::Continue;
 }
 
+void ProgramD3D::prepareForLink(const gl::ShaderMap<ShaderImpl *> &shaders)
+{
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())
+    {
+        mAttachedShaders[shaderType].reset();
+
+        if (shaders[shaderType] != nullptr)
+        {
+            const ShaderD3D *shaderD3D   = GetAs<ShaderD3D>(shaders[shaderType]);
+            mAttachedShaders[shaderType] = shaderD3D->getCompiledState();
+        }
+    }
+}
+
 std::unique_ptr<LinkEvent> ProgramD3D::link(const gl::Context *context,
                                             const gl::ProgramLinkedResources &resources,
                                             gl::InfoLog &infoLog,
@@ -2089,7 +2099,8 @@
 
     reset();
 
-    gl::Shader *computeShader = mState.getAttachedShader(gl::ShaderType::Compute);
+    const gl::SharedCompiledShaderState &computeShader =
+        mState.getAttachedShader(gl::ShaderType::Compute);
     if (computeShader)
     {
         mShaderSamplers[gl::ShaderType::Compute].resize(
@@ -2099,9 +2110,9 @@
 
         mShaderUniformsDirty.set(gl::ShaderType::Compute);
 
-        linkResources(context, resources);
+        linkResources(resources);
 
-        for (const sh::ShaderVariable &uniform : computeShader->getUniforms(context))
+        for (const sh::ShaderVariable &uniform : computeShader->uniforms)
         {
             if (gl::IsImageType(uniform.type) && gl::IsImage2DType(uniform.type))
             {
@@ -2109,30 +2120,28 @@
             }
         }
 
-        defineUniformsAndAssignRegisters(context);
+        defineUniformsAndAssignRegisters();
 
         return compileComputeExecutable(context, infoLog);
     }
     else
     {
-        gl::ShaderMap<const ShaderD3D *> shadersD3D = {};
         for (gl::ShaderType shaderType : gl::kAllGraphicsShaderTypes)
         {
-            if (gl::Shader *shader = mState.getAttachedShader(shaderType))
+            const SharedCompiledShaderStateD3D &shader = mAttachedShaders[shaderType];
+            if (shader)
             {
-                shadersD3D[shaderType] = GetImplAs<ShaderD3D>(mState.getAttachedShader(shaderType));
-
                 mShaderSamplers[shaderType].resize(
                     data.getCaps().maxShaderTextureImageUnits[shaderType]);
                 mImages[shaderType].resize(data.getCaps().maxImageUnits);
                 mReadonlyImages[shaderType].resize(data.getCaps().maxImageUnits);
 
-                shadersD3D[shaderType]->generateWorkarounds(&mShaderWorkarounds[shaderType]);
+                shader->generateWorkarounds(&mShaderWorkarounds[shaderType]);
 
                 mShaderUniformsDirty.set(shaderType);
 
                 const std::set<std::string> &slowCompilingUniformBlockSet =
-                    shadersD3D[shaderType]->getSlowCompilingUniformBlockSet();
+                    shader->slowCompilingUniformBlockSet;
                 if (slowCompilingUniformBlockSet.size() > 0)
                 {
                     std::ostringstream stream;
@@ -2152,7 +2161,8 @@
                                        stream.str().c_str());
                 }
 
-                for (const sh::ShaderVariable &uniform : shader->getUniforms(context))
+                for (const sh::ShaderVariable &uniform :
+                     mState.getAttachedShader(shaderType)->uniforms)
                 {
                     if (gl::IsImageType(uniform.type) && gl::IsImage2DType(uniform.type))
                     {
@@ -2164,8 +2174,9 @@
 
         if (mRenderer->getNativeLimitations().noFrontFacingSupport)
         {
-            const ShaderD3D *fragmentShader = shadersD3D[gl::ShaderType::Fragment];
-            if (fragmentShader && fragmentShader->usesFrontFacing())
+            const SharedCompiledShaderStateD3D &fragmentShader =
+                mAttachedShaders[gl::ShaderType::Fragment];
+            if (fragmentShader && fragmentShader->usesFrontFacing)
             {
                 infoLog << "The current renderer doesn't support gl_FrontFacing";
                 return std::make_unique<LinkEventDone>(angle::Result::Incomplete);
@@ -2175,14 +2186,18 @@
         const gl::VaryingPacking &varyingPacking =
             resources.varyingPacking.getOutputPacking(gl::ShaderType::Vertex);
 
-        ProgramD3DMetadata metadata(mRenderer, shadersD3D, context->getClientType());
+        ProgramD3DMetadata metadata(
+            mRenderer, mState.getAttachedShader(gl::ShaderType::Fragment), mAttachedShaders,
+            context->getClientType(),
+            mState.getAttachedShader(gl::ShaderType::Vertex)->shaderVersion);
         BuiltinVaryingsD3D builtins(metadata, varyingPacking);
 
-        mDynamicHLSL->generateShaderLinkHLSL(context, context->getCaps(), mState, metadata,
-                                             varyingPacking, builtins, &mShaderHLSL);
+        mDynamicHLSL->generateShaderLinkHLSL(context->getCaps(), mState.getAttachedShaders(),
+                                             mAttachedShaders, metadata, varyingPacking, builtins,
+                                             &mShaderHLSL);
 
-        const ShaderD3D *vertexShader = shadersD3D[gl::ShaderType::Vertex];
-        mUsesPointSize                = vertexShader && vertexShader->usesPointSize();
+        mUsesPointSize = mAttachedShaders[gl::ShaderType::Vertex] &&
+                         mAttachedShaders[gl::ShaderType::Vertex]->usesPointSize;
         mDynamicHLSL->getPixelShaderOutputKey(data, mState, metadata, &mPixelShaderKey);
         mFragDepthUsage      = metadata.getFragDepthUsage();
         mUsesSampleMask      = metadata.usesSampleMask();
@@ -2191,7 +2206,7 @@
         mHasMultiviewEnabled = metadata.hasMultiviewEnabled();
 
         // Cache if we use flat shading
-        mUsesFlatInterpolation = FindFlatInterpolationVarying(context, mState.getAttachedShaders());
+        mUsesFlatInterpolation = FindFlatInterpolationVarying(mState.getAttachedShaders());
 
         if (mRenderer->getMajorShaderModel() >= 4)
         {
@@ -2200,17 +2215,17 @@
                 metadata.canSelectViewInVertexShader());
         }
 
-        initAttribLocationsToD3DSemantic(context);
+        initAttribLocationsToD3DSemantic();
 
-        defineUniformsAndAssignRegisters(context);
+        defineUniformsAndAssignRegisters();
 
         gatherTransformFeedbackVaryings(varyingPacking, builtins[gl::ShaderType::Vertex]);
 
-        linkResources(context, resources);
+        linkResources(resources);
 
         if (mState.getAttachedShader(gl::ShaderType::Vertex))
         {
-            updateCachedInputLayoutFromShader(context);
+            updateCachedInputLayoutFromShader();
         }
 
         return compileProgramExecutables(context, infoLog);
@@ -2223,7 +2238,7 @@
     return GL_TRUE;
 }
 
-void ProgramD3D::initializeShaderStorageBlocks(const gl::Context *context)
+void ProgramD3D::initializeShaderStorageBlocks()
 {
     if (mState.getShaderStorageBlocks().empty())
     {
@@ -2233,11 +2248,6 @@
     ASSERT(mD3DShaderStorageBlocks.empty());
 
     // Assign registers and update sizes.
-    gl::ShaderMap<const ShaderD3D *> shadersD3D = {};
-    for (gl::ShaderType shaderType : gl::AllShaderTypes())
-    {
-        shadersD3D[shaderType] = SafeGetImplAs<ShaderD3D>(mState.getAttachedShader(shaderType));
-    }
     for (const gl::InterfaceBlock &shaderStorageBlock : mState.getShaderStorageBlocks())
     {
         unsigned int shaderStorageBlockElement =
@@ -2248,9 +2258,10 @@
         {
             if (shaderStorageBlock.isActive(shaderType))
             {
-                ASSERT(shadersD3D[shaderType]);
+                ASSERT(mAttachedShaders[shaderType]);
                 unsigned int baseRegister =
-                    shadersD3D[shaderType]->getShaderStorageBlockRegister(shaderStorageBlock.name);
+                    mAttachedShaders[shaderType]->getShaderStorageBlockRegister(
+                        shaderStorageBlock.name);
 
                 d3dShaderStorageBlock.mShaderRegisterIndexes[shaderType] =
                     baseRegister + shaderStorageBlockElement;
@@ -2261,22 +2272,22 @@
 
     for (gl::ShaderType shaderType : gl::AllShaderTypes())
     {
-        gl::Shader *shader = mState.getAttachedShader(shaderType);
+        const gl::SharedCompiledShaderState &shader = mState.getAttachedShader(shaderType);
         if (!shader)
         {
             continue;
         }
-        ShaderD3D *shaderD3D = SafeGetImplAs<ShaderD3D>(shader);
-        for (const sh::InterfaceBlock &ssbo : shader->getShaderStorageBlocks(context))
+        for (const sh::InterfaceBlock &ssbo : shader->shaderStorageBlocks)
         {
             if (!ssbo.active)
             {
                 continue;
             }
             ShaderStorageBlock block;
-            block.name          = !ssbo.instanceName.empty() ? ssbo.instanceName : ssbo.name;
-            block.arraySize     = ssbo.isArray() ? ssbo.arraySize : 0;
-            block.registerIndex = shaderD3D->getShaderStorageBlockRegister(ssbo.name);
+            block.name      = !ssbo.instanceName.empty() ? ssbo.instanceName : ssbo.name;
+            block.arraySize = ssbo.isArray() ? ssbo.arraySize : 0;
+            block.registerIndex =
+                mAttachedShaders[shaderType]->getShaderStorageBlockRegister(ssbo.name);
             mShaderStorageBlocks[shaderType].push_back(block);
         }
     }
@@ -2292,12 +2303,6 @@
     ASSERT(mD3DUniformBlocks.empty());
 
     // Assign registers and update sizes.
-    gl::ShaderMap<const ShaderD3D *> shadersD3D = {};
-    for (gl::ShaderType shaderType : gl::AllShaderTypes())
-    {
-        shadersD3D[shaderType] = SafeGetImplAs<ShaderD3D>(mState.getAttachedShader(shaderType));
-    }
-
     for (const gl::InterfaceBlock &uniformBlock : mState.getUniformBlocks())
     {
         unsigned int uniformBlockElement = uniformBlock.isArray ? uniformBlock.arrayElement : 0;
@@ -2308,13 +2313,13 @@
         {
             if (uniformBlock.isActive(shaderType))
             {
-                ASSERT(shadersD3D[shaderType]);
+                ASSERT(mAttachedShaders[shaderType]);
                 unsigned int baseRegister =
-                    shadersD3D[shaderType]->getUniformBlockRegister(uniformBlock.name);
+                    mAttachedShaders[shaderType]->getUniformBlockRegister(uniformBlock.name);
                 d3dUniformBlock.mShaderRegisterIndexes[shaderType] =
                     baseRegister + uniformBlockElement;
                 bool useStructuredBuffer =
-                    shadersD3D[shaderType]->shouldUniformBlockUseStructuredBuffer(
+                    mAttachedShaders[shaderType]->shouldUniformBlockUseStructuredBuffer(
                         uniformBlock.name);
                 if (useStructuredBuffer)
                 {
@@ -2615,25 +2620,25 @@
     setUniformInternal(location, count, v, GL_UNSIGNED_INT_VEC4);
 }
 
-void ProgramD3D::defineUniformsAndAssignRegisters(const gl::Context *context)
+void ProgramD3D::defineUniformsAndAssignRegisters()
 {
     D3DUniformMap uniformMap;
 
     gl::ShaderBitSet attachedShaders;
     for (gl::ShaderType shaderType : gl::AllShaderTypes())
     {
-        gl::Shader *shader = mState.getAttachedShader(shaderType);
+        const gl::SharedCompiledShaderState &shader = mState.getAttachedShader(shaderType);
         if (shader)
         {
-            for (const sh::ShaderVariable &uniform : shader->getUniforms(context))
+            for (const sh::ShaderVariable &uniform : shader->uniforms)
             {
                 if (uniform.active)
                 {
-                    defineUniformBase(shader, uniform, &uniformMap);
+                    defineUniformBase(shader->shaderType, uniform, &uniformMap);
                 }
             }
 
-            attachedShaders.set(shader->getType());
+            attachedShaders.set(shader->shaderType);
         }
     }
 
@@ -2676,7 +2681,7 @@
     initializeUniformStorage(attachedShaders);
 }
 
-void ProgramD3D::defineUniformBase(const gl::Shader *shader,
+void ProgramD3D::defineUniformBase(gl::ShaderType shaderType,
                                    const sh::ShaderVariable &uniform,
                                    D3DUniformMap *uniformMap)
 {
@@ -2686,8 +2691,8 @@
     // registers assigned in assignAllImageRegisters.
     if (gl::IsSamplerType(uniform.type))
     {
-        UniformEncodingVisitorD3D visitor(shader->getType(), HLSLRegisterType::Texture,
-                                          &stubEncoder, uniformMap);
+        UniformEncodingVisitorD3D visitor(shaderType, HLSLRegisterType::Texture, &stubEncoder,
+                                          uniformMap);
         sh::TraverseShaderVariable(uniform, false, &visitor);
         return;
     }
@@ -2696,14 +2701,14 @@
     {
         if (uniform.readonly)
         {
-            UniformEncodingVisitorD3D visitor(shader->getType(), HLSLRegisterType::Texture,
-                                              &stubEncoder, uniformMap);
+            UniformEncodingVisitorD3D visitor(shaderType, HLSLRegisterType::Texture, &stubEncoder,
+                                              uniformMap);
             sh::TraverseShaderVariable(uniform, false, &visitor);
         }
         else
         {
-            UniformEncodingVisitorD3D visitor(
-                shader->getType(), HLSLRegisterType::UnorderedAccessView, &stubEncoder, uniformMap);
+            UniformEncodingVisitorD3D visitor(shaderType, HLSLRegisterType::UnorderedAccessView,
+                                              &stubEncoder, uniformMap);
             sh::TraverseShaderVariable(uniform, false, &visitor);
         }
         mImageBindingMap[uniform.name] = uniform.binding;
@@ -2712,28 +2717,27 @@
 
     if (uniform.isBuiltIn() && !uniform.isEmulatedBuiltIn())
     {
-        UniformEncodingVisitorD3D visitor(shader->getType(), HLSLRegisterType::None, &stubEncoder,
+        UniformEncodingVisitorD3D visitor(shaderType, HLSLRegisterType::None, &stubEncoder,
                                           uniformMap);
         sh::TraverseShaderVariable(uniform, false, &visitor);
         return;
     }
     else if (gl::IsAtomicCounterType(uniform.type))
     {
-        UniformEncodingVisitorD3D visitor(shader->getType(), HLSLRegisterType::UnorderedAccessView,
+        UniformEncodingVisitorD3D visitor(shaderType, HLSLRegisterType::UnorderedAccessView,
                                           &stubEncoder, uniformMap);
         sh::TraverseShaderVariable(uniform, false, &visitor);
         mAtomicBindingMap[uniform.name] = uniform.binding;
         return;
     }
 
-    const ShaderD3D *shaderD3D = GetImplAs<ShaderD3D>(shader);
-    unsigned int startRegister = shaderD3D->getUniformRegister(uniform.name);
-    ShShaderOutput outputType  = shaderD3D->getCompilerOutputType();
+    const SharedCompiledShaderStateD3D &shaderD3D = mAttachedShaders[shaderType];
+    unsigned int startRegister                    = shaderD3D->getUniformRegister(uniform.name);
+    ShShaderOutput outputType                     = shaderD3D->compilerOutputType;
     sh::HLSLBlockEncoder encoder(sh::HLSLBlockEncoder::GetStrategyFor(outputType), true);
     encoder.skipRegisters(startRegister);
 
-    UniformEncodingVisitorD3D visitor(shader->getType(), HLSLRegisterType::None, &encoder,
-                                      uniformMap);
+    UniformEncodingVisitorD3D visitor(shaderType, HLSLRegisterType::None, &encoder, uniformMap);
     sh::TraverseShaderVariable(uniform, false, &visitor);
 }
 
@@ -2884,7 +2888,7 @@
             continue;
         }
 
-        const ShaderD3D *shaderD3D = GetImplAs<ShaderD3D>(mState.getAttachedShader(shaderType));
+        const SharedCompiledShaderStateD3D &shaderD3D = mAttachedShaders[shaderType];
         if (shaderD3D->hasUniform(baseName))
         {
             d3dUniform->mShaderRegisterIndexes[shaderType] =
@@ -2939,17 +2943,14 @@
     {
         return;
     }
-    gl::ShaderType shaderType       = gl::ShaderType::Compute;
-    const gl::Shader *computeShader = mState.getAttachedShader(shaderType);
+    const SharedCompiledShaderStateD3D &computeShader = mAttachedShaders[gl::ShaderType::Compute];
     if (computeShader)
     {
-        const ShaderD3D *computeShaderD3D = GetImplAs<ShaderD3D>(computeShader);
-        auto &registerIndices             = mComputeAtomicCounterBufferRegisterIndices;
+        auto &registerIndices = mComputeAtomicCounterBufferRegisterIndices;
         for (auto &atomicBinding : mAtomicBindingMap)
         {
-            ASSERT(computeShaderD3D->hasUniform(atomicBinding.first));
-            unsigned int currentRegister =
-                computeShaderD3D->getUniformRegister(atomicBinding.first);
+            ASSERT(computeShader->hasUniform(atomicBinding.first));
+            unsigned int currentRegister = computeShader->getUniformRegister(atomicBinding.first);
             ASSERT(currentRegister != GL_INVALID_INDEX);
             const int kBinding = atomicBinding.second;
 
@@ -2978,14 +2979,12 @@
     unsigned int registerOffset =
         mState.getUniforms()[uniformIndex].parentArrayIndex * d3dUniform->getArraySizeProduct();
 
-    const gl::Shader *computeShader = mState.getAttachedShader(gl::ShaderType::Compute);
+    const SharedCompiledShaderStateD3D &computeShader = mAttachedShaders[gl::ShaderType::Compute];
     if (computeShader)
     {
-        const ShaderD3D *computeShaderD3D =
-            GetImplAs<ShaderD3D>(mState.getAttachedShader(gl::ShaderType::Compute));
-        ASSERT(computeShaderD3D->hasUniform(baseName));
+        ASSERT(computeShader->hasUniform(baseName));
         d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Compute] =
-            computeShaderD3D->getUniformRegister(baseName) + registerOffset;
+            computeShader->getUniformRegister(baseName) + registerOffset;
         ASSERT(d3dUniform->mShaderRegisterIndexes[gl::ShaderType::Compute] != GL_INVALID_INDEX);
         auto bindingIter = mImageBindingMap.find(baseName);
         ASSERT(bindingIter != mImageBindingMap.end());
@@ -3132,9 +3131,10 @@
     return mCurrentSerial++;
 }
 
-void ProgramD3D::initAttribLocationsToD3DSemantic(const gl::Context *context)
+void ProgramD3D::initAttribLocationsToD3DSemantic()
 {
-    gl::Shader *vertexShader = mState.getAttachedShader(gl::ShaderType::Vertex);
+    const gl::SharedCompiledShaderState &vertexShader =
+        mState.getAttachedShader(gl::ShaderType::Vertex);
     if (!vertexShader)
     {
         return;
@@ -3142,7 +3142,7 @@
 
     // Init semantic index
     int semanticIndex = 0;
-    for (const sh::ShaderVariable &attribute : vertexShader->getActiveAttributes(context))
+    for (const sh::ShaderVariable &attribute : vertexShader->activeAttributes)
     {
         int regCount    = gl::VariableRegisterCount(attribute.type);
         GLuint location = mState.getAttributeLocation(attribute.name);
@@ -3414,16 +3414,15 @@
     }
 }
 
-void ProgramD3D::linkResources(const gl::Context *context,
-                               const gl::ProgramLinkedResources &resources)
+void ProgramD3D::linkResources(const gl::ProgramLinkedResources &resources)
 {
     HLSLBlockLayoutEncoderFactory hlslEncoderFactory;
     gl::ProgramLinkedResourcesLinker linker(&hlslEncoderFactory);
 
-    linker.linkResources(context, mState, resources);
+    linker.linkResources(mState, resources);
 
     initializeUniformBlocks();
-    initializeShaderStorageBlocks(context);
+    initializeShaderStorageBlocks();
 }
 
 }  // namespace rx
diff --git a/src/libANGLE/renderer/d3d/ProgramD3D.h b/src/libANGLE/renderer/d3d/ProgramD3D.h
index 0191207..15e8445 100644
--- a/src/libANGLE/renderer/d3d/ProgramD3D.h
+++ b/src/libANGLE/renderer/d3d/ProgramD3D.h
@@ -18,6 +18,7 @@
 #include "libANGLE/renderer/ProgramImpl.h"
 #include "libANGLE/renderer/d3d/DynamicHLSL.h"
 #include "libANGLE/renderer/d3d/RendererD3D.h"
+#include "libANGLE/renderer/d3d/ShaderD3D.h"
 #include "platform/autogen/FeaturesD3D_autogen.h"
 
 namespace rx
@@ -146,8 +147,10 @@
 {
   public:
     ProgramD3DMetadata(RendererD3D *renderer,
-                       const gl::ShaderMap<const ShaderD3D *> &attachedShaders,
-                       EGLenum clientType);
+                       const gl::SharedCompiledShaderState &fragmentShader,
+                       const gl::ShaderMap<SharedCompiledShaderStateD3D> &attachedShaders,
+                       EGLenum clientType,
+                       int shaderVersion);
     ~ProgramD3DMetadata();
 
     int getRendererMajorShaderModel() const;
@@ -168,7 +171,7 @@
     bool usesMultipleFragmentOuts() const;
     bool usesCustomOutVars() const;
     bool usesSampleMask() const;
-    const ShaderD3D *getFragmentShader() const;
+    const gl::SharedCompiledShaderState &getFragmentShader() const;
     FragDepthUsage getFragDepthUsage() const;
     uint8_t getClipDistanceArraySize() const;
     uint8_t getCullDistanceArraySize() const;
@@ -179,8 +182,10 @@
     const bool mUsesInstancedPointSpriteEmulation;
     const bool mUsesViewScale;
     const bool mCanSelectViewInVertexShader;
-    const gl::ShaderMap<const ShaderD3D *> mAttachedShaders;
+    gl::SharedCompiledShaderState mFragmentShader;
+    const gl::ShaderMap<SharedCompiledShaderStateD3D> &mAttachedShaders;
     const EGLenum mClientType;
+    int mShaderVersion;
 };
 
 using D3DUniformMap = std::map<std::string, D3DUniform *>;
@@ -242,6 +247,7 @@
                                                            d3d::Context *context,
                                                            ShaderExecutableD3D **outExecutable,
                                                            gl::InfoLog *infoLog);
+    void prepareForLink(const gl::ShaderMap<ShaderImpl *> &shaders) override;
     std::unique_ptr<LinkEvent> link(const gl::Context *context,
                                     const gl::ProgramLinkedResources &resources,
                                     gl::InfoLog &infoLog,
@@ -468,8 +474,8 @@
 
     void initializeUniformStorage(const gl::ShaderBitSet &availableShaderStages);
 
-    void defineUniformsAndAssignRegisters(const gl::Context *context);
-    void defineUniformBase(const gl::Shader *shader,
+    void defineUniformsAndAssignRegisters();
+    void defineUniformBase(gl::ShaderType shaderType,
                            const sh::ShaderVariable &uniform,
                            D3DUniformMap *uniformMap);
     void assignAllSamplerRegisters();
@@ -524,20 +530,20 @@
     D3DUniform *getD3DUniformFromLocation(GLint location);
     const D3DUniform *getD3DUniformFromLocation(GLint location) const;
 
-    void initAttribLocationsToD3DSemantic(const gl::Context *context);
+    void initAttribLocationsToD3DSemantic();
 
     void reset();
     void initializeUniformBlocks();
-    void initializeShaderStorageBlocks(const gl::Context *context);
+    void initializeShaderStorageBlocks();
 
-    void updateCachedInputLayoutFromShader(const gl::Context *context);
+    void updateCachedInputLayoutFromShader();
     void updateCachedOutputLayoutFromShader();
     void updateCachedImage2DBindLayoutFromShader(gl::ShaderType shaderType);
     void updateCachedVertexExecutableIndex();
     void updateCachedPixelExecutableIndex();
     void updateCachedComputeExecutableIndex();
 
-    void linkResources(const gl::Context *context, const gl::ProgramLinkedResources &resources);
+    void linkResources(const gl::ProgramLinkedResources &resources);
 
     RendererD3D *mRenderer;
     DynamicHLSL *mDynamicHLSL;
@@ -548,6 +554,8 @@
         mGeometryExecutables;
     std::vector<std::unique_ptr<ComputeExecutable>> mComputeExecutables;
 
+    gl::ShaderMap<SharedCompiledShaderStateD3D> mAttachedShaders;
+
     gl::ShaderMap<std::string> mShaderHLSL;
     gl::ShaderMap<CompilerWorkaroundsD3D> mShaderWorkarounds;
 
diff --git a/src/libANGLE/renderer/d3d/ShaderD3D.cpp b/src/libANGLE/renderer/d3d/ShaderD3D.cpp
index fbfd272..d26e59d 100644
--- a/src/libANGLE/renderer/d3d/ShaderD3D.cpp
+++ b/src/libANGLE/renderer/d3d/ShaderD3D.cpp
@@ -89,68 +89,64 @@
     std::shared_ptr<TranslateTaskD3D> mTranslateTask;
 };
 
+CompiledShaderStateD3D::CompiledShaderStateD3D()
+    : compilerOutputType(SH_ESSL_OUTPUT),
+      usesMultipleRenderTargets(false),
+      usesFragColor(false),
+      usesFragData(false),
+      usesSecondaryColor(false),
+      usesFragCoord(false),
+      usesFrontFacing(false),
+      usesHelperInvocation(false),
+      usesPointSize(false),
+      usesPointCoord(false),
+      usesDepthRange(false),
+      usesSampleID(false),
+      usesSamplePosition(false),
+      usesSampleMaskIn(false),
+      usesSampleMask(false),
+      hasMultiviewEnabled(false),
+      usesVertexID(false),
+      usesViewID(false),
+      usesDiscardRewriting(false),
+      usesNestedBreak(false),
+      requiresIEEEStrictCompiling(false),
+      fragDepthUsage(FragDepthUsage::Unused),
+      clipDistanceSize(0),
+      cullDistanceSize(0),
+      readonlyImage2DRegisterIndex(0),
+      image2DRegisterIndex(0)
+{}
+
+CompiledShaderStateD3D::~CompiledShaderStateD3D() = default;
+
 ShaderD3D::ShaderD3D(const gl::ShaderState &state, RendererD3D *renderer)
     : ShaderImpl(state), mRenderer(renderer)
-{
-    uncompile();
-}
+{}
 
 ShaderD3D::~ShaderD3D() {}
 
 std::string ShaderD3D::getDebugInfo() const
 {
-    if (mDebugInfo.empty())
+    if (!mCompiledState || mCompiledState->debugInfo.empty())
     {
         return "";
     }
 
-    return mDebugInfo + std::string("\n// ") + gl::GetShaderTypeString(mState.getShaderType()) +
-           " SHADER END\n";
+    return mCompiledState->debugInfo + std::string("\n// ") +
+           gl::GetShaderTypeString(mState.getShaderType()) + " SHADER END\n";
 }
 
-// initialize/clean up previous state
-void ShaderD3D::uncompile()
+void CompiledShaderStateD3D::generateWorkarounds(CompilerWorkaroundsD3D *workarounds) const
 {
-    // set by compileToHLSL
-    mCompilerOutputType = SH_ESSL_OUTPUT;
-
-    mUsesMultipleRenderTargets   = false;
-    mUsesFragColor               = false;
-    mUsesFragData                = false;
-    mUsesSecondaryColor          = false;
-    mUsesFragCoord               = false;
-    mUsesFrontFacing             = false;
-    mUsesHelperInvocation        = false;
-    mUsesPointSize               = false;
-    mUsesPointCoord              = false;
-    mUsesSampleID                = false;
-    mUsesSamplePosition          = false;
-    mUsesSampleMaskIn            = false;
-    mUsesSampleMask              = false;
-    mUsesDepthRange              = false;
-    mHasMultiviewEnabled         = false;
-    mUsesVertexID                = false;
-    mUsesViewID                  = false;
-    mUsesDiscardRewriting        = false;
-    mUsesNestedBreak             = false;
-    mRequiresIEEEStrictCompiling = false;
-    mFragDepthUsage              = FragDepthUsage::Unused;
-    mClipDistanceSize            = 0;
-    mCullDistanceSize            = 0;
-
-    mDebugInfo.clear();
-}
-
-void ShaderD3D::generateWorkarounds(CompilerWorkaroundsD3D *workarounds) const
-{
-    if (mUsesDiscardRewriting)
+    if (usesDiscardRewriting)
     {
         // ANGLE issue 486:
         // Work-around a D3D9 compiler bug that presents itself when using conditional discard, by
         // disabling optimization
         workarounds->skipOptimization = true;
     }
-    else if (mUsesNestedBreak)
+    else if (usesNestedBreak)
     {
         // ANGLE issue 603:
         // Work-around a D3D9 compiler bug that presents itself when using break in a nested loop,
@@ -160,55 +156,52 @@
         workarounds->useMaxOptimization = true;
     }
 
-    if (mRequiresIEEEStrictCompiling)
+    if (requiresIEEEStrictCompiling)
     {
         // IEEE Strictness for D3D compiler needs to be enabled for NaNs to work.
         workarounds->enableIEEEStrictness = true;
     }
 }
 
-unsigned int ShaderD3D::getUniformRegister(const std::string &uniformName) const
+unsigned int CompiledShaderStateD3D::getUniformRegister(const std::string &uniformName) const
 {
-    ASSERT(mUniformRegisterMap.count(uniformName) > 0);
-    return mUniformRegisterMap.find(uniformName)->second;
+    ASSERT(uniformRegisterMap.count(uniformName) > 0);
+    return uniformRegisterMap.find(uniformName)->second;
 }
 
-unsigned int ShaderD3D::getUniformBlockRegister(const std::string &blockName) const
+unsigned int CompiledShaderStateD3D::getUniformBlockRegister(const std::string &blockName) const
 {
-    ASSERT(mUniformBlockRegisterMap.count(blockName) > 0);
-    return mUniformBlockRegisterMap.find(blockName)->second;
+    ASSERT(uniformBlockRegisterMap.count(blockName) > 0);
+    return uniformBlockRegisterMap.find(blockName)->second;
 }
 
-bool ShaderD3D::shouldUniformBlockUseStructuredBuffer(const std::string &blockName) const
+bool CompiledShaderStateD3D::shouldUniformBlockUseStructuredBuffer(
+    const std::string &blockName) const
 {
-    ASSERT(mUniformBlockUseStructuredBufferMap.count(blockName) > 0);
-    return mUniformBlockUseStructuredBufferMap.find(blockName)->second;
+    ASSERT(uniformBlockUseStructuredBufferMap.count(blockName) > 0);
+    return uniformBlockUseStructuredBufferMap.find(blockName)->second;
 }
 
-unsigned int ShaderD3D::getShaderStorageBlockRegister(const std::string &blockName) const
+unsigned int CompiledShaderStateD3D::getShaderStorageBlockRegister(
+    const std::string &blockName) const
 {
-    ASSERT(mShaderStorageBlockRegisterMap.count(blockName) > 0);
-    return mShaderStorageBlockRegisterMap.find(blockName)->second;
+    ASSERT(shaderStorageBlockRegisterMap.count(blockName) > 0);
+    return shaderStorageBlockRegisterMap.find(blockName)->second;
 }
 
-ShShaderOutput ShaderD3D::getCompilerOutputType() const
+bool CompiledShaderStateD3D::useImage2DFunction(const std::string &functionName) const
 {
-    return mCompilerOutputType;
-}
-
-bool ShaderD3D::useImage2DFunction(const std::string &functionName) const
-{
-    if (mUsedImage2DFunctionNames.empty())
+    if (usedImage2DFunctionNames.empty())
     {
         return false;
     }
 
-    return mUsedImage2DFunctionNames.find(functionName) != mUsedImage2DFunctionNames.end();
+    return usedImage2DFunctionNames.find(functionName) != usedImage2DFunctionNames.end();
 }
 
-const std::set<std::string> &ShaderD3D::getSlowCompilingUniformBlockSet() const
+const std::set<std::string> &CompiledShaderStateD3D::getSlowCompilingUniformBlockSet() const
 {
-    return mSlowCompilingUniformBlockSet;
+    return slowCompilingUniformBlockSet;
 }
 
 const std::map<std::string, unsigned int> &GetUniformRegisterMap(
@@ -236,8 +229,11 @@
                                                          gl::ShCompilerInstance *compilerInstance,
                                                          ShCompileOptions *options)
 {
+    // Create a new compiled shader state.  Currently running program link jobs will use the
+    // previous state.
+    mCompiledState = std::make_shared<CompiledShaderStateD3D>();
+
     std::string sourcePath;
-    uncompile();
 
     const angle::FeaturesD3D &features = mRenderer->getFeatures();
     const gl::Extensions &extensions   = mRenderer->getNativeExtensions();
@@ -299,58 +295,67 @@
     }
 
     auto postTranslateFunctor = [this](gl::ShCompilerInstance *compiler, std::string *infoLog) {
+        const std::string &translatedSource = mState.getCompiledState()->translatedSource;
+        CompiledShaderStateD3D *state       = mCompiledState.get();
+
         // TODO(jmadill): We shouldn't need to cache this.
-        mCompilerOutputType = compiler->getShaderOutputType();
+        state->compilerOutputType = compiler->getShaderOutputType();
 
-        const std::string &translatedSource = mState.getTranslatedSource();
-
-        mUsesMultipleRenderTargets = translatedSource.find("GL_USES_MRT") != std::string::npos;
-        mUsesFragColor      = translatedSource.find("GL_USES_FRAG_COLOR") != std::string::npos;
-        mUsesFragData       = translatedSource.find("GL_USES_FRAG_DATA") != std::string::npos;
-        mUsesSecondaryColor = translatedSource.find("GL_USES_SECONDARY_COLOR") != std::string::npos;
-        mUsesFragCoord      = translatedSource.find("GL_USES_FRAG_COORD") != std::string::npos;
-        mUsesFrontFacing    = translatedSource.find("GL_USES_FRONT_FACING") != std::string::npos;
-        mUsesSampleID       = translatedSource.find("GL_USES_SAMPLE_ID") != std::string::npos;
-        mUsesSamplePosition = translatedSource.find("GL_USES_SAMPLE_POSITION") != std::string::npos;
-        mUsesSampleMaskIn   = translatedSource.find("GL_USES_SAMPLE_MASK_IN") != std::string::npos;
-        mUsesSampleMask     = translatedSource.find("GL_USES_SAMPLE_MASK_OUT") != std::string::npos;
-        mUsesHelperInvocation =
+        state->usesMultipleRenderTargets =
+            translatedSource.find("GL_USES_MRT") != std::string::npos;
+        state->usesFragColor = translatedSource.find("GL_USES_FRAG_COLOR") != std::string::npos;
+        state->usesFragData  = translatedSource.find("GL_USES_FRAG_DATA") != std::string::npos;
+        state->usesSecondaryColor =
+            translatedSource.find("GL_USES_SECONDARY_COLOR") != std::string::npos;
+        state->usesFragCoord   = translatedSource.find("GL_USES_FRAG_COORD") != std::string::npos;
+        state->usesFrontFacing = translatedSource.find("GL_USES_FRONT_FACING") != std::string::npos;
+        state->usesSampleID    = translatedSource.find("GL_USES_SAMPLE_ID") != std::string::npos;
+        state->usesSamplePosition =
+            translatedSource.find("GL_USES_SAMPLE_POSITION") != std::string::npos;
+        state->usesSampleMaskIn =
+            translatedSource.find("GL_USES_SAMPLE_MASK_IN") != std::string::npos;
+        state->usesSampleMask =
+            translatedSource.find("GL_USES_SAMPLE_MASK_OUT") != std::string::npos;
+        state->usesHelperInvocation =
             translatedSource.find("GL_USES_HELPER_INVOCATION") != std::string::npos;
-        mUsesPointSize       = translatedSource.find("GL_USES_POINT_SIZE") != std::string::npos;
-        mUsesPointCoord      = translatedSource.find("GL_USES_POINT_COORD") != std::string::npos;
-        mUsesDepthRange      = translatedSource.find("GL_USES_DEPTH_RANGE") != std::string::npos;
-        mHasMultiviewEnabled = translatedSource.find("GL_MULTIVIEW_ENABLED") != std::string::npos;
-        mUsesVertexID        = translatedSource.find("GL_USES_VERTEX_ID") != std::string::npos;
-        mUsesViewID          = translatedSource.find("GL_USES_VIEW_ID") != std::string::npos;
-        mUsesDiscardRewriting =
+        state->usesPointSize  = translatedSource.find("GL_USES_POINT_SIZE") != std::string::npos;
+        state->usesPointCoord = translatedSource.find("GL_USES_POINT_COORD") != std::string::npos;
+        state->usesDepthRange = translatedSource.find("GL_USES_DEPTH_RANGE") != std::string::npos;
+        state->hasMultiviewEnabled =
+            translatedSource.find("GL_MULTIVIEW_ENABLED") != std::string::npos;
+        state->usesVertexID = translatedSource.find("GL_USES_VERTEX_ID") != std::string::npos;
+        state->usesViewID   = translatedSource.find("GL_USES_VIEW_ID") != std::string::npos;
+        state->usesDiscardRewriting =
             translatedSource.find("ANGLE_USES_DISCARD_REWRITING") != std::string::npos;
-        mUsesNestedBreak = translatedSource.find("ANGLE_USES_NESTED_BREAK") != std::string::npos;
-        mRequiresIEEEStrictCompiling =
+        state->usesNestedBreak =
+            translatedSource.find("ANGLE_USES_NESTED_BREAK") != std::string::npos;
+        state->requiresIEEEStrictCompiling =
             translatedSource.find("ANGLE_REQUIRES_IEEE_STRICT_COMPILING") != std::string::npos;
 
         ShHandle compilerHandle = compiler->getHandle();
 
         if (translatedSource.find("GL_USES_FRAG_DEPTH_GREATER") != std::string::npos)
         {
-            mFragDepthUsage = FragDepthUsage::Greater;
+            state->fragDepthUsage = FragDepthUsage::Greater;
         }
         else if (translatedSource.find("GL_USES_FRAG_DEPTH_LESS") != std::string::npos)
         {
-            mFragDepthUsage = FragDepthUsage::Less;
+            state->fragDepthUsage = FragDepthUsage::Less;
         }
         else if (translatedSource.find("GL_USES_FRAG_DEPTH") != std::string::npos)
         {
-            mFragDepthUsage = FragDepthUsage::Any;
+            state->fragDepthUsage = FragDepthUsage::Any;
         }
-        mClipDistanceSize   = sh::GetClipDistanceArraySize(compilerHandle);
-        mCullDistanceSize   = sh::GetCullDistanceArraySize(compilerHandle);
-        mUniformRegisterMap = GetUniformRegisterMap(sh::GetUniformRegisterMap(compilerHandle));
-        mReadonlyImage2DRegisterIndex = sh::GetReadonlyImage2DRegisterIndex(compilerHandle);
-        mImage2DRegisterIndex         = sh::GetImage2DRegisterIndex(compilerHandle);
-        mUsedImage2DFunctionNames =
+        state->clipDistanceSize = sh::GetClipDistanceArraySize(compilerHandle);
+        state->cullDistanceSize = sh::GetCullDistanceArraySize(compilerHandle);
+        state->uniformRegisterMap =
+            GetUniformRegisterMap(sh::GetUniformRegisterMap(compilerHandle));
+        state->readonlyImage2DRegisterIndex = sh::GetReadonlyImage2DRegisterIndex(compilerHandle);
+        state->image2DRegisterIndex         = sh::GetImage2DRegisterIndex(compilerHandle);
+        state->usedImage2DFunctionNames =
             GetUsedImage2DFunctionNames(sh::GetUsedImage2DFunctionNames(compilerHandle));
 
-        for (const sh::InterfaceBlock &interfaceBlock : mState.getUniformBlocks())
+        for (const sh::InterfaceBlock &interfaceBlock : mState.getCompiledState()->uniformBlocks)
         {
             if (interfaceBlock.active)
             {
@@ -361,15 +366,17 @@
                 bool useStructuredBuffer =
                     sh::ShouldUniformBlockUseStructuredBuffer(compilerHandle, interfaceBlock.name);
 
-                mUniformBlockRegisterMap[interfaceBlock.name]            = index;
-                mUniformBlockUseStructuredBufferMap[interfaceBlock.name] = useStructuredBuffer;
+                state->uniformBlockRegisterMap[interfaceBlock.name] = index;
+                state->uniformBlockUseStructuredBufferMap[interfaceBlock.name] =
+                    useStructuredBuffer;
             }
         }
 
-        mSlowCompilingUniformBlockSet =
+        state->slowCompilingUniformBlockSet =
             GetSlowCompilingUniformBlockSet(sh::GetSlowCompilingUniformBlockSet(compilerHandle));
 
-        for (const sh::InterfaceBlock &interfaceBlock : mState.getShaderStorageBlocks())
+        for (const sh::InterfaceBlock &interfaceBlock :
+             mState.getCompiledState()->shaderStorageBlocks)
         {
             if (interfaceBlock.active)
             {
@@ -378,14 +385,14 @@
                     sh::GetShaderStorageBlockRegister(compilerHandle, interfaceBlock.name, &index);
                 ASSERT(blockRegisterResult);
 
-                mShaderStorageBlockRegisterMap[interfaceBlock.name] = index;
+                state->shaderStorageBlockRegisterMap[interfaceBlock.name] = index;
             }
         }
 
-        mDebugInfo += std::string("// ") + gl::GetShaderTypeString(mState.getShaderType()) +
-                      " SHADER BEGIN\n";
-        mDebugInfo += "\n// GLSL BEGIN\n\n" + mState.getSource() + "\n\n// GLSL END\n\n\n";
-        mDebugInfo +=
+        state->debugInfo += std::string("// ") + gl::GetShaderTypeString(mState.getShaderType()) +
+                            " SHADER BEGIN\n";
+        state->debugInfo += "\n// GLSL BEGIN\n\n" + mState.getSource() + "\n\n// GLSL END\n\n\n";
+        state->debugInfo +=
             "// INITIAL HLSL BEGIN\n\n" + translatedSource + "\n// INITIAL HLSL END\n\n\n";
         // Successive steps will append more info
         return true;
@@ -400,9 +407,9 @@
         std::move(postTranslateFunctor), translateTask);
 }
 
-bool ShaderD3D::hasUniform(const std::string &name) const
+bool CompiledShaderStateD3D::hasUniform(const std::string &name) const
 {
-    return mUniformRegisterMap.find(name) != mUniformRegisterMap.end();
+    return uniformRegisterMap.find(name) != uniformRegisterMap.end();
 }
 
 }  // namespace rx
diff --git a/src/libANGLE/renderer/d3d/ShaderD3D.h b/src/libANGLE/renderer/d3d/ShaderD3D.h
index 5209f56d..99259eb 100644
--- a/src/libANGLE/renderer/d3d/ShaderD3D.h
+++ b/src/libANGLE/renderer/d3d/ShaderD3D.h
@@ -12,6 +12,7 @@
 #include "libANGLE/renderer/ShaderImpl.h"
 
 #include <map>
+#include <memory>
 
 namespace angle
 {
@@ -49,6 +50,64 @@
     Less
 };
 
+struct CompiledShaderStateD3D : angle::NonCopyable
+{
+    CompiledShaderStateD3D();
+    ~CompiledShaderStateD3D();
+
+    bool hasUniform(const std::string &name) const;
+
+    // Query regular uniforms with their name. Query sampler fields of structs with field selection
+    // using dot (.) operator.
+    unsigned int getUniformRegister(const std::string &uniformName) const;
+
+    unsigned int getUniformBlockRegister(const std::string &blockName) const;
+    bool shouldUniformBlockUseStructuredBuffer(const std::string &blockName) const;
+    unsigned int getShaderStorageBlockRegister(const std::string &blockName) const;
+    bool useImage2DFunction(const std::string &functionName) const;
+    const std::set<std::string> &getSlowCompilingUniformBlockSet() const;
+    void appendDebugInfo(const std::string &info) { debugInfo += info; }
+
+    void generateWorkarounds(CompilerWorkaroundsD3D *workarounds) const;
+
+    ShShaderOutput compilerOutputType;
+
+    bool usesMultipleRenderTargets;
+    bool usesFragColor;
+    bool usesFragData;
+    bool usesSecondaryColor;
+    bool usesFragCoord;
+    bool usesFrontFacing;
+    bool usesHelperInvocation;
+    bool usesPointSize;
+    bool usesPointCoord;
+    bool usesDepthRange;
+    bool usesSampleID;
+    bool usesSamplePosition;
+    bool usesSampleMaskIn;
+    bool usesSampleMask;
+    bool hasMultiviewEnabled;
+    bool usesVertexID;
+    bool usesViewID;
+    bool usesDiscardRewriting;
+    bool usesNestedBreak;
+    bool requiresIEEEStrictCompiling;
+    FragDepthUsage fragDepthUsage;
+    uint8_t clipDistanceSize;
+    uint8_t cullDistanceSize;
+
+    std::string debugInfo;
+    std::map<std::string, unsigned int> uniformRegisterMap;
+    std::map<std::string, unsigned int> uniformBlockRegisterMap;
+    std::map<std::string, bool> uniformBlockUseStructuredBufferMap;
+    std::set<std::string> slowCompilingUniformBlockSet;
+    std::map<std::string, unsigned int> shaderStorageBlockRegisterMap;
+    unsigned int readonlyImage2DRegisterIndex;
+    unsigned int image2DRegisterIndex;
+    std::set<std::string> usedImage2DFunctionNames;
+};
+using SharedCompiledShaderStateD3D = std::shared_ptr<CompiledShaderStateD3D>;
+
 class ShaderD3D : public ShaderImpl
 {
   public:
@@ -61,85 +120,12 @@
 
     std::string getDebugInfo() const override;
 
-    // D3D-specific methods
-    void uncompile();
-
-    bool hasUniform(const std::string &name) const;
-
-    // Query regular uniforms with their name. Query sampler fields of structs with field selection
-    // using dot (.) operator.
-    unsigned int getUniformRegister(const std::string &uniformName) const;
-
-    unsigned int getUniformBlockRegister(const std::string &blockName) const;
-    bool shouldUniformBlockUseStructuredBuffer(const std::string &blockName) const;
-    unsigned int getShaderStorageBlockRegister(const std::string &blockName) const;
-    unsigned int getReadonlyImage2DRegisterIndex() const { return mReadonlyImage2DRegisterIndex; }
-    unsigned int getImage2DRegisterIndex() const { return mImage2DRegisterIndex; }
-    bool useImage2DFunction(const std::string &functionName) const;
-    const std::set<std::string> &getSlowCompilingUniformBlockSet() const;
-    void appendDebugInfo(const std::string &info) const { mDebugInfo += info; }
-
-    void generateWorkarounds(CompilerWorkaroundsD3D *workarounds) const;
-
-    bool usesMultipleRenderTargets() const { return mUsesMultipleRenderTargets; }
-    bool usesFragColor() const { return mUsesFragColor; }
-    bool usesFragData() const { return mUsesFragData; }
-    bool usesSecondaryColor() const { return mUsesSecondaryColor; }
-    bool usesFragCoord() const { return mUsesFragCoord; }
-    bool usesFrontFacing() const { return mUsesFrontFacing; }
-    bool usesHelperInvocation() const { return mUsesHelperInvocation; }
-    bool usesPointSize() const { return mUsesPointSize; }
-    bool usesPointCoord() const { return mUsesPointCoord; }
-    bool usesDepthRange() const { return mUsesDepthRange; }
-    bool usesVertexID() const { return mUsesVertexID; }
-    bool usesViewID() const { return mUsesViewID; }
-    bool usesSampleID() const { return mUsesSampleID; }
-    bool usesSamplePosition() const { return mUsesSamplePosition; }
-    bool usesSampleMaskIn() const { return mUsesSampleMaskIn; }
-    bool usesSampleMask() const { return mUsesSampleMask; }
-    bool hasMultiviewEnabled() const { return mHasMultiviewEnabled; }
-    FragDepthUsage getFragDepthUsage() const { return mFragDepthUsage; }
-    uint8_t getClipDistanceArraySize() const { return mClipDistanceSize; }
-    uint8_t getCullDistanceArraySize() const { return mCullDistanceSize; }
-
-    ShShaderOutput getCompilerOutputType() const;
+    const SharedCompiledShaderStateD3D &getCompiledState() const { return mCompiledState; }
 
   private:
-    bool mUsesMultipleRenderTargets;
-    bool mUsesFragColor;
-    bool mUsesFragData;
-    bool mUsesSecondaryColor;
-    bool mUsesFragCoord;
-    bool mUsesFrontFacing;
-    bool mUsesHelperInvocation;
-    bool mUsesPointSize;
-    bool mUsesPointCoord;
-    bool mUsesDepthRange;
-    bool mUsesSampleID;
-    bool mUsesSamplePosition;
-    bool mUsesSampleMaskIn;
-    bool mUsesSampleMask;
-    bool mHasMultiviewEnabled;
-    bool mUsesVertexID;
-    bool mUsesViewID;
-    bool mUsesDiscardRewriting;
-    bool mUsesNestedBreak;
-    bool mRequiresIEEEStrictCompiling;
-    FragDepthUsage mFragDepthUsage;
-    uint8_t mClipDistanceSize;
-    uint8_t mCullDistanceSize;
-
     RendererD3D *mRenderer;
-    ShShaderOutput mCompilerOutputType;
-    mutable std::string mDebugInfo;
-    std::map<std::string, unsigned int> mUniformRegisterMap;
-    std::map<std::string, unsigned int> mUniformBlockRegisterMap;
-    std::map<std::string, bool> mUniformBlockUseStructuredBufferMap;
-    std::set<std::string> mSlowCompilingUniformBlockSet;
-    std::map<std::string, unsigned int> mShaderStorageBlockRegisterMap;
-    unsigned int mReadonlyImage2DRegisterIndex;
-    unsigned int mImage2DRegisterIndex;
-    std::set<std::string> mUsedImage2DFunctionNames;
+
+    SharedCompiledShaderStateD3D mCompiledState;
 };
 }  // namespace rx
 
diff --git a/src/libANGLE/renderer/gl/ProgramGL.cpp b/src/libANGLE/renderer/gl/ProgramGL.cpp
index cd54184..b392351 100644
--- a/src/libANGLE/renderer/gl/ProgramGL.cpp
+++ b/src/libANGLE/renderer/gl/ProgramGL.cpp
@@ -29,6 +29,67 @@
 
 namespace rx
 {
+namespace
+{
+
+// Returns mapped name of a transform feedback varying. The original name may contain array
+// brackets with an index inside, which will get copied to the mapped name. The varying must be
+// known to be declared in the shader.
+std::string GetTransformFeedbackVaryingMappedName(const gl::SharedCompiledShaderState &shaderState,
+                                                  const std::string &tfVaryingName)
+{
+    ASSERT(shaderState->shaderType != gl::ShaderType::Fragment &&
+           shaderState->shaderType != gl::ShaderType::Compute);
+    const auto &varyings = shaderState->outputVaryings;
+    auto bracketPos      = tfVaryingName.find("[");
+    if (bracketPos != std::string::npos)
+    {
+        auto tfVaryingBaseName = tfVaryingName.substr(0, bracketPos);
+        for (const auto &varying : varyings)
+        {
+            if (varying.name == tfVaryingBaseName)
+            {
+                std::string mappedNameWithArrayIndex =
+                    varying.mappedName + tfVaryingName.substr(bracketPos);
+                return mappedNameWithArrayIndex;
+            }
+        }
+    }
+    else
+    {
+        for (const auto &varying : varyings)
+        {
+            if (varying.name == tfVaryingName)
+            {
+                return varying.mappedName;
+            }
+            else if (varying.isStruct())
+            {
+                GLuint fieldIndex = 0;
+                const auto *field = varying.findField(tfVaryingName, &fieldIndex);
+                if (field == nullptr)
+                {
+                    continue;
+                }
+                ASSERT(field != nullptr && !field->isStruct() &&
+                       (!field->isArray() || varying.isShaderIOBlock));
+                std::string mappedName;
+                // If it's an I/O block without an instance name, don't include the block name.
+                if (!varying.isShaderIOBlock || !varying.name.empty())
+                {
+                    mappedName = varying.isShaderIOBlock ? varying.mappedStructOrBlockName
+                                                         : varying.mappedName;
+                    mappedName += '.';
+                }
+                return mappedName + field->mappedName;
+            }
+        }
+    }
+    UNREACHABLE();
+    return std::string();
+}
+
+}  // anonymous namespace
 
 ProgramGL::ProgramGL(const gl::ProgramState &data,
                      const FunctionsGL *functions,
@@ -221,6 +282,20 @@
     PostLinkImplFunctor mPostLinkImplFunctor;
 };
 
+void ProgramGL::prepareForLink(const gl::ShaderMap<ShaderImpl *> &shaders)
+{
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())
+    {
+        mAttachedShaders[shaderType] = 0;
+
+        if (shaders[shaderType] != nullptr)
+        {
+            const ShaderGL *shaderGL     = GetAs<ShaderGL>(shaders[shaderType]);
+            mAttachedShaders[shaderType] = shaderGL->getShaderID();
+        }
+    }
+}
+
 std::unique_ptr<LinkEvent> ProgramGL::link(const gl::Context *context,
                                            const gl::ProgramLinkedResources &resources,
                                            gl::InfoLog &infoLog,
@@ -230,12 +305,9 @@
 
     preLink();
 
-    if (mState.getAttachedShader(gl::ShaderType::Compute))
+    if (mAttachedShaders[gl::ShaderType::Compute] != 0)
     {
-        const ShaderGL *computeShaderGL =
-            GetImplAs<ShaderGL>(mState.getAttachedShader(gl::ShaderType::Compute));
-
-        mFunctions->attachShader(mProgramID, computeShaderGL->getShaderID());
+        mFunctions->attachShader(mProgramID, mAttachedShaders[gl::ShaderType::Compute]);
     }
     else
     {
@@ -247,9 +319,8 @@
                 mState.getExecutable().hasLinkedShaderStage(gl::ShaderType::Geometry)
                     ? gl::ShaderType::Geometry
                     : gl::ShaderType::Vertex;
-            std::string tfVaryingMappedName =
-                mState.getAttachedShader(tfShaderType)
-                    ->getTransformFeedbackVaryingMappedName(context, tfVarying);
+            std::string tfVaryingMappedName = GetTransformFeedbackVaryingMappedName(
+                mState.getAttachedShader(tfShaderType), tfVarying);
             transformFeedbackVaryingMappedNames.push_back(tfVaryingMappedName);
         }
 
@@ -281,11 +352,9 @@
 
         for (const gl::ShaderType shaderType : gl::kAllGraphicsShaderTypes)
         {
-            const ShaderGL *shaderGL =
-                rx::SafeGetImplAs<ShaderGL, gl::Shader>(mState.getAttachedShader(shaderType));
-            if (shaderGL)
+            if (mAttachedShaders[shaderType] != 0)
             {
-                mFunctions->attachShader(mProgramID, shaderGL->getShaderID());
+                mFunctions->attachShader(mProgramID, mAttachedShaders[shaderType]);
             }
         }
 
@@ -306,12 +375,12 @@
         // Otherwise shader-assigned locations will work.
         if (context->getExtensions().blendFuncExtendedEXT)
         {
-            gl::Shader *fragmentShader = mState.getAttachedShader(gl::ShaderType::Fragment);
-            if (fragmentShader && fragmentShader->getShaderVersion(context) == 100 &&
+            const gl::SharedCompiledShaderState &fragmentShader =
+                mState.getAttachedShader(gl::ShaderType::Fragment);
+            if (fragmentShader && fragmentShader->shaderVersion == 100 &&
                 mFunctions->standard == STANDARD_GL_DESKTOP)
             {
-                const auto &shaderOutputs = mState.getAttachedShader(gl::ShaderType::Fragment)
-                                                ->getActiveOutputVariables(context);
+                const auto &shaderOutputs = fragmentShader->activeOutputVariables;
                 for (const auto &output : shaderOutputs)
                 {
                     // TODO(http://anglebug.com/1085) This could be cleaner if the transformed names
@@ -352,7 +421,7 @@
                     }
                 }
             }
-            else if (fragmentShader && fragmentShader->getShaderVersion(context) >= 300)
+            else if (fragmentShader && fragmentShader->shaderVersion >= 300)
             {
                 // ESSL 3.00 and up.
                 const auto &outputLocations          = mState.getOutputLocations();
@@ -436,22 +505,17 @@
             mFunctions->linkProgram(mProgramID);
         }
 
-        if (mState.getAttachedShader(gl::ShaderType::Compute))
+        if (mAttachedShaders[gl::ShaderType::Compute] != 0)
         {
-            const ShaderGL *computeShaderGL =
-                GetImplAs<ShaderGL>(mState.getAttachedShader(gl::ShaderType::Compute));
-
-            mFunctions->detachShader(mProgramID, computeShaderGL->getShaderID());
+            mFunctions->detachShader(mProgramID, mAttachedShaders[gl::ShaderType::Compute]);
         }
         else
         {
             for (const gl::ShaderType shaderType : gl::kAllGraphicsShaderTypes)
             {
-                const ShaderGL *shaderGL =
-                    rx::SafeGetImplAs<ShaderGL>(mState.getAttachedShader(shaderType));
-                if (shaderGL)
+                if (mAttachedShaders[shaderType] != 0)
                 {
-                    mFunctions->detachShader(mProgramID, shaderGL->getShaderID());
+                    mFunctions->detachShader(mProgramID, mAttachedShaders[shaderType]);
                 }
             }
         }
diff --git a/src/libANGLE/renderer/gl/ProgramGL.h b/src/libANGLE/renderer/gl/ProgramGL.h
index 574e0bc..a874855 100644
--- a/src/libANGLE/renderer/gl/ProgramGL.h
+++ b/src/libANGLE/renderer/gl/ProgramGL.h
@@ -43,6 +43,7 @@
     void setBinaryRetrievableHint(bool retrievable) override;
     void setSeparable(bool separable) override;
 
+    void prepareForLink(const gl::ShaderMap<ShaderImpl *> &shaders) override;
     std::unique_ptr<LinkEvent> link(const gl::Context *contextImpl,
                                     const gl::ProgramLinkedResources &resources,
                                     gl::InfoLog &infoLog,
@@ -150,6 +151,8 @@
     const angle::FeaturesGL &mFeatures;
     StateManagerGL *mStateManager;
 
+    gl::ShaderMap<GLuint> mAttachedShaders;
+
     std::vector<GLint> mUniformRealLocationMap;
     std::vector<GLuint> mUniformBlockRealLocationMap;
 
diff --git a/src/libANGLE/renderer/gl/ShaderGL.cpp b/src/libANGLE/renderer/gl/ShaderGL.cpp
index 46b717e..efe6cdf 100644
--- a/src/libANGLE/renderer/gl/ShaderGL.cpp
+++ b/src/libANGLE/renderer/gl/ShaderGL.cpp
@@ -461,7 +461,7 @@
 
 std::string ShaderGL::getDebugInfo() const
 {
-    return mState.getTranslatedSource();
+    return mState.getCompiledState()->translatedSource;
 }
 
 GLuint ShaderGL::getShaderID() const
diff --git a/src/libANGLE/renderer/metal/ProgramMtl.h b/src/libANGLE/renderer/metal/ProgramMtl.h
index 8ff8b9e..888f5e7 100644
--- a/src/libANGLE/renderer/metal/ProgramMtl.h
+++ b/src/libANGLE/renderer/metal/ProgramMtl.h
@@ -18,6 +18,7 @@
 #include "common/Optional.h"
 #include "common/utilities.h"
 #include "libANGLE/renderer/ProgramImpl.h"
+#include "libANGLE/renderer/metal/ShaderMtl.h"
 #include "libANGLE/renderer/metal/mtl_buffer_pool.h"
 #include "libANGLE/renderer/metal/mtl_command_buffer.h"
 #include "libANGLE/renderer/metal/mtl_common.h"
@@ -127,6 +128,7 @@
     void setBinaryRetrievableHint(bool retrievable) override;
     void setSeparable(bool separable) override;
 
+    void prepareForLink(const gl::ShaderMap<ShaderImpl *> &shaders) override;
     std::unique_ptr<LinkEvent> link(const gl::Context *context,
                                     const gl::ProgramLinkedResources &resources,
                                     gl::InfoLog &infoLog,
@@ -247,7 +249,7 @@
                                                     const std::vector<gl::InterfaceBlock> &blocks,
                                                     gl::ShaderType shaderType);
 
-    void initUniformBlocksRemapper(gl::Shader *shader, const gl::Context *glContext);
+    void initUniformBlocksRemapper(const gl::SharedCompiledShaderState &shader);
 
     angle::Result encodeUniformBuffersInfoArgumentBuffer(
         ContextMtl *context,
@@ -263,9 +265,9 @@
     void saveShaderInternalInfo(gl::BinaryOutputStream *stream);
     void loadShaderInternalInfo(gl::BinaryInputStream *stream);
 
-    void linkUpdateHasFlatAttributes(const gl::Context *context);
+    void linkUpdateHasFlatAttributes();
 
-    void linkResources(const gl::Context *context, const gl::ProgramLinkedResources &resources);
+    void linkResources(const gl::ProgramLinkedResources &resources);
     std::unique_ptr<LinkEvent> compileMslShaderLibs(const gl::Context *context,
                                                     gl::InfoLog &infoLog);
 
@@ -289,6 +291,8 @@
     gl::ShaderBitSet mDefaultUniformBlocksDirty;
     gl::ShaderBitSet mSamplerBindingsDirty;
 
+    gl::ShaderMap<SharedCompiledShaderStateMtl> mAttachedShaders;
+
     gl::ShaderMap<DefaultUniformBlock> mDefaultUniformBlocks;
     std::unordered_map<std::string, UBOConversionInfo> mUniformBlockConversions;
 
diff --git a/src/libANGLE/renderer/metal/ProgramMtl.mm b/src/libANGLE/renderer/metal/ProgramMtl.mm
index 6028765..c9e862e 100644
--- a/src/libANGLE/renderer/metal/ProgramMtl.mm
+++ b/src/libANGLE/renderer/metal/ProgramMtl.mm
@@ -227,7 +227,6 @@
 }
 
 void InitDefaultUniformBlock(const std::vector<sh::Uniform> &uniforms,
-                             gl::Shader *shader,
                              sh::BlockLayoutMap *blockLayoutMapOut,
                              size_t *blockSizeOut)
 {
@@ -430,10 +429,10 @@
 }  // namespace
 
 // TODO(angleproject:7979) Upgrade ANGLE Uniform buffer remapper to compute shaders
-void ProgramMtl::initUniformBlocksRemapper(gl::Shader *shader, const gl::Context *glContext)
+void ProgramMtl::initUniformBlocksRemapper(const gl::SharedCompiledShaderState &shader)
 {
     std::unordered_map<std::string, UBOConversionInfo> conversionMap;
-    const std::vector<sh::InterfaceBlock> ibs = shader->getUniformBlocks(glContext);
+    const std::vector<sh::InterfaceBlock> ibs = shader->uniformBlocks;
     for (size_t i = 0; i < ibs.size(); ++i)
     {
 
@@ -643,6 +642,20 @@
     UNIMPLEMENTED();
 }
 
+void ProgramMtl::prepareForLink(const gl::ShaderMap<ShaderImpl *> &shaders)
+{
+    for (gl::ShaderType shaderType : gl::AllShaderTypes())
+    {
+        mAttachedShaders[shaderType].reset();
+
+        if (shaders[shaderType] != nullptr)
+        {
+            const ShaderMtl *shaderMtl   = GetAs<ShaderMtl>(shaders[shaderType]);
+            mAttachedShaders[shaderType] = shaderMtl->getCompiledState();
+        }
+    }
+}
+
 std::unique_ptr<LinkEvent> ProgramMtl::link(const gl::Context *context,
                                             const gl::ProgramLinkedResources &resources,
                                             gl::InfoLog &infoLog,
@@ -652,24 +665,24 @@
 
     // Link resources before calling GetShaderSource to make sure they are ready for the set/binding
     // assignment done in that function.
-    linkResources(context, resources);
+    linkResources(resources);
 
     reset(contextMtl);
     ANGLE_PARALLEL_LINK_TRY(initDefaultUniformBlocks(context));
-    linkUpdateHasFlatAttributes(context);
+    linkUpdateHasFlatAttributes();
 
     gl::ShaderMap<std::string> shaderSources;
-    mtl::MSLGetShaderSource(context, mState, resources, &shaderSources);
+    mtl::MSLGetShaderSource(mState, resources, &shaderSources);
 
     ANGLE_PARALLEL_LINK_TRY(mtl::MTLGetMSL(
-        context, mState, contextMtl->getCaps(), shaderSources, &mMslShaderTranslateInfo,
-        mState.getExecutable().getTransformFeedbackBufferCount()));
+        context, mState, contextMtl->getCaps(), shaderSources, mAttachedShaders,
+        &mMslShaderTranslateInfo, mState.getExecutable().getTransformFeedbackBufferCount()));
     mMslXfbOnlyVertexShaderInfo = mMslShaderTranslateInfo[gl::ShaderType::Vertex];
 
     return compileMslShaderLibs(context, infoLog);
 }
 
-void ProgramMtl::linkUpdateHasFlatAttributes(const gl::Context *context)
+void ProgramMtl::linkUpdateHasFlatAttributes()
 {
     mProgramHasFlatAttributes = false;
 
@@ -683,8 +696,7 @@
         }
     }
 
-    const auto &flatVaryings =
-        mState.getAttachedShader(gl::ShaderType::Vertex)->getOutputVaryings(context);
+    const auto &flatVaryings = mState.getAttachedShader(gl::ShaderType::Vertex)->outputVaryings;
     for (auto &attribute : flatVaryings)
     {
         if (attribute.interpolation == sh::INTERPOLATION_FLAT)
@@ -824,13 +836,12 @@
     }
     return mAuxBufferPool;
 }
-void ProgramMtl::linkResources(const gl::Context *context,
-                               const gl::ProgramLinkedResources &resources)
+void ProgramMtl::linkResources(const gl::ProgramLinkedResources &resources)
 {
     Std140BlockLayoutEncoderFactory std140EncoderFactory;
     gl::ProgramLinkedResourcesLinker linker(&std140EncoderFactory);
 
-    linker.linkResources(context, mState, resources);
+    linker.linkResources(mState, resources);
 }
 
 angle::Result ProgramMtl::initDefaultUniformBlocks(const gl::Context *glContext)
@@ -842,14 +853,14 @@
 
     for (gl::ShaderType shaderType : gl::kAllGLES2ShaderTypes)
     {
-        gl::Shader *shader = mState.getAttachedShader(shaderType);
+        const gl::SharedCompiledShaderState &shader = mState.getAttachedShader(shaderType);
         if (shader)
         {
-            const std::vector<sh::Uniform> &uniforms = shader->getUniforms(glContext);
-            InitDefaultUniformBlock(uniforms, shader, &layoutMap[shaderType],
+            const std::vector<sh::Uniform> &uniforms = shader->uniforms;
+            InitDefaultUniformBlock(uniforms, &layoutMap[shaderType],
                                     &requiredBufferSize[shaderType]);
             // Set up block conversion buffer
-            initUniformBlocksRemapper(shader, glContext);
+            initUniformBlocksRemapper(shader);
         }
     }
 
diff --git a/src/libANGLE/renderer/metal/ShaderMtl.h b/src/libANGLE/renderer/metal/ShaderMtl.h
index 682ae6e..e9f7572 100644
--- a/src/libANGLE/renderer/metal/ShaderMtl.h
+++ b/src/libANGLE/renderer/metal/ShaderMtl.h
@@ -11,11 +11,11 @@
 
 #include <map>
 
-#include "compiler/translator/msl/TranslatorMSL.h"
 #include "libANGLE/renderer/ShaderImpl.h"
+#include "libANGLE/renderer/metal/mtl_msl_utils.h"
+
 namespace rx
 {
-
 class ShaderMtl : public ShaderImpl
 {
   public:
@@ -26,20 +26,17 @@
                                                   gl::ShCompilerInstance *compilerInstance,
                                                   ShCompileOptions *options) override;
 
-    sh::TranslatorMetalReflection *getTranslatorMetalReflection()
-    {
-        return &translatorMetalReflection;
-    }
+    const SharedCompiledShaderStateMtl &getCompiledState() const { return mCompiledState; }
 
     std::string getDebugInfo() const override;
 
-    sh::TranslatorMetalReflection translatorMetalReflection = {};
-
   private:
     std::shared_ptr<WaitableCompileEvent> compileImplMtl(const gl::Context *context,
                                                          gl::ShCompilerInstance *compilerInstance,
                                                          const std::string &source,
                                                          ShCompileOptions *compileOptions);
+
+    SharedCompiledShaderStateMtl mCompiledState;
 };
 
 }  // namespace rx
diff --git a/src/libANGLE/renderer/metal/ShaderMtl.mm b/src/libANGLE/renderer/metal/ShaderMtl.mm
index de1c47f..59015b5 100644
--- a/src/libANGLE/renderer/metal/ShaderMtl.mm
+++ b/src/libANGLE/renderer/metal/ShaderMtl.mm
@@ -52,7 +52,7 @@
 class MTLWaitableCompileEventImpl final : public WaitableCompileEvent
 {
   public:
-    MTLWaitableCompileEventImpl(ShaderMtl *shader,
+    MTLWaitableCompileEventImpl(const SharedCompiledShaderStateMtl &shader,
                                 std::shared_ptr<angle::WaitableEvent> waitableEvent,
                                 std::shared_ptr<TranslateTask> translateTask)
         : WaitableCompileEvent(waitableEvent), mTranslateTask(translateTask), mShader(shader)
@@ -76,7 +76,7 @@
 
   private:
     std::shared_ptr<TranslateTask> mTranslateTask;
-    ShaderMtl *mShader;
+    SharedCompiledShaderStateMtl mShader;
 };
 
 std::shared_ptr<WaitableCompileEvent> ShaderMtl::compileImplMtl(
@@ -96,7 +96,7 @@
         std::make_shared<TranslateTask>(compilerInstance->getHandle(), *compileOptions, source);
 
     return std::make_shared<MTLWaitableCompileEventImpl>(
-        this, workerThreadPool->postWorkerTask(translateTask), translateTask);
+        mCompiledState, workerThreadPool->postWorkerTask(translateTask), translateTask);
 }
 
 std::shared_ptr<WaitableCompileEvent> ShaderMtl::compile(const gl::Context *context,
@@ -106,6 +106,10 @@
     ContextMtl *contextMtl = mtl::GetImpl(context);
     DisplayMtl *displayMtl = contextMtl->getDisplay();
 
+    // Create a new compiled shader state.  Currently running program link jobs will use the
+    // previous state.
+    mCompiledState = std::make_shared<CompiledShaderStateMtl>();
+
     options->initializeUninitializedLocals = true;
 
     if (context->isWebGL() && mState.getShaderType() != gl::ShaderType::Compute)
@@ -150,12 +154,7 @@
 
 std::string ShaderMtl::getDebugInfo() const
 {
-    std::string debugInfo = mState.getTranslatedSource();
-    if (debugInfo.empty())
-    {
-        return mState.getCompiledBinary().empty() ? "" : "<binary blob>";
-    }
-    return debugInfo;
+    return mState.getCompiledState()->translatedSource;
 }
 
 }  // namespace rx
diff --git a/src/libANGLE/renderer/metal/mtl_msl_utils.h b/src/libANGLE/renderer/metal/mtl_msl_utils.h
index f243b34..413ffbc 100644
--- a/src/libANGLE/renderer/metal/mtl_msl_utils.h
+++ b/src/libANGLE/renderer/metal/mtl_msl_utils.h
@@ -8,12 +8,22 @@
 
 #ifndef mtl_msl_utils_h
 #define mtl_msl_utils_h
+
+#include <memory>
+
+#include "compiler/translator/msl/TranslatorMSL.h"
 #include "libANGLE/Context.h"
 #include "libANGLE/renderer/ProgramImpl.h"
 #include "libANGLE/renderer/metal/mtl_common.h"
 
 namespace rx
 {
+struct CompiledShaderStateMtl : angle::NonCopyable
+{
+    sh::TranslatorMetalReflection translatorMetalReflection = {};
+};
+using SharedCompiledShaderStateMtl = std::shared_ptr<CompiledShaderStateMtl>;
+
 namespace mtl
 {
 struct SamplerBinding
@@ -36,8 +46,8 @@
     bool hasUBOArgumentBuffer;
     bool hasInvariant;
 };
-void MSLGetShaderSource(const gl::Context *context,
-                        const gl::ProgramState &programState,
+
+void MSLGetShaderSource(const gl::ProgramState &programState,
                         const gl::ProgramLinkedResources &resources,
                         gl::ShaderMap<std::string> *shaderSourcesOut);
 
@@ -45,6 +55,7 @@
                         const gl::ProgramState &programState,
                         const gl::Caps &glCaps,
                         const gl::ShaderMap<std::string> &shaderSources,
+                        const gl::ShaderMap<SharedCompiledShaderStateMtl> &shadersState,
                         gl::ShaderMap<TranslatedShaderInfo> *mslShaderInfoOut,
                         size_t xfbBufferCount);
 
diff --git a/src/libANGLE/renderer/metal/mtl_msl_utils.mm b/src/libANGLE/renderer/metal/mtl_msl_utils.mm
index c796bc5..af063c7 100644
--- a/src/libANGLE/renderer/metal/mtl_msl_utils.mm
+++ b/src/libANGLE/renderer/metal/mtl_msl_utils.mm
@@ -112,15 +112,14 @@
     return samplerName;
 }
 
-void MSLGetShaderSource(const gl::Context *context,
-                        const gl::ProgramState &programState,
+void MSLGetShaderSource(const gl::ProgramState &programState,
                         const gl::ProgramLinkedResources &resources,
                         gl::ShaderMap<std::string> *shaderSourcesOut)
 {
     for (const gl::ShaderType shaderType : gl::AllShaderTypes())
     {
-        gl::Shader *glShader            = programState.getAttachedShader(shaderType);
-        (*shaderSourcesOut)[shaderType] = glShader ? glShader->getTranslatedSource(context) : "";
+        const gl::SharedCompiledShaderState &glShader = programState.getAttachedShader(shaderType);
+        (*shaderSourcesOut)[shaderType]               = glShader ? glShader->translatedSource : "";
     }
 }
 
@@ -159,12 +158,6 @@
     }
 }
 
-sh::TranslatorMetalReflection *getReflectionFromShader(gl::Shader *shader)
-{
-    ShaderMtl *shaderInstance = static_cast<ShaderMtl *>(shader->getImplementation());
-    return shaderInstance->getTranslatorMetalReflection();
-}
-
 std::string updateShaderAttributes(std::string shaderSourceIn, const gl::ProgramState &programState)
 {
     // Build string to attrib map.
@@ -466,6 +459,7 @@
                         const gl::ProgramState &programState,
                         const gl::Caps &glCaps,
                         const gl::ShaderMap<std::string> &shaderSources,
+                        const gl::ShaderMap<SharedCompiledShaderStateMtl> &shadersState,
                         gl::ShaderMap<TranslatedShaderInfo> *mslShaderInfoOut,
                         size_t xfbBufferCount)
 {
@@ -532,8 +526,8 @@
         }
         (*mslShaderInfoOut)[type].metalShaderSource =
             std::make_shared<const std::string>(std::move(source));
-        gl::Shader *shader                              = programState.getAttachedShader(type);
-        const sh::TranslatorMetalReflection *reflection = getReflectionFromShader(shader);
+        const sh::TranslatorMetalReflection *reflection =
+            &shadersState[type]->translatorMetalReflection;
         if (reflection->hasUBOs)
         {
             (*mslShaderInfoOut)[type].hasUBOArgumentBuffer = true;
diff --git a/src/libANGLE/renderer/vulkan/ProgramVk.cpp b/src/libANGLE/renderer/vulkan/ProgramVk.cpp
index 0223c3d..b84cc58 100644
--- a/src/libANGLE/renderer/vulkan/ProgramVk.cpp
+++ b/src/libANGLE/renderer/vulkan/ProgramVk.cpp
@@ -316,7 +316,7 @@
     ContextVk *contextVk = vk::GetImpl(context);
     // Link resources before calling GetShaderSource to make sure they are ready for the set/binding
     // assignment done in that function.
-    linkResources(context, resources);
+    linkResources(resources);
 
     reset(contextVk);
     mExecutable.clearVariableInfoMap();
@@ -324,8 +324,8 @@
     // Gather variable info and compiled SPIR-V binaries.
     gl::ShaderMap<const angle::spirv::Blob *> spirvBlobs;
     SpvSourceOptions options = SpvCreateSourceOptions(contextVk->getFeatures());
-    SpvGetShaderSpirvCode(context, options, mState, resources, &mSpvProgramInterfaceInfo,
-                          &spirvBlobs, &mExecutable.mVariableInfoMap);
+    SpvGetShaderSpirvCode(options, mState, resources, &mSpvProgramInterfaceInfo, &spirvBlobs,
+                          &mExecutable.mVariableInfoMap);
 
     if (contextVk->getFeatures().varyingsRequireMatchingPrecisionInSpirv.enabled &&
         contextVk->getFeatures().enablePrecisionQualifiers.enabled)
@@ -364,13 +364,12 @@
     return std::make_unique<LinkEventVulkan>(context->getShaderCompileThreadPool(), linkTask);
 }
 
-void ProgramVk::linkResources(const gl::Context *context,
-                              const gl::ProgramLinkedResources &resources)
+void ProgramVk::linkResources(const gl::ProgramLinkedResources &resources)
 {
     Std140BlockLayoutEncoderFactory std140EncoderFactory;
     gl::ProgramLinkedResourcesLinker linker(&std140EncoderFactory);
 
-    linker.linkResources(context, mState, resources);
+    linker.linkResources(mState, resources);
 }
 
 angle::Result ProgramVk::initDefaultUniformBlocks(const gl::Context *glContext)
@@ -382,7 +381,7 @@
     gl::ShaderMap<size_t> requiredBufferSize;
     requiredBufferSize.fill(0);
 
-    generateUniformLayoutMapping(glContext, layoutMap, requiredBufferSize);
+    generateUniformLayoutMapping(layoutMap, requiredBufferSize);
     initDefaultUniformLayoutMapping(layoutMap);
 
     // All uniform initializations are complete, now resize the buffers accordingly and return
@@ -390,19 +389,18 @@
                                                 requiredBufferSize);
 }
 
-void ProgramVk::generateUniformLayoutMapping(const gl::Context *context,
-                                             gl::ShaderMap<sh::BlockLayoutMap> &layoutMap,
+void ProgramVk::generateUniformLayoutMapping(gl::ShaderMap<sh::BlockLayoutMap> &layoutMap,
                                              gl::ShaderMap<size_t> &requiredBufferSize)
 {
     const gl::ProgramExecutable &glExecutable = mState.getExecutable();
 
     for (const gl::ShaderType shaderType : glExecutable.getLinkedShaderStages())
     {
-        gl::Shader *shader = mState.getAttachedShader(shaderType);
+        const gl::SharedCompiledShaderState &shader = mState.getAttachedShader(shaderType);
 
         if (shader)
         {
-            const std::vector<sh::ShaderVariable> &uniforms = shader->getUniforms(context);
+            const std::vector<sh::ShaderVariable> &uniforms = shader->uniforms;
             InitDefaultUniformBlock(uniforms, &layoutMap[shaderType],
                                     &requiredBufferSize[shaderType]);
         }
diff --git a/src/libANGLE/renderer/vulkan/ProgramVk.h b/src/libANGLE/renderer/vulkan/ProgramVk.h
index 8574706..e17540d 100644
--- a/src/libANGLE/renderer/vulkan/ProgramVk.h
+++ b/src/libANGLE/renderer/vulkan/ProgramVk.h
@@ -128,8 +128,7 @@
 
     void reset(ContextVk *contextVk);
     angle::Result initDefaultUniformBlocks(const gl::Context *glContext);
-    void generateUniformLayoutMapping(const gl::Context *context,
-                                      gl::ShaderMap<sh::BlockLayoutMap> &layoutMap,
+    void generateUniformLayoutMapping(gl::ShaderMap<sh::BlockLayoutMap> &layoutMap,
                                       gl::ShaderMap<size_t> &requiredBufferSize);
     void initDefaultUniformLayoutMapping(gl::ShaderMap<sh::BlockLayoutMap> &layoutMap);
 
@@ -138,7 +137,7 @@
 
     template <typename T>
     void setUniformImpl(GLint location, GLsizei count, const T *v, GLenum entryPointType);
-    void linkResources(const gl::Context *context, const gl::ProgramLinkedResources &resources);
+    void linkResources(const gl::ProgramLinkedResources &resources);
 
     angle::Result createGraphicsPipelineWithDefaultState(const gl::Context *context,
                                                          vk::PipelineCacheAccess *pipelineCache);
diff --git a/src/libANGLE/renderer/vulkan/ShaderVk.cpp b/src/libANGLE/renderer/vulkan/ShaderVk.cpp
index 097ccdf..365d7eb 100644
--- a/src/libANGLE/renderer/vulkan/ShaderVk.cpp
+++ b/src/libANGLE/renderer/vulkan/ShaderVk.cpp
@@ -133,7 +133,7 @@
 
 std::string ShaderVk::getDebugInfo() const
 {
-    return mState.getCompiledBinary().empty() ? "" : "<binary blob>";
+    return mState.getCompiledState()->compiledBinary.empty() ? "" : "<binary blob>";
 }
 
 }  // namespace rx
diff --git a/src/libANGLE/renderer/vulkan/spv_utils.cpp b/src/libANGLE/renderer/vulkan/spv_utils.cpp
index c571215..79088e2 100644
--- a/src/libANGLE/renderer/vulkan/spv_utils.cpp
+++ b/src/libANGLE/renderer/vulkan/spv_utils.cpp
@@ -4943,8 +4943,7 @@
     }
 }
 
-void SpvGetShaderSpirvCode(const gl::Context *context,
-                           const SpvSourceOptions &options,
+void SpvGetShaderSpirvCode(const SpvSourceOptions &options,
                            const gl::ProgramState &programState,
                            const gl::ProgramLinkedResources &resources,
                            SpvProgramInterfaceInfo *programInterfaceInfo,
@@ -4953,8 +4952,8 @@
 {
     for (const gl::ShaderType shaderType : gl::AllShaderTypes())
     {
-        gl::Shader *glShader         = programState.getAttachedShader(shaderType);
-        (*spirvBlobsOut)[shaderType] = glShader ? &glShader->getCompiledBinary(context) : nullptr;
+        const gl::SharedCompiledShaderState &glShader = programState.getAttachedShader(shaderType);
+        (*spirvBlobsOut)[shaderType] = glShader ? &glShader->compiledBinary : nullptr;
     }
 
     const gl::ProgramExecutable &programExecutable = programState.getExecutable();
diff --git a/src/libANGLE/renderer/vulkan/spv_utils.h b/src/libANGLE/renderer/vulkan/spv_utils.h
index 61940be..ea7a5f9 100644
--- a/src/libANGLE/renderer/vulkan/spv_utils.h
+++ b/src/libANGLE/renderer/vulkan/spv_utils.h
@@ -126,8 +126,7 @@
                                          ShaderInterfaceVariableInfoMap *variableInfoMapOut);
 
 // Retrieves the compiled SPIR-V code for each shader stage, and calls |SpvAssignLocations|.
-void SpvGetShaderSpirvCode(const gl::Context *context,
-                           const SpvSourceOptions &options,
+void SpvGetShaderSpirvCode(const SpvSourceOptions &options,
                            const gl::ProgramState &programState,
                            const gl::ProgramLinkedResources &resources,
                            SpvProgramInterfaceInfo *programInterfaceInfo,
diff --git a/src/libANGLE/validationES.cpp b/src/libANGLE/validationES.cpp
index 6812efa..4903df3 100644
--- a/src/libANGLE/validationES.cpp
+++ b/src/libANGLE/validationES.cpp
@@ -1409,7 +1409,7 @@
 
     if (!validProgram)
     {
-        if (context->getShader(id))
+        if (context->getShaderNoResolveCompile(id))
         {
             ANGLE_VALIDATION_ERROR(GL_INVALID_OPERATION, kExpectedProgramName);
         }
@@ -1436,7 +1436,7 @@
 {
     // See ValidProgram for spec details.
 
-    Shader *validShader = context->getShader(id);
+    Shader *validShader = context->getShaderNoResolveCompile(id);
 
     if (!validShader)
     {
diff --git a/src/libANGLE/validationES2.cpp b/src/libANGLE/validationES2.cpp
index f4b7e07..03c24fe 100644
--- a/src/libANGLE/validationES2.cpp
+++ b/src/libANGLE/validationES2.cpp
@@ -2153,7 +2153,7 @@
             return true;
 
         case GL_SHADER:
-            if (context->getShader({name}) == nullptr)
+            if (context->getShaderNoResolveCompile({name}) == nullptr)
             {
                 ANGLE_VALIDATION_ERROR(GL_INVALID_VALUE, kInvalidShaderName);
                 return false;
@@ -4192,7 +4192,7 @@
 
     if (!context->getProgramResolveLink(program))
     {
-        if (context->getShader(program))
+        if (context->getShaderNoResolveCompile(program))
         {
             ANGLE_VALIDATION_ERROR(GL_INVALID_OPERATION, kExpectedProgramName);
             return false;
@@ -4216,7 +4216,7 @@
         return false;
     }
 
-    if (!context->getShader(shader))
+    if (!context->getShaderNoResolveCompile(shader))
     {
         if (context->getProgramResolveLink(shader))
         {
@@ -6000,7 +6000,7 @@
         if (!programObject)
         {
             // ES 3.1.0 section 7.3 page 72
-            if (context->getShader(program))
+            if (context->getShaderNoResolveCompile(program))
             {
                 ANGLE_VALIDATION_ERROR(GL_INVALID_OPERATION, kExpectedProgramName);
                 return false;
@@ -6162,7 +6162,7 @@
         return false;
     }
 
-    Shader *shaderObject = context->getShader(shader);
+    Shader *shaderObject = context->getShaderNoResolveCompile(shader);
 
     if (!shaderObject)
     {
diff --git a/src/libANGLE/validationESEXT.cpp b/src/libANGLE/validationESEXT.cpp
index 9478a09..6b630a0 100644
--- a/src/libANGLE/validationESEXT.cpp
+++ b/src/libANGLE/validationESEXT.cpp
@@ -117,7 +117,7 @@
                 ANGLE_VALIDATION_ERROR(GL_INVALID_ENUM, kInvalidType);
                 return false;
             }
-            if (context->getShader({name}) == nullptr)
+            if (context->getShaderNoResolveCompile({name}) == nullptr)
             {
                 ANGLE_VALIDATION_ERROR(GL_INVALID_OPERATION, kInvalidShaderName);
                 return false;
diff --git a/src/tests/gl_tests/LinkAndRelinkTest.cpp b/src/tests/gl_tests/LinkAndRelinkTest.cpp
index 4fa5a6c..c828513 100644
--- a/src/tests/gl_tests/LinkAndRelinkTest.cpp
+++ b/src/tests/gl_tests/LinkAndRelinkTest.cpp
@@ -27,6 +27,12 @@
     LinkAndRelinkTestES31() {}
 };
 
+class LinkAndRelinkTestES32 : public ANGLETest<>
+{
+  protected:
+    LinkAndRelinkTestES32() {}
+};
+
 // When a program link or relink fails, if you try to install the unsuccessfully
 // linked program (via UseProgram) and start rendering or dispatch compute,
 // We can not always report INVALID_OPERATION for rendering/compute pipeline.
@@ -445,6 +451,166 @@
     EXPECT_GL_NO_ERROR();
 }
 
+// Parallel link should continue unscathed even if the attached shaders to the program are modified.
+TEST_P(LinkAndRelinkTestES31, ReattachShadersWhileParallelLinking)
+{
+    constexpr char kVS[]      = R"(#version 300 es
+void main()
+{
+    vec2 position = vec2(-1, -1);
+    if (gl_VertexID == 1)
+        position = vec2(3, -1);
+    else if (gl_VertexID == 2)
+        position = vec2(-1, 3);
+    gl_Position = vec4(position, 0, 1);
+})";
+    constexpr char kFSGreen[] = R"(#version 300 es
+out mediump vec4 color;
+void main()
+{
+    color = vec4(0, 1, 0, 1);
+})";
+    constexpr char kFSRed[]   = R"(#version 300 es
+out mediump vec4 color;
+void main()
+{
+    color = vec4(1, 0, 0, 1);
+})";
+
+    GLuint program = glCreateProgram();
+
+    GLuint vs    = CompileShader(GL_VERTEX_SHADER, kVS);
+    GLuint green = CompileShader(GL_FRAGMENT_SHADER, kFSGreen);
+    GLuint red   = CompileShader(GL_FRAGMENT_SHADER, kFSRed);
+
+    EXPECT_NE(0u, vs);
+    EXPECT_NE(0u, green);
+    EXPECT_NE(0u, red);
+
+    glAttachShader(program, vs);
+    glAttachShader(program, green);
+    glLinkProgram(program);
+    ASSERT_GL_NO_ERROR();
+
+    // Immediately reattach another shader
+    glDetachShader(program, green);
+    glAttachShader(program, red);
+    ASSERT_GL_NO_ERROR();
+
+    // Make sure the linked program draws with green
+    glUseProgram(program);
+    ASSERT_GL_NO_ERROR();
+
+    glDrawArrays(GL_TRIANGLES, 0, 3);
+    EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
+    ASSERT_GL_NO_ERROR();
+
+    glDeleteShader(vs);
+    glDeleteShader(green);
+    glDeleteShader(red);
+    ASSERT_GL_NO_ERROR();
+}
+
+// Parallel link should continue unscathed even if new shaders are attached to the program.
+TEST_P(LinkAndRelinkTestES31, AttachNewShadersWhileParallelLinking)
+{
+    ANGLE_SKIP_TEST_IF(!IsGLExtensionEnabled("GL_EXT_geometry_shader"));
+
+    constexpr char kVS[] = R"(#version 310 es
+#extension GL_EXT_geometry_shader : require
+void main()
+{
+    vec2 position = vec2(-1, -1);
+    if (gl_VertexID == 1)
+        position = vec2(3, -1);
+    else if (gl_VertexID == 2)
+        position = vec2(-1, 3);
+    gl_Position = vec4(position, 0, 1);
+})";
+    constexpr char kFS[] = R"(#version 310 es
+#extension GL_EXT_geometry_shader : require
+out mediump vec4 color;
+void main()
+{
+    color = vec4(0, 1, 0, 1);
+})";
+    constexpr char kGS[] = R"(#version 310 es
+#extension GL_EXT_geometry_shader : require
+layout (invocations = 3, triangles) in;
+layout (triangle_strip, max_vertices = 3) out;
+void main()
+{
+})";
+
+    GLuint program = glCreateProgram();
+
+    GLuint vs = CompileShader(GL_VERTEX_SHADER, kVS);
+    GLuint fs = CompileShader(GL_FRAGMENT_SHADER, kFS);
+    GLuint gs = CompileShader(GL_GEOMETRY_SHADER, kGS);
+
+    EXPECT_NE(0u, vs);
+    EXPECT_NE(0u, fs);
+    EXPECT_NE(0u, gs);
+
+    glAttachShader(program, vs);
+    glAttachShader(program, fs);
+    glLinkProgram(program);
+    ASSERT_GL_NO_ERROR();
+
+    // Immediately attach another shader
+    glAttachShader(program, gs);
+    ASSERT_GL_NO_ERROR();
+
+    // Make sure the linked program draws with green
+    glUseProgram(program);
+    ASSERT_GL_NO_ERROR();
+
+    glDrawArrays(GL_TRIANGLES, 0, 3);
+    EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
+    ASSERT_GL_NO_ERROR();
+
+    glDeleteShader(vs);
+    glDeleteShader(fs);
+    glDeleteShader(gs);
+    ASSERT_GL_NO_ERROR();
+}
+
+// Make sure the shader can be compiled in between attach and link
+TEST_P(LinkAndRelinkTest, AttachShaderThenCompile)
+{
+    GLuint program = glCreateProgram();
+
+    GLShader vs(GL_VERTEX_SHADER);
+    GLShader fs(GL_FRAGMENT_SHADER);
+
+    // Attach the shaders to the program first.  This makes sure the program doesn't prematurely
+    // attempt to look into the shader's compilation result.
+    glAttachShader(program, vs);
+    glAttachShader(program, fs);
+
+    // Compile the shaders after that.
+    const char *kVS = essl1_shaders::vs::Simple();
+    const char *kFS = essl1_shaders::fs::Green();
+    glShaderSource(vs, 1, &kVS, nullptr);
+    glShaderSource(fs, 1, &kFS, nullptr);
+    EXPECT_GL_NO_ERROR();
+
+    glCompileShader(vs);
+    glCompileShader(fs);
+
+    // Then link
+    glLinkProgram(program);
+    ASSERT_GL_NO_ERROR();
+
+    // Make sure it works
+    drawQuad(program, essl1_shaders::PositionAttrib(), 0.5);
+    EXPECT_PIXEL_COLOR_EQ(0, 0, GLColor::green);
+    ASSERT_GL_NO_ERROR();
+
+    glDeleteProgram(program);
+    ASSERT_GL_NO_ERROR();
+}
+
 ANGLE_INSTANTIATE_TEST_ES2_AND_ES3(LinkAndRelinkTest);
 
 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(LinkAndRelinkTestES31);