Update ktfmt to fix odd formatting with DSLs. (#214)

Summary:
Update KotlinInputAstVisitorBase.kt's binary expression handling to treat trailing lambdas as separate from the rest of the binary expression. While this should clean up some issues with binary expressions, the main intended side-effect was to fix DSLs:

https://github.com/facebookincubator/ktfmt/issues/206

Before, it would format DSLs like this:

```
a =
  a {
    b =
      b {
        c = c
      }
  }
```

After, it should format them like this

```
a = a {
  b = b {
    c = c
  }
}
```

Pull Request resolved: https://github.com/facebookincubator/ktfmt/pull/214

Reviewed By: hick209

Differential Revision: D27972459

Pulled By: cgrushko

fbshipit-source-id: 6f815a2083f11780e23c6cb857d3d0e4da4d3c0f
diff --git a/core/src/main/java/com/facebook/ktfmt/KotlinInputAstVisitorBase.kt b/core/src/main/java/com/facebook/ktfmt/KotlinInputAstVisitorBase.kt
index 0307738..8ae03ea 100644
--- a/core/src/main/java/com/facebook/ktfmt/KotlinInputAstVisitorBase.kt
+++ b/core/src/main/java/com/facebook/ktfmt/KotlinInputAstVisitorBase.kt
@@ -549,7 +549,7 @@
    * emitQualifiedExpression formats call expressions that are either part of a qualified
    * expression, or standing alone. This method makes it easier to handle both cases uniformly.
    */
-  private fun extractCallExpression(expression: KtExpression): KtCallExpression? {
+  private fun extractCallExpression(expression: KtExpression?): KtCallExpression? {
     val ktExpression = (expression as? KtQualifiedExpression)?.selectorExpression ?: expression
     return ktExpression as? KtCallExpression
   }
@@ -1017,27 +1017,44 @@
           }
         }
 
-    val leftMostExpression = parts.first()
-    leftMostExpression.left?.accept(this)
-    for (leftExpression in parts) {
-      when (leftExpression.operationToken) {
-        KtTokens.RANGE -> {}
-        KtTokens.ELVIS -> builder.breakOp(Doc.FillMode.INDEPENDENT, " ", expressionBreakIndent)
-        else -> builder.space()
+    // Don't count trailing lambdas as part of the binary expression, so they look like this:
+    //
+    // a + b + c {
+    //   d
+    // }
+    val hasTrailingLambda =
+        extractCallExpression(expression.right)?.lambdaArguments?.isNotEmpty() == true
+
+    builder.block(ZERO) {
+      val leftMostExpression = parts.first()
+      leftMostExpression.left?.accept(this)
+      for ((i, leftExpression) in parts.withIndex()) {
+        when (leftExpression.operationToken) {
+          KtTokens.RANGE -> {}
+          KtTokens.ELVIS -> builder.breakOp(Doc.FillMode.INDEPENDENT, " ", expressionBreakIndent)
+          else -> builder.space()
+        }
+        builder.token(leftExpression.operationReference.text)
+        val isFirst = leftExpression === leftMostExpression
+        if (isFirst) {
+          builder.open(expressionBreakIndent)
+        }
+
+        if (!hasTrailingLambda || i < parts.size - 1) {
+          when (leftExpression.operationToken) {
+            KtTokens.RANGE -> {}
+            KtTokens.ELVIS -> builder.space()
+            else -> builder.breakOp(Doc.FillMode.UNIFIED, " ", ZERO)
+          }
+          leftExpression.right?.accept(this)
+        }
       }
-      builder.token(leftExpression.operationReference.text)
-      val isFirst = leftExpression === leftMostExpression
-      if (isFirst) {
-        builder.open(expressionBreakIndent)
-      }
-      when (leftExpression.operationToken) {
-        KtTokens.RANGE -> {}
-        KtTokens.ELVIS -> builder.space()
-        else -> builder.breakOp(Doc.FillMode.UNIFIED, " ", ZERO)
-      }
-      leftExpression.right?.accept(this)
+      builder.close()
     }
-    builder.close()
+
+    if (hasTrailingLambda) {
+      processLambdaOrScopingFunction(expression.right)
+    }
   }
 
   override fun visitUnaryExpression(expression: KtUnaryExpression) {
@@ -1243,13 +1260,16 @@
         if (isGoogleStyle) doubleExpressionBreakIndent else expressionBreakIndent,
         Optional.of(tag))
 
-    if (initializer is KtLambdaExpression) {
-      initializer.accept(this)
-    } else {
+    if (initializer is KtCallExpression) {
       val call = initializer as KtCallExpression
       call.calleeExpression?.accept(this)
       builder.space()
       call.lambdaArguments.forEach { it.getArgumentExpression()?.accept(this) }
+    } else if (initializer is KtQualifiedExpression) {
+      // Known bug: qualified lambdas like `coroutineScope.launch {}` are not handled correctly.
+      initializer?.accept(this)
+    } else {
+      initializer?.accept(this)
     }
   }
 
diff --git a/core/src/test/java/com/facebook/ktfmt/FormatterKtTest.kt b/core/src/test/java/com/facebook/ktfmt/FormatterKtTest.kt
index bbc2b42..4837ca7 100644
--- a/core/src/test/java/com/facebook/ktfmt/FormatterKtTest.kt
+++ b/core/src/test/java/com/facebook/ktfmt/FormatterKtTest.kt
@@ -351,6 +351,34 @@
           deduceMaxWidth = true)
 
   @Test
+  fun `binary operators dont break when the last one is a lambda`() =
+      assertFormatted(
+          """
+      |----------------------
+      |foo =
+      |    foo + bar + dsl {
+      |      baz = 1
+      |    }
+      |""".trimMargin(),
+          deduceMaxWidth = true)
+
+  @Test
+  fun `binary operators break correctly when there's multiple before a lambda`() =
+      assertFormatted(
+          """
+      |----------------------
+      |foo =
+      |    foo +
+      |        bar +
+      |        dsl +
+      |        foo +
+      |        bar {
+      |      baz = 1
+      |    }
+      |""".trimMargin(),
+          deduceMaxWidth = true)
+
+  @Test
   fun `properties with accessors`() =
       assertFormatted(
           """
@@ -4401,4 +4429,39 @@
     """.trimMargin()
     assertThatFormatting(code).isEqualTo(code)
   }
+
+  @Test
+  fun `assignment in a dsl does not break if not needed`() =
+      assertFormatted(
+          """
+      |---------------------
+      |foo = fooDsl {
+      |  bar = barDsl {
+      |    baz = bazDsl {
+      |      bal = balDsl {
+      |        bim = 1
+      |      }
+      |    }
+      |  }
+      |}
+      |""".trimMargin(),
+          deduceMaxWidth = true)
+
+  @Test
+  fun `assignment in a dsl breaks when needed`() =
+      assertFormatted(
+          """
+      |------------------
+      |val foo = fooDsl {
+      |  bar += barDsl {
+      |    baz = bazDsl {
+      |      bal =
+      |          balDsl {
+      |        bim = 1
+      |      }
+      |    }
+      |  }
+      |}
+      |""".trimMargin(),
+          deduceMaxWidth = true)
 }
diff --git a/core/src/test/java/com/facebook/ktfmt/GoogleStyleFormatterKtTest.kt b/core/src/test/java/com/facebook/ktfmt/GoogleStyleFormatterKtTest.kt
index a9961eb..87bc26b 100644
--- a/core/src/test/java/com/facebook/ktfmt/GoogleStyleFormatterKtTest.kt
+++ b/core/src/test/java/com/facebook/ktfmt/GoogleStyleFormatterKtTest.kt
@@ -452,6 +452,36 @@
           deduceMaxWidth = true)
 
   @Test
+  fun `binary operators dont break when the last one is a lambda`() =
+      assertFormatted(
+          """
+      |--------------------
+      |foo =
+      |  foo + bar + dsl {
+      |    baz = 1
+      |  }
+      |""".trimMargin(),
+          formattingOptions = GOOGLE_FORMAT,
+          deduceMaxWidth = true)
+
+  @Test
+  fun `binary operators break correctly when there's multiple before a lambda`() =
+      assertFormatted(
+          """
+      |----------------------
+      |foo =
+      |  foo +
+      |    bar +
+      |    dsl +
+      |    foo +
+      |      bar {
+      |    baz = 1
+      |  }
+      |""".trimMargin(),
+          formattingOptions = GOOGLE_FORMAT,
+          deduceMaxWidth = true)
+
+  @Test
   fun `handle casting with breaks`() =
       assertFormatted(
           """
@@ -817,4 +847,41 @@
       |""".trimMargin(),
           formattingOptions = GOOGLE_FORMAT,
           deduceMaxWidth = true)
+
+  @Test
+  fun `assignment in a dsl does not break if not needed`() =
+      assertFormatted(
+          """
+      |---------------------
+      |foo = fooDsl {
+      |  bar = barDsl {
+      |    baz = bazDsl {
+      |      bal = balDsl {
+      |        bim = 1
+      |      }
+      |    }
+      |  }
+      |}
+      |""".trimMargin(),
+          formattingOptions = GOOGLE_FORMAT,
+          deduceMaxWidth = true)
+
+  @Test
+  fun `assignment in a dsl breaks when needed`() =
+      assertFormatted(
+          """
+      |------------------
+      |foo = fooDsl {
+      |  bar += barDsl {
+      |    baz = bazDsl {
+      |      bal =
+      |          balDsl {
+      |        bim = 1
+      |      }
+      |    }
+      |  }
+      |}
+      |""".trimMargin(),
+          formattingOptions = GOOGLE_FORMAT,
+          deduceMaxWidth = true)
 }