Reuse SIRT for C++ references

Change-Id: I8310e55da42f55f7ec60f6b17face436c77a979f
diff --git a/src/class_linker.cc b/src/class_linker.cc
index 7ca0361..44ca80c 100644
--- a/src/class_linker.cc
+++ b/src/class_linker.cc
@@ -22,6 +22,7 @@
 #include "runtime_support.h"
 #include "ScopedLocalRef.h"
 #include "space.h"
+#include "stack_indirect_reference_table.h"
 #include "stl_util.h"
 #include "thread.h"
 #include "UniquePtr.h"
@@ -228,38 +229,38 @@
   CHECK(!init_done_);
 
   // java_lang_Class comes first, it's needed for AllocClass
-  Class* java_lang_Class = down_cast<Class*>(Heap::AllocObject(NULL, sizeof(ClassClass)));
-  CHECK(java_lang_Class != NULL);
-  java_lang_Class->SetClass(java_lang_Class);
+  SirtRef<Class> java_lang_Class(down_cast<Class*>(Heap::AllocObject(NULL, sizeof(ClassClass))));
+  CHECK(java_lang_Class.get() != NULL);
+  java_lang_Class->SetClass(java_lang_Class.get());
   java_lang_Class->SetClassSize(sizeof(ClassClass));
   // AllocClass(Class*) can now be used
 
   // Class[] is used for reflection support.
-  Class* class_array_class = AllocClass(java_lang_Class, sizeof(Class));
-  class_array_class->SetComponentType(java_lang_Class);
+  SirtRef<Class> class_array_class(AllocClass(java_lang_Class.get(), sizeof(Class)));
+  class_array_class->SetComponentType(java_lang_Class.get());
 
   // java_lang_Object comes next so that object_array_class can be created
-  Class* java_lang_Object = AllocClass(java_lang_Class, sizeof(Class));
-  CHECK(java_lang_Object != NULL);
+  SirtRef<Class> java_lang_Object(AllocClass(java_lang_Class.get(), sizeof(Class)));
+  CHECK(java_lang_Object.get() != NULL);
   // backfill Object as the super class of Class
-  java_lang_Class->SetSuperClass(java_lang_Object);
+  java_lang_Class->SetSuperClass(java_lang_Object.get());
   java_lang_Object->SetStatus(Class::kStatusLoaded);
 
   // Object[] next to hold class roots
-  Class* object_array_class = AllocClass(java_lang_Class, sizeof(Class));
-  object_array_class->SetComponentType(java_lang_Object);
+  SirtRef<Class> object_array_class(AllocClass(java_lang_Class.get(), sizeof(Class)));
+  object_array_class->SetComponentType(java_lang_Object.get());
 
   // Setup the char class to be used for char[]
-  Class* char_class = AllocClass(java_lang_Class, sizeof(Class));
+  SirtRef<Class> char_class(AllocClass(java_lang_Class.get(), sizeof(Class)));
 
   // Setup the char[] class to be used for String
-  Class* char_array_class = AllocClass(java_lang_Class, sizeof(Class));
-  char_array_class->SetComponentType(char_class);
-  CharArray::SetArrayClass(char_array_class);
+  SirtRef<Class> char_array_class(AllocClass(java_lang_Class.get(), sizeof(Class)));
+  char_array_class->SetComponentType(char_class.get());
+  CharArray::SetArrayClass(char_array_class.get());
 
   // Setup String
-  Class* java_lang_String = AllocClass(java_lang_Class, sizeof(StringClass));
-  String::SetClass(java_lang_String);
+  SirtRef<Class> java_lang_String(AllocClass(java_lang_Class.get(), sizeof(StringClass)));
+  String::SetClass(java_lang_String.get());
   java_lang_String->SetObjectSize(sizeof(String));
   java_lang_String->SetStatus(Class::kStatusResolved);
 
@@ -273,14 +274,14 @@
 
   // Create storage for root classes, save away our work so far (requires
   // descriptors)
-  class_roots_ = ObjectArray<Class>::Alloc(object_array_class, kClassRootsMax);
+  class_roots_ = ObjectArray<Class>::Alloc(object_array_class.get(), kClassRootsMax);
   CHECK(class_roots_ != NULL);
-  SetClassRoot(kJavaLangClass, java_lang_Class);
-  SetClassRoot(kJavaLangObject, java_lang_Object);
-  SetClassRoot(kClassArrayClass, class_array_class);
-  SetClassRoot(kObjectArrayClass, object_array_class);
-  SetClassRoot(kCharArrayClass, char_array_class);
-  SetClassRoot(kJavaLangString, java_lang_String);
+  SetClassRoot(kJavaLangClass, java_lang_Class.get());
+  SetClassRoot(kJavaLangObject, java_lang_Object.get());
+  SetClassRoot(kClassArrayClass, class_array_class.get());
+  SetClassRoot(kObjectArrayClass, object_array_class.get());
+  SetClassRoot(kCharArrayClass, char_array_class.get());
+  SetClassRoot(kJavaLangString, java_lang_String.get());
 
   // Setup the primitive type classes.
   SetClassRoot(kPrimitiveBoolean, CreatePrimitiveClass("Z", Class::kPrimBoolean));
@@ -297,11 +298,11 @@
   array_iftable_ = AllocObjectArray<InterfaceEntry>(2);
 
   // Create int array type for AllocDexCache (done in AppendToBootClassPath)
-  Class* int_array_class = AllocClass(java_lang_Class, sizeof(Class));
+  SirtRef<Class> int_array_class(AllocClass(java_lang_Class.get(), sizeof(Class)));
   int_array_class->SetDescriptor(intern_table_->InternStrong("[I"));
   int_array_class->SetComponentType(GetClassRoot(kPrimitiveInt));
-  IntArray::SetArrayClass(int_array_class);
-  SetClassRoot(kIntArrayClass, int_array_class);
+  IntArray::SetArrayClass(int_array_class.get());
+  SetClassRoot(kIntArrayClass, int_array_class.get());
 
   // now that these are registered, we can use AllocClass() and AllocObjectArray
 
@@ -316,43 +317,43 @@
   }
 
   // Constructor, Field, and Method are necessary so that FindClass can link members
-  Class* java_lang_reflect_Constructor = AllocClass(java_lang_Class, sizeof(MethodClass));
+  SirtRef<Class> java_lang_reflect_Constructor(AllocClass(java_lang_Class.get(), sizeof(MethodClass)));
   java_lang_reflect_Constructor->SetDescriptor(intern_table_->InternStrong("Ljava/lang/reflect/Constructor;"));
-  CHECK(java_lang_reflect_Constructor != NULL);
+  CHECK(java_lang_reflect_Constructor.get() != NULL);
   java_lang_reflect_Constructor->SetObjectSize(sizeof(Method));
-  SetClassRoot(kJavaLangReflectConstructor, java_lang_reflect_Constructor);
+  SetClassRoot(kJavaLangReflectConstructor, java_lang_reflect_Constructor.get());
   java_lang_reflect_Constructor->SetStatus(Class::kStatusResolved);
 
-  Class* java_lang_reflect_Field = AllocClass(java_lang_Class, sizeof(FieldClass));
-  CHECK(java_lang_reflect_Field != NULL);
+  SirtRef<Class> java_lang_reflect_Field(AllocClass(java_lang_Class.get(), sizeof(FieldClass)));
+  CHECK(java_lang_reflect_Field.get() != NULL);
   java_lang_reflect_Field->SetDescriptor(intern_table_->InternStrong("Ljava/lang/reflect/Field;"));
   java_lang_reflect_Field->SetObjectSize(sizeof(Field));
-  SetClassRoot(kJavaLangReflectField, java_lang_reflect_Field);
+  SetClassRoot(kJavaLangReflectField, java_lang_reflect_Field.get());
   java_lang_reflect_Field->SetStatus(Class::kStatusResolved);
-  Field::SetClass(java_lang_reflect_Field);
+  Field::SetClass(java_lang_reflect_Field.get());
 
-  Class* java_lang_reflect_Method = AllocClass(java_lang_Class, sizeof(MethodClass));
+  SirtRef<Class> java_lang_reflect_Method(AllocClass(java_lang_Class.get(), sizeof(MethodClass)));
   java_lang_reflect_Method->SetDescriptor(intern_table_->InternStrong("Ljava/lang/reflect/Method;"));
-  CHECK(java_lang_reflect_Method != NULL);
+  CHECK(java_lang_reflect_Method.get() != NULL);
   java_lang_reflect_Method->SetObjectSize(sizeof(Method));
-  SetClassRoot(kJavaLangReflectMethod, java_lang_reflect_Method);
+  SetClassRoot(kJavaLangReflectMethod, java_lang_reflect_Method.get());
   java_lang_reflect_Method->SetStatus(Class::kStatusResolved);
-  Method::SetClasses(java_lang_reflect_Constructor, java_lang_reflect_Method);
+  Method::SetClasses(java_lang_reflect_Constructor.get(), java_lang_reflect_Method.get());
 
   // now we can use FindSystemClass
 
   // run char class through InitializePrimitiveClass to finish init
-  InitializePrimitiveClass(char_class, "C", Class::kPrimChar);
-  SetClassRoot(kPrimitiveChar, char_class);  // needs descriptor
+  InitializePrimitiveClass(char_class.get(), "C", Class::kPrimChar);
+  SetClassRoot(kPrimitiveChar, char_class.get());  // needs descriptor
 
   // Object and String need to be rerun through FindSystemClass to finish init
   java_lang_Object->SetStatus(Class::kStatusNotReady);
   Class* Object_class = FindSystemClass("Ljava/lang/Object;");
-  CHECK_EQ(java_lang_Object, Object_class);
+  CHECK_EQ(java_lang_Object.get(), Object_class);
   CHECK_EQ(java_lang_Object->GetObjectSize(), sizeof(Object));
   java_lang_String->SetStatus(Class::kStatusNotReady);
   Class* String_class = FindSystemClass("Ljava/lang/String;");
-  CHECK_EQ(java_lang_String, String_class);
+  CHECK_EQ(java_lang_String.get(), String_class);
   CHECK_EQ(java_lang_String->GetObjectSize(), sizeof(String));
 
   // Setup the primitive array type classes - can't be done until Object has a vtable
@@ -363,13 +364,13 @@
   ByteArray::SetArrayClass(GetClassRoot(kByteArrayClass));
 
   Class* found_char_array_class = FindSystemClass("[C");
-  CHECK_EQ(char_array_class, found_char_array_class);
+  CHECK_EQ(char_array_class.get(), found_char_array_class);
 
   SetClassRoot(kShortArrayClass, FindSystemClass("[S"));
   ShortArray::SetArrayClass(GetClassRoot(kShortArrayClass));
 
   Class* found_int_array_class = FindSystemClass("[I");
-  CHECK_EQ(int_array_class, found_int_array_class);
+  CHECK_EQ(int_array_class.get(), found_int_array_class);
 
   SetClassRoot(kLongArrayClass, FindSystemClass("[J"));
   LongArray::SetArrayClass(GetClassRoot(kLongArrayClass));
@@ -381,10 +382,10 @@
   DoubleArray::SetArrayClass(GetClassRoot(kDoubleArrayClass));
 
   Class* found_class_array_class = FindSystemClass("[Ljava/lang/Class;");
-  CHECK_EQ(class_array_class, found_class_array_class);
+  CHECK_EQ(class_array_class.get(), found_class_array_class);
 
   Class* found_object_array_class = FindSystemClass("[Ljava/lang/Object;");
-  CHECK_EQ(object_array_class, found_object_array_class);
+  CHECK_EQ(object_array_class.get(), found_object_array_class);
 
   // Setup the single, global copies of "interfaces" and "iftable"
   Class* java_lang_Cloneable = FindSystemClass("Ljava/lang/Cloneable;");
@@ -409,19 +410,19 @@
   // run Class, Constructor, Field, and Method through FindSystemClass.
   // this initializes their dex_cache_ fields and register them in classes_.
   Class* Class_class = FindSystemClass("Ljava/lang/Class;");
-  CHECK_EQ(java_lang_Class, Class_class);
+  CHECK_EQ(java_lang_Class.get(), Class_class);
 
   java_lang_reflect_Constructor->SetStatus(Class::kStatusNotReady);
   Class* Constructor_class = FindSystemClass("Ljava/lang/reflect/Constructor;");
-  CHECK_EQ(java_lang_reflect_Constructor, Constructor_class);
+  CHECK_EQ(java_lang_reflect_Constructor.get(), Constructor_class);
 
   java_lang_reflect_Field->SetStatus(Class::kStatusNotReady);
   Class* Field_class = FindSystemClass("Ljava/lang/reflect/Field;");
-  CHECK_EQ(java_lang_reflect_Field, Field_class);
+  CHECK_EQ(java_lang_reflect_Field.get(), Field_class);
 
   java_lang_reflect_Method->SetStatus(Class::kStatusNotReady);
   Class* Method_class = FindSystemClass("Ljava/lang/reflect/Method;");
-  CHECK_EQ(java_lang_reflect_Method, Method_class);
+  CHECK_EQ(java_lang_reflect_Method.get(), Method_class);
 
   // End of special init trickery, subsequent classes may be loaded via FindSystemClass
 
@@ -647,7 +648,7 @@
       CHECK_EQ(oat_file->GetOatHeader().GetDexFileCount(),
                static_cast<uint32_t>(dex_caches->GetLength()));
       for (int i = 0; i < dex_caches->GetLength(); i++) {
-        DexCache* dex_cache = dex_caches->Get(i);
+        SirtRef<DexCache> dex_cache(dex_caches->Get(i));
         const std::string& dex_file_location = dex_cache->GetLocation()->ToModifiedUtf8();
 
         std::string dex_filename;
@@ -741,6 +742,7 @@
   }
 
   visitor(array_interfaces_, arg);
+  visitor(array_iftable_, arg);
 }
 
 ClassLinker::~ClassLinker() {
@@ -762,23 +764,47 @@
 }
 
 DexCache* ClassLinker::AllocDexCache(const DexFile& dex_file) {
-  String* location = intern_table_->InternStrong(dex_file.GetLocation().c_str());
-  if (location == NULL) {
+  SirtRef<DexCache> dex_cache(down_cast<DexCache*>(AllocObjectArray<Object>(DexCache::LengthAsArray())));
+  if (dex_cache.get() == NULL) {
     return NULL;
   }
-  DexCache* dex_cache = down_cast<DexCache*>(AllocObjectArray<Object>(DexCache::LengthAsArray()));
-  if (dex_cache == NULL) {
+  SirtRef<String> location(intern_table_->InternStrong(dex_file.GetLocation().c_str()));
+  if (location.get() == NULL) {
     return NULL;
   }
-  // TODO: lots of missing null checks hidden in this call...
-  dex_cache->Init(location,
-                  AllocObjectArray<String>(dex_file.NumStringIds()),
-                  AllocClassArray(dex_file.NumTypeIds()),
-                  AllocObjectArray<Method>(dex_file.NumMethodIds()),
-                  AllocObjectArray<Field>(dex_file.NumFieldIds()),
-                  AllocCodeAndDirectMethods(dex_file.NumMethodIds()),
-                  AllocObjectArray<StaticStorageBase>(dex_file.NumTypeIds()));
-  return dex_cache;
+  SirtRef<ObjectArray<String> > strings(AllocObjectArray<String>(dex_file.NumStringIds()));
+  if (strings.get() == NULL) {
+    return NULL;
+  }
+  SirtRef<ObjectArray<Class> > types(AllocClassArray(dex_file.NumTypeIds()));
+  if (types.get() == NULL) {
+    return NULL;
+  }
+  SirtRef<ObjectArray<Method> > methods(AllocObjectArray<Method>(dex_file.NumMethodIds()));
+  if (methods.get() == NULL) {
+    return NULL;
+  }
+  SirtRef<ObjectArray<Field> > fields(AllocObjectArray<Field>(dex_file.NumFieldIds()));
+  if (fields.get() == NULL) {
+    return NULL;
+  }
+  SirtRef<CodeAndDirectMethods> code_and_direct_methods(AllocCodeAndDirectMethods(dex_file.NumMethodIds()));
+  if (code_and_direct_methods.get() == NULL) {
+    return NULL;
+  }
+  SirtRef<ObjectArray<StaticStorageBase> > initialized_static_storage(AllocObjectArray<StaticStorageBase>(dex_file.NumTypeIds()));
+  if (initialized_static_storage.get() == NULL) {
+    return NULL;
+  }
+
+  dex_cache->Init(location.get(),
+                  strings.get(),
+                  types.get(),
+                  methods.get(),
+                  fields.get(),
+                  code_and_direct_methods.get(),
+                  initialized_static_storage.get());
+  return dex_cache.get();
 }
 
 CodeAndDirectMethods* ClassLinker::AllocCodeAndDirectMethods(size_t length) {
@@ -787,18 +813,18 @@
 
 InterfaceEntry* ClassLinker::AllocInterfaceEntry(Class* interface) {
   DCHECK(interface->IsInterface());
-  ObjectArray<Object>* array = AllocObjectArray<Object>(InterfaceEntry::LengthAsArray());
-  InterfaceEntry* interface_entry = down_cast<InterfaceEntry*>(array);
+  SirtRef<ObjectArray<Object> > array(AllocObjectArray<Object>(InterfaceEntry::LengthAsArray()));
+  SirtRef<InterfaceEntry> interface_entry(down_cast<InterfaceEntry*>(array.get()));
   interface_entry->SetInterface(interface);
-  return interface_entry;
+  return interface_entry.get();
 }
 
 Class* ClassLinker::AllocClass(Class* java_lang_Class, size_t class_size) {
   DCHECK_GE(class_size, sizeof(Class));
-  Class* klass = Heap::AllocObject(java_lang_Class, class_size)->AsClass();
+  SirtRef<Class> klass(Heap::AllocObject(java_lang_Class, class_size)->AsClass());
   klass->SetPrimitiveType(Class::kPrimNot);  // default to not being primitive
   klass->SetClassSize(class_size);
-  return klass;
+  return klass.get();
 }
 
 Class* ClassLinker::AllocClass(size_t class_size) {
@@ -907,27 +933,27 @@
                                 const ClassLoader* class_loader,
                                 const DexFile& dex_file,
                                 const DexFile::ClassDef& dex_class_def) {
-  Class* klass;
+  SirtRef<Class> klass(NULL);
   // Load the class from the dex file.
   if (!init_done_) {
     // finish up init of hand crafted class_roots_
     if (descriptor == "Ljava/lang/Object;") {
-      klass = GetClassRoot(kJavaLangObject);
+      klass.reset(GetClassRoot(kJavaLangObject));
     } else if (descriptor == "Ljava/lang/Class;") {
-      klass = GetClassRoot(kJavaLangClass);
+      klass.reset(GetClassRoot(kJavaLangClass));
     } else if (descriptor == "Ljava/lang/String;") {
-      klass = GetClassRoot(kJavaLangString);
+      klass.reset(GetClassRoot(kJavaLangString));
     } else if (descriptor == "Ljava/lang/reflect/Constructor;") {
-      klass = GetClassRoot(kJavaLangReflectConstructor);
+      klass.reset(GetClassRoot(kJavaLangReflectConstructor));
     } else if (descriptor == "Ljava/lang/reflect/Field;") {
-      klass = GetClassRoot(kJavaLangReflectField);
+      klass.reset(GetClassRoot(kJavaLangReflectField));
     } else if (descriptor == "Ljava/lang/reflect/Method;") {
-      klass = GetClassRoot(kJavaLangReflectMethod);
+      klass.reset(GetClassRoot(kJavaLangReflectMethod));
     } else {
-      klass = AllocClass(SizeOfClass(dex_file, dex_class_def));
+      klass.reset(AllocClass(SizeOfClass(dex_file, dex_class_def)));
     }
   } else {
-    klass = AllocClass(SizeOfClass(dex_file, dex_class_def));
+    klass.reset(AllocClass(SizeOfClass(dex_file, dex_class_def)));
   }
   klass->SetDexCache(FindDexCache(dex_file));
   LoadClass(dex_file, dex_class_def, klass, class_loader);
@@ -936,16 +962,16 @@
   if (self->IsExceptionPending()) {
     return NULL;
   }
-  ObjectLock lock(klass);
+  ObjectLock lock(klass.get());
   klass->SetClinitThreadId(self->GetTid());
   // Add the newly loaded class to the loaded classes table.
-  bool success = InsertClass(descriptor, klass);  // TODO: just return collision
+  bool success = InsertClass(descriptor, klass.get());  // TODO: just return collision
   if (!success) {
     // We may fail to insert if we raced with another thread.
     klass->SetClinitThreadId(0);
-    klass = LookupClass(descriptor, class_loader);
-    CHECK(klass != NULL);
-    return klass;
+    klass.reset(LookupClass(descriptor, class_loader));
+    CHECK(klass.get() != NULL);
+    return klass.get();
   }
   // Finish loading (if necessary) by finding parents
   CHECK(!klass->IsLoaded());
@@ -965,7 +991,7 @@
     return NULL;
   }
   CHECK(klass->IsResolved());
-  return klass;
+  return klass.get();
 }
 
 // Precomputes size that will be needed for Class, matching LinkStaticFields
@@ -1015,11 +1041,11 @@
   return size;
 }
 
-void LinkCode(Method* method, const OatFile::OatClass* oat_class, uint32_t method_index) {
+void LinkCode(SirtRef<Method>& method, const OatFile::OatClass* oat_class, uint32_t method_index) {
   // Every kind of method should at least get an invoke stub from the oat_method.
   // non-abstract methods also get their code pointers.
   const OatFile::OatMethod oat_method = oat_class->GetOatMethod(method_index);
-  oat_method.LinkMethod(method);
+  oat_method.LinkMethod(method.get());
 
   if (method->IsAbstract()) {
     method->SetCode(Runtime::Current()->GetAbstractMethodErrorStubArray()->GetData());
@@ -1034,9 +1060,9 @@
 
 void ClassLinker::LoadClass(const DexFile& dex_file,
                             const DexFile::ClassDef& dex_class_def,
-                            Class* klass,
+                            SirtRef<Class>& klass,
                             const ClassLoader* class_loader) {
-  CHECK(klass != NULL);
+  CHECK(klass.get() != NULL);
   CHECK(klass->GetDexCache() != NULL);
   CHECK_EQ(Class::kStatusNotReady, klass->GetStatus());
   const byte* class_data = dex_file.GetClassData(dex_class_def);
@@ -1088,8 +1114,8 @@
     for (size_t i = 0; i < num_static_fields; ++i) {
       DexFile::Field dex_field;
       dex_file.dexReadClassDataField(&class_data, &dex_field, &last_idx);
-      Field* sfield = AllocField();
-      klass->SetStaticField(i, sfield);
+      SirtRef<Field> sfield(AllocField());
+      klass->SetStaticField(i, sfield.get());
       LoadField(dex_file, dex_field, klass, sfield);
     }
   }
@@ -1101,8 +1127,8 @@
     for (size_t i = 0; i < num_instance_fields; ++i) {
       DexFile::Field dex_field;
       dex_file.dexReadClassDataField(&class_data, &dex_field, &last_idx);
-      Field* ifield = AllocField();
-      klass->SetInstanceField(i, ifield);
+      SirtRef<Field> ifield(AllocField());
+      klass->SetInstanceField(i, ifield.get());
       LoadField(dex_file, dex_field, klass, ifield);
     }
   }
@@ -1131,8 +1157,8 @@
     for (size_t i = 0; i < num_direct_methods; ++i, ++method_index) {
       DexFile::Method dex_method;
       dex_file.dexReadClassDataMethod(&class_data, &dex_method, &last_idx);
-      Method* method = AllocMethod();
-      klass->SetDirectMethod(i, method);
+      SirtRef<Method> method(AllocMethod());
+      klass->SetDirectMethod(i, method.get());
       LoadMethod(dex_file, dex_method, klass, method);
       if (oat_class.get() != NULL) {
         LinkCode(method, oat_class.get(), method_index);
@@ -1148,8 +1174,8 @@
     for (size_t i = 0; i < num_virtual_methods; ++i, ++method_index) {
       DexFile::Method dex_method;
       dex_file.dexReadClassDataMethod(&class_data, &dex_method, &last_idx);
-      Method* method = AllocMethod();
-      klass->SetVirtualMethod(i, method);
+      SirtRef<Method> method(AllocMethod());
+      klass->SetVirtualMethod(i, method.get());
       LoadMethod(dex_file, dex_method, klass, method);
       if (oat_class.get() != NULL) {
         LinkCode(method, oat_class.get(), method_index);
@@ -1160,7 +1186,7 @@
 
 void ClassLinker::LoadInterfaces(const DexFile& dex_file,
                                  const DexFile::ClassDef& dex_class_def,
-                                 Class* klass) {
+                                 SirtRef<Class>& klass) {
   const DexFile::TypeList* list = dex_file.GetInterfacesList(dex_class_def);
   if (list != NULL) {
     klass->SetInterfaces(AllocClassArray(list->Size()));
@@ -1175,10 +1201,10 @@
 
 void ClassLinker::LoadField(const DexFile& dex_file,
                             const DexFile::Field& src,
-                            Class* klass,
-                            Field* dst) {
+                            SirtRef<Class>& klass,
+                            SirtRef<Field>& dst) {
   const DexFile::FieldId& field_id = dex_file.GetFieldId(src.field_idx_);
-  dst->SetDeclaringClass(klass);
+  dst->SetDeclaringClass(klass.get());
   dst->SetName(ResolveString(dex_file, field_id.name_idx_, klass->GetDexCache()));
   dst->SetTypeIdx(field_id.type_idx_);
   dst->SetAccessFlags(src.access_flags_);
@@ -1188,17 +1214,17 @@
   const char* descriptor = dex_file.dexStringByTypeIdx(field_id.type_idx_);
   if (descriptor[1] == '\0') {
     // only the descriptors of primitive types should be 1 character long
-    Class* resolved = ResolveType(dex_file, field_id.type_idx_, klass);
+    Class* resolved = ResolveType(dex_file, field_id.type_idx_, klass.get());
     DCHECK(resolved->IsPrimitive());
   }
 }
 
 void ClassLinker::LoadMethod(const DexFile& dex_file,
                              const DexFile::Method& src,
-                             Class* klass,
-                             Method* dst) {
+                             SirtRef<Class>& klass,
+                             SirtRef<Method>& dst) {
   const DexFile::MethodId& method_id = dex_file.GetMethodId(src.method_idx_);
-  dst->SetDeclaringClass(klass);
+  dst->SetDeclaringClass(klass.get());
 
   String* method_name = ResolveString(dex_file, method_id.name_idx_, klass->GetDexCache());
   if (method_name == NULL) {
@@ -1268,11 +1294,12 @@
 }
 
 void ClassLinker::AppendToBootClassPath(const DexFile& dex_file) {
-  AppendToBootClassPath(dex_file, AllocDexCache(dex_file));
+  SirtRef<DexCache> dex_cache(AllocDexCache(dex_file));
+  AppendToBootClassPath(dex_file, dex_cache);
 }
 
-void ClassLinker::AppendToBootClassPath(const DexFile& dex_file, DexCache* dex_cache) {
-  CHECK(dex_cache != NULL) << dex_file.GetLocation();
+void ClassLinker::AppendToBootClassPath(const DexFile& dex_file, SirtRef<DexCache>& dex_cache) {
+  CHECK(dex_cache.get() != NULL) << dex_file.GetLocation();
   boot_class_path_.push_back(&dex_file);
   RegisterDexFile(dex_file, dex_cache);
 }
@@ -1292,12 +1319,12 @@
   return IsDexFileRegisteredLocked(dex_file);
 }
 
-void ClassLinker::RegisterDexFileLocked(const DexFile& dex_file, DexCache* dex_cache) {
+void ClassLinker::RegisterDexFileLocked(const DexFile& dex_file, SirtRef<DexCache>& dex_cache) {
   dex_lock_.AssertHeld();
-  CHECK(dex_cache != NULL) << dex_file.GetLocation();
+  CHECK(dex_cache.get() != NULL) << dex_file.GetLocation();
   CHECK(dex_cache->GetLocation()->Equals(dex_file.GetLocation()));
   dex_files_.push_back(&dex_file);
-  dex_caches_.push_back(dex_cache);
+  dex_caches_.push_back(dex_cache.get());
 }
 
 void ClassLinker::RegisterDexFile(const DexFile& dex_file) {
@@ -1310,7 +1337,7 @@
   // Don't alloc while holding the lock, since allocation may need to
   // suspend all threads and another thread may need the dex_lock_ to
   // get to a suspend point.
-  DexCache* dex_cache = AllocDexCache(dex_file);
+  SirtRef<DexCache> dex_cache(AllocDexCache(dex_file));
   {
     MutexLock mu(dex_lock_);
     if (IsDexFileRegisteredLocked(dex_file)) {
@@ -1320,7 +1347,7 @@
   }
 }
 
-void ClassLinker::RegisterDexFile(const DexFile& dex_file, DexCache* dex_cache) {
+void ClassLinker::RegisterDexFile(const DexFile& dex_file, SirtRef<DexCache>& dex_cache) {
   MutexLock mu(dex_lock_);
   RegisterDexFileLocked(dex_file, dex_cache);
 }
@@ -1420,22 +1447,22 @@
   // Array classes are simple enough that we don't need to do a full
   // link step.
 
-  Class* new_class = NULL;
+  SirtRef<Class> new_class(NULL);
   if (!init_done_) {
     // Classes that were hand created, ie not by FindSystemClass
     if (descriptor == "[Ljava/lang/Class;") {
-      new_class = GetClassRoot(kClassArrayClass);
+      new_class.reset(GetClassRoot(kClassArrayClass));
     } else if (descriptor == "[Ljava/lang/Object;") {
-      new_class = GetClassRoot(kObjectArrayClass);
+      new_class.reset(GetClassRoot(kObjectArrayClass));
     } else if (descriptor == "[C") {
-      new_class = GetClassRoot(kCharArrayClass);
+      new_class.reset(GetClassRoot(kCharArrayClass));
     } else if (descriptor == "[I") {
-      new_class = GetClassRoot(kIntArrayClass);
+      new_class.reset(GetClassRoot(kIntArrayClass));
     }
   }
-  if (new_class == NULL) {
-    new_class = AllocClass(sizeof(Class));
-    if (new_class == NULL) {
+  if (new_class.get() == NULL) {
+    new_class.reset(AllocClass(sizeof(Class)));
+    if (new_class.get() == NULL) {
       return NULL;
     }
     new_class->SetComponentType(component_type);
@@ -1486,8 +1513,8 @@
   new_class->SetAccessFlags(((new_class->GetComponentType()->GetAccessFlags() &
                              ~kAccInterface) | kAccFinal) & kAccJavaFlagsMask);
 
-  if (InsertClass(descriptor, new_class)) {
-    return new_class;
+  if (InsertClass(descriptor, new_class.get())) {
+    return new_class.get();
   }
   // Another thread must have loaded the class after we
   // started but before we finished.  Abandon what we've
@@ -1570,8 +1597,8 @@
 
 Class* ClassLinker::CreateProxyClass(String* name, ObjectArray<Class>* interfaces,
     ClassLoader* loader, ObjectArray<Method>* methods, ObjectArray<ObjectArray<Class> >* throws) {
-  Class* klass = AllocClass(GetClassRoot(kJavaLangClass), sizeof(ProxyClass));
-  CHECK(klass != NULL);
+  SirtRef<Class> klass(AllocClass(GetClassRoot(kJavaLangClass), sizeof(ProxyClass)));
+  CHECK(klass.get() != NULL);
   klass->SetObjectSize(sizeof(Proxy));
   const char* descriptor = DotToDescriptor(name->ToModifiedUtf8().c_str()).c_str();;
   klass->SetDescriptor(intern_table_->InternStrong(descriptor));
@@ -1590,7 +1617,7 @@
   size_t num_virtual_methods = methods->GetLength();
   klass->SetVirtualMethods(AllocObjectArray<Method>(num_virtual_methods));
   for (size_t i = 0; i < num_virtual_methods; ++i) {
-    Method* prototype = methods->Get(i);
+    SirtRef<Method> prototype(methods->Get(i));
     klass->SetVirtualMethod(i, CreateProxyMethod(klass, prototype, throws->Get(i)));
   }
   // Link the virtual methods, creating vtable and iftables
@@ -1598,10 +1625,10 @@
     DCHECK(Thread::Current()->IsExceptionPending());
     return NULL;
   }
-  return klass;
+  return klass.get();
 }
 
-Method* ClassLinker::CreateProxyConstructor(Class* klass) {
+Method* ClassLinker::CreateProxyConstructor(SirtRef<Class>& klass) {
   // Create constructor for Proxy that must initialize h
   Class* proxy_class = GetClassRoot(kJavaLangReflectProxy);
   ObjectArray<Method>* proxy_direct_methods = proxy_class->GetDirectMethods();
@@ -1612,7 +1639,7 @@
   Method* constructor = down_cast<Method*>(proxy_constructor->Clone());
   // Make this constructor public and fix the class to be our Proxy version
   constructor->SetAccessFlags((constructor->GetAccessFlags() & ~kAccProtected) | kAccPublic);
-  constructor->SetDeclaringClass(klass);
+  constructor->SetDeclaringClass(klass.get());
   // Sanity checks
   CHECK(constructor->IsConstructor());
   CHECK(constructor->GetName()->Equals("<init>"));
@@ -1621,7 +1648,7 @@
   return constructor;
 }
 
-Method* ClassLinker::CreateProxyMethod(Class* klass, Method* prototype,
+Method* ClassLinker::CreateProxyMethod(SirtRef<Class>& klass, SirtRef<Method>& prototype,
                                        ObjectArray<Class>* throws) {
   // We steal everything from the prototype (such as DexCache, invoke stub, etc.) then specialise
   // as necessary
@@ -1629,7 +1656,7 @@
 
   // Set class to be the concrete proxy class and clear the abstract flag, modify exceptions to
   // the intersection of throw exceptions as defined in Proxy
-  method->SetDeclaringClass(klass);
+  method->SetDeclaringClass(klass.get());
   method->SetAccessFlags((method->GetAccessFlags() & ~kAccAbstract) | kAccFinal);
   method->SetExceptionTypes(throws);
 
@@ -1993,7 +2020,7 @@
   }
 }
 
-bool ClassLinker::LinkClass(Class* klass) {
+bool ClassLinker::LinkClass(SirtRef<Class>& klass) {
   CHECK_EQ(Class::kStatusLoaded, klass->GetStatus());
   if (!LinkSuperClass(klass)) {
     return false;
@@ -2014,10 +2041,10 @@
   return true;
 }
 
-bool ClassLinker::LoadSuperAndInterfaces(Class* klass, const DexFile& dex_file) {
+bool ClassLinker::LoadSuperAndInterfaces(SirtRef<Class>& klass, const DexFile& dex_file) {
   CHECK_EQ(Class::kStatusIdx, klass->GetStatus());
   if (klass->GetSuperClassTypeIdx() != DexFile::kDexNoIndex) {
-    Class* super_class = ResolveType(dex_file, klass->GetSuperClassTypeIdx(), klass);
+    Class* super_class = ResolveType(dex_file, klass->GetSuperClassTypeIdx(), klass.get());
     if (super_class == NULL) {
       DCHECK(Thread::Current()->IsExceptionPending());
       return false;
@@ -2026,7 +2053,7 @@
   }
   for (size_t i = 0; i < klass->NumInterfaces(); ++i) {
     uint32_t idx = klass->GetInterfacesTypeIdx()->Get(i);
-    Class* interface = ResolveType(dex_file, idx, klass);
+    Class* interface = ResolveType(dex_file, idx, klass.get());
     klass->SetInterface(i, interface);
     if (interface == NULL) {
       DCHECK(Thread::Current()->IsExceptionPending());
@@ -2047,7 +2074,7 @@
   return true;
 }
 
-bool ClassLinker::LinkSuperClass(Class* klass) {
+bool ClassLinker::LinkSuperClass(SirtRef<Class>& klass) {
   CHECK(!klass->IsPrimitive());
   Class* super = klass->GetSuperClass();
   if (klass->GetDescriptor()->Equals("Ljava/lang/Object;")) {
@@ -2108,7 +2135,7 @@
 }
 
 // Populate the class vtable and itable. Compute return type indices.
-bool ClassLinker::LinkMethods(Class* klass) {
+bool ClassLinker::LinkMethods(SirtRef<Class>& klass) {
   if (klass->IsInterface()) {
     // No vtable.
     size_t count = klass->NumVirtualMethods();
@@ -2128,7 +2155,7 @@
   return true;
 }
 
-bool ClassLinker::LinkVirtualMethods(Class* klass) {
+bool ClassLinker::LinkVirtualMethods(SirtRef<Class>& klass) {
   if (klass->HasSuperClass()) {
     uint32_t max_count = klass->NumVirtualMethods() + klass->GetSuperClass()->GetVTable()->GetLength();
     size_t actual_count = klass->GetSuperClass()->GetVTable()->GetLength();
@@ -2179,18 +2206,18 @@
       ThrowClassFormatError("Too many methods: %d", num_virtual_methods);
       return false;
     }
-    ObjectArray<Method>* vtable = AllocObjectArray<Method>(num_virtual_methods);
+    SirtRef<ObjectArray<Method> > vtable(AllocObjectArray<Method>(num_virtual_methods));
     for (size_t i = 0; i < num_virtual_methods; ++i) {
       Method* virtual_method = klass->GetVirtualMethodDuringLinking(i);
       vtable->Set(i, virtual_method);
       virtual_method->SetMethodIndex(i & 0xFFFF);
     }
-    klass->SetVTable(vtable);
+    klass->SetVTable(vtable.get());
   }
   return true;
 }
 
-bool ClassLinker::LinkInterfaceMethods(Class* klass) {
+bool ClassLinker::LinkInterfaceMethods(SirtRef<Class>& klass) {
   size_t super_ifcount;
   if (klass->HasSuperClass()) {
     super_ifcount = klass->GetSuperClass()->GetIfTableCount();
@@ -2208,7 +2235,7 @@
     // DCHECK(klass->GetIfTable() == NULL);
     return true;
   }
-  ObjectArray<InterfaceEntry>* iftable = AllocObjectArray<InterfaceEntry>(ifcount);
+  SirtRef<ObjectArray<InterfaceEntry> > iftable(AllocObjectArray<InterfaceEntry>(ifcount));
   if (super_ifcount != 0) {
     ObjectArray<InterfaceEntry>* super_iftable = klass->GetSuperClass()->GetIfTable();
     for (size_t i = 0; i < super_ifcount; i++) {
@@ -2234,7 +2261,7 @@
       iftable->Set(idx++, AllocInterfaceEntry(interface->GetIfTable()->Get(j)->GetInterface()));
     }
   }
-  klass->SetIfTable(iftable);
+  klass->SetIfTable(iftable.get());
   CHECK_EQ(idx, ifcount);
 
   // If we're an interface, we don't need the vtable pointers, so we're done.
@@ -2272,20 +2299,20 @@
         }
       }
       if (k < 0) {
-        Method* miranda_method = NULL;
+        SirtRef<Method> miranda_method(NULL);
         for (size_t mir = 0; mir < miranda_list.size(); mir++) {
           if (miranda_list[mir]->HasSameNameAndSignature(interface_method)) {
-            miranda_method = miranda_list[mir];
+            miranda_method.reset(miranda_list[mir]);
             break;
           }
         }
-        if (miranda_method == NULL) {
+        if (miranda_method.get() == NULL) {
           // point the interface table at a phantom slot
-          miranda_method = AllocMethod();
-          memcpy(miranda_method, interface_method, sizeof(Method));
-          miranda_list.push_back(miranda_method);
+          miranda_method.reset(AllocMethod());
+          memcpy(miranda_method.get(), interface_method, sizeof(Method));
+          miranda_list.push_back(miranda_method.get());
         }
-        method_array->Set(j, miranda_method);
+        method_array->Set(j, miranda_method.get());
       }
     }
   }
@@ -2303,7 +2330,7 @@
     vtable = vtable->CopyOf(new_vtable_count);
     for (size_t i = 0; i < miranda_list.size(); ++i) {
       Method* method = miranda_list[i];
-      method->SetDeclaringClass(klass);
+      method->SetDeclaringClass(klass.get());
       method->SetAccessFlags(method->GetAccessFlags() | kAccMiranda);
       method->SetMethodIndex(0xFFFF & (old_vtable_count + i));
       klass->SetVirtualMethod(old_method_count + i, method);
@@ -2323,13 +2350,13 @@
   return true;
 }
 
-bool ClassLinker::LinkInstanceFields(Class* klass) {
-  CHECK(klass != NULL);
+bool ClassLinker::LinkInstanceFields(SirtRef<Class>& klass) {
+  CHECK(klass.get() != NULL);
   return LinkFields(klass, false);
 }
 
-bool ClassLinker::LinkStaticFields(Class* klass) {
-  CHECK(klass != NULL);
+bool ClassLinker::LinkStaticFields(SirtRef<Class>& klass) {
+  CHECK(klass.get() != NULL);
   size_t allocated_class_size = klass->GetClassSize();
   bool success = LinkFields(klass, true);
   CHECK_EQ(allocated_class_size, klass->GetClassSize());
@@ -2358,7 +2385,7 @@
   }
 };
 
-bool ClassLinker::LinkFields(Class* klass, bool is_static) {
+bool ClassLinker::LinkFields(SirtRef<Class>& klass, bool is_static) {
   size_t num_fields =
       is_static ? klass->NumStaticFields() : klass->NumInstanceFields();
 
@@ -2467,7 +2494,7 @@
     Field* field = fields->Get(i);
     if (false) {  // enable to debug field layout
       LOG(INFO) << "LinkFields: " << (is_static ? "static" : "instance")
-                << " class=" << PrettyClass(klass)
+                << " class=" << PrettyClass(klass.get())
                 << " field=" << PrettyField(field)
                 << " offset=" << field->GetField32(MemberOffset(Field::OffsetOffset()), false);
     }
@@ -2505,7 +2532,7 @@
 
 //  Set the bitmap of reference offsets, refOffsets, from the ifields
 //  list.
-void ClassLinker::CreateReferenceInstanceOffsets(Class* klass) {
+void ClassLinker::CreateReferenceInstanceOffsets(SirtRef<Class>& klass) {
   uint32_t reference_offsets = 0;
   Class* super_class = klass->GetSuperClass();
   if (super_class != NULL) {
@@ -2519,11 +2546,11 @@
   CreateReferenceOffsets(klass, false, reference_offsets);
 }
 
-void ClassLinker::CreateReferenceStaticOffsets(Class* klass) {
+void ClassLinker::CreateReferenceStaticOffsets(SirtRef<Class>& klass) {
   CreateReferenceOffsets(klass, true, 0);
 }
 
-void ClassLinker::CreateReferenceOffsets(Class* klass, bool is_static,
+void ClassLinker::CreateReferenceOffsets(SirtRef<Class>& klass, bool is_static,
                                          uint32_t reference_offsets) {
   size_t num_reference_fields =
       is_static ? klass->NumReferenceStaticFieldsDuringLinking()
diff --git a/src/class_linker.h b/src/class_linker.h
index f14b770..12d7b37 100644
--- a/src/class_linker.h
+++ b/src/class_linker.h
@@ -28,6 +28,7 @@
 #include "mutex.h"
 #include "oat_file.h"
 #include "object.h"
+#include "stack_indirect_reference_table.h"
 #include "unordered_map.h"
 #include "unordered_set.h"
 
@@ -196,7 +197,7 @@
   void RunRootClinits();
 
   void RegisterDexFile(const DexFile& dex_file);
-  void RegisterDexFile(const DexFile& dex_file, DexCache* dex_cache);
+  void RegisterDexFile(const DexFile& dex_file, SirtRef<DexCache>& dex_cache);
 
   const std::vector<const DexFile*>& GetBootClassPath() {
     return boot_class_path_;
@@ -274,7 +275,7 @@
                           const ClassLoader* class_loader);
 
   void AppendToBootClassPath(const DexFile& dex_file);
-  void AppendToBootClassPath(const DexFile& dex_file, DexCache* dex_cache);
+  void AppendToBootClassPath(const DexFile& dex_file, SirtRef<DexCache>& dex_cache);
 
   void ConstructFieldMap(const DexFile& dex_file, const DexFile::ClassDef& dex_class_def,
       Class* c, std::map<int, Field*>& field_map);
@@ -284,28 +285,28 @@
 
   void LoadClass(const DexFile& dex_file,
                  const DexFile::ClassDef& dex_class_def,
-                 Class* klass,
+                 SirtRef<Class>& klass,
                  const ClassLoader* class_loader);
 
   void LoadInterfaces(const DexFile& dex_file,
                       const DexFile::ClassDef& dex_class_def,
-                      Class *klass);
+                      SirtRef<Class>& klass);
 
   void LoadField(const DexFile& dex_file,
                  const DexFile::Field& dex_field,
-                 Class* klass,
-                 Field* dst);
+                 SirtRef<Class>& klass,
+                 SirtRef<Field>& dst);
 
   void LoadMethod(const DexFile& dex_file,
                   const DexFile::Method& dex_method,
-                  Class* klass,
-                  Method* dst);
+                  SirtRef<Class>& klass,
+                  SirtRef<Method>& dst);
 
   // Inserts a class into the class table.  Returns true if the class
   // was inserted.
   bool InsertClass(const std::string& descriptor, Class* klass);
 
-  void RegisterDexFileLocked(const DexFile& dex_file, DexCache* dex_cache);
+  void RegisterDexFileLocked(const DexFile& dex_file, SirtRef<DexCache>& dex_cache);
   bool IsDexFileRegisteredLocked(const DexFile& dex_file) const;
 
   bool InitializeClass(Class* klass, bool can_run_clinit);
@@ -322,26 +323,26 @@
                                       const Class* klass1,
                                       const Class* klass2);
 
-  bool LinkClass(Class* klass);
+  bool LinkClass(SirtRef<Class>& klass);
 
-  bool LinkSuperClass(Class* klass);
+  bool LinkSuperClass(SirtRef<Class>& klass);
 
-  bool LoadSuperAndInterfaces(Class* klass, const DexFile& dex_file);
+  bool LoadSuperAndInterfaces(SirtRef<Class>& klass, const DexFile& dex_file);
 
-  bool LinkMethods(Class* klass);
+  bool LinkMethods(SirtRef<Class>& klass);
 
-  bool LinkVirtualMethods(Class* klass);
+  bool LinkVirtualMethods(SirtRef<Class>& klass);
 
-  bool LinkInterfaceMethods(Class* klass);
+  bool LinkInterfaceMethods(SirtRef<Class>& klass);
 
-  bool LinkStaticFields(Class* klass);
-  bool LinkInstanceFields(Class* klass);
-  bool LinkFields(Class *klass, bool is_static);
+  bool LinkStaticFields(SirtRef<Class>& klass);
+  bool LinkInstanceFields(SirtRef<Class>& klass);
+  bool LinkFields(SirtRef<Class>& klass, bool is_static);
 
 
-  void CreateReferenceInstanceOffsets(Class* klass);
-  void CreateReferenceStaticOffsets(Class* klass);
-  void CreateReferenceOffsets(Class *klass, bool is_static,
+  void CreateReferenceInstanceOffsets(SirtRef<Class>& klass);
+  void CreateReferenceStaticOffsets(SirtRef<Class>& klass);
+  void CreateReferenceOffsets(SirtRef<Class>& klass, bool is_static,
                               uint32_t reference_offsets);
 
   // For use by ImageWriter to find DexCaches for its roots
@@ -351,8 +352,8 @@
 
   const OatFile* FindOpenedOatFile(const std::string& location);
 
-  Method* CreateProxyConstructor(Class* klass);
-  Method* CreateProxyMethod(Class* klass, Method* prototype, ObjectArray<Class>* throws);
+  Method* CreateProxyConstructor(SirtRef<Class>& klass);
+  Method* CreateProxyMethod(SirtRef<Class>& klass, SirtRef<Method>& prototype, ObjectArray<Class>* throws);
 
   std::vector<const DexFile*> boot_class_path_;
 
diff --git a/src/class_linker_test.cc b/src/class_linker_test.cc
index 75f2740..d1be1de 100644
--- a/src/class_linker_test.cc
+++ b/src/class_linker_test.cc
@@ -144,7 +144,8 @@
 
   void AssertClass(const std::string& descriptor, Class* klass) {
     EXPECT_TRUE(klass->GetDescriptor()->Equals(descriptor));
-    if (klass->GetDescriptor()->Equals(String::AllocFromModifiedUtf8("Ljava/lang/Object;"))) {
+    SirtRef<String> Object_descriptor(String::AllocFromModifiedUtf8("Ljava/lang/Object;"));
+    if (klass->GetDescriptor()->Equals(Object_descriptor.get())) {
       EXPECT_FALSE(klass->HasSuperClass());
     } else {
       EXPECT_TRUE(klass->HasSuperClass());
@@ -661,14 +662,14 @@
 }
 
 TEST_F(ClassLinkerTest, FindClassNested) {
-  const ClassLoader* class_loader = LoadDex("Nested");
+  SirtRef<ClassLoader> class_loader(LoadDex("Nested"));
 
-  Class* outer = class_linker_->FindClass("LNested;", class_loader);
+  Class* outer = class_linker_->FindClass("LNested;", class_loader.get());
   ASSERT_TRUE(outer != NULL);
   EXPECT_EQ(0U, outer->NumVirtualMethods());
   EXPECT_EQ(1U, outer->NumDirectMethods());
 
-  Class* inner = class_linker_->FindClass("LNested$Inner;", class_loader);
+  Class* inner = class_linker_->FindClass("LNested$Inner;", class_loader.get());
   ASSERT_TRUE(inner != NULL);
   EXPECT_EQ(0U, inner->NumVirtualMethods());
   EXPECT_EQ(1U, inner->NumDirectMethods());
@@ -720,9 +721,9 @@
   EXPECT_EQ(0U, JavaLangObject->NumStaticFields());
   EXPECT_EQ(0U, JavaLangObject->NumInterfaces());
 
-  const ClassLoader* class_loader = LoadDex("MyClass");
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClass"));
   AssertNonExistentClass("LMyClass;");
-  Class* MyClass = class_linker_->FindClass("LMyClass;", class_loader);
+  Class* MyClass = class_linker_->FindClass("LMyClass;", class_loader.get());
   ASSERT_TRUE(MyClass != NULL);
   ASSERT_TRUE(MyClass->GetClass() != NULL);
   ASSERT_EQ(MyClass->GetClass(), MyClass->GetClass()->GetClass());
@@ -730,7 +731,7 @@
   ASSERT_TRUE(MyClass->GetDescriptor()->Equals("LMyClass;"));
   EXPECT_TRUE(MyClass->GetSuperClass() == JavaLangObject);
   EXPECT_TRUE(MyClass->HasSuperClass());
-  EXPECT_EQ(class_loader, MyClass->GetClassLoader());
+  EXPECT_EQ(class_loader.get(), MyClass->GetClassLoader());
   EXPECT_EQ(Class::kStatusResolved, MyClass->GetStatus());
   EXPECT_FALSE(MyClass->IsErroneous());
   EXPECT_TRUE(MyClass->IsLoaded());
@@ -758,7 +759,7 @@
   AssertArrayClass("[Ljava/lang/Object;", "Ljava/lang/Object;", NULL);
   // synthesized on the fly
   AssertArrayClass("[[C", "[C", NULL);
-  AssertArrayClass("[[[LMyClass;", "[[LMyClass;", class_loader);
+  AssertArrayClass("[[[LMyClass;", "[[LMyClass;", class_loader.get());
   // or not available at all
   AssertNonExistentClass("[[[[LNonExistentClass;");
 }
@@ -779,9 +780,9 @@
 }
 
 TEST_F(ClassLinkerTest, ValidatePrimitiveArrayElementsOffset) {
-  LongArray* array = LongArray::Alloc(0);
+  SirtRef<LongArray> array(LongArray::Alloc(0));
   EXPECT_EQ(class_linker_->FindSystemClass("[J"), array->GetClass());
-  uint32_t array_offset = reinterpret_cast<uint32_t>(array);
+  uint32_t array_offset = reinterpret_cast<uint32_t>(array.get());
   uint32_t data_offset = reinterpret_cast<uint32_t>(array->GetData());
   EXPECT_EQ(16U, data_offset - array_offset);
 }
@@ -809,18 +810,18 @@
 }
 
 TEST_F(ClassLinkerTest, TwoClassLoadersOneClass) {
-  const ClassLoader* class_loader_1 = LoadDex("MyClass");
-  const ClassLoader* class_loader_2 = LoadDex("MyClass");
-  Class* MyClass_1 = class_linker_->FindClass("LMyClass;", class_loader_1);
-  Class* MyClass_2 = class_linker_->FindClass("LMyClass;", class_loader_2);
+  SirtRef<ClassLoader> class_loader_1(LoadDex("MyClass"));
+  SirtRef<ClassLoader> class_loader_2(LoadDex("MyClass"));
+  Class* MyClass_1 = class_linker_->FindClass("LMyClass;", class_loader_1.get());
+  Class* MyClass_2 = class_linker_->FindClass("LMyClass;", class_loader_2.get());
   EXPECT_TRUE(MyClass_1 != NULL);
   EXPECT_TRUE(MyClass_2 != NULL);
   EXPECT_NE(MyClass_1, MyClass_2);
 }
 
 TEST_F(ClassLinkerTest, StaticFields) {
-  const ClassLoader* class_loader = LoadDex("Statics");
-  Class* statics = class_linker_->FindClass("LStatics;", class_loader);
+  SirtRef<ClassLoader> class_loader(LoadDex("Statics"));
+  Class* statics = class_linker_->FindClass("LStatics;", class_loader.get());
   class_linker_->EnsureInitialized(statics, true);
 
   // Static final primitives that are initialized by a compile-time constant
@@ -831,48 +832,48 @@
 
   EXPECT_EQ(9U, statics->NumStaticFields());
 
-  Field* s0 = statics->FindStaticField("s0", class_linker_->FindClass("Z", class_loader));
+  Field* s0 = statics->FindStaticField("s0", class_linker_->FindClass("Z", class_loader.get()));
   EXPECT_TRUE(s0->GetClass()->GetDescriptor()->Equals("Ljava/lang/reflect/Field;"));
   EXPECT_TRUE(s0->GetType()->IsPrimitiveBoolean());
   EXPECT_EQ(true, s0->GetBoolean(NULL));
   s0->SetBoolean(NULL, false);
 
-  Field* s1 = statics->FindStaticField("s1", class_linker_->FindClass("B", class_loader));
+  Field* s1 = statics->FindStaticField("s1", class_linker_->FindClass("B", class_loader.get()));
   EXPECT_TRUE(s1->GetType()->IsPrimitiveByte());
   EXPECT_EQ(5, s1->GetByte(NULL));
   s1->SetByte(NULL, 6);
 
-  Field* s2 = statics->FindStaticField("s2", class_linker_->FindClass("C", class_loader));
+  Field* s2 = statics->FindStaticField("s2", class_linker_->FindClass("C", class_loader.get()));
   EXPECT_TRUE(s2->GetType()->IsPrimitiveChar());
   EXPECT_EQ('a', s2->GetChar(NULL));
   s2->SetChar(NULL, 'b');
 
-  Field* s3 = statics->FindStaticField("s3", class_linker_->FindClass("S", class_loader));
+  Field* s3 = statics->FindStaticField("s3", class_linker_->FindClass("S", class_loader.get()));
   EXPECT_TRUE(s3->GetType()->IsPrimitiveShort());
   EXPECT_EQ(-536, s3->GetShort(NULL));
   s3->SetShort(NULL, -535);
 
-  Field* s4 = statics->FindStaticField("s4", class_linker_->FindClass("I", class_loader));
+  Field* s4 = statics->FindStaticField("s4", class_linker_->FindClass("I", class_loader.get()));
   EXPECT_TRUE(s4->GetType()->IsPrimitiveInt());
   EXPECT_EQ(2000000000, s4->GetInt(NULL));
   s4->SetInt(NULL, 2000000001);
 
-  Field* s5 = statics->FindStaticField("s5", class_linker_->FindClass("J", class_loader));
+  Field* s5 = statics->FindStaticField("s5", class_linker_->FindClass("J", class_loader.get()));
   EXPECT_TRUE(s5->GetType()->IsPrimitiveLong());
   EXPECT_EQ(0x1234567890abcdefLL, s5->GetLong(NULL));
   s5->SetLong(NULL, 0x34567890abcdef12LL);
 
-  Field* s6 = statics->FindStaticField("s6", class_linker_->FindClass("F", class_loader));
+  Field* s6 = statics->FindStaticField("s6", class_linker_->FindClass("F", class_loader.get()));
   EXPECT_TRUE(s6->GetType()->IsPrimitiveFloat());
   EXPECT_EQ(0.5, s6->GetFloat(NULL));
   s6->SetFloat(NULL, 0.75);
 
-  Field* s7 = statics->FindStaticField("s7", class_linker_->FindClass("D", class_loader));
+  Field* s7 = statics->FindStaticField("s7", class_linker_->FindClass("D", class_loader.get()));
   EXPECT_TRUE(s7->GetType()->IsPrimitiveDouble());
   EXPECT_EQ(16777217, s7->GetDouble(NULL));
   s7->SetDouble(NULL, 16777219);
 
-  Field* s8 = statics->FindStaticField("s8", class_linker_->FindClass("Ljava/lang/String;", class_loader));
+  Field* s8 = statics->FindStaticField("s8", class_linker_->FindClass("Ljava/lang/String;", class_loader.get()));
   EXPECT_FALSE(s8->GetType()->IsPrimitive());
   EXPECT_TRUE(s8->GetObject(NULL)->AsString()->Equals("android"));
   s8->SetObject(NULL, String::AllocFromModifiedUtf8("robot"));
@@ -889,12 +890,12 @@
 }
 
 TEST_F(ClassLinkerTest, Interfaces) {
-  const ClassLoader* class_loader = LoadDex("Interfaces");
-  Class* I = class_linker_->FindClass("LInterfaces$I;", class_loader);
-  Class* J = class_linker_->FindClass("LInterfaces$J;", class_loader);
-  Class* K = class_linker_->FindClass("LInterfaces$K;", class_loader);
-  Class* A = class_linker_->FindClass("LInterfaces$A;", class_loader);
-  Class* B = class_linker_->FindClass("LInterfaces$B;", class_loader);
+  SirtRef<ClassLoader> class_loader(LoadDex("Interfaces"));
+  Class* I = class_linker_->FindClass("LInterfaces$I;", class_loader.get());
+  Class* J = class_linker_->FindClass("LInterfaces$J;", class_loader.get());
+  Class* K = class_linker_->FindClass("LInterfaces$K;", class_loader.get());
+  Class* A = class_linker_->FindClass("LInterfaces$A;", class_loader.get());
+  Class* B = class_linker_->FindClass("LInterfaces$B;", class_loader.get());
   EXPECT_TRUE(I->IsAssignableFrom(A));
   EXPECT_TRUE(J->IsAssignableFrom(A));
   EXPECT_TRUE(J->IsAssignableFrom(K));
@@ -938,11 +939,11 @@
   // case 1, get the uninitialized storage from StaticsFromCode.<clinit>
   // case 2, get the initialized storage from StaticsFromCode.getS0
 
-  const ClassLoader* class_loader = LoadDex("StaticsFromCode");
-  const DexFile* dex_file = ClassLoader::GetCompileTimeClassPath(class_loader)[0];
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticsFromCode"));
+  const DexFile* dex_file = ClassLoader::GetCompileTimeClassPath(class_loader.get())[0];
   CHECK(dex_file != NULL);
 
-  Class* klass = class_linker_->FindClass("LStaticsFromCode;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticsFromCode;", class_loader.get());
   Method* clinit = klass->FindDirectMethod("<clinit>", "()V");
   Method* getS0 = klass->FindDirectMethod("getS0", "()Ljava/lang/Object;");
   uint32_t type_idx = FindTypeIdxByDescriptor(*dex_file, "LStaticsFromCode;");
diff --git a/src/class_loader.cc b/src/class_loader.cc
index 94a212f..1623303 100644
--- a/src/class_loader.cc
+++ b/src/class_loader.cc
@@ -31,12 +31,12 @@
 // TODO: get global references for these
 Class* PathClassLoader::dalvik_system_PathClassLoader_ = NULL;
 
-const PathClassLoader* PathClassLoader::AllocCompileTime(std::vector<const DexFile*>& dex_files) {
+PathClassLoader* PathClassLoader::AllocCompileTime(std::vector<const DexFile*>& dex_files) {
   CHECK(!Runtime::Current()->IsStarted());
   DCHECK(dalvik_system_PathClassLoader_ != NULL);
-  PathClassLoader* p = down_cast<PathClassLoader*>(dalvik_system_PathClassLoader_->AllocObject());
-  SetCompileTimeClassPath(p, dex_files);
-  return p;
+  SirtRef<PathClassLoader> p(down_cast<PathClassLoader*>(dalvik_system_PathClassLoader_->AllocObject()));
+  SetCompileTimeClassPath(p.get(), dex_files);
+  return p.get();
 }
 
 void PathClassLoader::SetClass(Class* dalvik_system_PathClassLoader) {
diff --git a/src/class_loader.h b/src/class_loader.h
index 2904e3e..a5436fa 100644
--- a/src/class_loader.h
+++ b/src/class_loader.h
@@ -51,7 +51,7 @@
 // TODO: add MANAGED when class_path_ removed
 class PathClassLoader : public BaseDexClassLoader {
  public:
-  static const PathClassLoader* AllocCompileTime(std::vector<const DexFile*>& dex_files);
+  static PathClassLoader* AllocCompileTime(std::vector<const DexFile*>& dex_files);
   static void SetClass(Class* dalvik_system_PathClassLoader);
   static void ResetClass();
  private:
diff --git a/src/common_test.h b/src/common_test.h
index fc02a49..3f56aab 100644
--- a/src/common_test.h
+++ b/src/common_test.h
@@ -297,14 +297,6 @@
     return 0;
   }
 
-  const ClassLoader* AllocPathClassLoader(const DexFile* dex_file) {
-    CHECK(dex_file != NULL);
-    class_linker_->RegisterDexFile(*dex_file);
-    std::vector<const DexFile*> dex_files;
-    dex_files.push_back(dex_file);
-    return PathClassLoader::AllocCompileTime(dex_files);
-  }
-
   const DexFile* OpenTestDexFile(const char* name) {
     CHECK(name != NULL);
     std::string filename;
@@ -320,17 +312,17 @@
     return dex_file;
   }
 
-  const ClassLoader* LoadDex(const char* dex_name) {
+  ClassLoader* LoadDex(const char* dex_name) {
     const DexFile* dex_file = OpenTestDexFile(dex_name);
     CHECK(dex_file != NULL);
     loaded_dex_files_.push_back(dex_file);
     class_linker_->RegisterDexFile(*dex_file);
     std::vector<const DexFile*> class_path;
     class_path.push_back(dex_file);
-    const ClassLoader* class_loader = PathClassLoader::AllocCompileTime(class_path);
-    CHECK(class_loader != NULL);
-    Thread::Current()->SetClassLoaderOverride(class_loader);
-    return class_loader;
+    SirtRef<ClassLoader> class_loader(PathClassLoader::AllocCompileTime(class_path));
+    CHECK(class_loader.get() != NULL);
+    Thread::Current()->SetClassLoaderOverride(class_loader.get());
+    return class_loader.get();
   }
 
   void CompileClass(const ClassLoader* class_loader, const char* class_name) {
@@ -353,7 +345,7 @@
     MakeExecutable(runtime_->GetJniStubArray());
   }
 
-  void CompileDirectMethod(const ClassLoader* class_loader,
+  void CompileDirectMethod(ClassLoader* class_loader,
                            const char* class_name,
                            const char* method_name,
                            const char* signature) {
@@ -366,7 +358,7 @@
     CompileMethod(method);
   }
 
-  void CompileVirtualMethod(const ClassLoader* class_loader,
+  void CompileVirtualMethod(ClassLoader* class_loader,
                             const char* class_name,
                             const char* method_name,
                             const char* signature) {
diff --git a/src/compiler_test.cc b/src/compiler_test.cc
index 1066783..35da6f2 100644
--- a/src/compiler_test.cc
+++ b/src/compiler_test.cc
@@ -127,9 +127,9 @@
 TEST_F(CompilerTest, AbstractMethodErrorStub) {
   CompileDirectMethod(NULL, "java.lang.Object", "<init>", "()V");
 
-  const ClassLoader* class_loader = LoadDex("AbstractMethod");
-  ASSERT_TRUE(class_loader != NULL);
-  EnsureCompiled(class_loader, "AbstractClass", "foo", "()V", true);
+  SirtRef<ClassLoader> class_loader(LoadDex("AbstractMethod"));
+  ASSERT_TRUE(class_loader.get() != NULL);
+  EnsureCompiled(class_loader.get(), "AbstractClass", "foo", "()V", true);
 
   // Create a jobj_ of ConcreteClass, NOT AbstractClass.
   jclass c_class = env_->FindClass("ConcreteClass");
@@ -138,7 +138,7 @@
   ASSERT_TRUE(jobj_ != NULL);
 
 #if defined(__arm__)
-  Class* jlame = class_linker_->FindClass("Ljava/lang/AbstractMethodError;", class_loader);
+  Class* jlame = class_linker_->FindClass("Ljava/lang/AbstractMethodError;", class_loader.get());
   // Force non-virtual call to AbstractClass foo, will throw AbstractMethodError exception.
   env_->CallNonvirtualVoidMethod(jobj_, class_, mid_);
   EXPECT_TRUE(Thread::Current()->IsExceptionPending());
diff --git a/src/dex2oat.cc b/src/dex2oat.cc
index aca8e8c..455b8b9 100644
--- a/src/dex2oat.cc
+++ b/src/dex2oat.cc
@@ -191,16 +191,14 @@
   }
 
   // ClassLoader creation needs to come after Runtime::Create
-  const ClassLoader* class_loader;
-  if (boot_image_option.empty()) {
-    class_loader = NULL;
-  } else {
+  SirtRef<ClassLoader> class_loader(NULL);
+  if (!boot_image_option.empty()) {
     std::vector<const DexFile*> dex_files;
     DexFile::OpenDexFiles(dex_filenames, dex_files, host_prefix);
     for (size_t i = 0; i < dex_files.size(); i++) {
       class_linker->RegisterDexFile(*dex_files[i]);
     }
-    class_loader = PathClassLoader::AllocCompileTime(dex_files);
+    class_loader.reset(PathClassLoader::AllocCompileTime(dex_files));
   }
 
   // if we loaded an existing image, we will reuse values from the image roots.
@@ -224,7 +222,7 @@
   }
   Compiler compiler(kThumb2, image_filename != NULL);
   if (method_names.empty()) {
-    compiler.CompileAll(class_loader);
+    compiler.CompileAll(class_loader.get());
   } else {
     for (size_t i = 0; i < method_names.size(); i++) {
       // names are actually class_descriptor + name + signature.
@@ -246,7 +244,7 @@
                                             end_of_name - end_of_class_descriptor).ToString();
       std::string signature = method_name.substr(end_of_name).ToString();
 
-      Class* klass = class_linker->FindClass(class_descriptor, class_loader);
+      Class* klass = class_linker->FindClass(class_descriptor, class_loader.get());
       if (klass == NULL) {
         fprintf(stderr, "could not find class for descriptor %s in method %s\n",
                 class_descriptor.c_str(), method_name.data());
@@ -268,7 +266,7 @@
     }
   }
 
-  if (!OatWriter::Create(oat_filename, class_loader, compiler)) {
+  if (!OatWriter::Create(oat_filename, class_loader.get(), compiler)) {
     fprintf(stderr, "Failed to create oat file %s\n", oat_filename.c_str());
     return EXIT_FAILURE;
   }
diff --git a/src/dex_cache_test.cc b/src/dex_cache_test.cc
index 25f194a..054f991 100644
--- a/src/dex_cache_test.cc
+++ b/src/dex_cache_test.cc
@@ -13,8 +13,8 @@
 class DexCacheTest : public CommonTest {};
 
 TEST_F(DexCacheTest, Open) {
-  DexCache* dex_cache = class_linker_->AllocDexCache(*java_lang_dex_file_.get());
-  ASSERT_TRUE(dex_cache != NULL);
+  SirtRef<DexCache> dex_cache(class_linker_->AllocDexCache(*java_lang_dex_file_.get()));
+  ASSERT_TRUE(dex_cache.get() != NULL);
 
   EXPECT_EQ(java_lang_dex_file_->NumStringIds(), dex_cache->NumStrings());
   EXPECT_EQ(java_lang_dex_file_->NumTypeIds(),   dex_cache->NumResolvedTypes());
diff --git a/src/dex_file.cc b/src/dex_file.cc
index 220c111..83866a9 100644
--- a/src/dex_file.cc
+++ b/src/dex_file.cc
@@ -596,13 +596,6 @@
   return static_cast<ValueType>(type);
 }
 
-String* DexFile::dexArtStringById(int32_t idx) const {
-  if (idx == -1) {
-    return NULL;
-  }
-  return String::AllocFromModifiedUtf8(dexStringById(idx));
-}
-
 int32_t DexFile::GetLineNumFromPC(const art::Method* method, uint32_t rel_pc) const {
   // For native method, lineno should be -2 to indicate it is native. Note that
   // "line number == -2" is how libcore tells from StackTraceElement.
@@ -630,9 +623,11 @@
 
   if (!method->IsStatic()) {
     if (need_locals) {
-      local_in_reg[arg_reg].name_ = String::AllocFromModifiedUtf8("this");
-      local_in_reg[arg_reg].descriptor_ = method->GetDeclaringClass()->GetDescriptor();
-      local_in_reg[arg_reg].signature_ = NULL;
+      std::string descriptor = method->GetDeclaringClass()->GetDescriptor()->ToModifiedUtf8();
+      const ClassDef* class_def = FindClassDef(descriptor);
+      CHECK(class_def != NULL) << descriptor;
+      local_in_reg[arg_reg].name_ = "this";
+      local_in_reg[arg_reg].descriptor_ = GetClassDescriptor(*class_def);
       local_in_reg[arg_reg].start_address_ = 0;
       local_in_reg[arg_reg].is_live_ = true;
     }
@@ -646,17 +641,15 @@
       return;
     }
     int32_t id = DecodeUnsignedLeb128P1(&stream);
-    const char* descriptor_utf8 = it->GetDescriptor();
+    const char* descriptor = it->GetDescriptor();
     if (need_locals) {
-      String* descriptor = String::AllocFromModifiedUtf8(descriptor_utf8);
-      String* name = dexArtStringById(id);
+      const char* name = dexStringById(id);
       local_in_reg[arg_reg].name_ = name;
       local_in_reg[arg_reg].descriptor_ = descriptor;
-      local_in_reg[arg_reg].signature_ = NULL;
       local_in_reg[arg_reg].start_address_ = address;
       local_in_reg[arg_reg].is_live_ = true;
     }
-    switch (*descriptor_utf8) {
+    switch (*descriptor) {
       case 'D':
       case 'J':
         arg_reg += 2;
@@ -700,12 +693,10 @@
         if (need_locals) {
           InvokeLocalCbIfLive(cnxt, reg, address, local_in_reg, local_cb);
 
-          local_in_reg[reg].name_ = dexArtStringById(DecodeUnsignedLeb128P1(&stream));
-          local_in_reg[reg].descriptor_ = dexArtStringByTypeIdx(DecodeUnsignedLeb128P1(&stream));
+          local_in_reg[reg].name_ = dexStringById(DecodeUnsignedLeb128P1(&stream));
+          local_in_reg[reg].descriptor_ = dexStringByTypeIdx(DecodeUnsignedLeb128P1(&stream));
           if (opcode == DBG_START_LOCAL_EXTENDED) {
-            local_in_reg[reg].signature_ = dexArtStringById(DecodeUnsignedLeb128P1(&stream));
-          } else {
-            local_in_reg[reg].signature_ = NULL;
+            local_in_reg[reg].signature_ = dexStringById(DecodeUnsignedLeb128P1(&stream));
           }
           local_in_reg[reg].start_address_ = address;
           local_in_reg[reg].is_live_ = true;
diff --git a/src/dex_file.h b/src/dex_file.h
index f41040e..070901b 100644
--- a/src/dex_file.h
+++ b/src/dex_file.h
@@ -575,8 +575,6 @@
     return dexStringById(idx, &unicode_length);
   }
 
-  String* dexArtStringById(int32_t idx) const;
-
   // Get the descriptor string associated with a given type index.
   const char* dexStringByTypeIdx(uint32_t idx, int32_t* unicode_length) const {
     const TypeId& type_id = GetTypeId(idx);
@@ -588,11 +586,6 @@
     return dexStringById(type_id.descriptor_idx_);
   }
 
-  String* dexArtStringByTypeIdx(int32_t idx) const {
-    const TypeId& type_id = GetTypeId(idx);
-    return dexArtStringById(type_id.descriptor_idx_);
-  }
-
   // TODO: encoded_field is actually a stream of bytes
   void dexReadClassDataField(const byte** encoded_field,
                              DexFile::Field* field,
@@ -712,9 +705,9 @@
   typedef void (*DexDebugNewLocalCb)(void* cnxt, uint16_t reg,
                                      uint32_t startAddress,
                                      uint32_t endAddress,
-                                     const String* name,
-                                     const String* descriptor,
-                                     const String* signature);
+                                     const char* name,
+                                     const char* descriptor,
+                                     const char* signature);
 
   static bool LineNumForPcCb(void* cnxt, uint32_t address, uint32_t line_num) {
     LineNumFromPcContext* context = (LineNumFromPcContext*) cnxt;
@@ -751,16 +744,16 @@
   };
 
   struct LocalInfo {
-    LocalInfo() : name_(NULL), descriptor_(NULL), signature_(NULL), start_address_(0), is_live_(false) {}
+    LocalInfo() : name_(NULL), start_address_(0), is_live_(false) {}
 
     // E.g., list
-    const String* name_;
+    const char* name_;
 
     // E.g., Ljava/util/LinkedList;
-    const String* descriptor_;
+    const char* descriptor_;
 
     // E.g., java.util.LinkedList<java.lang.Integer>
-    const String* signature_;
+    const char* signature_;
 
     // PC location where the local is first defined.
     uint16_t start_address_;
diff --git a/src/dex_verifier.cc b/src/dex_verifier.cc
index 3e41d23..6c7335e 100644
--- a/src/dex_verifier.cc
+++ b/src/dex_verifier.cc
@@ -345,14 +345,20 @@
 
   /* Generate a register map and add it to the method. */
   UniquePtr<RegisterMap> map(GenerateRegisterMapV(vdata));
-  ByteArray* header = ByteArray::Alloc(sizeof(RegisterMapHeader));
-  ByteArray* data = ByteArray::Alloc(ComputeRegisterMapSize(map.get()));
+  SirtRef<ByteArray> header(ByteArray::Alloc(sizeof(RegisterMapHeader)));
+  if (header.get() == NULL) {
+    return false;
+  }
+  SirtRef<ByteArray> data(ByteArray::Alloc(ComputeRegisterMapSize(map.get())));
+  if (data.get() == NULL) {
+    return false;
+  }
 
   memcpy(header->GetData(), map.get()->header_, sizeof(RegisterMapHeader));
   memcpy(data->GetData(), map.get()->data_, ComputeRegisterMapSize(map.get()));
 
-  method->SetRegisterMapHeader(header);
-  method->SetRegisterMapData(data);
+  method->SetRegisterMapHeader(header.get());
+  method->SetRegisterMapData(data.get());
 
   return true;
 }
@@ -5305,8 +5311,14 @@
   }
 
   /* Update method, and free compressed map if it was sitting on the heap. */
-  //ByteArray* header = ByteArray::Alloc(sizeof(RegisterMapHeader));
-  //ByteArray* data = ByteArray::Alloc(ComputeRegisterMapSize(map));
+  //SirtRef<ByteArray> header(ByteArray::Alloc(sizeof(RegisterMapHeader)));
+  //if (header.get() == NULL) {
+  //  return NULL;
+  //}
+  //SirtRef<ByteArray> data(ByteArray::Alloc(ComputeRegisterMapSize(map)));
+  //if (data.get() == NULL) {
+  //  return NULL;
+  //}
 
   //memcpy(header->GetData(), map->header_, sizeof(RegisterMapHeader));
   //memcpy(data->GetData(), map->data_, ComputeRegisterMapSize(map));
diff --git a/src/dex_verifier_test.cc b/src/dex_verifier_test.cc
index b72f495..0560abf 100644
--- a/src/dex_verifier_test.cc
+++ b/src/dex_verifier_test.cc
@@ -38,8 +38,8 @@
 }
 
 TEST_F(DexVerifierTest, IntMath) {
-  const ClassLoader* class_loader = LoadDex("IntMath");
-  Class* klass = class_linker_->FindClass("LIntMath;", class_loader);
+  SirtRef<ClassLoader> class_loader(LoadDex("IntMath"));
+  Class* klass = class_linker_->FindClass("LIntMath;", class_loader.get());
   ASSERT_TRUE(DexVerifier::VerifyClass(klass));
 }
 
diff --git a/src/exception_test.cc b/src/exception_test.cc
index 29d91bb..ab462a9 100644
--- a/src/exception_test.cc
+++ b/src/exception_test.cc
@@ -19,8 +19,8 @@
   virtual void SetUp() {
     CommonTest::SetUp();
 
-    const ClassLoader* class_loader = LoadDex("ExceptionHandle");
-    my_klass_ = class_linker_->FindClass("LExceptionHandle;", class_loader);
+    SirtRef<ClassLoader> class_loader(LoadDex("ExceptionHandle"));
+    my_klass_ = class_linker_->FindClass("LExceptionHandle;", class_loader.get());
     ASSERT_TRUE(my_klass_ != NULL);
 
     dex_ = &Runtime::Current()->GetClassLinker()->FindDexFile(my_klass_->GetDexCache());
diff --git a/src/heap_test.cc b/src/heap_test.cc
index 8e869be..addbc0b 100644
--- a/src/heap_test.cc
+++ b/src/heap_test.cc
@@ -11,7 +11,7 @@
 
   Class* c = class_linker_->FindSystemClass("[Ljava/lang/Object;");
   for (size_t i = 0; i < 1024; ++i) {
-    ObjectArray<Object>* array = ObjectArray<Object>::Alloc(c, 2048);
+    SirtRef<ObjectArray<Object> > array(ObjectArray<Object>::Alloc(c, 2048));
     for (size_t j = 0; j < 2048; ++j) {
       array->Set(j, String::AllocFromModifiedUtf8("hello, world!"));
     }
diff --git a/src/image_writer.cc b/src/image_writer.cc
index f3ecda0..d8385ca 100644
--- a/src/image_writer.cc
+++ b/src/image_writer.cc
@@ -89,14 +89,14 @@
       DCHECK_EQ(obj, obj->AsString()->Intern());
       return;
     }
-    String* interned = obj->AsString()->Intern();
-    if (obj != interned) {
-      if (!IsImageOffsetAssigned(interned)) {
+    SirtRef<String> interned(obj->AsString()->Intern());
+    if (obj != interned.get()) {
+      if (!IsImageOffsetAssigned(interned.get())) {
         // interned obj is after us, allocate its location early
-        image_writer->AssignImageOffset(interned);
+        image_writer->AssignImageOffset(interned.get());
       }
       // point those looking for this object to the interned version.
-      SetImageOffset(obj, GetImageOffset(interned));
+      SetImageOffset(obj, GetImageOffset(interned.get()));
       return;
     }
     // else (obj == interned), nothing to do but fall through to the normal case
@@ -137,8 +137,8 @@
   }
 
   // build an Object[] of the roots needed to restore the runtime
-  ObjectArray<Object>* image_roots = ObjectArray<Object>::Alloc(object_array_class,
-                                                                ImageHeader::kImageRootsMax);
+  SirtRef<ObjectArray<Object> > image_roots(
+      ObjectArray<Object>::Alloc(object_array_class, ImageHeader::kImageRootsMax));
   image_roots->Set(ImageHeader::kJniStubArray, runtime->GetJniStubArray());
   image_roots->Set(ImageHeader::kAbstractMethodErrorStubArray,
                    runtime->GetAbstractMethodErrorStubArray());
@@ -163,11 +163,11 @@
   for (int i = 0; i < ImageHeader::kImageRootsMax; i++) {
     CHECK(image_roots->Get(i) != NULL);
   }
-  return image_roots;
+  return image_roots.get();
 }
 
 void ImageWriter::CalculateNewObjectOffsets() {
-  ObjectArray<Object>* image_roots = CreateImageRoots();
+  SirtRef<ObjectArray<Object> > image_roots(CreateImageRoots());
 
   HeapBitmap* heap_bitmap = Heap::GetLiveBits();
   DCHECK(heap_bitmap != NULL);
@@ -186,7 +186,7 @@
 
   // return to write header at start of image with future location of image_roots
   ImageHeader image_header(reinterpret_cast<uint32_t>(image_base_),
-                           reinterpret_cast<uint32_t>(GetImageAddress(image_roots)),
+                           reinterpret_cast<uint32_t>(GetImageAddress(image_roots.get())),
                            oat_file_->GetOatHeader().GetChecksum(),
                            reinterpret_cast<uint32_t>(oat_base_),
                            reinterpret_cast<uint32_t>(oat_limit));
diff --git a/src/intern_table_test.cc b/src/intern_table_test.cc
index d9a6a45..a0c47ec 100644
--- a/src/intern_table_test.cc
+++ b/src/intern_table_test.cc
@@ -11,26 +11,27 @@
 
 TEST_F(InternTableTest, Intern) {
   InternTable intern_table;
-  const String* foo_1 = intern_table.InternStrong(3, "foo");
-  const String* foo_2 = intern_table.InternStrong(3, "foo");
-  const String* foo_3 = String::AllocFromModifiedUtf8("foo");
-  const String* bar = intern_table.InternStrong(3, "bar");
+  SirtRef<String> foo_1(intern_table.InternStrong(3, "foo"));
+  SirtRef<String> foo_2(intern_table.InternStrong(3, "foo"));
+  SirtRef<String> foo_3(String::AllocFromModifiedUtf8("foo"));
+  SirtRef<String> bar(intern_table.InternStrong(3, "bar"));
   EXPECT_TRUE(foo_1->Equals("foo"));
   EXPECT_TRUE(foo_2->Equals("foo"));
   EXPECT_TRUE(foo_3->Equals("foo"));
-  EXPECT_TRUE(foo_1 != NULL);
-  EXPECT_TRUE(foo_2 != NULL);
-  EXPECT_EQ(foo_1, foo_2);
-  EXPECT_NE(foo_1, bar);
-  EXPECT_NE(foo_2, bar);
-  EXPECT_NE(foo_3, bar);
+  EXPECT_TRUE(foo_1.get() != NULL);
+  EXPECT_TRUE(foo_2.get() != NULL);
+  EXPECT_EQ(foo_1.get(), foo_2.get());
+  EXPECT_NE(foo_1.get(), bar.get());
+  EXPECT_NE(foo_2.get(), bar.get());
+  EXPECT_NE(foo_3.get(), bar.get());
 }
 
 TEST_F(InternTableTest, Size) {
   InternTable t;
   EXPECT_EQ(0U, t.Size());
   t.InternStrong(3, "foo");
-  t.InternWeak(String::AllocFromModifiedUtf8("foo"));
+  SirtRef<String> foo(String::AllocFromModifiedUtf8("foo"));
+  t.InternWeak(foo.get());
   EXPECT_EQ(1U, t.Size());
   t.InternStrong(3, "bar");
   EXPECT_EQ(2U, t.Size());
@@ -72,21 +73,24 @@
   InternTable t;
   t.InternStrong(3, "foo");
   t.InternStrong(3, "bar");
-  const String* s0 = t.InternWeak(String::AllocFromModifiedUtf8("hello"));
-  const String* s1 = t.InternWeak(String::AllocFromModifiedUtf8("world"));
+  SirtRef<String> hello(String::AllocFromModifiedUtf8("hello"));
+  SirtRef<String> world(String::AllocFromModifiedUtf8("world"));
+  SirtRef<String> s0(t.InternWeak(hello.get()));
+  SirtRef<String> s1(t.InternWeak(world.get()));
 
   EXPECT_EQ(4U, t.Size());
 
   // We should traverse only the weaks...
   TestPredicate p;
-  p.Expect(s0);
-  p.Expect(s1);
+  p.Expect(s0.get());
+  p.Expect(s1.get());
   t.SweepInternTableWeaks(IsMarked, &p);
 
   EXPECT_EQ(2U, t.Size());
 
   // Just check that we didn't corrupt the unordered_multimap.
-  t.InternWeak(String::AllocFromModifiedUtf8("still here"));
+  SirtRef<String> still_here(String::AllocFromModifiedUtf8("still here"));
+  t.InternWeak(still_here.get());
   EXPECT_EQ(3U, t.Size());
 }
 
@@ -94,41 +98,45 @@
   {
     // Strongs are never weak.
     InternTable t;
-    String* foo_1 = t.InternStrong(3, "foo");
-    EXPECT_FALSE(t.ContainsWeak(foo_1));
-    String* foo_2 = t.InternStrong(3, "foo");
-    EXPECT_FALSE(t.ContainsWeak(foo_2));
-    EXPECT_EQ(foo_1, foo_2);
+    SirtRef<String> interned_foo_1(t.InternStrong(3, "foo"));
+    EXPECT_FALSE(t.ContainsWeak(interned_foo_1.get()));
+    SirtRef<String> interned_foo_2(t.InternStrong(3, "foo"));
+    EXPECT_FALSE(t.ContainsWeak(interned_foo_2.get()));
+    EXPECT_EQ(interned_foo_1.get(), interned_foo_2.get());
   }
 
   {
     // Weaks are always weak.
     InternTable t;
-    String* foo_1 = t.InternWeak(String::AllocFromModifiedUtf8("foo"));
-    EXPECT_TRUE(t.ContainsWeak(foo_1));
-    String* foo_2 = t.InternWeak(String::AllocFromModifiedUtf8("foo"));
-    EXPECT_TRUE(t.ContainsWeak(foo_2));
-    EXPECT_EQ(foo_1, foo_2);
+    SirtRef<String> foo_1(String::AllocFromModifiedUtf8("foo"));
+    SirtRef<String> foo_2(String::AllocFromModifiedUtf8("foo"));
+    EXPECT_NE(foo_1.get(), foo_2.get());
+    SirtRef<String> interned_foo_1(t.InternWeak(foo_1.get()));
+    SirtRef<String> interned_foo_2(t.InternWeak(foo_2.get()));
+    EXPECT_TRUE(t.ContainsWeak(interned_foo_2.get()));
+    EXPECT_EQ(interned_foo_1.get(), interned_foo_2.get());
   }
 
   {
     // A weak can be promoted to a strong.
     InternTable t;
-    String* foo_1 = t.InternWeak(String::AllocFromModifiedUtf8("foo"));
-    EXPECT_TRUE(t.ContainsWeak(foo_1));
-    String* foo_2 = t.InternStrong(3, "foo");
-    EXPECT_FALSE(t.ContainsWeak(foo_2));
-    EXPECT_EQ(foo_1, foo_2);
+    SirtRef<String> foo(String::AllocFromModifiedUtf8("foo"));
+    SirtRef<String> interned_foo_1(t.InternWeak(foo.get()));
+    EXPECT_TRUE(t.ContainsWeak(interned_foo_1.get()));
+    SirtRef<String> interned_foo_2(t.InternStrong(3, "foo"));
+    EXPECT_FALSE(t.ContainsWeak(interned_foo_2.get()));
+    EXPECT_EQ(interned_foo_1.get(), interned_foo_2.get());
   }
 
   {
     // Interning a weak after a strong gets you the strong.
     InternTable t;
-    String* foo_1 = t.InternStrong(3, "foo");
-    EXPECT_FALSE(t.ContainsWeak(foo_1));
-    String* foo_2 = t.InternWeak(String::AllocFromModifiedUtf8("foo"));
-    EXPECT_FALSE(t.ContainsWeak(foo_2));
-    EXPECT_EQ(foo_1, foo_2);
+    SirtRef<String> interned_foo_1(t.InternStrong(3, "foo"));
+    EXPECT_FALSE(t.ContainsWeak(interned_foo_1.get()));
+    SirtRef<String> foo(String::AllocFromModifiedUtf8("foo"));
+    SirtRef<String> interned_foo_2(t.InternWeak(foo.get()));
+    EXPECT_FALSE(t.ContainsWeak(interned_foo_2.get()));
+    EXPECT_EQ(interned_foo_1.get(), interned_foo_2.get());
   }
 }
 
diff --git a/src/java_lang_reflect_Array.cc b/src/java_lang_reflect_Array.cc
index f1456b1..d157282 100644
--- a/src/java_lang_reflect_Array.cc
+++ b/src/java_lang_reflect_Array.cc
@@ -29,13 +29,13 @@
 // Objects or primitive types.
 Array* CreateMultiArray(Class* array_class, int current_dimension, IntArray* dimensions) {
   int32_t array_length = dimensions->Get(current_dimension++);
-  Array* new_array = Array::Alloc(array_class, array_length);
-  if (new_array == NULL) {
+  SirtRef<Array> new_array(Array::Alloc(array_class, array_length));
+  if (new_array.get() == NULL) {
     CHECK(Thread::Current()->IsExceptionPending());
     return NULL;
   }
   if (current_dimension == dimensions->GetLength()) {
-    return new_array;
+    return new_array.get();
   }
 
   if (!array_class->GetComponentType()->IsArrayClass()) {
@@ -53,16 +53,16 @@
   }
   DCHECK(sub_array_class->IsArrayClass());
   // Create a new sub-array in every element of the array.
-  ObjectArray<Array>* object_array = new_array->AsObjectArray<Array>();
+  SirtRef<ObjectArray<Array> > object_array(new_array->AsObjectArray<Array>());
   for (int32_t i = 0; i < array_length; i++) {
-    Array* sub_array = CreateMultiArray(sub_array_class, current_dimension, dimensions);
-    if (sub_array == NULL) {
+    SirtRef<Array> sub_array(CreateMultiArray(sub_array_class, current_dimension, dimensions));
+    if (sub_array.get() == NULL) {
       CHECK(Thread::Current()->IsExceptionPending());
       return NULL;
     }
-    object_array->Set(i, sub_array);
+    object_array->Set(i, sub_array.get());
   }
-  return new_array;
+  return new_array.get();
 }
 
 // Create a multi-dimensional array of Objects or primitive types.
diff --git a/src/jni_compiler_test.cc b/src/jni_compiler_test.cc
index ecc2f88..938d733 100644
--- a/src/jni_compiler_test.cc
+++ b/src/jni_compiler_test.cc
@@ -21,14 +21,11 @@
 
 class JniCompilerTest : public CommonTest {
  protected:
-  virtual void SetUp() {
-    CommonTest::SetUp();
-    class_loader_ = LoadDex("MyClassNatives");
-  }
 
-  void CompileForTest(bool direct, const char* method_name, const char* method_sig) {
+  void CompileForTest(ClassLoader* class_loader, bool direct,
+                      const char* method_name, const char* method_sig) {
     // Compile the native method before starting the runtime
-    Class* c = class_linker_->FindClass("LMyClass;", class_loader_);
+    Class* c = class_linker_->FindClass("LMyClass;", class_loader);
     Method* method;
     if (direct) {
       method = c->FindDirectMethod(method_name, method_sig);
@@ -43,9 +40,10 @@
     ASSERT_TRUE(method->GetCode() != NULL);
   }
 
-  void SetupForTest(bool direct, const char* method_name, const char* method_sig,
+  void SetupForTest(ClassLoader* class_loader, bool direct,
+                    const char* method_name, const char* method_sig,
                     void* native_fnptr) {
-    CompileForTest(direct, method_name, method_sig);
+    CompileForTest(class_loader, direct, method_name, method_sig);
     if (!runtime_->IsStarted()) {
       runtime_->Start();
     }
@@ -78,7 +76,6 @@
   static jclass jklass_;
   static jobject jobj_;
  protected:
-  const ClassLoader* class_loader_;
   JNIEnv* env_;
   jmethodID jmethod_;
 };
@@ -88,7 +85,8 @@
 
 int gJava_MyClass_foo_calls = 0;
 void Java_MyClass_foo(JNIEnv* env, jobject thisObj) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -97,7 +95,9 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunNoArgMethod) {
-  SetupForTest(false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "foo", "()V",
+               reinterpret_cast<void*>(&Java_MyClass_foo));
 
   EXPECT_EQ(0, gJava_MyClass_foo_calls);
   env_->CallNonvirtualVoidMethod(jobj_, jklass_, jmethod_);
@@ -107,15 +107,13 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunIntMethodThroughStub) {
-  SetupForTest(false,
-               "bar",
-               "(I)I",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "bar", "(I)I",
                NULL /* dlsym will find &Java_MyClass_bar later */);
 
   std::string path("libarttest.so");
   std::string reason;
-  ASSERT_TRUE(Runtime::Current()->GetJavaVM()->LoadNativeLibrary(
-      path, const_cast<ClassLoader*>(class_loader_), reason))
+  ASSERT_TRUE(Runtime::Current()->GetJavaVM()->LoadNativeLibrary(path, class_loader.get(), reason))
       << path << ": " << reason;
 
   jint result = env_->CallNonvirtualIntMethod(jobj_, jklass_, jmethod_, 24);
@@ -124,7 +122,8 @@
 
 int gJava_MyClass_fooI_calls = 0;
 jint Java_MyClass_fooI(JNIEnv* env, jobject thisObj, jint x) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -134,7 +133,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunIntMethod) {
-  SetupForTest(false, "fooI", "(I)I",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooI", "(I)I",
                reinterpret_cast<void*>(&Java_MyClass_fooI));
 
   EXPECT_EQ(0, gJava_MyClass_fooI_calls);
@@ -148,7 +148,8 @@
 
 int gJava_MyClass_fooII_calls = 0;
 jint Java_MyClass_fooII(JNIEnv* env, jobject thisObj, jint x, jint y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -158,7 +159,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunIntIntMethod) {
-  SetupForTest(false, "fooII", "(II)I",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooII", "(II)I",
                reinterpret_cast<void*>(&Java_MyClass_fooII));
 
   EXPECT_EQ(0, gJava_MyClass_fooII_calls);
@@ -173,7 +175,8 @@
 
 int gJava_MyClass_fooJJ_calls = 0;
 jlong Java_MyClass_fooJJ(JNIEnv* env, jobject thisObj, jlong x, jlong y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -183,7 +186,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunLongLongMethod) {
-  SetupForTest(false, "fooJJ", "(JJ)J",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooJJ", "(JJ)J",
                reinterpret_cast<void*>(&Java_MyClass_fooJJ));
 
   EXPECT_EQ(0, gJava_MyClass_fooJJ_calls);
@@ -199,7 +203,8 @@
 
 int gJava_MyClass_fooDD_calls = 0;
 jdouble Java_MyClass_fooDD(JNIEnv* env, jobject thisObj, jdouble x, jdouble y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + thisObj
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -209,7 +214,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunDoubleDoubleMethod) {
-  SetupForTest(false, "fooDD", "(DD)D",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooDD", "(DD)D",
                reinterpret_cast<void*>(&Java_MyClass_fooDD));
 
   EXPECT_EQ(0, gJava_MyClass_fooDD_calls);
@@ -227,7 +233,8 @@
 int gJava_MyClass_fooIOO_calls = 0;
 jobject Java_MyClass_fooIOO(JNIEnv* env, jobject thisObj, jint x, jobject y,
                             jobject z) {
-  EXPECT_EQ(3u, Thread::Current()->NumSirtReferences());
+  // 4 = SirtRef<ClassLoader> + this + y + z
+  EXPECT_EQ(4U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(thisObj != NULL);
@@ -244,7 +251,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunIntObjectObjectMethod) {
-  SetupForTest(false, "fooIOO",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooIOO",
                "(ILjava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;",
                reinterpret_cast<void*>(&Java_MyClass_fooIOO));
 
@@ -276,7 +284,8 @@
 
 int gJava_MyClass_fooSII_calls = 0;
 jint Java_MyClass_fooSII(JNIEnv* env, jclass klass, jint x, jint y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + klass
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(klass != NULL);
@@ -286,8 +295,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunStaticIntIntMethod) {
-  SetupForTest(true, "fooSII",
-               "(II)I",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "fooSII", "(II)I",
                reinterpret_cast<void*>(&Java_MyClass_fooSII));
 
   EXPECT_EQ(0, gJava_MyClass_fooSII_calls);
@@ -298,7 +307,8 @@
 
 int gJava_MyClass_fooSDD_calls = 0;
 jdouble Java_MyClass_fooSDD(JNIEnv* env, jclass klass, jdouble x, jdouble y) {
-  EXPECT_EQ(1u, Thread::Current()->NumSirtReferences());
+  // 2 = SirtRef<ClassLoader> + klass
+  EXPECT_EQ(2U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(klass != NULL);
@@ -308,7 +318,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunStaticDoubleDoubleMethod) {
-  SetupForTest(true, "fooSDD", "(DD)D",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "fooSDD", "(DD)D",
                reinterpret_cast<void*>(&Java_MyClass_fooSDD));
 
   EXPECT_EQ(0, gJava_MyClass_fooSDD_calls);
@@ -325,7 +336,8 @@
 int gJava_MyClass_fooSIOO_calls = 0;
 jobject Java_MyClass_fooSIOO(JNIEnv* env, jclass klass, jint x, jobject y,
                              jobject z) {
-  EXPECT_EQ(3u, Thread::Current()->NumSirtReferences());
+  // 4 = SirtRef<ClassLoader> + klass + y + z
+  EXPECT_EQ(4U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(klass != NULL);
@@ -343,7 +355,8 @@
 
 
 TEST_F(JniCompilerTest, CompileAndRunStaticIntObjectObjectMethod) {
-  SetupForTest(true, "fooSIOO",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "fooSIOO",
                "(ILjava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;",
                reinterpret_cast<void*>(&Java_MyClass_fooSIOO));
 
@@ -376,7 +389,8 @@
 int gJava_MyClass_fooSSIOO_calls = 0;
 jobject Java_MyClass_fooSSIOO(JNIEnv* env, jclass klass, jint x, jobject y,
                              jobject z) {
-  EXPECT_EQ(3u, Thread::Current()->NumSirtReferences());
+  // 4 = SirtRef<ClassLoader> + klass + y + z
+  EXPECT_EQ(4U, Thread::Current()->NumSirtReferences());
   EXPECT_EQ(Thread::kNative, Thread::Current()->GetState());
   EXPECT_EQ(Thread::Current()->GetJniEnv(), env);
   EXPECT_TRUE(klass != NULL);
@@ -393,7 +407,8 @@
 }
 
 TEST_F(JniCompilerTest, CompileAndRunStaticSynchronizedIntObjectObjectMethod) {
-  SetupForTest(true, "fooSSIOO",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "fooSSIOO",
                "(ILjava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;",
                reinterpret_cast<void*>(&Java_MyClass_fooSSIOO));
 
@@ -429,22 +444,25 @@
 }
 
 TEST_F(JniCompilerTest, ExceptionHandling) {
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+
   // all compilation needs to happen before SetupForTest calls Runtime::Start
-  CompileForTest(false, "foo", "()V");
-  CompileForTest(false, "throwException", "()V");
-  CompileForTest(false, "foo", "()V");
+  CompileForTest(class_loader.get(), false, "foo", "()V");
+  CompileForTest(class_loader.get(), false, "throwException", "()V");
+  CompileForTest(class_loader.get(), false, "foo", "()V");
 
   gJava_MyClass_foo_calls = 0;
 
   // Check a single call of a JNI method is ok
-  SetupForTest(false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
+  SetupForTest(class_loader.get(), false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
   env_->CallNonvirtualVoidMethod(jobj_, jklass_, jmethod_);
   EXPECT_EQ(1, gJava_MyClass_foo_calls);
   EXPECT_FALSE(Thread::Current()->IsExceptionPending());
 
   // Get class for exception we expect to be thrown
-  Class* jlre = class_linker_->FindClass("Ljava/lang/RuntimeException;", class_loader_);
-  SetupForTest(false, "throwException", "()V", reinterpret_cast<void*>(&Java_MyClass_throwException));
+  Class* jlre = class_linker_->FindClass("Ljava/lang/RuntimeException;", class_loader.get());
+  SetupForTest(class_loader.get(), false, "throwException", "()V",
+               reinterpret_cast<void*>(&Java_MyClass_throwException));
   // Call Java_MyClass_throwException (JNI method that throws exception)
   env_->CallNonvirtualVoidMethod(jobj_, jklass_, jmethod_);
   EXPECT_EQ(1, gJava_MyClass_foo_calls);
@@ -453,7 +471,7 @@
   Thread::Current()->ClearException();
 
   // Check a single call of a JNI method is ok
-  SetupForTest(false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
+  SetupForTest(class_loader.get(), false, "foo", "()V", reinterpret_cast<void*>(&Java_MyClass_foo));
   env_->CallNonvirtualVoidMethod(jobj_, jklass_, jmethod_);
   EXPECT_EQ(2, gJava_MyClass_foo_calls);
 }
@@ -497,7 +515,9 @@
 }
 
 TEST_F(JniCompilerTest, NativeStackTraceElement) {
-  SetupForTest(false, "fooI", "(I)I", reinterpret_cast<void*>(&Java_MyClass_nativeUpCall));
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooI", "(I)I",
+               reinterpret_cast<void*>(&Java_MyClass_nativeUpCall));
   jint result = env_->CallNonvirtualIntMethod(jobj_, jklass_, jmethod_, 10);
   EXPECT_EQ(10+9+8+7+6+5+4+3+2+1, result);
 }
@@ -507,7 +527,8 @@
 }
 
 TEST_F(JniCompilerTest, ReturnGlobalRef) {
-  SetupForTest(false, "fooO", "(Ljava/lang/Object;)Ljava/lang/Object;",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooO", "(Ljava/lang/Object;)Ljava/lang/Object;",
                reinterpret_cast<void*>(&Java_MyClass_fooO));
   jobject result = env_->CallNonvirtualObjectMethod(jobj_, jklass_, jmethod_, jobj_);
   EXPECT_EQ(JNILocalRefType, env_->GetObjectRefType(result));
@@ -523,7 +544,8 @@
 }
 
 TEST_F(JniCompilerTest, LocalReferenceTableClearingTest) {
-  SetupForTest(false, "fooI", "(I)I", reinterpret_cast<void*>(&local_ref_test));
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "fooI", "(I)I", reinterpret_cast<void*>(&local_ref_test));
   // 1000 invocations of a method that adds 10 local references
   for (int i=0; i < 1000; i++) {
     jint result = env_->CallIntMethod(jobj_, jmethod_, i);
@@ -541,7 +563,8 @@
 }
 
 TEST_F(JniCompilerTest, JavaLangSystemArrayCopy) {
-  SetupForTest(true, "arraycopy", "(Ljava/lang/Object;ILjava/lang/Object;II)V",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), true, "arraycopy", "(Ljava/lang/Object;ILjava/lang/Object;II)V",
                reinterpret_cast<void*>(&my_arraycopy));
   env_->CallStaticVoidMethod(jklass_, jmethod_, jobj_, 1234, jklass_, 5678, 9876);
 }
@@ -556,7 +579,8 @@
 }
 
 TEST_F(JniCompilerTest, CompareAndSwapInt) {
-  SetupForTest(false, "compareAndSwapInt", "(Ljava/lang/Object;JII)Z",
+  SirtRef<ClassLoader> class_loader(LoadDex("MyClassNatives"));
+  SetupForTest(class_loader.get(), false, "compareAndSwapInt", "(Ljava/lang/Object;JII)Z",
                reinterpret_cast<void*>(&my_casi));
   jboolean result = env_->CallBooleanMethod(jobj_, jmethod_, jobj_, 0x12345678ABCDEF88ll, 0xCAFEF00D, 0xEBADF00D);
   EXPECT_EQ(result, JNI_TRUE);
diff --git a/src/jni_internal_test.cc b/src/jni_internal_test.cc
index e1ab9db..18d8e0a 100644
--- a/src/jni_internal_test.cc
+++ b/src/jni_internal_test.cc
@@ -652,7 +652,7 @@
 
 
 TEST_F(JniInternalTest, GetPrimitiveField_SetPrimitiveField) {
-  LoadDex("AllFields");
+  SirtRef<ClassLoader> class_loader(LoadDex("AllFields"));
   runtime_->Start();
 
   jclass c = env_->FindClass("AllFields");
@@ -680,7 +680,7 @@
 }
 
 TEST_F(JniInternalTest, GetObjectField_SetObjectField) {
-  LoadDex("AllFields");
+  SirtRef<ClassLoader> class_loader(LoadDex("AllFields"));
   runtime_->Start();
 
   jclass c = env_->FindClass("AllFields");
@@ -865,10 +865,10 @@
 
 #if defined(__arm__)
 TEST_F(JniInternalTest, StaticMainMethod) {
-  const ClassLoader* class_loader = LoadDex("Main");
-  CompileDirectMethod(class_loader, "Main", "main", "([Ljava/lang/String;)V");
+  SirtRef<ClassLoader> class_loader(LoadDex("Main"));
+  CompileDirectMethod(class_loader.get(), "Main", "main", "([Ljava/lang/String;)V");
 
-  Class* klass = class_linker_->FindClass("LMain;", class_loader);
+  Class* klass = class_linker_->FindClass("LMain;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("main", "([Ljava/lang/String;)V");
@@ -882,10 +882,10 @@
 }
 
 TEST_F(JniInternalTest, StaticNopMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "nop", "()V");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "nop", "()V");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("nop", "()V");
@@ -897,10 +897,10 @@
 }
 
 TEST_F(JniInternalTest, StaticIdentityByteMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "identity", "(B)B");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "identity", "(B)B");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("identity", "(B)B");
@@ -933,10 +933,10 @@
 }
 
 TEST_F(JniInternalTest, StaticIdentityIntMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "identity", "(I)I");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "identity", "(I)I");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("identity", "(I)I");
@@ -969,10 +969,10 @@
 }
 
 TEST_F(JniInternalTest, StaticIdentityDoubleMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "identity", "(D)D");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "identity", "(D)D");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("identity", "(D)D");
@@ -1005,10 +1005,10 @@
 }
 
 TEST_F(JniInternalTest, StaticSumIntIntMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "sum", "(II)I");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "sum", "(II)I");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("sum", "(II)I");
@@ -1051,10 +1051,10 @@
 }
 
 TEST_F(JniInternalTest, StaticSumIntIntIntMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "sum", "(III)I");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "sum", "(III)I");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("sum", "(III)I");
@@ -1102,10 +1102,10 @@
 }
 
 TEST_F(JniInternalTest, StaticSumIntIntIntIntMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "sum", "(IIII)I");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "sum", "(IIII)I");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("sum", "(IIII)I");
@@ -1158,10 +1158,10 @@
 }
 
 TEST_F(JniInternalTest, StaticSumIntIntIntIntIntMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "sum", "(IIIII)I");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "sum", "(IIIII)I");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("sum", "(IIIII)I");
@@ -1219,10 +1219,10 @@
 }
 
 TEST_F(JniInternalTest, StaticSumDoubleDoubleMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "sum", "(DD)D");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "sum", "(DD)D");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("sum", "(DD)D");
@@ -1265,10 +1265,10 @@
 }
 
 TEST_F(JniInternalTest, StaticSumDoubleDoubleDoubleMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "sum", "(DDD)D");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "sum", "(DDD)D");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("sum", "(DDD)D");
@@ -1302,10 +1302,10 @@
 }
 
 TEST_F(JniInternalTest, StaticSumDoubleDoubleDoubleDoubleMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "sum", "(DDDD)D");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "sum", "(DDDD)D");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("sum", "(DDDD)D");
@@ -1342,10 +1342,10 @@
 }
 
 TEST_F(JniInternalTest, StaticSumDoubleDoubleDoubleDoubleDoubleMethod) {
-  const ClassLoader* class_loader = LoadDex("StaticLeafMethods");
-  CompileDirectMethod(class_loader, "StaticLeafMethods", "sum", "(DDDDD)D");
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticLeafMethods"));
+  CompileDirectMethod(class_loader.get(), "StaticLeafMethods", "sum", "(DDDDD)D");
 
-  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticLeafMethods;", class_loader.get());
   ASSERT_TRUE(klass != NULL);
 
   Method* method = klass->FindDirectMethod("sum", "(DDDDD)D");
diff --git a/src/mark_sweep.cc b/src/mark_sweep.cc
index 71983fe..d70d10d 100644
--- a/src/mark_sweep.cc
+++ b/src/mark_sweep.cc
@@ -216,29 +216,8 @@
 // Scans the header, static field references, and interface pointers
 // of a class object.
 inline void MarkSweep::ScanClass(const Object* obj) {
-  DCHECK(obj != NULL);
-  DCHECK(obj->IsClass());
-  const Class* klass = obj->AsClass();
-  MarkObject(klass->GetClass());
   ScanInstanceFields(obj);
-  MarkObject(klass->GetDescriptor());
-  MarkObject(klass->GetDexCache());
-  MarkObject(klass->GetVerifyErrorClass());
-  if (klass->IsArrayClass()) {
-    MarkObject(klass->GetComponentType());
-  }
-  if (klass->IsLoaded()) {
-    MarkObject(klass->GetSuperClass());
-  }
-  MarkObject(klass->GetClassLoader());
-  if (klass->IsLoaded()) {
-    MarkObject(klass->GetInterfaces());
-    MarkObject(klass->GetDirectMethods());
-    MarkObject(klass->GetVirtualMethods());
-    MarkObject(klass->GetIFields());
-    MarkObject(klass->GetSFields());
-  }
-  ScanStaticFields(klass);
+  ScanStaticFields(obj->AsClass());
 }
 
 // Scans the header of all array objects.  If the array object is
diff --git a/src/oat_test.cc b/src/oat_test.cc
index 392a822..3433e44 100644
--- a/src/oat_test.cc
+++ b/src/oat_test.cc
@@ -12,18 +12,18 @@
 TEST_F(OatTest, WriteRead) {
   const bool compile = false;  // DISABLED_ due to the time to compile libcore
 
-  const ClassLoader* class_loader = NULL;
+  SirtRef<ClassLoader> class_loader(NULL);
   if (compile) {
     compiler_.reset(new Compiler(kThumb2, false));
-    compiler_->CompileAll(class_loader);
+    compiler_->CompileAll(class_loader.get());
   }
 
   ScratchFile tmp;
-  bool success = OatWriter::Create(tmp.GetFilename(), class_loader, *compiler_.get());
+  bool success = OatWriter::Create(tmp.GetFilename(), class_loader.get(), *compiler_.get());
   ASSERT_TRUE(success);
 
   if (compile) {  // OatWriter strips the code, regenerate to compare
-    compiler_->CompileAll(class_loader);
+    compiler_->CompileAll(class_loader.get());
   }
   UniquePtr<OatFile> oat_file(OatFile::Open(std::string(tmp.GetFilename()), "", NULL));
   ASSERT_TRUE(oat_file.get() != NULL);
@@ -43,7 +43,7 @@
 
     UniquePtr<const OatFile::OatClass> oat_class(oat_dex_file->GetOatClass(i));
 
-    Class* klass = class_linker->FindClass(descriptor, class_loader);
+    Class* klass = class_linker->FindClass(descriptor, class_loader.get());
 
     size_t method_index = 0;
     for (size_t i = 0; i < klass->NumDirectMethods(); i++, method_index++) {
diff --git a/src/object.cc b/src/object.cc
index b28594a..6ea1f59 100644
--- a/src/object.cc
+++ b/src/object.cc
@@ -30,23 +30,23 @@
   // Object::SizeOf gets the right size even if we're an array.
   // Using c->AllocObject() here would be wrong.
   size_t num_bytes = SizeOf();
-  Object* copy = Heap::AllocObject(c, num_bytes);
-  if (copy == NULL) {
+  SirtRef<Object> copy(Heap::AllocObject(c, num_bytes));
+  if (copy.get() == NULL) {
     return NULL;
   }
 
   // Copy instance data.  We assume memcpy copies by words.
   // TODO: expose and use move32.
   byte* src_bytes = reinterpret_cast<byte*>(this);
-  byte* dst_bytes = reinterpret_cast<byte*>(copy);
+  byte* dst_bytes = reinterpret_cast<byte*>(copy.get());
   size_t offset = sizeof(Object);
   memcpy(dst_bytes + offset, src_bytes + offset, num_bytes - offset);
 
   if (c->IsFinalizable()) {
-    Heap::AddFinalizerReference(copy);
+    Heap::AddFinalizerReference(copy.get());
   }
 
-  return copy;
+  return copy.get();
 }
 
 uint32_t Object::GetThinLockId() {
@@ -1329,14 +1329,15 @@
 }
 
 String* String::Alloc(Class* java_lang_String, int32_t utf16_length) {
-  CharArray* array = CharArray::Alloc(utf16_length);
-  if (array == NULL) {
+  SirtRef<CharArray> array(CharArray::Alloc(utf16_length));
+  if (array.get() == NULL) {
     return NULL;
   }
-  return Alloc(java_lang_String, array);
+  return Alloc(java_lang_String, array.get());
 }
 
 String* String::Alloc(Class* java_lang_String, CharArray* array) {
+  SirtRef<CharArray> array_ref(array);  // hold reference in case AllocObject causes GC
   String* string = down_cast<String*>(java_lang_String->AllocObject());
   if (string == NULL) {
     return NULL;
diff --git a/src/object_test.cc b/src/object_test.cc
index e5af663..c493126 100644
--- a/src/object_test.cc
+++ b/src/object_test.cc
@@ -27,7 +27,7 @@
       utf16_expected[i] = ch;
     }
 
-    String* string = String::AllocFromModifiedUtf8(length, utf8_in);
+    SirtRef<String> string(String::AllocFromModifiedUtf8(length, utf8_in));
     ASSERT_EQ(length, string->GetLength());
     ASSERT_TRUE(string->GetCharArray() != NULL);
     ASSERT_TRUE(string->GetCharArray()->GetData() != NULL);
@@ -42,20 +42,22 @@
 
 TEST_F(ObjectTest, IsInSamePackage) {
   // Matches
-  EXPECT_TRUE(Class::IsInSamePackage(String::AllocFromModifiedUtf8("Ljava/lang/Object;"),
-                                     String::AllocFromModifiedUtf8("Ljava/lang/Class")));
-  EXPECT_TRUE(Class::IsInSamePackage(String::AllocFromModifiedUtf8("LFoo;"),
-                                     String::AllocFromModifiedUtf8("LBar;")));
+  SirtRef<String> Object_descriptor(String::AllocFromModifiedUtf8("Ljava/lang/Object;"));
+  SirtRef<String> Class_descriptor(String::AllocFromModifiedUtf8("Ljava/lang/Class;"));
+  EXPECT_TRUE(Class::IsInSamePackage(Object_descriptor.get(), Class_descriptor.get()));
+  SirtRef<String> Foo_descriptor(String::AllocFromModifiedUtf8("LFoo;"));
+  SirtRef<String> Bar_descriptor(String::AllocFromModifiedUtf8("LBar;"));
+  EXPECT_TRUE(Class::IsInSamePackage(Foo_descriptor.get(), Bar_descriptor.get()));
 
   // Mismatches
-  EXPECT_FALSE(Class::IsInSamePackage(String::AllocFromModifiedUtf8("Ljava/lang/Object;"),
-                                      String::AllocFromModifiedUtf8("Ljava/io/File;")));
-  EXPECT_FALSE(Class::IsInSamePackage(String::AllocFromModifiedUtf8("Ljava/lang/Object;"),
-                                      String::AllocFromModifiedUtf8("Ljava/lang/reflect/Method;")));
+  SirtRef<String> File_descriptor(String::AllocFromModifiedUtf8("Ljava/io/File;"));
+  EXPECT_FALSE(Class::IsInSamePackage(Object_descriptor.get(), File_descriptor.get()));
+  SirtRef<String> Method_descriptor(String::AllocFromModifiedUtf8("Ljava/lang/reflect/Method;"));
+  EXPECT_FALSE(Class::IsInSamePackage(Object_descriptor.get(), Method_descriptor.get()));
 }
 
 TEST_F(ObjectTest, Clone) {
-  ObjectArray<Object>* a1 = class_linker_->AllocObjectArray<Object>(256);
+  SirtRef<ObjectArray<Object> > a1(class_linker_->AllocObjectArray<Object>(256));
   size_t s1 = a1->SizeOf();
   Object* clone = a1->Clone();
   EXPECT_EQ(s1, clone->SizeOf());
@@ -63,16 +65,16 @@
 }
 
 TEST_F(ObjectTest, AllocObjectArray) {
-  ObjectArray<Object>* oa = class_linker_->AllocObjectArray<Object>(2);
+  SirtRef<ObjectArray<Object> > oa(class_linker_->AllocObjectArray<Object>(2));
   EXPECT_EQ(2, oa->GetLength());
   EXPECT_TRUE(oa->Get(0) == NULL);
   EXPECT_TRUE(oa->Get(1) == NULL);
-  oa->Set(0, oa);
-  EXPECT_TRUE(oa->Get(0) == oa);
+  oa->Set(0, oa.get());
+  EXPECT_TRUE(oa->Get(0) == oa.get());
   EXPECT_TRUE(oa->Get(1) == NULL);
-  oa->Set(1, oa);
-  EXPECT_TRUE(oa->Get(0) == oa);
-  EXPECT_TRUE(oa->Get(1) == oa);
+  oa->Set(1, oa.get());
+  EXPECT_TRUE(oa->Get(0) == oa.get());
+  EXPECT_TRUE(oa->Get(1) == oa.get());
 
   Thread* self = Thread::Current();
   Class* aioobe = class_linker_->FindSystemClass("Ljava/lang/ArrayIndexOutOfBoundsException;");
@@ -97,15 +99,15 @@
 
 TEST_F(ObjectTest, AllocArray) {
   Class* c = class_linker_->FindSystemClass("[I");
-  Array* a = Array::Alloc(c, 1);
+  SirtRef<Array> a(Array::Alloc(c, 1));
   ASSERT_TRUE(c == a->GetClass());
 
   c = class_linker_->FindSystemClass("[Ljava/lang/Object;");
-  a = Array::Alloc(c, 1);
+  a.reset(Array::Alloc(c, 1));
   ASSERT_TRUE(c == a->GetClass());
 
   c = class_linker_->FindSystemClass("[[Ljava/lang/Object;");
-  a = Array::Alloc(c, 1);
+  a.reset(Array::Alloc(c, 1));
   ASSERT_TRUE(c == a->GetClass());
 }
 
@@ -177,20 +179,20 @@
 
 TEST_F(ObjectTest, StaticFieldFromCode) {
   // pretend we are trying to access 'Static.s0' from StaticsFromCode.<clinit>
-  const ClassLoader* class_loader = LoadDex("StaticsFromCode");
-  const DexFile* dex_file = ClassLoader::GetCompileTimeClassPath(class_loader)[0];
+  SirtRef<ClassLoader> class_loader(LoadDex("StaticsFromCode"));
+  const DexFile* dex_file = ClassLoader::GetCompileTimeClassPath(class_loader.get())[0];
   CHECK(dex_file != NULL);
 
-  Class* klass = class_linker_->FindClass("LStaticsFromCode;", class_loader);
+  Class* klass = class_linker_->FindClass("LStaticsFromCode;", class_loader.get());
   Method* clinit = klass->FindDirectMethod("<clinit>", "()V");
   uint32_t field_idx = FindFieldIdxByDescriptorAndName(*dex_file, "LStaticsFromCode;", "s0");
   Field* field = FindFieldFromCode(field_idx, clinit, true);
   Object* s0 = field->GetObj(NULL);
   EXPECT_EQ(NULL, s0);
 
-  CharArray* char_array = CharArray::Alloc(0);
-  field->SetObj(NULL, char_array);
-  EXPECT_EQ(char_array, field->GetObj(NULL));
+  SirtRef<CharArray> char_array(CharArray::Alloc(0));
+  field->SetObj(NULL, char_array.get());
+  EXPECT_EQ(char_array.get(), field->GetObj(NULL));
 
   field->SetObj(NULL, NULL);
   EXPECT_EQ(NULL, field->GetObj(NULL));
@@ -222,7 +224,7 @@
 }
 
 TEST_F(ObjectTest, StringEqualsUtf8) {
-  String* string = String::AllocFromModifiedUtf8("android");
+  SirtRef<String> string(String::AllocFromModifiedUtf8("android"));
   EXPECT_TRUE(string->Equals("android"));
   EXPECT_FALSE(string->Equals("Android"));
   EXPECT_FALSE(string->Equals("ANDROID"));
@@ -230,21 +232,22 @@
   EXPECT_FALSE(string->Equals("and"));
   EXPECT_FALSE(string->Equals("androids"));
 
-  String* empty = String::AllocFromModifiedUtf8("");
+  SirtRef<String> empty(String::AllocFromModifiedUtf8(""));
   EXPECT_TRUE(empty->Equals(""));
   EXPECT_FALSE(empty->Equals("a"));
 }
 
 TEST_F(ObjectTest, StringEquals) {
-  String* string = String::AllocFromModifiedUtf8("android");
-  EXPECT_TRUE(string->Equals(String::AllocFromModifiedUtf8("android")));
+  SirtRef<String> string(String::AllocFromModifiedUtf8("android"));
+  SirtRef<String> string_2(String::AllocFromModifiedUtf8("android"));
+  EXPECT_TRUE(string->Equals(string_2.get()));
   EXPECT_FALSE(string->Equals("Android"));
   EXPECT_FALSE(string->Equals("ANDROID"));
   EXPECT_FALSE(string->Equals(""));
   EXPECT_FALSE(string->Equals("and"));
   EXPECT_FALSE(string->Equals("androids"));
 
-  String* empty = String::AllocFromModifiedUtf8("");
+  SirtRef<String> empty(String::AllocFromModifiedUtf8(""));
   EXPECT_TRUE(empty->Equals(""));
   EXPECT_FALSE(empty->Equals("a"));
 }
@@ -252,12 +255,12 @@
 TEST_F(ObjectTest, DescriptorCompare) {
   ClassLinker* linker = class_linker_;
 
-  const ClassLoader* class_loader_1 = LoadDex("ProtoCompare");
-  const ClassLoader* class_loader_2 = LoadDex("ProtoCompare2");
+  SirtRef<ClassLoader> class_loader_1(LoadDex("ProtoCompare"));
+  SirtRef<ClassLoader> class_loader_2(LoadDex("ProtoCompare2"));
 
-  Class* klass1 = linker->FindClass("LProtoCompare;", class_loader_1);
+  Class* klass1 = linker->FindClass("LProtoCompare;", class_loader_1.get());
   ASSERT_TRUE(klass1 != NULL);
-  Class* klass2 = linker->FindClass("LProtoCompare2;", class_loader_2);
+  Class* klass2 = linker->FindClass("LProtoCompare2;", class_loader_2.get());
   ASSERT_TRUE(klass2 != NULL);
 
   Method* m1_1 = klass1->GetVirtualMethod(0);
@@ -293,22 +296,26 @@
 
 
 TEST_F(ObjectTest, StringHashCode) {
-  EXPECT_EQ(0, String::AllocFromModifiedUtf8("")->GetHashCode());
-  EXPECT_EQ(65, String::AllocFromModifiedUtf8("A")->GetHashCode());
-  EXPECT_EQ(64578, String::AllocFromModifiedUtf8("ABC")->GetHashCode());
+  SirtRef<String> empty(String::AllocFromModifiedUtf8(""));
+  SirtRef<String> A(String::AllocFromModifiedUtf8("A"));
+  SirtRef<String> ABC(String::AllocFromModifiedUtf8("ABC"));
+
+  EXPECT_EQ(0, empty->GetHashCode());
+  EXPECT_EQ(65, A->GetHashCode());
+  EXPECT_EQ(64578, ABC->GetHashCode());
 }
 
 TEST_F(ObjectTest, InstanceOf) {
-  const ClassLoader* class_loader = LoadDex("XandY");
-  Class* X = class_linker_->FindClass("LX;", class_loader);
-  Class* Y = class_linker_->FindClass("LY;", class_loader);
+  SirtRef<ClassLoader> class_loader(LoadDex("XandY"));
+  Class* X = class_linker_->FindClass("LX;", class_loader.get());
+  Class* Y = class_linker_->FindClass("LY;", class_loader.get());
   ASSERT_TRUE(X != NULL);
   ASSERT_TRUE(Y != NULL);
 
-  Object* x = X->AllocObject();
-  Object* y = Y->AllocObject();
-  ASSERT_TRUE(x != NULL);
-  ASSERT_TRUE(y != NULL);
+  SirtRef<Object> x(X->AllocObject());
+  SirtRef<Object> y(Y->AllocObject());
+  ASSERT_TRUE(x.get() != NULL);
+  ASSERT_TRUE(y.get() != NULL);
 
   EXPECT_EQ(1U, IsAssignableFromCode(X, x->GetClass()));
   EXPECT_EQ(0U, IsAssignableFromCode(Y, x->GetClass()));
@@ -335,9 +342,9 @@
 }
 
 TEST_F(ObjectTest, IsAssignableFrom) {
-  const ClassLoader* class_loader = LoadDex("XandY");
-  Class* X = class_linker_->FindClass("LX;", class_loader);
-  Class* Y = class_linker_->FindClass("LY;", class_loader);
+  SirtRef<ClassLoader> class_loader(LoadDex("XandY"));
+  Class* X = class_linker_->FindClass("LX;", class_loader.get());
+  Class* Y = class_linker_->FindClass("LY;", class_loader.get());
 
   EXPECT_TRUE(X->IsAssignableFrom(X));
   EXPECT_TRUE(X->IsAssignableFrom(Y));
@@ -346,18 +353,18 @@
 }
 
 TEST_F(ObjectTest, IsAssignableFromArray) {
-  const ClassLoader* class_loader = LoadDex("XandY");
-  Class* X = class_linker_->FindClass("LX;", class_loader);
-  Class* Y = class_linker_->FindClass("LY;", class_loader);
+  SirtRef<ClassLoader> class_loader(LoadDex("XandY"));
+  Class* X = class_linker_->FindClass("LX;", class_loader.get());
+  Class* Y = class_linker_->FindClass("LY;", class_loader.get());
   ASSERT_TRUE(X != NULL);
   ASSERT_TRUE(Y != NULL);
 
-  Class* YA = class_linker_->FindClass("[LY;", class_loader);
-  Class* YAA = class_linker_->FindClass("[[LY;", class_loader);
+  Class* YA = class_linker_->FindClass("[LY;", class_loader.get());
+  Class* YAA = class_linker_->FindClass("[[LY;", class_loader.get());
   ASSERT_TRUE(YA != NULL);
   ASSERT_TRUE(YAA != NULL);
 
-  Class* XAA = class_linker_->FindClass("[[LX;", class_loader);
+  Class* XAA = class_linker_->FindClass("[[LX;", class_loader.get());
   ASSERT_TRUE(XAA != NULL);
 
   Class* O = class_linker_->FindSystemClass("Ljava/lang/Object;");
@@ -397,8 +404,8 @@
 }
 
 TEST_F(ObjectTest, FindInstanceField) {
-  String* s = String::AllocFromModifiedUtf8("ABC");
-  ASSERT_TRUE(s != NULL);
+  SirtRef<String> s(String::AllocFromModifiedUtf8("ABC"));
+  ASSERT_TRUE(s.get() != NULL);
   Class* c = s->GetClass();
   ASSERT_TRUE(c != NULL);
 
@@ -429,8 +436,8 @@
 }
 
 TEST_F(ObjectTest, FindStaticField) {
-  String* s = String::AllocFromModifiedUtf8("ABC");
-  ASSERT_TRUE(s != NULL);
+  SirtRef<String> s(String::AllocFromModifiedUtf8("ABC"));
+  ASSERT_TRUE(s.get() != NULL);
   Class* c = s->GetClass();
   ASSERT_TRUE(c != NULL);
 
diff --git a/src/runtime.cc b/src/runtime.cc
index 6148ae7..64823fd 100644
--- a/src/runtime.cc
+++ b/src/runtime.cc
@@ -726,7 +726,7 @@
 
 Method* Runtime::CreateCalleeSaveMethod(InstructionSet insns, CalleeSaveType type) {
   Class* method_class = Method::GetMethodClass();
-  Method* method = down_cast<Method*>(method_class->AllocObject());
+  SirtRef<Method> method(down_cast<Method*>(method_class->AllocObject()));
   method->SetDeclaringClass(method_class);
   const char* name;
   if (type == kSaveAll) {
@@ -777,7 +777,7 @@
   } else {
     UNIMPLEMENTED(FATAL);
   }
-  return method;
+  return method.get();
 }
 
 bool Runtime::HasCalleeSaveMethod(CalleeSaveType type) const {
diff --git a/src/stack_indirect_reference_table.h b/src/stack_indirect_reference_table.h
index f0b6698..8b98763 100644
--- a/src/stack_indirect_reference_table.h
+++ b/src/stack_indirect_reference_table.h
@@ -23,22 +23,55 @@
 
 class Object;
 
-// Stack allocated indirect reference table, allocated within the bridge frame
-// between managed and native code.
+// Stack allocated indirect reference table. It can allocated within
+// the bridge frame between managed and native code backed by stack
+// storage or manually allocated by SirtRef to hold one reference.
 class StackIndirectReferenceTable {
 public:
+
+  StackIndirectReferenceTable(Object* object) {
+    number_of_references_ = 1;
+    references_[0] = object;
+    Thread::Current()->PushSirt(this);
+  }
+
+  ~StackIndirectReferenceTable() {
+    StackIndirectReferenceTable* sirt = Thread::Current()->PopSirt();
+    CHECK_EQ(this, sirt);
+  }
+
   // Number of references contained within this SIRT
-  size_t NumberOfReferences() {
+  size_t NumberOfReferences() const {
     return number_of_references_;
   }
 
   // Link to previous SIRT or NULL
-  StackIndirectReferenceTable* Link() {
+  StackIndirectReferenceTable* GetLink() const {
     return link_;
   }
 
-  Object** References() {
-    return references_;
+  void SetLink(StackIndirectReferenceTable* sirt) {
+    DCHECK_NE(this, sirt);
+    link_ = sirt;
+  }
+
+  Object* GetReference(size_t i) const {
+    DCHECK_LT(i, number_of_references_);
+    return references_[i];
+  }
+
+  void SetReference(size_t i, Object* object) {
+    DCHECK_LT(i, number_of_references_);
+    references_[i] = object;
+  }
+
+  bool Contains(Object** sirt_entry) const {
+    // A SIRT should always contain something. One created by the
+    // jni_compiler should have a jobject/jclass as a native method is
+    // passed in a this pointer or a class
+    DCHECK_GT(number_of_references_, 0U);
+    return ((&references_[0] <= sirt_entry)
+            && (sirt_entry <= (&references_[number_of_references_ - 1])));
   }
 
   // Offset of length within SIRT, used by generated code
@@ -57,12 +90,34 @@
   size_t number_of_references_;
   StackIndirectReferenceTable* link_;
 
-  // Fake array, really allocated and filled in by jni_compiler.
-  Object* references_[0];
+  // number_of_references_ are available if this is allocated and filled in by jni_compiler.
+  Object* references_[1];
 
   DISALLOW_COPY_AND_ASSIGN(StackIndirectReferenceTable);
 };
 
+template<class T>
+class SirtRef {
+public:
+  SirtRef(T* object) : sirt_(object) {}
+  ~SirtRef() {}
+
+  T& operator*() const { return *get(); }
+  T* operator->() const { return get(); }
+  T* get() const {
+    return down_cast<T*>(sirt_.GetReference(0));
+  }
+
+  void reset(T* object = NULL) {
+    sirt_.SetReference(0, object);
+  }
+
+private:
+  StackIndirectReferenceTable sirt_;
+
+  DISALLOW_COPY_AND_ASSIGN(SirtRef);
+};
+
 }  // namespace art
 
 #endif  // ART_SRC_STACK_INDIRECT_REFERENCE_TABLE_H_
diff --git a/src/stub_arm.cc b/src/stub_arm.cc
index dfbd524..3768721 100644
--- a/src/stub_arm.cc
+++ b/src/stub_arm.cc
@@ -3,6 +3,7 @@
 #include "assembler_arm.h"
 #include "jni_internal.h"
 #include "object.h"
+#include "stack_indirect_reference_table.h"
 
 #define __ assembler->
 
@@ -38,12 +39,12 @@
 
   assembler->EmitSlowPaths();
   size_t cs = assembler->CodeSize();
-  ByteArray* resolution_trampoline = ByteArray::Alloc(cs);
-  CHECK(resolution_trampoline != NULL);
+  SirtRef<ByteArray> resolution_trampoline(ByteArray::Alloc(cs));
+  CHECK(resolution_trampoline.get() != NULL);
   MemoryRegion code(resolution_trampoline->GetData(), resolution_trampoline->GetLength());
   assembler->FinalizeInstructions(code);
 
-  return resolution_trampoline;
+  return resolution_trampoline.get();
 }
 
 typedef void (*ThrowAme)(Method*, Thread*);
@@ -69,13 +70,13 @@
   assembler->EmitSlowPaths();
 
   size_t cs = assembler->CodeSize();
-  ByteArray* abstract_stub = ByteArray::Alloc(cs);
-  CHECK(abstract_stub != NULL);
+  SirtRef<ByteArray> abstract_stub(ByteArray::Alloc(cs));
+  CHECK(abstract_stub.get() != NULL);
   CHECK(abstract_stub->GetClass()->GetDescriptor());
   MemoryRegion code(abstract_stub->GetData(), abstract_stub->GetLength());
   assembler->FinalizeInstructions(code);
 
-  return abstract_stub;
+  return abstract_stub.get();
 }
 
 ByteArray* CreateJniStub() {
@@ -98,12 +99,12 @@
   assembler->EmitSlowPaths();
 
   size_t cs = assembler->CodeSize();
-  ByteArray* jni_stub = ByteArray::Alloc(cs);
-  CHECK(jni_stub != NULL);
+  SirtRef<ByteArray> jni_stub(ByteArray::Alloc(cs));
+  CHECK(jni_stub.get() != NULL);
   MemoryRegion code(jni_stub->GetData(), jni_stub->GetLength());
   assembler->FinalizeInstructions(code);
 
-  return jni_stub;
+  return jni_stub.get();
 }
 
 } // namespace arm
diff --git a/src/stub_x86.cc b/src/stub_x86.cc
index ea745ee..7660f6f 100644
--- a/src/stub_x86.cc
+++ b/src/stub_x86.cc
@@ -3,6 +3,7 @@
 #include "assembler_x86.h"
 #include "jni_internal.h"
 #include "object.h"
+#include "stack_indirect_reference_table.h"
 
 #define __ assembler->
 
@@ -17,12 +18,12 @@
 
   assembler->EmitSlowPaths();
   size_t cs = assembler->CodeSize();
-  ByteArray* resolution_trampoline = ByteArray::Alloc(cs);
-  CHECK(resolution_trampoline != NULL);
+  SirtRef<ByteArray> resolution_trampoline(ByteArray::Alloc(cs));
+  CHECK(resolution_trampoline.get() != NULL);
   MemoryRegion code(resolution_trampoline->GetData(), resolution_trampoline->GetLength());
   assembler->FinalizeInstructions(code);
 
-  return resolution_trampoline;
+  return resolution_trampoline.get();
 }
 
 typedef void (*ThrowAme)(Method*, Thread*);
@@ -46,12 +47,12 @@
   assembler->EmitSlowPaths();
 
   size_t cs = assembler->CodeSize();
-  ByteArray* abstract_stub = ByteArray::Alloc(cs);
-  CHECK(abstract_stub != NULL);
+  SirtRef<ByteArray> abstract_stub(ByteArray::Alloc(cs));
+  CHECK(abstract_stub.get() != NULL);
   MemoryRegion code(abstract_stub->GetData(), abstract_stub->GetLength());
   assembler->FinalizeInstructions(code);
 
-  return abstract_stub;
+  return abstract_stub.get();
 }
 
 ByteArray* CreateJniStub() {
@@ -79,12 +80,12 @@
   assembler->EmitSlowPaths();
 
   size_t cs = assembler->CodeSize();
-  ByteArray* jni_stub = ByteArray::Alloc(cs);
-  CHECK(jni_stub != NULL);
+  SirtRef<ByteArray> jni_stub(ByteArray::Alloc(cs));
+  CHECK(jni_stub.get() != NULL);
   MemoryRegion code(jni_stub->GetData(), jni_stub->GetLength());
   assembler->FinalizeInstructions(code);
 
-  return jni_stub;
+  return jni_stub.get();
 }
 
 } // namespace x86
diff --git a/src/thread.cc b/src/thread.cc
index 73b8a94..7dc8104 100644
--- a/src/thread.cc
+++ b/src/thread.cc
@@ -826,7 +826,7 @@
 
 size_t Thread::NumSirtReferences() {
   size_t count = 0;
-  for (StackIndirectReferenceTable* cur = top_sirt_; cur; cur = cur->Link()) {
+  for (StackIndirectReferenceTable* cur = top_sirt_; cur; cur = cur->GetLink()) {
     count += cur->NumberOfReferences();
   }
   return count;
@@ -834,13 +834,8 @@
 
 bool Thread::SirtContains(jobject obj) {
   Object** sirt_entry = reinterpret_cast<Object**>(obj);
-  for (StackIndirectReferenceTable* cur = top_sirt_; cur; cur = cur->Link()) {
-    size_t num_refs = cur->NumberOfReferences();
-    // A SIRT should always have a jobject/jclass as a native method is passed
-    // in a this pointer or a class
-    DCHECK_GT(num_refs, 0u);
-    if ((&cur->References()[0] <= sirt_entry) &&
-        (sirt_entry <= (&cur->References()[num_refs - 1]))) {
+  for (StackIndirectReferenceTable* cur = top_sirt_; cur; cur = cur->GetLink()) {
+    if (cur->Contains(sirt_entry)) {
       return true;
     }
   }
@@ -848,10 +843,10 @@
 }
 
 void Thread::SirtVisitRoots(Heap::RootVisitor* visitor, void* arg) {
-  for (StackIndirectReferenceTable* cur = top_sirt_; cur; cur = cur->Link()) {
+  for (StackIndirectReferenceTable* cur = top_sirt_; cur; cur = cur->GetLink()) {
     size_t num_refs = cur->NumberOfReferences();
     for (size_t j = 0; j < num_refs; j++) {
-      Object* object = cur->References()[j];
+      Object* object = cur->GetReference(j);
       if (object != NULL) {
         visitor(object, arg);
       }
@@ -1025,6 +1020,18 @@
   return pc + 2;
 }
 
+void Thread::PushSirt(StackIndirectReferenceTable* sirt) {
+  sirt->SetLink(top_sirt_);
+  top_sirt_ = sirt;
+}
+
+StackIndirectReferenceTable* Thread::PopSirt() {
+  CHECK(top_sirt_ != NULL);
+  StackIndirectReferenceTable* sirt = top_sirt_;
+  top_sirt_ = top_sirt_->GetLink();
+  return sirt;
+}
+
 void Thread::WalkStack(StackVisitor* visitor) const {
   Frame frame = GetTopOfStack();
   uintptr_t pc = ManglePc(top_of_managed_stack_pc_);
@@ -1129,9 +1136,15 @@
       line_number = dex_file.GetLineNumFromPC(method, method->ToDexPC(native_pc));
     }
     // Allocate element, potentially triggering GC
-    StackTraceElement* obj =
-        StackTraceElement::Alloc(String::AllocFromModifiedUtf8(class_name.c_str()),
-                                 method->GetName(), klass->GetSourceFile(), line_number);
+    // TODO: reuse class_name_object via Class::name_?
+    SirtRef<String> class_name_object(String::AllocFromModifiedUtf8(class_name.c_str()));
+    if (class_name_object.get() == NULL) {
+      return NULL;
+    }
+    StackTraceElement* obj = StackTraceElement::Alloc(class_name_object.get(),
+                                                      method->GetName(),
+                                                      klass->GetSourceFile(),
+                                                      line_number);
     if (obj == NULL) {
       return NULL;
     }
@@ -1417,6 +1430,9 @@
   if (pre_allocated_OutOfMemoryError_ != NULL) {
     visitor(pre_allocated_OutOfMemoryError_, arg);
   }
+  if (class_loader_override_ != NULL) {
+    visitor(class_loader_override_, arg);
+  }
   jni_env_->locals.VisitRoots(visitor, arg);
   jni_env_->monitors.VisitRoots(visitor, arg);
 
diff --git a/src/thread.h b/src/thread.h
index f30d092..37c1042 100644
--- a/src/thread.h
+++ b/src/thread.h
@@ -47,12 +47,12 @@
 class Monitor;
 class Object;
 class Runtime;
-class Thread;
-class ThreadList;
-class Throwable;
 class StackIndirectReferenceTable;
 class StackTraceElement;
 class StaticStorageBase;
+class Thread;
+class ThreadList;
+class Throwable;
 
 template<class T> class ObjectArray;
 template<class T> class PrimitiveArray;
@@ -446,6 +446,9 @@
     return ThreadOffset(OFFSETOF_MEMBER(Thread, top_of_managed_stack_pc_));
   }
 
+  void PushSirt(StackIndirectReferenceTable* sirt);
+  StackIndirectReferenceTable* PopSirt();
+
   static ThreadOffset TopSirtOffset() {
     return ThreadOffset(OFFSETOF_MEMBER(Thread, top_sirt_));
   }
diff --git a/src/utils_test.cc b/src/utils_test.cc
index 0869b89..bf66d05 100644
--- a/src/utils_test.cc
+++ b/src/utils_test.cc
@@ -11,8 +11,8 @@
 
 #define EXPECT_DESCRIPTOR(pretty_descriptor, descriptor) \
   do { \
-    String* s = String::AllocFromModifiedUtf8(descriptor); \
-    std::string result(PrettyDescriptor(s)); \
+    SirtRef<String> s(String::AllocFromModifiedUtf8(descriptor)); \
+    std::string result(PrettyDescriptor(s.get())); \
     EXPECT_EQ(pretty_descriptor, result); \
   } while (false)
 
@@ -59,11 +59,11 @@
 TEST_F(UtilsTest, PrettyTypeOf) {
   EXPECT_EQ("null", PrettyTypeOf(NULL));
 
-  String* s = String::AllocFromModifiedUtf8("");
-  EXPECT_EQ("java.lang.String", PrettyTypeOf(s));
+  SirtRef<String> s(String::AllocFromModifiedUtf8(""));
+  EXPECT_EQ("java.lang.String", PrettyTypeOf(s.get()));
 
-  ShortArray* a = ShortArray::Alloc(2);
-  EXPECT_EQ("short[]", PrettyTypeOf(a));
+  SirtRef<ShortArray> a(ShortArray::Alloc(2));
+  EXPECT_EQ("short[]", PrettyTypeOf(a.get()));
 
   Class* c = class_linker_->FindSystemClass("[Ljava/lang/String;");
   ASSERT_TRUE(c != NULL);
diff --git a/test/ReferenceMap/stack_walk_refmap_jni.cc b/test/ReferenceMap/stack_walk_refmap_jni.cc
index c19ccf9..ca46569 100644
--- a/test/ReferenceMap/stack_walk_refmap_jni.cc
+++ b/test/ReferenceMap/stack_walk_refmap_jni.cc
@@ -29,7 +29,7 @@
 
   void VisitFrame(const Frame& frame, uintptr_t pc) {
     Method* m = frame.GetMethod();
-    if (!m ||m->IsNative()) {
+    if (!m || m->IsNative()) {
       return;
     }
     LOG(INFO) << "At " << PrettyMethod(m, false);