Apply SH_TIMING_RESTRICTIONS to all samplers.

Issue: 332
Review URL: https://codereview.appspot.com/6273044/


git-svn-id: https://angleproject.googlecode.com/svn/trunk@1131 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/Compiler.cpp b/src/compiler/Compiler.cpp
index c6fcb41..245bf35 100644
--- a/src/compiler/Compiler.cpp
+++ b/src/compiler/Compiler.cpp
@@ -165,12 +165,8 @@
         if (success && (compileOptions & SH_VALIDATE_LOOP_INDEXING))
             success = validateLimitations(root);
 
-        // FIXME(mvujovic): For now, we only consider "u_texture" to be a potentially unsafe symbol.
-        // If we end up using timing restrictions in WebGL and CSS Shaders, we should expose an API
-        // to pass in the names of other potentially unsafe symbols (e.g. uniforms referencing 
-        // cross-domain textures).
         if (success && (compileOptions & SH_TIMING_RESTRICTIONS))
-            success = enforceTimingRestrictions(root, "u_texture", (compileOptions & SH_DEPENDENCY_GRAPH) != 0);
+            success = enforceTimingRestrictions(root, (compileOptions & SH_DEPENDENCY_GRAPH) != 0);
 
         // Unroll for-loop markup needs to happen after validateLimitations pass.
         if (success && (compileOptions & SH_UNROLL_FOR_LOOP_WITH_INTEGER_INDEX))
@@ -252,9 +248,7 @@
     return validate.numErrors() == 0;
 }
 
-bool TCompiler::enforceTimingRestrictions(TIntermNode* root,
-                                          const TString& restrictedSymbol,
-                                          bool outputGraph)
+bool TCompiler::enforceTimingRestrictions(TIntermNode* root, bool outputGraph)
 {
     if (shaderSpec != SH_WEBGL_SPEC) {
         infoSink.info << "Timing restrictions must be enforced under the WebGL spec.";
@@ -265,7 +259,7 @@
         TDependencyGraph graph(root);
 
         // Output any errors first.
-        bool success = enforceFragmentShaderTimingRestrictions(graph, restrictedSymbol);
+        bool success = enforceFragmentShaderTimingRestrictions(graph);
         
         // Then, output the dependency graph.
         if (outputGraph) {
@@ -276,22 +270,20 @@
         return success;
     }
     else {
-        return enforceVertexShaderTimingRestrictions(root, restrictedSymbol);
+        return enforceVertexShaderTimingRestrictions(root);
     }
 }
 
-bool TCompiler::enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph,
-                                                        const TString& restrictedSymbol)
+bool TCompiler::enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph)
 {
-    RestrictFragmentShaderTiming restrictor(infoSink.info, restrictedSymbol);
+    RestrictFragmentShaderTiming restrictor(infoSink.info);
     restrictor.enforceRestrictions(graph);
     return restrictor.numErrors() == 0;
 }
 
-bool TCompiler::enforceVertexShaderTimingRestrictions(TIntermNode* root,
-                                                      const TString& restrictedSymbol)
+bool TCompiler::enforceVertexShaderTimingRestrictions(TIntermNode* root)
 {
-    RestrictVertexShaderTiming restrictor(infoSink.info, restrictedSymbol);
+    RestrictVertexShaderTiming restrictor(infoSink.info);
     restrictor.enforceRestrictions(root);
     return restrictor.numErrors() == 0;
 }
diff --git a/src/compiler/ShHandle.h b/src/compiler/ShHandle.h
index 0faaeb1..5e5b893 100644
--- a/src/compiler/ShHandle.h
+++ b/src/compiler/ShHandle.h
@@ -81,16 +81,12 @@
     // Translate to object code.
     virtual void translate(TIntermNode* root) = 0;
     // Returns true if the shader passes the restrictions that aim to prevent timing attacks.
-    bool enforceTimingRestrictions(TIntermNode* root,
-                                   const TString& restrictedSymbol,
-                                   bool outputGraph);
-    // Returns true if the shader does not define the restricted symbol.
-    bool enforceVertexShaderTimingRestrictions(TIntermNode* root,
-                                               const TString& restrictedSymbol);
-    // Returns true if the shader does not use the restricted symbol to affect control flow or in
-    // operations whose time can depend on the input values.
-    bool enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph,
-                                                 const TString& restrictedSymbol);
+    bool enforceTimingRestrictions(TIntermNode* root, bool outputGraph);
+    // Returns true if the shader does not use samplers.
+    bool enforceVertexShaderTimingRestrictions(TIntermNode* root);
+    // Returns true if the shader does not use sampler dependent values to affect control 
+    // flow or in operations whose time can depend on the input values.
+    bool enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph);
     // Get built-in extensions with default behavior.
     const TExtensionBehavior& getExtensionBehavior() const;
 
diff --git a/src/compiler/depgraph/DependencyGraph.cpp b/src/compiler/depgraph/DependencyGraph.cpp
index 9b69cc6..ca661d6 100644
--- a/src/compiler/depgraph/DependencyGraph.cpp
+++ b/src/compiler/depgraph/DependencyGraph.cpp
@@ -23,17 +23,6 @@
     }
 }
 
-TGraphSymbol* TDependencyGraph::getGlobalSymbolByName(const TString& name) const
-{
-    TSymbolNameMap::const_iterator iter = mGlobalSymbolMap.find(name);
-    if (iter == mGlobalSymbolMap.end())
-        return NULL;
-
-    TSymbolNamePair pair = *iter;
-    TGraphSymbol* symbol = pair.second;
-    return symbol;
-}
-
 TGraphArgument* TDependencyGraph::createArgument(TIntermAggregate* intermFunctionCall,
                                                  int argumentNumber)
 {
@@ -51,7 +40,7 @@
     return functionCall;
 }
 
-TGraphSymbol* TDependencyGraph::getOrCreateSymbol(TIntermSymbol* intermSymbol, bool isGlobalSymbol)
+TGraphSymbol* TDependencyGraph::getOrCreateSymbol(TIntermSymbol* intermSymbol)
 {
     TSymbolIdMap::const_iterator iter = mSymbolIdMap.find(intermSymbol->getId());
 
@@ -67,12 +56,9 @@
         TSymbolIdPair pair(intermSymbol->getId(), symbol);
         mSymbolIdMap.insert(pair);
 
-        if (isGlobalSymbol) {
-            // We map all symbols in the global scope by name, so traversers of the graph can
-            // quickly start searches at global symbols with specific names.
-            TSymbolNamePair pair(intermSymbol->getSymbol(), symbol);
-            mGlobalSymbolMap.insert(pair);
-        }
+        // We save all sampler symbols in a collection, so we can start graph traversals from them quickly.
+        if (IsSampler(intermSymbol->getBasicType()))
+            mSamplerSymbols.push_back(symbol);
     }
 
     return symbol;
diff --git a/src/compiler/depgraph/DependencyGraph.h b/src/compiler/depgraph/DependencyGraph.h
index 14aefb9..5a9c35d 100644
--- a/src/compiler/depgraph/DependencyGraph.h
+++ b/src/compiler/depgraph/DependencyGraph.h
@@ -25,6 +25,7 @@
 
 typedef std::set<TGraphNode*> TGraphNodeSet;
 typedef std::vector<TGraphNode*> TGraphNodeVector;
+typedef std::vector<TGraphSymbol*> TGraphSymbolVector;
 typedef std::vector<TGraphFunctionCall*> TFunctionCallVector;
 
 //
@@ -142,6 +143,16 @@
     TGraphNodeVector::const_iterator begin() const { return mAllNodes.begin(); }
     TGraphNodeVector::const_iterator end() const { return mAllNodes.end(); }
 
+    TGraphSymbolVector::const_iterator beginSamplerSymbols() const
+    {
+        return mSamplerSymbols.begin();
+    }
+
+    TGraphSymbolVector::const_iterator endSamplerSymbols() const
+    {
+        return mSamplerSymbols.end();
+    }
+
     TFunctionCallVector::const_iterator beginUserDefinedFunctionCalls() const
     {
         return mUserDefinedFunctionCalls.begin();
@@ -152,12 +163,9 @@
         return mUserDefinedFunctionCalls.end();
     }
 
-    // Returns NULL if the symbol is not found.
-    TGraphSymbol* getGlobalSymbolByName(const TString& name) const;
-
     TGraphArgument* createArgument(TIntermAggregate* intermFunctionCall, int argumentNumber);
     TGraphFunctionCall* createFunctionCall(TIntermAggregate* intermFunctionCall);
-    TGraphSymbol* getOrCreateSymbol(TIntermSymbol* intermSymbol, bool isGlobalSymbol);
+    TGraphSymbol* getOrCreateSymbol(TIntermSymbol* intermSymbol);
     TGraphSelection* createSelection(TIntermSelection* intermSelection);
     TGraphLoop* createLoop(TIntermLoop* intermLoop);
     TGraphLogicalOp* createLogicalOp(TIntermBinary* intermLogicalOp);
@@ -165,13 +173,10 @@
     typedef TMap<int, TGraphSymbol*> TSymbolIdMap;
     typedef std::pair<int, TGraphSymbol*> TSymbolIdPair;
 
-    typedef TMap<TString, TGraphSymbol*> TSymbolNameMap;
-    typedef std::pair<TString, TGraphSymbol*> TSymbolNamePair;
-
-    TSymbolIdMap mSymbolIdMap;
-    TSymbolNameMap mGlobalSymbolMap;
-    TFunctionCallVector mUserDefinedFunctionCalls;
     TGraphNodeVector mAllNodes;
+    TGraphSymbolVector mSamplerSymbols;
+    TFunctionCallVector mUserDefinedFunctionCalls;
+    TSymbolIdMap mSymbolIdMap;
 };
 
 //
diff --git a/src/compiler/depgraph/DependencyGraphBuilder.cpp b/src/compiler/depgraph/DependencyGraphBuilder.cpp
index 8e45c37..a49870a 100644
--- a/src/compiler/depgraph/DependencyGraphBuilder.cpp
+++ b/src/compiler/depgraph/DependencyGraphBuilder.cpp
@@ -31,18 +31,11 @@
 
 void TDependencyGraphBuilder::visitFunctionDefinition(TIntermAggregate* intermAggregate)
 {
-    // Function defintions should only exist in the global scope.
-    ASSERT(mIsGlobalScope);
-
     // Currently, we do not support user defined functions.
     if (intermAggregate->getName() != "main(")
         return;
 
-    mIsGlobalScope = false;
-
     visitAggregateChildren(intermAggregate);
-
-    mIsGlobalScope = true;
 }
 
 // Takes an expression like "f(x)" and creates a dependency graph like
@@ -93,7 +86,7 @@
 {
     // Push this symbol into the set of dependent symbols for the current assignment or condition
     // that we are traversing.
-    TGraphSymbol* symbol = mGraph->getOrCreateSymbol(intermSymbol, mIsGlobalScope);
+    TGraphSymbol* symbol = mGraph->getOrCreateSymbol(intermSymbol);
     mNodeSets.insertIntoTopSet(symbol);
 
     // If this symbol is the current leftmost symbol under an assignment, replace the previous
diff --git a/src/compiler/depgraph/DependencyGraphBuilder.h b/src/compiler/depgraph/DependencyGraphBuilder.h
index 91bc490..c7c29e2 100644
--- a/src/compiler/depgraph/DependencyGraphBuilder.h
+++ b/src/compiler/depgraph/DependencyGraphBuilder.h
@@ -164,8 +164,7 @@
 
     TDependencyGraphBuilder(TDependencyGraph* graph)
         : TIntermTraverser(true, false, false)
-        , mGraph(graph)
-        , mIsGlobalScope(true) {}
+        , mGraph(graph) {}
     void build(TIntermNode* intermNode) { intermNode->traverse(this); }
 
     void connectMultipleNodesToSingleNode(TParentNodeSet* nodes, TGraphNode* node) const;
@@ -180,7 +179,6 @@
     TDependencyGraph* mGraph;
     TNodeSetStack mNodeSets;
     TSymbolStack mLeftmostSymbols;
-    bool mIsGlobalScope;
 };
 
 #endif  // COMPILER_DEPGRAPH_DEPENDENCY_GRAPH_BUILDER_H
diff --git a/src/compiler/timing/RestrictFragmentShaderTiming.cpp b/src/compiler/timing/RestrictFragmentShaderTiming.cpp
index 3b3bbeb..47eb581 100644
--- a/src/compiler/timing/RestrictFragmentShaderTiming.cpp
+++ b/src/compiler/timing/RestrictFragmentShaderTiming.cpp
@@ -19,13 +19,16 @@
     // so we generate errors for them.
     validateUserDefinedFunctionCallUsage(graph);
 
-    // Traverse the dependency graph starting at s_texture and generate an error each time we hit a
-    // condition node.
-    TGraphSymbol* uTextureGraphSymbol = graph.getGlobalSymbolByName(mRestrictedSymbol);
-    if (uTextureGraphSymbol &&
-        uTextureGraphSymbol->getIntermSymbol()->getQualifier() == EvqUniform &&
-        uTextureGraphSymbol->getIntermSymbol()->getBasicType() == EbtSampler2D)
-        uTextureGraphSymbol->traverse(this);
+    // Starting from each sampler, traverse the dependency graph and generate an error each time we
+    // hit a node where sampler dependent values are not allowed.
+    for (TGraphSymbolVector::const_iterator iter = graph.beginSamplerSymbols();
+         iter != graph.endSamplerSymbols();
+         ++iter)
+    {
+        TGraphSymbol* samplerSymbol = *iter;
+        clearVisited();
+        samplerSymbol->traverse(this);
+    }
 }
 
 void RestrictFragmentShaderTiming::validateUserDefinedFunctionCallUsage(const TDependencyGraph& graph)
@@ -50,35 +53,33 @@
 void RestrictFragmentShaderTiming::visitArgument(TGraphArgument* parameter)
 {
     // FIXME(mvujovic): We should restrict sampler dependent values from being texture coordinates 
-    // in all available sampling operationsn supported in GLSL ES.
+    // in all available sampling operations supported in GLSL ES.
     // This includes overloaded signatures of texture2D, textureCube, and others.
     if (parameter->getIntermFunctionCall()->getName() != "texture2D(s21;vf2;" ||
         parameter->getArgumentNumber() != 1)
         return;
 
     beginError(parameter->getIntermFunctionCall());
-    mSink << "An expression dependent on a uniform sampler2D by the name '" << mRestrictedSymbol
-          << "' is not permitted to be the second argument of a texture2D call.\n";
+    mSink << "An expression dependent on a sampler is not permitted to be the second argument"
+          << " of a texture2D call.\n";
 }
 
 void RestrictFragmentShaderTiming::visitSelection(TGraphSelection* selection)
 {
     beginError(selection->getIntermSelection());
-    mSink << "An expression dependent on a uniform sampler2D by the name '" << mRestrictedSymbol
-          << "' is not permitted in a conditional statement.\n";
+    mSink << "An expression dependent on a sampler is not permitted in a conditional statement.\n";
 }
 
 void RestrictFragmentShaderTiming::visitLoop(TGraphLoop* loop)
 {
     beginError(loop->getIntermLoop());
-    mSink << "An expression dependent on a uniform sampler2D by the name '" << mRestrictedSymbol
-          << "' is not permitted in a loop condition.\n";
+    mSink << "An expression dependent on a sampler is not permitted in a loop condition.\n";
 }
 
 void RestrictFragmentShaderTiming::visitLogicalOp(TGraphLogicalOp* logicalOp)
 {
     beginError(logicalOp->getIntermLogicalOp());
-    mSink << "An expression dependent on a uniform sampler2D by the name '" << mRestrictedSymbol
-          << "' is not permitted on the left hand side of a logical " << logicalOp->getOpString()
+    mSink << "An expression dependent on a sampler is not permitted on the left hand side of a logical "
+          << logicalOp->getOpString()
           << " operator.\n";
 }
diff --git a/src/compiler/timing/RestrictFragmentShaderTiming.h b/src/compiler/timing/RestrictFragmentShaderTiming.h
index 57cb4a3..5793079 100644
--- a/src/compiler/timing/RestrictFragmentShaderTiming.h
+++ b/src/compiler/timing/RestrictFragmentShaderTiming.h
@@ -16,9 +16,8 @@
 
 class RestrictFragmentShaderTiming : TDependencyGraphTraverser {
 public:
-    RestrictFragmentShaderTiming(TInfoSinkBase& sink, const TString& restrictedSymbol)
+    RestrictFragmentShaderTiming(TInfoSinkBase& sink)
         : mSink(sink)
-        , mRestrictedSymbol(restrictedSymbol)
         , mNumErrors(0) {}
 
     void enforceRestrictions(const TDependencyGraph& graph);
@@ -34,7 +33,6 @@
     void validateUserDefinedFunctionCallUsage(const TDependencyGraph& graph);
 
 	TInfoSinkBase& mSink;
-    const TString mRestrictedSymbol;
     int mNumErrors;
 };
 
diff --git a/src/compiler/timing/RestrictVertexShaderTiming.cpp b/src/compiler/timing/RestrictVertexShaderTiming.cpp
index 220cb49..524c6cf 100644
--- a/src/compiler/timing/RestrictVertexShaderTiming.cpp
+++ b/src/compiler/timing/RestrictVertexShaderTiming.cpp
@@ -8,23 +8,10 @@
 
 void RestrictVertexShaderTiming::visitSymbol(TIntermSymbol* node)
 {
-    if (node->getQualifier() == EvqUniform &&
-        node->getBasicType() == EbtSampler2D &&
-        node->getSymbol() == mRestrictedSymbol) {
-        mFoundRestrictedSymbol = true;
+    if (IsSampler(node->getBasicType())) {
+        ++mNumErrors;
         mSink.prefix(EPrefixError);
         mSink.location(node->getLine());
-        mSink << "Definition of a uniform sampler2D by the name '" << mRestrictedSymbol
-              << "' is not permitted in vertex shaders.\n";
+        mSink << "Samplers are not permitted in vertex shaders.\n";
     }
 }
-
-bool RestrictVertexShaderTiming::visitAggregate(Visit visit, TIntermAggregate* node)
-{
-    // Don't keep exploring if we've found the restricted symbol, and don't explore anything besides
-    // the global scope (i.e. don't explore function definitions).
-    if (mFoundRestrictedSymbol || node->getOp() == EOpFunction)
-        return false;
-
-    return true;
-}
diff --git a/src/compiler/timing/RestrictVertexShaderTiming.h b/src/compiler/timing/RestrictVertexShaderTiming.h
index c5cdae5..19a05fa 100644
--- a/src/compiler/timing/RestrictVertexShaderTiming.h
+++ b/src/compiler/timing/RestrictVertexShaderTiming.h
@@ -16,26 +16,18 @@
 
 class RestrictVertexShaderTiming : public TIntermTraverser {
 public:
-    RestrictVertexShaderTiming(TInfoSinkBase& sink, const TString& restrictedSymbol)
+    RestrictVertexShaderTiming(TInfoSinkBase& sink)
         : TIntermTraverser(true, false, false)
         , mSink(sink)
-        , mRestrictedSymbol(restrictedSymbol)
-        , mFoundRestrictedSymbol(false) {}
+        , mNumErrors(0) {}
 
     void enforceRestrictions(TIntermNode* root) { root->traverse(this); }
-    int numErrors() { return mFoundRestrictedSymbol ? 1 : 0; }
+    int numErrors() { return mNumErrors; }
 
     virtual void visitSymbol(TIntermSymbol*);
-    virtual bool visitBinary(Visit visit, TIntermBinary*) { return false; }
-    virtual bool visitUnary(Visit visit, TIntermUnary*) { return false; }
-    virtual bool visitSelection(Visit visit, TIntermSelection*) { return false; }
-    virtual bool visitAggregate(Visit visit, TIntermAggregate*);
-    virtual bool visitLoop(Visit visit, TIntermLoop*) { return false; };
-    virtual bool visitBranch(Visit visit, TIntermBranch*) { return false; };
 private:
     TInfoSinkBase& mSink;
-    const TString mRestrictedSymbol;
-    bool mFoundRestrictedSymbol;
+    int mNumErrors;
 };
 
 #endif  // COMPILER_TIMING_RESTRICT_VERTEX_SHADER_TIMING_H_