handle java array type resolution
diff --git a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/processing/impl/ResolverImpl.kt b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/processing/impl/ResolverImpl.kt
index 0c13f3a..6ea2590 100644
--- a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/processing/impl/ResolverImpl.kt
+++ b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/processing/impl/ResolverImpl.kt
@@ -25,7 +25,6 @@
import com.google.devtools.ksp.symbol.Variance
import com.google.devtools.ksp.symbol.impl.binary.*
import com.google.devtools.ksp.symbol.impl.findParentAnnotated
-import com.google.devtools.ksp.symbol.impl.findParentPsiDeclaration
import com.google.devtools.ksp.symbol.impl.findPsi
import com.google.devtools.ksp.symbol.impl.getInstanceForCurrentRound
import com.google.devtools.ksp.symbol.impl.java.*
@@ -57,7 +56,9 @@
import org.jetbrains.kotlin.load.java.lazy.types.JavaTypeResolver
import org.jetbrains.kotlin.load.java.lazy.types.toAttributes
import org.jetbrains.kotlin.load.java.sources.JavaSourceElement
+import org.jetbrains.kotlin.load.java.structure.impl.JavaArrayTypeImpl
import org.jetbrains.kotlin.load.java.structure.impl.JavaClassImpl
+import org.jetbrains.kotlin.load.java.structure.impl.JavaConstructorImpl
import org.jetbrains.kotlin.load.java.structure.impl.JavaFieldImpl
import org.jetbrains.kotlin.load.java.structure.impl.JavaMethodImpl
import org.jetbrains.kotlin.load.java.structure.impl.JavaTypeImpl
@@ -521,31 +522,41 @@
} as PropertyAccessorDescriptor?
}
- fun resolveJavaType(psi: PsiType): KotlinType {
+ fun resolveJavaType(psi: PsiType, parentTypeReference: KSTypeReference? = null): KotlinType {
incrementalContext.recordLookup(psi)
val javaType = JavaTypeImpl.create(psi)
- var parent: PsiElement? = (psi as? PsiClassReferenceType)?.resolve()
- val stack = Stack<PsiElement>()
- while (parent != null && parent !is PsiJavaFile) {
- stack.push(parent)
- parent = parent.findParentPsiDeclaration()
+ var parent: KSNode? = parentTypeReference
+
+ val stack = Stack<KSNode>()
+ while (parent != null) {
+ if (parent is KSFunctionDeclarationJavaImpl || parent is KSClassDeclarationJavaImpl) {
+ stack.push(parent)
+ }
+ parent = parent.parent
}
// Construct resolver context for the PsiType
var resolverContext = lazyJavaResolverContext
for (e in stack) {
- val descriptor = resolveJavaDeclaration(e)!!
when (e) {
- is PsiMethod -> {
- resolverContext = resolverContext.childForMethod(descriptor, JavaMethodImpl(e))
- }
- is PsiClass -> {
+ is KSFunctionDeclarationJavaImpl -> {
resolverContext = resolverContext
- .childForClassOrPackage(descriptor as ClassDescriptor, JavaClassImpl(e))
+ .childForMethod(
+ resolveJavaDeclaration(e.psi)!!,
+ if (e.psi.isConstructor) JavaConstructorImpl(e.psi) else JavaMethodImpl(e.psi)
+ )
+ }
+ is KSClassDeclarationJavaImpl -> {
+ resolverContext = resolverContext
+ .childForClassOrPackage(resolveJavaDeclaration(e.psi) as ClassDescriptor, JavaClassImpl(e.psi))
}
}
}
- return resolverContext.typeResolver.transformJavaType(javaType, TypeUsage.COMMON.toAttributes())
+ return if (javaType is JavaArrayTypeImpl)
+ resolverContext
+ .typeResolver.transformArrayType(javaType, TypeUsage.COMMON.toAttributes(), psi is PsiEllipsisType)
+ else
+ resolverContext.typeResolver.transformJavaType(javaType, TypeUsage.COMMON.toAttributes())
}
fun KotlinType.expandNonRecursively(): KotlinType =
@@ -619,7 +630,10 @@
).defaultType
)
} else {
- return getKSTypeCached(resolveJavaType(type.psi), type.element.typeArguments, type.annotations)
+ return getKSTypeCached(
+ resolveJavaType(type.psi, type),
+ type.element.typeArguments, type.annotations
+ )
}
}
else -> throw IllegalStateException("Unable to resolve type for $type, $ExceptionMessage")
diff --git a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/java/KSTypeReferenceJavaImpl.kt b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/java/KSTypeReferenceJavaImpl.kt
index 34b65c3..ef0389b 100644
--- a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/java/KSTypeReferenceJavaImpl.kt
+++ b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/java/KSTypeReferenceJavaImpl.kt
@@ -91,7 +91,7 @@
is PsiWildcardType -> KSClassifierReferenceJavaImpl.getCached(type.extendsBound as PsiClassType, this)
is PsiPrimitiveType -> KSClassifierReferenceDescriptorImpl.getCached(type.toKotlinType(), origin, this)
is PsiArrayType -> {
- val componentType = ResolverImpl.instance.resolveJavaType(type.componentType)
+ val componentType = ResolverImpl.instance.resolveJavaType(type.componentType, this)
if (type.componentType !is PsiPrimitiveType) {
KSClassifierReferenceDescriptorImpl.getCached(
ResolverImpl.instance.module.builtIns.getArrayType(Variance.INVARIANT, componentType),
diff --git a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/utils.kt b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/utils.kt
index b2ed1a1..01459d6 100644
--- a/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/utils.kt
+++ b/compiler-plugin/src/main/kotlin/com/google/devtools/ksp/symbol/impl/utils.kt
@@ -236,14 +236,6 @@
}
}
-fun PsiElement.findParentPsiDeclaration(): PsiElement? {
- var parent = this.parent
- while (parent != null && parent !is PsiClass && parent !is PsiMethod && parent !is PsiJavaFile) {
- parent = parent.parent
- }
- return parent
-}
-
fun PsiElement.findParentDeclaration(): KSDeclaration? {
return this.findParentAnnotated() as? KSDeclaration
}
diff --git a/compiler-plugin/src/test/kotlin/com/google/devtools/ksp/processor/ResolveJavaTypeProcessor.kt b/compiler-plugin/src/test/kotlin/com/google/devtools/ksp/processor/ResolveJavaTypeProcessor.kt
index df30cbe..af24004 100644
--- a/compiler-plugin/src/test/kotlin/com/google/devtools/ksp/processor/ResolveJavaTypeProcessor.kt
+++ b/compiler-plugin/src/test/kotlin/com/google/devtools/ksp/processor/ResolveJavaTypeProcessor.kt
@@ -18,7 +18,14 @@
package com.google.devtools.ksp.processor
import com.google.devtools.ksp.processing.Resolver
-import com.google.devtools.ksp.symbol.*
+import com.google.devtools.ksp.symbol.KSAnnotated
+import com.google.devtools.ksp.symbol.KSClassDeclaration
+import com.google.devtools.ksp.symbol.KSFunctionDeclaration
+import com.google.devtools.ksp.symbol.KSNode
+import com.google.devtools.ksp.symbol.KSTypeReference
+import com.google.devtools.ksp.symbol.Nullability
+import com.google.devtools.ksp.symbol.Origin
+import com.google.devtools.ksp.symbol.Variance
import com.google.devtools.ksp.visitor.KSTopDownVisitor
class ResolveJavaTypeProcessor : AbstractTestProcessor() {
@@ -31,8 +38,12 @@
override fun process(resolver: Resolver): List<KSAnnotated> {
val symbol = resolver.getClassDeclarationByName(resolver.getKSNameFromString("C"))
+ val symbolTypeParameter = resolver.getClassDeclarationByName(resolver.getKSNameFromString("Base"))
+ val another = resolver.getClassDeclarationByName(resolver.getKSNameFromString("Another"))
assert(symbol?.origin == Origin.JAVA)
symbol!!.accept(visitor, Unit)
+ symbolTypeParameter!!.accept(visitor, Unit)
+ another!!.accept(visitor, Unit)
return emptyList()
}
diff --git a/compiler-plugin/testData/api/resolveJavaType.kt b/compiler-plugin/testData/api/resolveJavaType.kt
index cf93c71..a40761e 100644
--- a/compiler-plugin/testData/api/resolveJavaType.kt
+++ b/compiler-plugin/testData/api/resolveJavaType.kt
@@ -31,6 +31,25 @@
// kotlin.collections.MutableList<in kotlin.collections.MutableList<out kotlin.Double?>?>?
// Bar?
// kotlin.Array<Bar?>?
+// Foo<Base.T?, Base.Inner.P?>?
+// Bar<Base.Inner.P?, Base.T?>?
+// kotlin.collections.MutableList<Base.T?>?
+// kotlin.Unit
+// Base.T?
+// kotlin.Unit
+// kotlin.Array<Base.T?>?
+// kotlin.Unit
+// kotlin.Array<Base.T?>?
+// kotlin.Unit
+// kotlin.collections.MutableList<Base.T?>?
+// kotlin.Unit
+// Base.T?
+// kotlin.Unit
+// kotlin.Array<Base.T?>?
+// kotlin.Unit
+// kotlin.Array<Base.T?>?
+// kotlin.Unit
+// Base<Another.T?, Another.T?>?
// END
// FILE: a.kt
annotation class Test
@@ -69,3 +88,29 @@
public Bar[] BarArryFun() {}
}
+
+// FILE: Base.java
+import java.util.List;
+
+class Foo<T1,T2> {}
+class Bar<T1, T2> {}
+
+class Base<T,P> {
+ void genericT(List<T> t){};
+ void singleT(T t){};
+ void varargT(T... t){};
+ void arrayT(T[] t){};
+
+ class Inner<P> {
+ void genericT(List<T> t){};
+ void singleT(T t){};
+ void varargT(T... t){};
+ void arrayT(T[] t){};
+ Foo<T, P> foo;
+ Bar<P, T> bar;
+ }
+}
+
+class Another<T> {
+ Base<T, T> base;
+}