InferAddressSpaces: Support atomics

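Handle the pointer operands of atomicrmw and cmpxchg the same way as the
pointer operands of loads and stores: push them onto the postorder stack
of flat address expressions, and rewrite the use in place when the access
is non-volatile. Volatile atomics are left untouched, matching the
existing load/store behavior. Intrinsics are still not handled (see the
TODO below).

As an illustrative example mirroring the new tests, the addrspacecast
feeding a non-volatile atomic is folded away and the operation is
performed directly on the group pointer:

  ; before
  %cast = addrspacecast i32 addrspace(3)* %group.ptr to i32 addrspace(4)*
  %ret = atomicrmw add i32 addrspace(4)* %cast, i32 %y seq_cst

  ; after
  %ret = atomicrmw add i32 addrspace(3)* %group.ptr, i32 %y seq_cst
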
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@293584 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/InferAddressSpaces.cpp b/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 46f0609..16f2fc0 100644
--- a/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -232,16 +232,25 @@
   std::vector<std::pair<Value*, bool>> PostorderStack;
   // The set of visited expressions.
   DenseSet<Value*> Visited;
+
+  auto PushPtrOperand = [&](Value *Ptr) {
+    appendsFlatAddressExpressionToPostorderStack(
+      Ptr, &PostorderStack, &Visited);
+  };
+
   // We only explore address expressions that are reachable from loads and
   // stores for now because we aim at generating faster loads and stores.
   for (Instruction &I : instructions(F)) {
-    if (isa<LoadInst>(I)) {
-      appendsFlatAddressExpressionToPostorderStack(
-        I.getOperand(0), &PostorderStack, &Visited);
-    } else if (isa<StoreInst>(I)) {
-      appendsFlatAddressExpressionToPostorderStack(
-        I.getOperand(1), &PostorderStack, &Visited);
-    }
+    if (auto *LI = dyn_cast<LoadInst>(&I))
+      PushPtrOperand(LI->getPointerOperand());
+    else if (auto *SI = dyn_cast<StoreInst>(&I))
+      PushPtrOperand(SI->getPointerOperand());
+    else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
+      PushPtrOperand(RMW->getPointerOperand());
+    else if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(&I))
+      PushPtrOperand(CmpX->getPointerOperand());
+
+    // TODO: Support intrinsics
   }
 
   std::vector<Value *> Postorder; // The resultant postorder.
@@ -527,6 +536,30 @@
   return NewAS;
 }
 
+/// \returns true if \p U is the pointer operand of a memory instruction with
+/// a single pointer operand that can have its address space changed by simply
+/// mutating the use to a new value.
+static bool isSimplePointerUseValidToReplace(Use &U) {
+  User *Inst = U.getUser();
+  unsigned OpNo = U.getOperandNo();
+
+  if (auto *LI = dyn_cast<LoadInst>(Inst))
+    return OpNo == LoadInst::getPointerOperandIndex() && !LI->isVolatile();
+
+  if (auto *SI = dyn_cast<StoreInst>(Inst))
+    return OpNo == StoreInst::getPointerOperandIndex() && !SI->isVolatile();
+
+  if (auto *RMW = dyn_cast<AtomicRMWInst>(Inst))
+    return OpNo == AtomicRMWInst::getPointerOperandIndex() && !RMW->isVolatile();
+
+  if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) {
+    return OpNo == AtomicCmpXchgInst::getPointerOperandIndex() &&
+           !CmpX->isVolatile();
+  }
+
+  return false;
+}
+
 bool InferAddressSpaces::rewriteWithNewAddressSpaces(
   const std::vector<Value *> &Postorder,
   const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const {
@@ -570,15 +603,10 @@
                  << "\n  with\n  " << *NewV << '\n');
 
     for (Use *U : Uses) {
-      LoadInst *LI = dyn_cast<LoadInst>(U->getUser());
-      StoreInst *SI = dyn_cast<StoreInst>(U->getUser());
-
-      if ((LI && !LI->isVolatile()) ||
-          (SI && !SI->isVolatile() &&
-           U->getOperandNo() == StoreInst::getPointerOperandIndex())) {
-        // If V is used as the pointer operand of a load/store, sets the pointer
-        // operand to NewV. This replacement does not change the element type,
-        // so the resultant load/store is still valid.
+      if (isSimplePointerUseValidToReplace(*U)) {
+        // If V is used as the pointer operand of a compatible memory operation,
+        // set the pointer operand to NewV. This replacement does not change
+        // the element type, so the resultant memory operation is still valid.
         U->set(NewV);
       } else if (isa<Instruction>(U->getUser())) {
         // Otherwise, replaces the use with flat(NewV).
diff --git a/test/Transforms/InferAddressSpaces/AMDGPU/basic.ll b/test/Transforms/InferAddressSpaces/AMDGPU/basic.ll
index 4f93841..67b4ccd 100644
--- a/test/Transforms/InferAddressSpaces/AMDGPU/basic.ll
+++ b/test/Transforms/InferAddressSpaces/AMDGPU/basic.ll
@@ -128,4 +128,46 @@
   ret void
 }
 
+; CHECK-LABEL: @atomicrmw_add_global_to_flat(
+; CHECK-NEXT: %ret = atomicrmw add i32 addrspace(1)* %global.ptr, i32 %y seq_cst
+define i32 @atomicrmw_add_global_to_flat(i32 addrspace(1)* %global.ptr, i32 %y) #0 {
+  %cast = addrspacecast i32 addrspace(1)* %global.ptr to i32 addrspace(4)*
+  %ret = atomicrmw add i32 addrspace(4)* %cast, i32 %y seq_cst
+  ret i32 %ret
+}
+
+; CHECK-LABEL: @atomicrmw_add_group_to_flat(
+; CHECK-NEXT: %ret = atomicrmw add i32 addrspace(3)* %group.ptr, i32 %y seq_cst
+define i32 @atomicrmw_add_group_to_flat(i32 addrspace(3)* %group.ptr, i32 %y) #0 {
+  %cast = addrspacecast i32 addrspace(3)* %group.ptr to i32 addrspace(4)*
+  %ret = atomicrmw add i32 addrspace(4)* %cast, i32 %y seq_cst
+  ret i32 %ret
+}
+
+; CHECK-LABEL: @cmpxchg_global_to_flat(
+; CHECK: %ret = cmpxchg i32 addrspace(1)* %global.ptr, i32 %cmp, i32 %val seq_cst monotonic
+define { i32, i1 } @cmpxchg_global_to_flat(i32 addrspace(1)* %global.ptr, i32 %cmp, i32 %val) #0 {
+  %cast = addrspacecast i32 addrspace(1)* %global.ptr to i32 addrspace(4)*
+  %ret = cmpxchg i32 addrspace(4)* %cast, i32 %cmp, i32 %val seq_cst monotonic
+  ret { i32, i1 } %ret
+}
+
+; CHECK-LABEL: @cmpxchg_group_to_flat(
+; CHECK: %ret = cmpxchg i32 addrspace(3)* %group.ptr, i32 %cmp, i32 %val seq_cst monotonic
+define { i32, i1 } @cmpxchg_group_to_flat(i32 addrspace(3)* %group.ptr, i32 %cmp, i32 %val) #0 {
+  %cast = addrspacecast i32 addrspace(3)* %group.ptr to i32 addrspace(4)*
+  %ret = cmpxchg i32 addrspace(4)* %cast, i32 %cmp, i32 %val seq_cst monotonic
+  ret { i32, i1 } %ret
+}
+
+; Not pointer operand
+; CHECK-LABEL: @cmpxchg_group_to_flat_wrong_operand(
+; CHECK: %cast.cmp = addrspacecast i32 addrspace(3)* %cmp.ptr to i32 addrspace(4)*
+; CHECK: %ret = cmpxchg i32 addrspace(4)* addrspace(3)* %cas.ptr, i32 addrspace(4)* %cast.cmp, i32 addrspace(4)* %val seq_cst monotonic
+define { i32 addrspace(4)*, i1 } @cmpxchg_group_to_flat_wrong_operand(i32 addrspace(4)* addrspace(3)* %cas.ptr, i32 addrspace(3)* %cmp.ptr, i32 addrspace(4)* %val) #0 {
+  %cast.cmp = addrspacecast i32 addrspace(3)* %cmp.ptr to i32 addrspace(4)*
+  %ret = cmpxchg i32 addrspace(4)* addrspace(3)* %cas.ptr, i32 addrspace(4)* %cast.cmp, i32 addrspace(4)* %val seq_cst monotonic
+  ret { i32 addrspace(4)*, i1 } %ret
+}
+
 attributes #0 = { nounwind }
diff --git a/test/Transforms/InferAddressSpaces/AMDGPU/volatile.ll b/test/Transforms/InferAddressSpaces/AMDGPU/volatile.ll
index 57dff1f..f32d65b 100644
--- a/test/Transforms/InferAddressSpaces/AMDGPU/volatile.ll
+++ b/test/Transforms/InferAddressSpaces/AMDGPU/volatile.ll
@@ -79,4 +79,40 @@
   ret void
 }
 
+; CHECK-LABEL: @volatile_atomicrmw_add_group_to_flat(
+; CHECK: addrspacecast i32 addrspace(3)* %group.ptr to i32 addrspace(4)*
+; CHECK: atomicrmw volatile add i32 addrspace(4)*
+define i32 @volatile_atomicrmw_add_group_to_flat(i32 addrspace(3)* %group.ptr, i32 %y) #0 {
+  %cast = addrspacecast i32 addrspace(3)* %group.ptr to i32 addrspace(4)*
+  %ret = atomicrmw volatile add i32 addrspace(4)* %cast, i32 %y seq_cst
+  ret i32 %ret
+}
+
+; CHECK-LABEL: @volatile_atomicrmw_add_global_to_flat(
+; CHECK: addrspacecast i32 addrspace(1)* %global.ptr to i32 addrspace(4)*
+; CHECK: %ret = atomicrmw volatile add i32 addrspace(4)*
+define i32 @volatile_atomicrmw_add_global_to_flat(i32 addrspace(1)* %global.ptr, i32 %y) #0 {
+  %cast = addrspacecast i32 addrspace(1)* %global.ptr to i32 addrspace(4)*
+  %ret = atomicrmw volatile add i32 addrspace(4)* %cast, i32 %y seq_cst
+  ret i32 %ret
+}
+
+; CHECK-LABEL: @volatile_cmpxchg_global_to_flat(
+; CHECK: addrspacecast i32 addrspace(1)* %global.ptr to i32 addrspace(4)*
+; CHECK: cmpxchg volatile i32 addrspace(4)*
+define { i32, i1 } @volatile_cmpxchg_global_to_flat(i32 addrspace(1)* %global.ptr, i32 %cmp, i32 %val) #0 {
+  %cast = addrspacecast i32 addrspace(1)* %global.ptr to i32 addrspace(4)*
+  %ret = cmpxchg volatile i32 addrspace(4)* %cast, i32 %cmp, i32 %val seq_cst monotonic
+  ret { i32, i1 } %ret
+}
+
+; CHECK-LABEL: @volatile_cmpxchg_group_to_flat(
+; CHECK: addrspacecast i32 addrspace(3)* %group.ptr to i32 addrspace(4)*
+; CHECK: cmpxchg volatile i32 addrspace(4)*
+define { i32, i1 } @volatile_cmpxchg_group_to_flat(i32 addrspace(3)* %group.ptr, i32 %cmp, i32 %val) #0 {
+  %cast = addrspacecast i32 addrspace(3)* %group.ptr to i32 addrspace(4)*
+  %ret = cmpxchg volatile i32 addrspace(4)* %cast, i32 %cmp, i32 %val seq_cst monotonic
+  ret { i32, i1 } %ret
+}
+
 attributes #0 = { nounwind }
\ No newline at end of file