Merge "Initialize temp var for a return value"
diff --git a/slang_rs_object_ref_count.cpp b/slang_rs_object_ref_count.cpp
index 906d08c..eaaf1e0 100644
--- a/slang_rs_object_ref_count.cpp
+++ b/slang_rs_object_ref_count.cpp
@@ -211,77 +211,6 @@
   void VisitWhileStmt(clang::WhileStmt *WS);
 };
 
-// Given a return statement RS that returns an rsObject, creates a temporary
-// variable, sets it to the original return expression using rsSetObject(),
-// adds these new statements into NewStmts.
-// Finally, creates and returns a new return statement that returns the
-// temporary variable.
-clang::CompoundStmt* CreateRetStmtWithTempVar(
-    clang::ASTContext& C,
-    clang::DeclContext* DC,
-    clang::ReturnStmt* RS,
-    const unsigned id) {
-  std::list<clang::Stmt*> NewStmts;
-  // Since we insert rsClearObj() calls before the return statement, we need
-  // to make sure none of the cleared RS objects are referenced in the
-  // return statement.
-  // For that, we create a new local variable named .rs.retval, assign the
-  // original return expression to it, make all necessary rsClearObj()
-  // calls, then return .rs.retval. Note rsClearObj() is not called on
-  // .rs.retval.
-
-  clang::SourceLocation Loc = RS->getLocStart();
-  std::stringstream ss;
-  ss << ".rs.retval" << id;
-  llvm::StringRef VarName(ss.str());
-
-  clang::Expr* RetVal = RS->getRetValue();
-  const clang::QualType RetTy = RetVal->getType();
-  clang::VarDecl* RSRetValDecl = clang::VarDecl::Create(
-      C,                                     // AST context
-      DC,                                    // Decl context
-      Loc,                                   // Start location
-      Loc,                                   // Id location
-      &C.Idents.get(VarName),                // Id
-      RetTy,                                 // Type
-      C.getTrivialTypeSourceInfo(RetTy),     // Type info
-      clang::SC_None                         // Storage class
-  );
-  clang::Decl* Decls[] = { RSRetValDecl };
-  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
-      C, Decls, sizeof(Decls) / sizeof(*Decls));
-  clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
-  NewStmts.push_back(DS);
-
-  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
-      C,
-      clang::NestedNameSpecifierLoc(),       // QualifierLoc
-      Loc,                                   // TemplateKWLoc
-      RSRetValDecl,
-      false,                                 // RefersToEnclosingVariableOrCapture
-      Loc,                                   // NameLoc
-      RetTy,
-      clang::VK_LValue
-  );
-  clang::Stmt* SetRetTempVar = CreateSingleRSSetObject(C, DRE, RetVal, Loc, Loc);
-  NewStmts.push_back(SetRetTempVar);
-
-  // Creates a new return statement
-  clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(RS->getReturnLoc());
-  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
-      C,
-      RetTy,
-      clang::CK_LValueToRValue,
-      DRE,
-      nullptr,
-      clang::VK_RValue
-  );
-  NewRet->setRetValue(CastExpr);
-  NewStmts.push_back(NewRet);
-
-  return BuildCompoundStmt(C, NewStmts, Loc);
-}
-
 void DestructorVisitor::VisitStmt(clang::Stmt *S) {
   for (clang::Stmt* Child : S->children()) {
     if (Child) {
@@ -1498,6 +1427,78 @@
   return Res;
 }
 
+clang::CompoundStmt* RSObjectRefCount::CreateRetStmtWithTempVar(
+    clang::ASTContext& C,
+    clang::DeclContext* DC,
+    clang::ReturnStmt* RS,
+    const unsigned id) {
+  std::list<clang::Stmt*> NewStmts;
+  // Since we insert rsClearObj() calls before the return statement, we need
+  // to make sure none of the cleared RS objects are referenced in the
+  // return statement. For that, we create a new local variable named
+  // .rs.retval<id>, assign the original return expression to it, make all
+  // necessary rsClearObj() calls, then return .rs.retval<id>. Note
+  // rsClearObj() is not called on .rs.retval<id>.
+
+  clang::SourceLocation Loc = RS->getLocStart();
+  std::stringstream ss;
+  ss << ".rs.retval" << id;
+  // Must be an owning std::string: a StringRef bound to the temporary
+  // returned by ss.str() would dangle once that temporary is destroyed.
+  const std::string VarName = ss.str();
+
+  clang::Expr* RetVal = RS->getRetValue();
+  const clang::QualType RetTy = RetVal->getType();
+  clang::VarDecl* RSRetValDecl = clang::VarDecl::Create(
+      C,                                     // AST context
+      DC,                                    // Decl context
+      Loc,                                   // Start location
+      Loc,                                   // Id location
+      &C.Idents.get(VarName),                // Id
+      RetTy,                                 // Type
+      C.getTrivialTypeSourceInfo(RetTy),     // Type info
+      clang::SC_None                         // Storage class
+  );
+  // Zero-initialize the temporary before rsSetObject() assigns into it.
+  const clang::Type *T = RetTy.getTypePtr();
+  DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
+  clang::Expr *ZeroInitializer =
+      RSObjectRefCount::CreateZeroInitializerForRSSpecificType(DT, C, Loc);
+  // NOTE(review): skip the initializer instead of crashing if none was built.
+  if (ZeroInitializer) {
+    ZeroInitializer->setType(T->getCanonicalTypeInternal());
+    RSRetValDecl->setInit(ZeroInitializer);
+  }
+  clang::Decl* Decls[] = { RSRetValDecl };
+  const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
+      C, Decls, sizeof(Decls) / sizeof(*Decls));
+  clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
+  NewStmts.push_back(DS);
+
+  clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
+      C,
+      clang::NestedNameSpecifierLoc(),       // QualifierLoc
+      Loc,                                   // TemplateKWLoc
+      RSRetValDecl,
+      false,                                 // RefersToEnclosingVariableOrCapture
+      Loc,                                   // NameLoc
+      RetTy,
+      clang::VK_LValue
+  );
+  clang::Stmt* SetRetTempVar = CreateSingleRSSetObject(C, DRE, RetVal, Loc, Loc);
+  NewStmts.push_back(SetRetTempVar);
+
+  // Return the temporary: wrap an lvalue-to-rvalue cast of a reference to it
+  // in a new return statement placed at the original return location.
+  clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(RS->getReturnLoc());
+  clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
+      C, RetTy, clang::CK_LValueToRValue, DRE, nullptr, clang::VK_RValue);
+  NewRet->setRetValue(CastExpr);
+  NewStmts.push_back(NewRet);
+
+  return BuildCompoundStmt(C, NewStmts, Loc);
+}
+
 void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
   VisitStmt(DS);
   getCurrentScope()->setCurrentStmt(DS);
diff --git a/slang_rs_object_ref_count.h b/slang_rs_object_ref_count.h
index 5b49c34..88a7a49 100644
--- a/slang_rs_object_ref_count.h
+++ b/slang_rs_object_ref_count.h
@@ -138,6 +138,17 @@
       clang::ASTContext &C,
       const clang::SourceLocation &Loc);
 
+  // Given a return statement RS that returns an rsObject (RS must have a
+  // return value), declares a zero-initialized temporary ".rs.retval<id>" in
+  // DC and sets it to the original return expression using rsSetObject().
+  // Returns a new compound statement containing that declaration, the
+  // rsSetObject() call, and a new return statement returning the temporary.
+  static clang::CompoundStmt* CreateRetStmtWithTempVar(
+      clang::ASTContext& C,
+      clang::DeclContext* DC,
+      clang::ReturnStmt* RS,
+      const unsigned id);
+
  public:
   explicit RSObjectRefCount(clang::ASTContext &C)
       : mCtx(C), RSInitFD(false), mTempID(0) {