Fixed multiple issues with struct declaration and construction.
Review URL: http://codereview.appspot.com/1076041

git-svn-id: http://angleproject.googlecode.com/svn/trunk@220 736b8ea6-26fd-11df-bfd4-992fa37f6226
diff --git a/src/compiler/OutputGLSL.cpp b/src/compiler/OutputGLSL.cpp
index 2b42a81..0453af7 100644
--- a/src/compiler/OutputGLSL.cpp
+++ b/src/compiler/OutputGLSL.cpp
@@ -38,56 +38,15 @@
     }
     return TString(out.c_str());
 }
-
-TString getIndentationString(int depth)
-{
-    TString indentation(depth, ' ');
-    return indentation;
-}
 }  // namespace
 
-TOutputGLSL::TOutputGLSL(TParseContext &context)
+TOutputGLSL::TOutputGLSL(TInfoSinkBase& objSink)
     : TIntermTraverser(true, true, true),
-      writeFullSymbol(false),
-      parseContext(context)
+      mObjSink(objSink),
+      mWriteFullSymbol(false)
 {
 }
 
-// Header declares user-defined structs.
-void TOutputGLSL::header()
-{
-    TInfoSinkBase& out = objSink();
-
-    TSymbolTableLevel* symbols = parseContext.symbolTable.getGlobalLevel();
-    for (TSymbolTableLevel::const_iterator symbolIter = symbols->begin(); symbolIter != symbols->end(); ++symbolIter)
-    {
-        const TSymbol* symbol = symbolIter->second;
-        if (!symbol->isVariable())
-            continue;
-
-        const TVariable* variable = static_cast<const TVariable*>(symbol);
-        if (!variable->isUserType())
-            continue;
-
-        const TType& type = variable->getType();
-        ASSERT(type.getQualifier() == EvqTemporary);
-        ASSERT(type.getBasicType() == EbtStruct);
-
-        out << "struct " << variable->getName() << "{\n";
-        const TTypeList* structure = type.getStruct();
-        ASSERT(structure != NULL);
-        incrementDepth();
-        for (size_t i = 0; i < structure->size(); ++i) {
-            const TType* fieldType = (*structure)[i].type;
-            ASSERT(fieldType != NULL);
-            out << getIndentationString(depth);
-            out << getTypeName(*fieldType) << " " << fieldType->getFieldName() << ";\n";
-        }
-        decrementDepth();
-        out << "};\n";
-    }
-}
-
 void TOutputGLSL::writeTriplet(Visit visit, const char* preStr, const char* inStr, const char* postStr)
 {
     TInfoSinkBase& out = objSink();
@@ -108,16 +67,38 @@
 void TOutputGLSL::visitSymbol(TIntermSymbol* node)
 {
     TInfoSinkBase& out = objSink();
-    if (writeFullSymbol)
+    const TType& type = node->getType();
+
+    if (mWriteFullSymbol)
     {
         TQualifier qualifier = node->getQualifier();
         if ((qualifier != EvqTemporary) && (qualifier != EvqGlobal))
             out << node->getQualifierString() << " ";
- 
-        out << getTypeName(node->getType()) << " ";
+
+        // Declare the struct if we have not done so already.
+        if ((type.getBasicType() == EbtStruct) && (mDeclaredStructs.find(type.getTypeName()) == mDeclaredStructs.end()))
+        {
+            out << "struct " << type.getTypeName() << "{\n";
+            const TTypeList* structure = type.getStruct();
+            ASSERT(structure != NULL);
+            for (size_t i = 0; i < structure->size(); ++i)
+            {
+                const TType* fieldType = (*structure)[i].type;
+                ASSERT(fieldType != NULL);
+                out << getTypeName(*fieldType) << " " << fieldType->getFieldName() << ";\n";
+            }
+            out << "} ";
+            mDeclaredStructs.insert(type.getTypeName());
+        }
+        else
+        {
+            out << getTypeName(type) << " ";
+        }
     }
+
     out << node->getSymbol();
-    if (writeFullSymbol && node->getType().isArray())
+
+    if (mWriteFullSymbol && node->getType().isArray())
     {
         out << "[" << node->getType().getArraySize() << "]";
     }
@@ -127,11 +108,13 @@
 {
     TInfoSinkBase& out = objSink();
 
-    TType type = node->getType();
+    const TType& type = node->getType();
     int size = type.getObjectSize();
-    if (size > 1)
+    bool writeType = (size > 1) || (type.getBasicType() == EbtStruct);
+    if (writeType)
         out << getTypeName(type) << "(";
-    for (int i = 0; i < size; ++i) {
+    for (int i = 0; i < size; ++i)
+    {
         const constUnion& data = node->getUnionArrayPointer()[i];
         switch (data.getType())
         {
@@ -143,7 +126,7 @@
         if (i != size - 1)
             out << ", ";
     }
-    if (size > 1)
+    if (writeType)
         out << ")";
 }
 
@@ -157,7 +140,7 @@
         case EOpInitialize:
             if (visit == InVisit) {
                 out << " = ";
-                writeFullSymbol= false;
+                mWriteFullSymbol= false;
             }
             break;
         case EOpAddAssign: writeTriplet(visit, NULL, " += ", NULL); break;
@@ -255,12 +238,39 @@
         case EOpPreIncrement: writeTriplet(visit, "(++", NULL, ")"); break;
         case EOpPreDecrement: writeTriplet(visit, "(--", NULL, ")"); break;
 
-        case EOpConvIntToBool: writeTriplet(visit, "bool(", NULL, ")"); break;
-        case EOpConvFloatToBool: writeTriplet(visit, "bool(", NULL, ")"); break;
-        case EOpConvBoolToFloat: writeTriplet(visit, "float(", NULL, ")"); break;
-        case EOpConvIntToFloat: writeTriplet(visit, "float(", NULL, ")"); break;
-        case EOpConvFloatToInt: writeTriplet(visit, "int(", NULL, ")"); break;
-        case EOpConvBoolToInt: writeTriplet(visit, "int(", NULL, ")"); break;
+        case EOpConvIntToBool:
+        case EOpConvFloatToBool:
+            switch (node->getOperand()->getType().getNominalSize())
+            {
+                case 1: writeTriplet(visit, "bool(", NULL, ")");  break;
+                case 2: writeTriplet(visit, "bvec2(", NULL, ")"); break;
+                case 3: writeTriplet(visit, "bvec3(", NULL, ")"); break;
+                case 4: writeTriplet(visit, "bvec4(", NULL, ")"); break;
+                default: UNREACHABLE();
+            }
+            break;
+        case EOpConvBoolToFloat:
+        case EOpConvIntToFloat:
+            switch (node->getOperand()->getType().getNominalSize())
+            {
+                case 1: writeTriplet(visit, "float(", NULL, ")");  break;
+                case 2: writeTriplet(visit, "vec2(", NULL, ")"); break;
+                case 3: writeTriplet(visit, "vec3(", NULL, ")"); break;
+                case 4: writeTriplet(visit, "vec4(", NULL, ")"); break;
+                default: UNREACHABLE();
+            }
+            break;
+        case EOpConvFloatToInt:
+        case EOpConvBoolToInt:
+            switch (node->getOperand()->getType().getNominalSize())
+            {
+                case 1: writeTriplet(visit, "int(", NULL, ")");  break;
+                case 2: writeTriplet(visit, "ivec2(", NULL, ")"); break;
+                case 3: writeTriplet(visit, "ivec3(", NULL, ")"); break;
+                case 4: writeTriplet(visit, "ivec4(", NULL, ")"); break;
+                default: UNREACHABLE();
+            }
+            break;
 
         case EOpRadians: writeTriplet(visit, "radians(", NULL, ")"); break;
         case EOpDegrees: writeTriplet(visit, "degrees(", NULL, ")"); break;
@@ -321,13 +331,13 @@
         {
             node->getTrueBlock()->traverse(this);
         }
-        out << getIndentationString(depth - 2) << "}";
+        out << "}";
 
         if (node->getFalseBlock())
         {
             out << " else {\n";
             node->getFalseBlock()->traverse(this);
-            out << getIndentationString(depth - 2) << "}";
+            out << "}";
         }
         decrementDepth();
         out << "\n";
@@ -341,19 +351,7 @@
     switch (node->getOp())
     {
         case EOpSequence:
-            if (visit == PreVisit)
-            {
-                out << getIndentationString(depth);
-            }
-            else if (visit == InVisit)
-            {
-                out << ";\n";
-                out << getIndentationString(depth - 1);
-            }
-            else
-            {
-                out << ";\n";
-            }
+            writeTriplet(visit, NULL, ";\n", ";\n");
             break;
         case EOpPrototype:
             // Function declaration.
@@ -361,7 +359,7 @@
             {
                 TString returnType = getTypeName(node->getType());
                 out << returnType << " " << node->getName() << "(";
-                writeFullSymbol = true;
+                mWriteFullSymbol = true;
             }
             else if (visit == InVisit)
             {
@@ -372,7 +370,7 @@
             {
                 // Called after fucntion arguments.
                 out << ")";
-                writeFullSymbol = false;
+                mWriteFullSymbol = false;
             }
             break;
         case EOpFunction:
@@ -416,7 +414,7 @@
             if (visit == PreVisit)
             {
                 out << "(";
-                writeFullSymbol = true;
+                mWriteFullSymbol = true;
             }
             else if (visit == InVisit)
             {
@@ -425,23 +423,23 @@
             else
             {
                 out << ")";
-                writeFullSymbol = false;
+                mWriteFullSymbol = false;
             }
             break;
         case EOpDeclaration:
             // Variable declaration.
             if (visit == PreVisit)
             {
-                writeFullSymbol = true;
+                mWriteFullSymbol = true;
             }
             else if (visit == InVisit)
             {
                 out << ", ";
-                writeFullSymbol = false;
+                mWriteFullSymbol = false;
             }
             else
             {
-                writeFullSymbol = false;
+                mWriteFullSymbol = false;
             }
             break;
 
@@ -460,7 +458,22 @@
         case EOpConstructMat2: writeTriplet(visit, "mat2(", ", ", ")"); break;
         case EOpConstructMat3: writeTriplet(visit, "mat3(", ", ", ")"); break;
         case EOpConstructMat4: writeTriplet(visit, "mat4(", ", ", ")"); break;
-        case EOpConstructStruct: UNIMPLEMENTED(); break;
+        case EOpConstructStruct:
+            if (visit == PreVisit)
+            {
+                const TType& type = node->getType();
+                ASSERT(type.getBasicType() == EbtStruct);
+                out << type.getTypeName() << "(";
+            }
+            else if (visit == InVisit)
+            {
+                out << ", ";
+            }
+            else
+            {
+                out << ")";
+            }
+            break;
 
         case EOpLessThan: writeTriplet(visit, "lessThan(", ", ", ")"); break;
         case EOpGreaterThan: writeTriplet(visit, "greaterThan(", ", ", ")"); break;
@@ -549,7 +562,7 @@
         case EOpKill: writeTriplet(visit, "discard", NULL, NULL); break;
         case EOpBreak: writeTriplet(visit, "break", NULL, NULL); break;
         case EOpContinue: writeTriplet(visit, "continue", NULL, NULL); break;
-        case EOpReturn: writeTriplet(visit, "return", NULL, NULL); break;
+        case EOpReturn: writeTriplet(visit, "return ", NULL, NULL); break;
         default: UNREACHABLE(); break;
     }
 
diff --git a/src/compiler/OutputGLSL.h b/src/compiler/OutputGLSL.h
index 4dde653..a25ea45 100644
--- a/src/compiler/OutputGLSL.h
+++ b/src/compiler/OutputGLSL.h
@@ -13,12 +13,10 @@
 class TOutputGLSL : public TIntermTraverser
 {
 public:
-    TOutputGLSL(TParseContext &context);
-
-    void header();
+    TOutputGLSL(TInfoSinkBase& objSink);
 
 protected:
-    TInfoSinkBase& objSink() { return parseContext.infoSink.obj; }
+    TInfoSinkBase& objSink() { return mObjSink; }
     void writeTriplet(Visit visit, const char* preStr, const char* inStr, const char* postStr);
 
     virtual void visitSymbol(TIntermSymbol* node);
@@ -31,8 +29,14 @@
     virtual bool visitBranch(Visit visit, TIntermBranch* node);
 
 private:
-    bool writeFullSymbol;
-    TParseContext &parseContext;
+    TInfoSinkBase& mObjSink;
+    bool mWriteFullSymbol;
+
+    // Structs are declared as the tree is traversed. This set contains all
+    // the structs already declared. It is maintained so that a struct is
+    // declared only once.
+    typedef std::set<TString> DeclaredStructs;
+    DeclaredStructs mDeclaredStructs;
 };
 
 #endif  // CROSSCOMPILERGLSL_OUTPUTGLSL_H_
diff --git a/src/compiler/TranslatorGLSL.cpp b/src/compiler/TranslatorGLSL.cpp
index 061543d..6149d72 100644
--- a/src/compiler/TranslatorGLSL.cpp
+++ b/src/compiler/TranslatorGLSL.cpp
@@ -14,10 +14,8 @@
 }
 
 bool TranslatorGLSL::compile(TIntermNode* root) {
-    TParseContext& parseContext = *GetGlobalParseContext();
-    TOutputGLSL outputGLSL(parseContext);
-    outputGLSL.header();
-    parseContext.treeRoot->traverse(&outputGLSL);
+    TOutputGLSL outputGLSL(infoSink.obj);
+    root->traverse(&outputGLSL);
 
     return true;
 }