Support shebang for KTS files

Summary: This was crashing in JavaInput, so let's just remove it and add it back at the end.

Reviewed By: cortinico

Differential Revision: D36987246

fbshipit-source-id: e404d213d9ba5a916caaa63cba850066d1d77436
diff --git a/core/src/main/java/com/facebook/ktfmt/format/Formatter.kt b/core/src/main/java/com/facebook/ktfmt/format/Formatter.kt
index 3bf2e2a..1fa90f7 100644
--- a/core/src/main/java/com/facebook/ktfmt/format/Formatter.kt
+++ b/core/src/main/java/com/facebook/ktfmt/format/Formatter.kt
@@ -78,9 +78,15 @@
   @JvmStatic
   @Throws(FormatterException::class, ParseError::class)
   fun format(options: FormattingOptions, code: String): String {
-    checkEscapeSequences(code)
+    val (shebang, kotlinCode) =
+        if (code.startsWith("#!")) {
+          code.split("\n".toRegex(), limit = 2)
+        } else {
+          listOf("", code)
+        }
+    checkEscapeSequences(kotlinCode)
 
-    val lfCode = StringUtilRt.convertLineSeparators(code)
+    val lfCode = StringUtilRt.convertLineSeparators(kotlinCode)
     val sortedImports = sortedAndDistinctImports(lfCode)
     val pretty = prettyPrint(sortedImports, options, "\n")
     val noRedundantElements =
@@ -89,7 +95,9 @@
         } catch (e: ParseError) {
           throw IllegalStateException("Failed to re-parse code after pretty printing:\n $pretty", e)
         }
-    return prettyPrint(noRedundantElements, options, Newlines.guessLineSeparator(code)!!)
+    val prettyCode =
+        prettyPrint(noRedundantElements, options, Newlines.guessLineSeparator(kotlinCode)!!)
+    return if (shebang.isNotEmpty()) shebang + "\n" + prettyCode else prettyCode
   }
 
   /** prettyPrint reflows 'code' using google-java-format's engine. */
diff --git a/core/src/test/java/com/facebook/ktfmt/format/FormatterTest.kt b/core/src/test/java/com/facebook/ktfmt/format/FormatterTest.kt
index b403a01..bf57e02 100644
--- a/core/src/test/java/com/facebook/ktfmt/format/FormatterTest.kt
+++ b/core/src/test/java/com/facebook/ktfmt/format/FormatterTest.kt
@@ -50,6 +50,16 @@
         |""".trimMargin())
 
   @Test
+  fun `support script (kts) files with a shebang`() =
+      assertFormatted(
+          """
+        |#!/usr/bin/env kscript
+        |package foo
+        |
+        |println("Called")
+        |""".trimMargin())
+
+  @Test
   fun `call chains`() =
       assertFormatted(
           """