Compiler - only declare used HLSL functions
TRAC #11315
Signed-off-by: Shannon Woods
Signed-off-by: Daniel Koch
Author:    Nicolas Capens

git-svn-id: https://angleproject.googlecode.com/svn/trunk@75 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/OutputHLSL.cpp b/src/compiler/OutputHLSL.cpp
index c8a7c98..5566673 100644
--- a/src/compiler/OutputHLSL.cpp
+++ b/src/compiler/OutputHLSL.cpp
@@ -411,6 +411,113 @@
            "\n";
 }
 
+void OutputHLSL::footer()
+{
+    EShLanguage language = context.language;
+    TInfoSinkBase &out = context.infoSink.obj;
+    TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
+
+    if (language == EShLangFragment)
+    {
+        out << "PS_OUTPUT main(PS_INPUT input)\n"   // FIXME: Prevent name clashes
+               "{\n"
+               "    float rhw = 1.0 / input.gl_FragCoord.w;\n"
+               "    gl_FragCoord.x = (input.gl_FragCoord.x * rhw) * gl_Window.x + gl_Window.z;\n"
+               "    gl_FragCoord.y = (input.gl_FragCoord.y * rhw) * gl_Window.y + gl_Window.w;\n"
+               "    gl_FragCoord.z = (input.gl_FragCoord.z * rhw) * gl_Depth.x + gl_Depth.y;\n"
+               "    gl_FragCoord.w = rhw;\n"
+               "    gl_FrontFacing = __frontCCW ? (input.__vFace >= 0.0) : (input.__vFace <= 0.0);\n";
+
+        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
+        {
+            const TSymbol *symbol = (*namedSymbol).second;
+            const TString &name = symbol->getName();
+
+            if (symbol->isVariable())
+            {
+                const TVariable *variable = static_cast<const TVariable*>(symbol);
+                const TType &type = variable->getType();
+                TQualifier qualifier = type.getQualifier();
+
+                if (qualifier == EvqVaryingIn)
+                {
+                    out << "    " + name + " = input." + name + ";\n";   // FIXME: Prevent name clashes
+                }
+            }
+        }
+
+        out << "\n"
+               "    gl_main();\n"
+               "\n"
+               "    PS_OUTPUT output;\n"                    // FIXME: Prevent name clashes
+               "    output.gl_Color[0] = gl_Color[0];\n";   // FIXME: Prevent name clashes
+
+        TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
+
+        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
+        {
+            const TSymbol *symbol = (*namedSymbol).second;
+            const TString &name = symbol->getName();
+        }
+    }
+    else
+    {
+        out << "VS_OUTPUT main(VS_INPUT input)\n"   // FIXME: Prevent name clashes
+               "{\n";
+
+        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
+        {
+            const TSymbol *symbol = (*namedSymbol).second;
+            const TString &name = symbol->getName();
+
+            if (symbol->isVariable())
+            {
+                const TVariable *variable = static_cast<const TVariable*>(symbol);
+                const TType &type = variable->getType();
+                TQualifier qualifier = type.getQualifier();
+
+                if (qualifier == EvqAttribute)
+                {
+                    out << "    " + name + " = input." + name + ";\n";   // FIXME: Prevent name clashes
+                }
+            }
+        }
+
+        out << "\n"
+               "    gl_main();\n"
+               "\n"
+               "    VS_OUTPUT output;\n"   // FIXME: Prevent name clashes
+               "    output.gl_Position.x = gl_Position.x - gl_HalfPixelSize.x * gl_Position.w;\n"
+               "    output.gl_Position.y = -(gl_Position.y - gl_HalfPixelSize.y * gl_Position.w);\n"
+               "    output.gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;\n"
+               "    output.gl_Position.w = gl_Position.w;\n"
+               "    output.gl_PointSize = gl_PointSize;\n"
+               "    output.gl_FragCoord = gl_Position;\n";
+
+        TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
+
+        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
+        {
+            const TSymbol *symbol = (*namedSymbol).second;
+            const TString &name = symbol->getName();
+
+            if (symbol->isVariable())
+            {
+                const TVariable *variable = static_cast<const TVariable*>(symbol);
+                TQualifier qualifier = variable->getType().getQualifier();
+
+                if (qualifier == EvqVaryingOut || qualifier == EvqInvariantVaryingOut)
+                {
+                    out << "    output." + name + " = " + name + ";\n";   // FIXME: Prevent name clashes
+                }
+            }
+        }
+    }
+
+    out << "    return output;\n"   // FIXME: Prevent name clashes
+           "}\n";
+}
+
 void OutputHLSL::visitSymbol(TIntermSymbol *node)
 {
     TInfoSinkBase &out = context.infoSink.obj;
@@ -678,151 +785,41 @@
             {
                 if (name == "main")
                 {
-                    TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
-
-                    if (language == EShLangFragment)
-                    {
-                        out << "PS_OUTPUT main(PS_INPUT input)\n"   // FIXME: Prevent name clashes
-                               "{\n"
-                               "    float rhw = 1.0 / input.gl_FragCoord.w;\n"
-                               "    gl_FragCoord.x = (input.gl_FragCoord.x * rhw) * gl_Window.x + gl_Window.z;\n"
-                               "    gl_FragCoord.y = (input.gl_FragCoord.y * rhw) * gl_Window.y + gl_Window.w;\n"
-                               "    gl_FragCoord.z = (input.gl_FragCoord.z * rhw) * gl_Depth.x + gl_Depth.y;\n"
-                               "    gl_FragCoord.w = rhw;\n"
-                               "    gl_FrontFacing = __frontCCW ? (input.__vFace >= 0.0) : (input.__vFace <= 0.0);\n";
-
-                        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
-                        {
-                            const TSymbol *symbol = (*namedSymbol).second;
-                            const TString &name = symbol->getName();
-
-                            if (symbol->isVariable())
-                            {
-                                const TVariable *variable = static_cast<const TVariable*>(symbol);
-                                const TType &type = variable->getType();
-                                TQualifier qualifier = type.getQualifier();
-
-                                if(qualifier == EvqVaryingIn)
-                                {
-                                    out << "    " + name + " = input." + name + ";\n";   // FIXME: Prevent name clashes
-                                }
-                            }
-                        }
-                    }
-                    else
-                    {
-                        out << "VS_OUTPUT main(VS_INPUT input)\n"   // FIXME: Prevent name clashes
-                               "{\n";
-
-                        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
-                        {
-                            const TSymbol *symbol = (*namedSymbol).second;
-                            const TString &name = symbol->getName();
-
-                            if (symbol->isVariable())
-                            {
-                                const TVariable *variable = static_cast<const TVariable*>(symbol);
-                                const TType &type = variable->getType();
-                                TQualifier qualifier = type.getQualifier();
-
-                                if (qualifier == EvqAttribute)
-                                {
-                                    out << "    " + name + " = input." + name + ";\n";   // FIXME: Prevent name clashes
-                                }
-                            }
-                        }
-                    }
-
-                    // Erase the (empty) argument list
-                    TIntermSequence &sequence = node->getSequence();
-                    sequence.erase(sequence.begin());
+                    name = "gl_main";
                 }
-                else
+
+                out << typeString(node->getType()) << " " << name << "(";
+
+                TIntermSequence &sequence = node->getSequence();
+                TIntermSequence &arguments = sequence[0]->getAsAggregate()->getSequence();
+
+                for (unsigned int i = 0; i < arguments.size(); i++)
                 {
-                    out << typeString(node->getType()) << " " << name << "(";
+                    TIntermSymbol *symbol = arguments[i]->getAsSymbolNode();
 
-                    TIntermSequence &sequence = node->getSequence();
-                    TIntermSequence &arguments = sequence[0]->getAsAggregate()->getSequence();
-
-                    for (unsigned int i = 0; i < arguments.size(); i++)
+                    if (symbol)
                     {
-                        TIntermSymbol *symbol = arguments[i]->getAsSymbolNode();
+                        const TType &type = symbol->getType();
+                        const TString &name = symbol->getSymbol();
 
-                        if (symbol)
+                        out << typeString(type) + " " + name;
+
+                        if (i < arguments.size() - 1)
                         {
-                            const TType &type = symbol->getType();
-                            const TString &name = symbol->getSymbol();
-
-                            out << typeString(type) + " " + name;
-
-                            if(i < arguments.size() - 1)
-                            {
-                                out << ", ";
-                            }
+                            out << ", ";
                         }
-                        else UNREACHABLE();
                     }
-
-                    sequence.erase(sequence.begin());
-
-                    out << ")\n"
-                           "{\n";
+                    else UNREACHABLE();
                 }
+
+                sequence.erase(sequence.begin());
+
+                out << ")\n"
+                       "{\n";
             }
             else if (visit == PostVisit)
             {
-                if (name == "main")
-                {
-                    if (language == EShLangFragment)
-                    {
-                        out << "    PS_OUTPUT output;\n"                    // FIXME: Prevent name clashes
-                               "    output.gl_Color[0] = gl_Color[0];\n";   // FIXME: Prevent name clashes
-
-                        TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
-
-                        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
-                        {
-                            const TSymbol *symbol = (*namedSymbol).second;
-                            const TString &name = symbol->getName();
-                        }
-                    }
-                    else
-                    {
-                        out << "    VS_OUTPUT output;\n"   // FIXME: Prevent name clashes
-                               "    output.gl_Position.x = gl_Position.x - gl_HalfPixelSize.x * gl_Position.w;\n"
-                               "    output.gl_Position.y = -(gl_Position.y - gl_HalfPixelSize.y * gl_Position.w);\n"
-                               "    output.gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;\n"
-                               "    output.gl_Position.w = gl_Position.w;\n"
-                               "    output.gl_PointSize = 1.0;\n"
-                               "    output.gl_FragCoord = gl_Position;\n";
-
-                        TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
-
-                        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
-                        {
-                            const TSymbol *symbol = (*namedSymbol).second;
-                            const TString &name = symbol->getName();
-
-                            if (symbol->isVariable())
-                            {
-                                const TVariable *variable = static_cast<const TVariable*>(symbol);
-                                TQualifier qualifier = variable->getType().getQualifier();
-
-                                if(qualifier == EvqVaryingOut || qualifier == EvqInvariantVaryingOut)
-                                {
-                                    out << "    output." + name + " = " + name + ";\n";   // FIXME: Prevent name clashes
-                                }
-                            }
-                        }
-                    }
-
-                    out << "    return output;\n"   // FIXME: Prevent name clashes
-                           "}\n";
-                }
-                else
-                {
-                    out << "}\n";
-                }
+                out << "}\n";
             }
         }
         break;
@@ -944,7 +941,7 @@
 {
     TInfoSinkBase &out = context.infoSink.obj;
 
-    if(node->getType().getBasicType() == EbtVoid)   // if/else statement
+    if (node->getType().getBasicType() == EbtVoid)   // if/else statement
     {
         out << "if(";
 
diff --git a/src/compiler/OutputHLSL.h b/src/compiler/OutputHLSL.h
index 9af7942..667d9e8 100644
--- a/src/compiler/OutputHLSL.h
+++ b/src/compiler/OutputHLSL.h
@@ -23,6 +23,7 @@
     OutputHLSL(TParseContext &context);
 
     void header();
+    void footer();
 
   protected:
     void visitSymbol(TIntermSymbol*);
diff --git a/src/compiler/TranslatorHLSL.cpp b/src/compiler/TranslatorHLSL.cpp
index 52ce812..c92764c 100644
--- a/src/compiler/TranslatorHLSL.cpp
+++ b/src/compiler/TranslatorHLSL.cpp
@@ -7,16 +7,19 @@
 #include "TranslatorHLSL.h"
 #include "OutputHLSL.h"
 
-TranslatorHLSL::TranslatorHLSL(EShLanguage l, int dOptions)
-        : TCompiler(l),
-          debugOptions(dOptions) {
+TranslatorHLSL::TranslatorHLSL(EShLanguage language, int debugOptions)
+    : TCompiler(language), debugOptions(debugOptions)
+{
 }
 
-bool TranslatorHLSL::compile(TIntermNode* root) {
+bool TranslatorHLSL::compile(TIntermNode *root)
+{
     TParseContext& parseContext = *GetGlobalParseContext();
     sh::OutputHLSL outputHLSL(parseContext);
+
     outputHLSL.header();
     parseContext.treeRoot->traverse(&outputHLSL);
-
+    outputHLSL.footer();
+    
     return true;
 }