Adjust the special non-C++ enum block return type inference
so that it looks through certain syntactic forms and applies
even if normal inference would have succeeded.

There is potential for source incompatibility from this
change, but overall we feel that it produces a much
cleaner and more defensible result, and the block
compatibility rules should curb a lot of the potential
for annoyance.

rdar://13200889

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@176743 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaLambda.cpp b/lib/Sema/SemaLambda.cpp
index 21202c1..468fa02 100644
--- a/lib/Sema/SemaLambda.cpp
+++ b/lib/Sema/SemaLambda.cpp
@@ -225,73 +225,153 @@
   }
 }
 
-static bool checkReturnValueType(const ASTContext &Ctx, const Expr *E,
-                                 QualType &DeducedType,
-                                 QualType &AlternateType) {
-  // Handle ReturnStmts with no expressions.
-  if (!E) {
-    if (AlternateType.isNull())
-      AlternateType = Ctx.VoidTy;
+/// If this expression is an enumerator-like expression of some type
+/// T, return the type T; otherwise, return null.
+///
+/// Pointer comparisons on the result here should always work because
+/// it's derived from either the parent of an EnumConstantDecl
+/// (i.e. the definition) or the declaration returned by
+/// EnumType::getDecl() (i.e. the definition).
+static EnumDecl *findEnumForBlockReturn(Expr *E) {
+  // An expression is an enumerator-like expression of type T if,
+  // ignoring parens and parens-like expressions:
+  E = E->IgnoreParens();
 
-    return Ctx.hasSameType(DeducedType, Ctx.VoidTy);
+  //  - it is an enumerator whose enum type is T or
+  if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E)) {
+    if (EnumConstantDecl *D
+          = dyn_cast<EnumConstantDecl>(DRE->getDecl())) {
+      return cast<EnumDecl>(D->getDeclContext());
+    }
+    return 0;
   }
 
-  QualType StrictType = E->getType();
-  QualType LooseType = StrictType;
+  //  - it is a comma expression whose RHS is an enumerator-like
+  //    expression of type T or
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(E)) {
+    if (BO->getOpcode() == BO_Comma)
+      return findEnumForBlockReturn(BO->getRHS());
+    return 0;
+  }
 
-  // In C, enum constants have the type of their underlying integer type,
-  // not the enum. When inferring block return types, we should allow
-  // the enum type if an enum constant is used, unless the enum is
-  // anonymous (in which case there can be no variables of its type).
-  if (!Ctx.getLangOpts().CPlusPlus) {
-    const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParenImpCasts());
-    if (DRE) {
-      const Decl *D = DRE->getDecl();
-      if (const EnumConstantDecl *ECD = dyn_cast<EnumConstantDecl>(D)) {
-        const EnumDecl *Enum = cast<EnumDecl>(ECD->getDeclContext());
-        if (Enum->getDeclName() || Enum->getTypedefNameForAnonDecl())
-          LooseType = Ctx.getTypeDeclType(Enum);
-      }
+  //  - it is a statement-expression whose value expression is an
+  //    enumerator-like expression of type T or
+  if (StmtExpr *SE = dyn_cast<StmtExpr>(E)) {
+    if (Expr *last = dyn_cast_or_null<Expr>(SE->getSubStmt()->body_back()))
+      return findEnumForBlockReturn(last);
+    return 0;
+  }
+
+  //   - it is a ternary conditional operator (not the GNU ?:
+  //     extension) whose second and third operands are
+  //     enumerator-like expressions of type T or
+  if (ConditionalOperator *CO = dyn_cast<ConditionalOperator>(E)) {
+    if (EnumDecl *ED = findEnumForBlockReturn(CO->getTrueExpr()))
+      if (ED == findEnumForBlockReturn(CO->getFalseExpr()))
+        return ED;
+    return 0;
+  }
+
+  // (implicitly:)
+  //   - it is an implicit integral conversion applied to an
+  //     enumerator-like expression of type T or
+  if (ImplicitCastExpr *ICE = dyn_cast<ImplicitCastExpr>(E)) {
+    // We can only see integral conversions in valid enumerator-like
+    // expressions.
+    if (ICE->getCastKind() == CK_IntegralCast)
+      return findEnumForBlockReturn(ICE->getSubExpr());
+    return 0;
+  }
+
+  //   - it is an expression of that formal enum type.
+  if (const EnumType *ET = E->getType()->getAs<EnumType>()) {
+    return ET->getDecl();
+  }
+
+  // Otherwise, nope.
+  return 0;
+}
+
+/// Attempt to find a type T for which the returned expression of the
+/// given statement is an enumerator-like expression of that type.
+static EnumDecl *findEnumForBlockReturn(ReturnStmt *ret) {
+  if (Expr *retValue = ret->getRetValue())
+    return findEnumForBlockReturn(retValue);
+  return 0;
+}
+
+/// Attempt to find a common type T for which all of the returned
+/// expressions in a block are enumerator-like expressions of that
+/// type.
+static EnumDecl *findCommonEnumForBlockReturns(ArrayRef<ReturnStmt*> returns) {
+  ArrayRef<ReturnStmt*>::iterator i = returns.begin(), e = returns.end();
+
+  // Try to find one for the first return.
+  EnumDecl *ED = findEnumForBlockReturn(*i);
+  if (!ED) return 0;
+
+  // Check that the rest of the returns have the same enum.
+  for (++i; i != e; ++i) {
+    if (findEnumForBlockReturn(*i) != ED)
+      return 0;
+  }
+
+  // Never infer an anonymous enum type.
+  if (!ED->hasNameForLinkage()) return 0;
+
+  return ED;
+}
+
+/// Adjust the given return statements so that they formally return
+/// the given type.  It should require, at most, an IntegralCast.
+static void adjustBlockReturnsToEnum(Sema &S, ArrayRef<ReturnStmt*> returns,
+                                     QualType returnType) {
+  for (ArrayRef<ReturnStmt*>::iterator
+         i = returns.begin(), e = returns.end(); i != e; ++i) {
+    ReturnStmt *ret = *i;
+    Expr *retValue = ret->getRetValue();
+    if (S.Context.hasSameType(retValue->getType(), returnType))
+      continue;
+
+    // Right now we only support integral fixup casts.
+    assert(returnType->isIntegralOrUnscopedEnumerationType());
+    assert(retValue->getType()->isIntegralOrUnscopedEnumerationType());
+
+    ExprWithCleanups *cleanups = dyn_cast<ExprWithCleanups>(retValue);
+
+    Expr *E = (cleanups ? cleanups->getSubExpr() : retValue);
+    E = ImplicitCastExpr::Create(S.Context, returnType, CK_IntegralCast,
+                                 E, /*base path*/ 0, VK_RValue);
+    if (cleanups) {
+      cleanups->setSubExpr(E);
+    } else {
+      ret->setRetValue(E);
     }
   }
-
-  // Special case for the first return statement we find.
-  // The return type has already been tentatively set, but we might still
-  // have an alternate type we should prefer.
-  if (AlternateType.isNull())
-    AlternateType = LooseType;
-
-  if (Ctx.hasSameType(DeducedType, StrictType)) {
-    // FIXME: The loose type is different when there are constants from two
-    // different enums. We could consider warning here.
-    if (AlternateType != Ctx.DependentTy)
-      if (!Ctx.hasSameType(AlternateType, LooseType))
-        AlternateType = Ctx.VoidTy;
-    return true;
-  }
-
-  if (Ctx.hasSameType(DeducedType, LooseType)) {
-    // Use DependentTy to signal that we're using an alternate type and may
-    // need to add casts somewhere.
-    AlternateType = Ctx.DependentTy;
-    return true;
-  }
-
-  if (Ctx.hasSameType(AlternateType, StrictType) ||
-      Ctx.hasSameType(AlternateType, LooseType)) {
-    DeducedType = AlternateType;
-    // Use DependentTy to signal that we're using an alternate type and may
-    // need to add casts somewhere.
-    AlternateType = Ctx.DependentTy;
-    return true;
-  }
-
-  return false;
 }
 
 void Sema::deduceClosureReturnType(CapturingScopeInfo &CSI) {
   assert(CSI.HasImplicitReturnType);
 
+  // C++ Core Issue #975, proposed resolution:
+  //   If a lambda-expression does not include a trailing-return-type,
+  //   it is as if the trailing-return-type denotes the following type:
+  //     - if there are no return statements in the compound-statement,
+  //       or all return statements return either an expression of type
+  //       void or no expression or braced-init-list, the type void;
+  //     - otherwise, if all return statements return an expression
+  //       and the types of the returned expressions after
+  //       lvalue-to-rvalue conversion (4.1 [conv.lval]),
+  //       array-to-pointer conversion (4.2 [conv.array]), and
+  //       function-to-pointer conversion (4.3 [conv.func]) are the
+  //       same, that common type;
+  //     - otherwise, the program is ill-formed.
+  //
+  // In addition, in blocks in non-C++ modes, if all of the return
+  // statements are enumerator-like expressions of some type T, where
+  // T has a name for linkage, then we infer the return type of the
+  // block to be that type.
+
   // First case: no return statements, implicit void return type.
   ASTContext &Ctx = getASTContext();
   if (CSI.Returns.empty()) {
@@ -308,6 +388,17 @@
   if (CSI.ReturnType->isDependentType())
     return;
 
+  // Try to apply the enum-fuzz rule.
+  if (!getLangOpts().CPlusPlus) {
+    assert(isa<BlockScopeInfo>(CSI));
+    const EnumDecl *ED = findCommonEnumForBlockReturns(CSI.Returns);
+    if (ED) {
+      CSI.ReturnType = Context.getTypeDeclType(ED);
+      adjustBlockReturnsToEnum(*this, CSI.Returns, CSI.ReturnType);
+      return;
+    }
+  }
+
   // Third case: only one return statement. Don't bother doing extra work!
   SmallVectorImpl<ReturnStmt*>::iterator I = CSI.Returns.begin(),
                                          E = CSI.Returns.end();
@@ -316,47 +407,25 @@
 
   // General case: many return statements.
   // Check that they all have compatible return types.
-  // For now, that means "identical", with an exception for enum constants.
-  // (In C, enum constants have the type of their underlying integer type,
-  // not the type of the enum. C++ uses the type of the enum.)
-  QualType AlternateType;
 
   // We require the return types to strictly match here.
+  // Note that we've already done the required promotions as part of
+  // processing the return statement.
   for (; I != E; ++I) {
     const ReturnStmt *RS = *I;
     const Expr *RetE = RS->getRetValue();
-    if (!checkReturnValueType(Ctx, RetE, CSI.ReturnType, AlternateType)) {
-      // FIXME: This is a poor diagnostic for ReturnStmts without expressions.
-      Diag(RS->getLocStart(),
-           diag::err_typecheck_missing_return_type_incompatible)
-        << (RetE ? RetE->getType() : Ctx.VoidTy) << CSI.ReturnType
-        << isa<LambdaScopeInfo>(CSI);
-      // Don't bother fixing up the return statements in the block if some of
-      // them are unfixable anyway.
-      AlternateType = Ctx.VoidTy;
-      // Continue iterating so that we keep emitting diagnostics.
-    }
-  }
 
-  // If our return statements turned out to be compatible, but we needed to
-  // pick a different return type, go through and fix the ones that need it.
-  if (AlternateType == Ctx.DependentTy) {
-    for (SmallVectorImpl<ReturnStmt*>::iterator I = CSI.Returns.begin(),
-                                                E = CSI.Returns.end();
-         I != E; ++I) {
-      ReturnStmt *RS = *I;
-      Expr *RetE = RS->getRetValue();
-      if (RetE->getType() == CSI.ReturnType)
-        continue;
+    QualType ReturnType = (RetE ? RetE->getType() : Context.VoidTy);
+    if (Context.hasSameType(ReturnType, CSI.ReturnType))
+      continue;
 
-      // Right now we only support integral fixup casts.
-      assert(CSI.ReturnType->isIntegralOrUnscopedEnumerationType());
-      assert(RetE->getType()->isIntegralOrUnscopedEnumerationType());
-      ExprResult Casted = ImpCastExprToType(RetE, CSI.ReturnType,
-                                            CK_IntegralCast);
-      assert(Casted.isUsable());
-      RS->setRetValue(Casted.take());
-    }
+    // FIXME: This is a poor diagnostic for ReturnStmts without expressions.
+    // TODO: It's possible that the *first* return is the divergent one.
+    Diag(RS->getLocStart(),
+         diag::err_typecheck_missing_return_type_incompatible)
+      << ReturnType << CSI.ReturnType
+      << isa<LambdaScopeInfo>(CSI);
+    // Continue iterating so that we keep emitting diagnostics.
   }
 }
 
diff --git a/test/SemaObjC/blocks.m b/test/SemaObjC/blocks.m
index 9926b08..dd659ad 100644
--- a/test/SemaObjC/blocks.m
+++ b/test/SemaObjC/blocks.m
@@ -75,10 +75,11 @@
 
 
 // In C, enum constants have the type of the underlying integer type, not the
-// enumeration they are part of. We pretend the constants have enum type when
-// they are mixed with other expressions of enum type.
+// enumeration they are part of. We pretend the constants have enum type if
+// all the returns seem to be playing along.
 enum CStyleEnum {
-  CSE_Value = 1
+  CSE_Value = 1,
+  CSE_Value2 = 2
 };
 enum CStyleEnum getCSE();
 typedef enum CStyleEnum (^cse_block_t)();
@@ -92,7 +93,9 @@
   a = ^{ // expected-error {{incompatible block pointer types assigning to 'cse_block_t' (aka 'enum CStyleEnum (^)()') from 'int (^)(void)'}}
     return 1;
   };
-  a = ^{ // expected-error {{incompatible block pointer types assigning to 'cse_block_t' (aka 'enum CStyleEnum (^)()') from 'int (^)(void)'}}
+
+  // No warning here.
+  a = ^{
     return CSE_Value;
   };
 
@@ -114,6 +117,15 @@
     else
       return 1;
   };
+
+  // rdar://13200889
+  extern void check_enum(void);
+  a = ^{
+    return (arg ? (CSE_Value) : (check_enum(), (!arg ? CSE_Value2 : getCSE())));
+  };
+  a = ^{
+    return (arg ? (CSE_Value) : ({check_enum(); CSE_Value2; }));
+  };
 }