Fix handling of case patterns in unused import removal

https://github.com/google/google-java-format/issues/684

PiperOrigin-RevId: 411136606
diff --git a/core/src/main/java/com/google/googlejavaformat/java/RemoveUnusedImports.java b/core/src/main/java/com/google/googlejavaformat/java/RemoveUnusedImports.java
index 42af3f3..20e55e9 100644
--- a/core/src/main/java/com/google/googlejavaformat/java/RemoveUnusedImports.java
+++ b/core/src/main/java/com/google/googlejavaformat/java/RemoveUnusedImports.java
@@ -32,6 +32,7 @@
 import com.google.googlejavaformat.Newlines;
 import com.sun.source.doctree.DocCommentTree;
 import com.sun.source.doctree.ReferenceTree;
+import com.sun.source.tree.CaseTree;
 import com.sun.source.tree.IdentifierTree;
 import com.sun.source.tree.ImportTree;
 import com.sun.source.tree.Tree;
@@ -55,8 +56,10 @@
 import com.sun.tools.javac.util.Options;
 import java.io.IOError;
 import java.io.IOException;
+import java.lang.reflect.Method;
 import java.net.URI;
 import java.util.LinkedHashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.tools.Diagnostic;
@@ -115,6 +118,31 @@
       return null;
     }
 
+    // TODO(cushon): remove this override when pattern matching in switch is no longer a preview
+    // feature, and TreePathScanner visits CaseTree#getLabels instead of CaseTree#getExpressions
+    @SuppressWarnings("unchecked") // reflection
+    @Override
+    public Void visitCase(CaseTree tree, Void unused) {
+      if (CASE_TREE_GET_LABELS != null) {
+        try {
+          scan((List<? extends Tree>) CASE_TREE_GET_LABELS.invoke(tree), null);
+        } catch (ReflectiveOperationException e) {
+          throw new LinkageError(e.getMessage(), e);
+        }
+      }
+      return super.visitCase(tree, null);
+    }
+
+    private static final Method CASE_TREE_GET_LABELS = caseTreeGetLabels();
+
+    private static Method caseTreeGetLabels() {
+      try {
+        return CaseTree.class.getMethod("getLabels");
+      } catch (NoSuchMethodException e) {
+        return null;
+      }
+    }
+
     @Override
     public Void scan(Tree tree, Void unused) {
       if (tree == null) {
diff --git a/core/src/test/java/com/google/googlejavaformat/java/RemoveUnusedImportsCaseLabelsTest.java b/core/src/test/java/com/google/googlejavaformat/java/RemoveUnusedImportsCaseLabelsTest.java
new file mode 100644
index 0000000..c0babb0
--- /dev/null
+++ b/core/src/test/java/com/google/googlejavaformat/java/RemoveUnusedImportsCaseLabelsTest.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2021 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+ * in compliance with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package com.google.googlejavaformat.java;
+
+import static com.google.common.truth.Truth.assertThat;
+import static com.google.googlejavaformat.java.RemoveUnusedImports.removeUnusedImports;
+import static org.junit.Assume.assumeTrue;
+
+import com.google.common.base.Joiner;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests that unused import removal doesn't remove types used in case labels. */
+@RunWith(JUnit4.class)
+public class RemoveUnusedImportsCaseLabelsTest {
+  @Test
+  public void preserveTypesInCaseLabels() throws FormatterException {
+    assumeTrue(Runtime.version().feature() >= 17);
+    String input =
+        Joiner.on('\n')
+            .join(
+                "package example;",
+                "import example.model.SealedInterface;",
+                "import example.model.TypeA;",
+                "import example.model.TypeB;",
+                "public class Main {",
+                "  public void apply(SealedInterface sealedInterface) {",
+                "    switch(sealedInterface) {",
+                "      case TypeA a -> System.out.println(\"A!\");",
+                "      case TypeB b -> System.out.println(\"B!\");",
+                "    }",
+                "  }",
+                "}");
+    assertThat(removeUnusedImports(input)).isEqualTo(input);
+  }
+}