Compiler - split header, body and footer output
TRAC #11798
Signed-off-by: Shannon Woods
Signed-off-by: Daniel Koch

Author:    Nicolas Capens

git-svn-id: https://angleproject.googlecode.com/svn/trunk@126 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/OutputHLSL.cpp b/src/compiler/OutputHLSL.cpp
index fdb7f49..d102e20 100644
--- a/src/compiler/OutputHLSL.cpp
+++ b/src/compiler/OutputHLSL.cpp
@@ -11,14 +11,25 @@
 
 namespace sh
 {
-OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), context(context)
+OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), mContext(context)
 {
 }
 
+void OutputHLSL::output()
+{
+    mContext.treeRoot->traverse(this);   // Output the body first to determine what has to go in the header and footer
+    header();
+    footer();
+
+    mContext.infoSink.obj << mHeader.c_str();
+    mContext.infoSink.obj << mBody.c_str();
+    mContext.infoSink.obj << mFooter.c_str();
+}
+
 void OutputHLSL::header()
 {
-    EShLanguage language = context.language;
-    TInfoSinkBase &out = context.infoSink.obj;
+    EShLanguage language = mContext.language;
+    TInfoSinkBase &out = mHeader;
 
     if (language == EShLangFragment)
     {
@@ -26,7 +37,7 @@
         TString varyingInput;
         TString varyingGlobals;
 
-        TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
+        TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
         int semanticIndex = 0;
 
         for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
@@ -119,7 +130,7 @@
         TString varyingOutput;
         TString varyingGlobals;
 
-        TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
+        TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
         int semanticIndex = 0;
 
         for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
@@ -416,9 +427,9 @@
 
 void OutputHLSL::footer()
 {
-    EShLanguage language = context.language;
-    TInfoSinkBase &out = context.infoSink.obj;
-    TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
+    EShLanguage language = mContext.language;
+    TInfoSinkBase &out = mFooter;
+    TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
 
     if (language == EShLangFragment)
     {
@@ -489,7 +500,7 @@
                "    output.gl_PointSize = gl_PointSize;\n"
                "    output.gl_FragCoord = gl_Position;\n";
 
-        TSymbolTableLevel *symbols = context.symbolTable.getGlobalLevel();
+        TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
 
         for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
         {
@@ -516,7 +527,7 @@
 
 void OutputHLSL::visitSymbol(TIntermSymbol *node)
 {
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
 
     TString name = node->getSymbol();
 
@@ -536,7 +547,7 @@
 
 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
 {
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
 
     switch (node->getOp())
     {
@@ -711,7 +722,7 @@
 
 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
 {
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
 
     switch (node->getOp())
     {
@@ -789,8 +800,8 @@
 
 bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
 {
-    EShLanguage language = context.language;
-    TInfoSinkBase &out = context.infoSink.obj;
+    EShLanguage language = mContext.language;
+    TInfoSinkBase &out = mBody;
 
     if (node->getOp() == EOpNull)
     {
@@ -1057,7 +1068,7 @@
 
 bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
 {
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
 
     if (node->usesTernaryOperator())
     {
@@ -1098,7 +1109,7 @@
 
 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
 {
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
     
     const TType &type = node->getType();
 
@@ -1231,7 +1242,7 @@
         return false;
     }
 
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
 
     if (!node->testFirst())
     {
@@ -1288,7 +1299,7 @@
 
 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
 {
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
 
     switch (node->getFlowOp())
     {
@@ -1321,7 +1332,7 @@
 // Handle loops with more than 255 iterations (unsupported by D3D9) by splitting them
 bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node)
 {
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
 
     // Parse loops of the form:
     // for(int index = initial; index [comparator] limit; index += increment)
@@ -1486,7 +1497,7 @@
 
 void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString)
 {
-    TInfoSinkBase &out = context.infoSink.obj;
+    TInfoSinkBase &out = mBody;
 
     if (visit == PreVisit && preString)
     {
diff --git a/src/compiler/OutputHLSL.h b/src/compiler/OutputHLSL.h
index ccd1fe5..f4876f6 100644
--- a/src/compiler/OutputHLSL.h
+++ b/src/compiler/OutputHLSL.h
@@ -17,10 +17,13 @@
   public:
     OutputHLSL(TParseContext &context);
 
+    void output();
+
+  protected:
     void header();
     void footer();
 
-  protected:
+    // Visit AST nodes and output their code to the body stream
     void visitSymbol(TIntermSymbol*);
     void visitConstantUnion(TIntermConstantUnion*);
     bool visitBinary(Visit visit, TIntermBinary*);
@@ -39,7 +42,12 @@
     static TString arrayString(const TType &type);
     static TString initializer(const TType &type);
 
-    TParseContext &context;
+    TParseContext &mContext;
+
+    // Output streams
+    TInfoSinkBase mHeader;
+    TInfoSinkBase mBody;
+    TInfoSinkBase mFooter;
 };
 }
 
diff --git a/src/compiler/TranslatorHLSL.cpp b/src/compiler/TranslatorHLSL.cpp
index c92764c..1d18c68 100644
--- a/src/compiler/TranslatorHLSL.cpp
+++ b/src/compiler/TranslatorHLSL.cpp
@@ -17,9 +17,7 @@
     TParseContext& parseContext = *GetGlobalParseContext();
     sh::OutputHLSL outputHLSL(parseContext);
 
-    outputHLSL.header();
-    parseContext.treeRoot->traverse(&outputHLSL);
-    outputHLSL.footer();
+    outputHLSL.output();
     
     return true;
 }