release-request-160c4b31-7fa0-4e2b-aabe-85380836a1ce-for-git_oc-release-4129081 snap-temp-L15300000077039010

Change-Id: I869a498990cb2879d9c6a89536e9213c2810d32f
diff --git a/lit-tests/P_ref_count/ref_count.rs b/lit-tests/P_ref_count/ref_count.rs
index 081d6d1..7f6dae8 100644
--- a/lit-tests/P_ref_count/ref_count.rs
+++ b/lit-tests/P_ref_count/ref_count.rs
@@ -35,5 +35,3 @@
     rsDebug("good objects", 0);
   }
 }
-
-
diff --git a/lit-tests/P_ref_count/ref_count2.rs b/lit-tests/P_ref_count/ref_count2.rs
new file mode 100644
index 0000000..8ac2d81
--- /dev/null
+++ b/lit-tests/P_ref_count/ref_count2.rs
@@ -0,0 +1,17 @@
+// RUN: %Slang %s
+// RUN: %rs-filecheck-wrapper %s
+
+#pragma version(1)
+#pragma rs java_package_name(ref_count2)
+
+// CHECK: %[[RETVAL:[A-Za-z][A-Za-z0-9]*]] = call i32 @_Z18rsGetElementAt_int13rs_allocationj{{.*}}
+// CHECK: call void @_Z13rsClearObjectP13rs_allocation(%struct.rs_allocation{{.*}}* {{.*}})
+// CHECK: ret i32 %[[RETVAL]]
+static int goo(rs_allocation a) {  // Return type is not an RS object,
+  return rsGetElementAt_int(a, 0);  // but this expression references one.
+}
+
+void entrypoint() {
+  rs_allocation a = rsCreateAllocation_int(100);  // Local RS object.
+  rsDebug("val at 0:", goo(a));  // Passes it to goo(), which returns an int.
+}
diff --git a/slang_rs_object_ref_count.cpp b/slang_rs_object_ref_count.cpp
index 2471473..e1050f4 100644
--- a/slang_rs_object_ref_count.cpp
+++ b/slang_rs_object_ref_count.cpp
@@ -20,6 +20,7 @@
 #include "clang/AST/Expr.h"
 #include "clang/AST/NestedNameSpecifier.h"
 #include "clang/AST/OperationKinds.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtVisitor.h"
 
@@ -1389,7 +1390,18 @@
   );
 
   clang::Stmt *UpdatedStmt = nullptr;
-  if (!RSExportPrimitiveType::IsRSObjectType(Ty.getTypePtr())) {
+  if (CountRSObjectTypes(Ty.getTypePtr()) == 0) {
+    // The expression E is not an RS object itself. Instead of calling
+    // rsSetObject(), create an assignment statement to set the value of the
+    // temporary "guard" variable to the expression.
+    // This can happen if called from RSObjectRefCount::VisitReturnStmt(),
+    // when the return expression is not an RS object but references one.
+    UpdatedStmt =
+      new(C) clang::BinaryOperator(DRE, E, clang::BO_Assign, Ty,
+                                   clang::VK_RValue, clang::OK_Ordinary, Loc,
+                                   false);
+
+  } else if (!RSExportPrimitiveType::IsRSObjectType(Ty.getTypePtr())) {
     // By definition, this is a struct assignment if we get here
     UpdatedStmt =
         CreateStructRSSetObject(C, DRE, E, Loc, Loc);
@@ -1641,6 +1653,28 @@
   }
 }
 
+namespace {
+
+// Sets a flag if any visited expression's type is, or contains, an RS object.
+class FindRSObjRefVisitor
+    : public clang::RecursiveASTVisitor<FindRSObjRefVisitor> {
+public:
+  explicit FindRSObjRefVisitor() : mRefRSObj(false) {}
+  bool VisitExpr(clang::Expr *Expression) {
+    const bool Found =
+        CountRSObjectTypes(Expression->getType().getTypePtr()) > 0;
+    mRefRSObj = mRefRSObj || Found;
+    // Returning false aborts the AST traversal at the first RS object found.
+    return !Found;
+  }
+
+  bool foundRSObjRef() const { return mRefRSObj; }
+
+private:
+  bool mRefRSObj;  // Becomes true once an RS object reference is seen.
+};
+}  // anonymous namespace
+
 void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
   getCurrentScope()->setCurrentStmt(RS);
 
@@ -1660,11 +1694,14 @@
     return;
   }
 
-  // If the return statement does not return anything, or if it does not return
+  FindRSObjRefVisitor visitor;
+
+  visitor.TraverseStmt(RS);
+
+  // If the return statement does not return anything, or if it does not reference
   // a rsObject, no need to transform it.
 
-  clang::Expr* RetVal = RS->getRetValue();
-  if (!RetVal || CountRSObjectTypes(RetVal->getType().getTypePtr()) == 0) {
+  if (!visitor.foundRSObjRef()) {
     return;
   }