handle java array type resolution
diff --git a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/processing/impl/ResolverImpl.kt b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/processing/impl/ResolverImpl.kt
index 0c13f3a..6ea2590 100644
--- a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/processing/impl/ResolverImpl.kt
+++ b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/processing/impl/ResolverImpl.kt
@@ -25,7 +25,6 @@
 import com.google.devtools.ksp.symbol.Variance
 import com.google.devtools.ksp.symbol.impl.binary.*
 import com.google.devtools.ksp.symbol.impl.findParentAnnotated
-import com.google.devtools.ksp.symbol.impl.findParentPsiDeclaration
 import com.google.devtools.ksp.symbol.impl.findPsi
 import com.google.devtools.ksp.symbol.impl.getInstanceForCurrentRound
 import com.google.devtools.ksp.symbol.impl.java.*
@@ -57,7 +56,9 @@
 import org.jetbrains.kotlin.load.java.lazy.types.JavaTypeResolver
 import org.jetbrains.kotlin.load.java.lazy.types.toAttributes
 import org.jetbrains.kotlin.load.java.sources.JavaSourceElement
+import org.jetbrains.kotlin.load.java.structure.impl.JavaArrayTypeImpl
 import org.jetbrains.kotlin.load.java.structure.impl.JavaClassImpl
+import org.jetbrains.kotlin.load.java.structure.impl.JavaConstructorImpl
 import org.jetbrains.kotlin.load.java.structure.impl.JavaFieldImpl
 import org.jetbrains.kotlin.load.java.structure.impl.JavaMethodImpl
 import org.jetbrains.kotlin.load.java.structure.impl.JavaTypeImpl
@@ -521,31 +522,41 @@
         } as PropertyAccessorDescriptor?
     }
 
-    fun resolveJavaType(psi: PsiType): KotlinType {
+    fun resolveJavaType(psi: PsiType, parentTypeReference: KSTypeReference? = null): KotlinType {
         incrementalContext.recordLookup(psi)
         val javaType = JavaTypeImpl.create(psi)
 
-        var parent: PsiElement? = (psi as? PsiClassReferenceType)?.resolve()
-        val stack = Stack<PsiElement>()
-        while (parent != null && parent !is PsiJavaFile) {
-            stack.push(parent)
-            parent = parent.findParentPsiDeclaration()
+        var parent: KSNode? = parentTypeReference
+
+        val stack = Stack<KSNode>()
+        while (parent != null) {
+            if (parent is KSFunctionDeclarationJavaImpl || parent is KSClassDeclarationJavaImpl) {
+                stack.push(parent)
+            }
+            parent = parent.parent
         }
         // Construct resolver context for the PsiType
         var resolverContext = lazyJavaResolverContext
         for (e in stack) {
-            val descriptor = resolveJavaDeclaration(e)!!
             when (e) {
-                is PsiMethod -> {
-                    resolverContext = resolverContext.childForMethod(descriptor, JavaMethodImpl(e))
-                }
-                is PsiClass -> {
+                is KSFunctionDeclarationJavaImpl -> {
                     resolverContext = resolverContext
-                        .childForClassOrPackage(descriptor as ClassDescriptor, JavaClassImpl(e))
+                        .childForMethod(
+                            resolveJavaDeclaration(e.psi)!!,
+                            if (e.psi.isConstructor) JavaConstructorImpl(e.psi) else JavaMethodImpl(e.psi)
+                        )
+                }
+                is KSClassDeclarationJavaImpl -> {
+                    resolverContext = resolverContext
+                        .childForClassOrPackage(resolveJavaDeclaration(e.psi) as ClassDescriptor, JavaClassImpl(e.psi))
                 }
             }
         }
-        return resolverContext.typeResolver.transformJavaType(javaType, TypeUsage.COMMON.toAttributes())
+        return if (javaType is JavaArrayTypeImpl)
+            resolverContext
+                .typeResolver.transformArrayType(javaType, TypeUsage.COMMON.toAttributes(), psi is PsiEllipsisType)
+        else
+            resolverContext.typeResolver.transformJavaType(javaType, TypeUsage.COMMON.toAttributes())
     }
 
     fun KotlinType.expandNonRecursively(): KotlinType =
@@ -619,7 +630,10 @@
                         ).defaultType
                     )
                 } else {
-                    return getKSTypeCached(resolveJavaType(type.psi), type.element.typeArguments, type.annotations)
+                    return getKSTypeCached(
+                        resolveJavaType(type.psi, type),
+                        type.element.typeArguments, type.annotations
+                    )
                 }
             }
             else -> throw IllegalStateException("Unable to resolve type for $type, $ExceptionMessage")
diff --git a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/java/KSTypeReferenceJavaImpl.kt b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/java/KSTypeReferenceJavaImpl.kt
index 34b65c3..ef0389b 100644
--- a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/java/KSTypeReferenceJavaImpl.kt
+++ b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/java/KSTypeReferenceJavaImpl.kt
@@ -91,7 +91,7 @@
             is PsiWildcardType -> KSClassifierReferenceJavaImpl.getCached(type.extendsBound as PsiClassType, this)
             is PsiPrimitiveType -> KSClassifierReferenceDescriptorImpl.getCached(type.toKotlinType(), origin, this)
             is PsiArrayType -> {
-                val componentType = ResolverImpl.instance.resolveJavaType(type.componentType)
+                val componentType = ResolverImpl.instance.resolveJavaType(type.componentType, this)
                 if (type.componentType !is PsiPrimitiveType) {
                     KSClassifierReferenceDescriptorImpl.getCached(
                         ResolverImpl.instance.module.builtIns.getArrayType(Variance.INVARIANT, componentType),
diff --git a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/utils.kt b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/utils.kt
index b2ed1a1..01459d6 100644
--- a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/utils.kt
+++ b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/utils.kt
@@ -236,14 +236,6 @@
     }
 }
 
-fun PsiElement.findParentPsiDeclaration(): PsiElement? {
-    var parent = this.parent
-    while (parent != null && parent !is PsiClass && parent !is PsiMethod && parent !is PsiJavaFile) {
-        parent = parent.parent
-    }
-    return parent
-}
-
 fun PsiElement.findParentDeclaration(): KSDeclaration? {
     return this.findParentAnnotated() as? KSDeclaration
 }
diff --git a/compiler-plugin/src/test/kotlin/com/google/devtools/ksp/processor/ResolveJavaTypeProcessor.kt b/compiler-plugin/src/test/kotlin/com/google/devtools/ksp/processor/ResolveJavaTypeProcessor.kt
index df30cbe..af24004 100644
--- a/compiler-plugin/src/test/kotlin/com/google/devtools/ksp/processor/ResolveJavaTypeProcessor.kt
+++ b/compiler-plugin/src/test/kotlin/com/google/devtools/ksp/processor/ResolveJavaTypeProcessor.kt
@@ -18,7 +18,14 @@
 package com.google.devtools.ksp.processor
 
 import com.google.devtools.ksp.processing.Resolver
-import com.google.devtools.ksp.symbol.*
+import com.google.devtools.ksp.symbol.KSAnnotated
+import com.google.devtools.ksp.symbol.KSClassDeclaration
+import com.google.devtools.ksp.symbol.KSFunctionDeclaration
+import com.google.devtools.ksp.symbol.KSNode
+import com.google.devtools.ksp.symbol.KSTypeReference
+import com.google.devtools.ksp.symbol.Nullability
+import com.google.devtools.ksp.symbol.Origin
+import com.google.devtools.ksp.symbol.Variance
 import com.google.devtools.ksp.visitor.KSTopDownVisitor
 
 class ResolveJavaTypeProcessor : AbstractTestProcessor() {
@@ -31,8 +38,12 @@
 
     override fun process(resolver: Resolver): List<KSAnnotated> {
         val symbol = resolver.getClassDeclarationByName(resolver.getKSNameFromString("C"))
+        val symbolTypeParameter = resolver.getClassDeclarationByName(resolver.getKSNameFromString("Base"))
+        val another = resolver.getClassDeclarationByName(resolver.getKSNameFromString("Another"))
         assert(symbol?.origin == Origin.JAVA)
         symbol!!.accept(visitor, Unit)
+        symbolTypeParameter!!.accept(visitor, Unit)
+        another!!.accept(visitor, Unit)
         return emptyList()
     }
 
diff --git a/compiler-plugin/testData/api/resolveJavaType.kt b/compiler-plugin/testData/api/resolveJavaType.kt
index cf93c71..a40761e 100644
--- a/compiler-plugin/testData/api/resolveJavaType.kt
+++ b/compiler-plugin/testData/api/resolveJavaType.kt
@@ -31,6 +31,25 @@
 // kotlin.collections.MutableList<in kotlin.collections.MutableList<out kotlin.Double?>?>?
 // Bar?
 // kotlin.Array<Bar?>?
+// Foo<Base.T?, Base.Inner.P?>?
+// Bar<Base.Inner.P?, Base.T?>?
+// kotlin.collections.MutableList<Base.T?>?
+// kotlin.Unit
+// Base.T?
+// kotlin.Unit
+// kotlin.Array<Base.T?>?
+// kotlin.Unit
+// kotlin.Array<Base.T?>?
+// kotlin.Unit
+// kotlin.collections.MutableList<Base.T?>?
+// kotlin.Unit
+// Base.T?
+// kotlin.Unit
+// kotlin.Array<Base.T?>?
+// kotlin.Unit
+// kotlin.Array<Base.T?>?
+// kotlin.Unit
+// Base<Another.T?, Another.T?>?
 // END
 // FILE: a.kt
 annotation class Test
@@ -69,3 +88,29 @@
 
     public Bar[] BarArryFun() {}
 }
+
+// FILE: Base.java
+import java.util.List;
+
+class Foo<T1,T2> {}
+class Bar<T1, T2> {}
+
+class Base<T,P> {
+    void genericT(List<T> t){};
+    void singleT(T t){};
+    void varargT(T... t){};
+    void arrayT(T[] t){};
+
+    class Inner<P> {
+        void genericT(List<T> t){};
+        void singleT(T t){};
+        void varargT(T... t){};
+        void arrayT(T[] t){};
+        Foo<T, P> foo;
+        Bar<P, T> bar;
+    }
+}
+
+class Another<T> {
+    Base<T, T> base;
+}