blob: 38db77c856e6fc11ef41cd270e0121abc0cef057 [file] [log] [blame]
/*
* Copyright 2010-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license
* that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.idea.inspections.coroutines
import com.intellij.codeInsight.FileModificationService
import com.intellij.codeInspection.*
import com.intellij.codeInspection.ProblemHighlightType.*
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElementVisitor
import com.intellij.psi.codeStyle.CodeStyleManager
import org.jetbrains.kotlin.descriptors.DeclarationDescriptor
import org.jetbrains.kotlin.descriptors.ReceiverParameterDescriptor
import org.jetbrains.kotlin.idea.caches.resolve.analyzeWithContent
import org.jetbrains.kotlin.idea.caches.resolve.resolveToDescriptorIfAny
import org.jetbrains.kotlin.idea.core.ShortenReferences
import org.jetbrains.kotlin.idea.core.replaced
import org.jetbrains.kotlin.idea.inspections.AbstractKotlinInspection
import org.jetbrains.kotlin.idea.inspections.UnusedReceiverParameterInspection
import org.jetbrains.kotlin.idea.intentions.ConvertReceiverToParameterIntention
import org.jetbrains.kotlin.idea.intentions.MoveMemberToCompanionObjectIntention
import org.jetbrains.kotlin.idea.util.application.executeWriteCommand
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.psi.psiUtil.containingClassOrObject
import org.jetbrains.kotlin.psi.psiUtil.forEachDescendantOfType
import org.jetbrains.kotlin.psi.psiUtil.getNonStrictParentOfType
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.resolve.calls.callUtil.getResolvedCall
import org.jetbrains.kotlin.resolve.descriptorUtil.fqNameSafe
import org.jetbrains.kotlin.resolve.scopes.receivers.ExpressionReceiver
import org.jetbrains.kotlin.resolve.scopes.receivers.ImplicitReceiver
import org.jetbrains.kotlin.resolve.scopes.receivers.ReceiverValue
import org.jetbrains.kotlin.types.KotlinType
/**
 * Inspection for `suspend` functions whose extension receiver or dispatch receiver is a
 * `kotlinx.coroutines.CoroutineScope`.
 *
 * Inside such a function two kinds of usages are reported when they go through the function's own receiver:
 * - calls to functions whose extension receiver parameter is a `CoroutineScope`;
 * - references to `CoroutineScope.coroutineContext`.
 *
 * Each reported usage gets a quick fix that wraps just that call with `coroutineScope { ... }`. The function
 * itself is reported once with fixes that operate on the whole declaration.
 */
class SuspendFunctionOnCoroutineScopeInspection : AbstractKotlinInspection() {

    /**
     * Builds a visitor that analyzes every named `suspend` function with a body.
     *
     * @param holder collector the found problems are registered in
     * @param isOnTheFly not used by this inspection
     */
    override fun buildVisitor(holder: ProblemsHolder, isOnTheFly: Boolean): PsiElementVisitor {
        return namedFunctionVisitor(fun(function: KtNamedFunction) {
            if (!function.hasModifier(KtTokens.SUSPEND_KEYWORD)) return
            if (!function.hasBody()) return

            val context = function.analyzeWithContent()
            val descriptor = context[BindingContext.FUNCTION, function] ?: return
            val extensionOfCoroutineScope = descriptor.extensionReceiverParameter.ofCoroutineScopeType()
            val memberOfCoroutineScope = descriptor.dispatchReceiverParameter.ofCoroutineScopeType()
            if (!extensionOfCoroutineScope && !memberOfCoroutineScope) return

            // True when this descriptor is what the analyzed function's CoroutineScope receiver is declared by:
            // the function itself for an extension receiver, its containing declaration for a dispatch receiver.
            fun DeclarationDescriptor.isReceiverOfAnalyzedFunction(): Boolean =
                (extensionOfCoroutineScope && this == descriptor) ||
                        (memberOfCoroutineScope && this == descriptor.containingDeclaration)

            // The function-level problem is identical for every suspicious usage, so it is registered only once.
            var functionProblemRegistered = false

            fun registerFunctionProblemOnce() {
                if (functionProblemRegistered) return
                functionProblemRegistered = true

                val fixes = mutableListOf<LocalQuickFix>(
                    WrapWithCoroutineScopeFix(removeReceiver = extensionOfCoroutineScope, wrapCallOnly = false)
                )
                val file = function.containingKtFile
                if (extensionOfCoroutineScope) {
                    fixes += IntentionWrapper(ConvertReceiverToParameterIntention(), file)
                }
                if (memberOfCoroutineScope) {
                    // The function is already known to have a body (checked at the top of the visitor).
                    val containingDeclaration = function.containingClassOrObject
                    if (containingDeclaration is KtClass && !containingDeclaration.isInterface()) {
                        fixes += IntentionWrapper(MoveMemberToCompanionObjectIntention(), file)
                    }
                }
                holder.registerProblem(
                    with(function) { receiverTypeReference ?: nameIdentifier ?: funKeyword ?: this },
                    MESSAGE,
                    GENERIC_ERROR_OR_WARNING,
                    *fixes.toTypedArray()
                )
            }

            // Reports [problemExpression] (and, once, the function) if [receiver] is the function's own receiver.
            fun checkSuspiciousReceiver(receiver: ReceiverValue, problemExpression: KtExpression) {
                when (receiver) {
                    is ImplicitReceiver -> if (!receiver.declarationDescriptor.isReceiverOfAnalyzedFunction()) return
                    is ExpressionReceiver -> {
                        // Only explicit `this` / `this@label` receivers are of interest.
                        val receiverThisExpression = receiver.expression as? KtThisExpression ?: return
                        if (receiverThisExpression.getTargetLabel() != null) {
                            val instanceReference = receiverThisExpression.instanceReference
                            val target = context[BindingContext.REFERENCE_TARGET, instanceReference]
                            if (target?.isReceiverOfAnalyzedFunction() != true) return
                        }
                    }
                }

                // For a call, highlight only the callee name rather than the whole call expression.
                val reportElement = (problemExpression as? KtCallExpression)?.calleeExpression ?: problemExpression
                holder.registerProblem(
                    reportElement,
                    MESSAGE,
                    GENERIC_ERROR_OR_WARNING,
                    WrapWithCoroutineScopeFix(removeReceiver = false, wrapCallOnly = true)
                )
                registerFunctionProblemOnce()
            }

            // Calls to functions declared as extensions of CoroutineScope.
            function.forEachDescendantOfType(fun(callExpression: KtCallExpression) {
                val resolvedCall = callExpression.getResolvedCall(context) ?: return
                val extensionReceiverParameter = resolvedCall.resultingDescriptor.extensionReceiverParameter ?: return
                if (!extensionReceiverParameter.type.isCoroutineScope()) return
                val extensionReceiver = resolvedCall.extensionReceiver ?: return
                checkSuspiciousReceiver(extensionReceiver, callExpression)
            })

            // References to CoroutineScope.coroutineContext.
            function.forEachDescendantOfType(fun(nameReferenceExpression: KtNameReferenceExpression) {
                if (nameReferenceExpression.getReferencedName() != COROUTINE_CONTEXT) return
                val resolvedCall = nameReferenceExpression.getResolvedCall(context) ?: return
                if (resolvedCall.resultingDescriptor.fqNameSafe.asString() == "$COROUTINE_SCOPE.$COROUTINE_CONTEXT") {
                    val dispatchReceiver = resolvedCall.dispatchReceiver ?: return
                    checkSuspiciousReceiver(dispatchReceiver, nameReferenceExpression)
                }
            })
        })
    }

    /**
     * Quick fix that wraps either a single call or the whole function body with `coroutineScope { ... }`.
     *
     * @property removeReceiver when set (and [wrapCallOnly] is not), the function's receiver type is removed
     * after wrapping
     * @property wrapCallOnly when set, only the reported call (with its qualifiers) is wrapped instead of the
     * function body
     */
    private class WrapWithCoroutineScopeFix(
        private val removeReceiver: Boolean,
        private val wrapCallOnly: Boolean
    ) : LocalQuickFix {
        override fun getFamilyName(): String = "Wrap with coroutineScope"

        override fun getName(): String =
            when {
                removeReceiver && !wrapCallOnly -> "Remove receiver & wrap with 'coroutineScope { ... }'"
                wrapCallOnly -> "Wrap call with 'coroutineScope { ... }'"
                else -> "Wrap function body with 'coroutineScope { ... }'"
            }

        // Resolution is done before any modification; the PSI changes run in a write command opened by applyFix.
        override fun startInWriteAction() = false

        /**
         * Applies the fix to the function containing the problem element.
         *
         * Steps: pick the expression to wrap, select `this@label.` qualifiers that point to the function's own
         * CoroutineScope extension receiver, then (inside one write command) strip those qualifiers, wrap the
         * expression, reformat and shorten references. Finally the receiver is removed if requested.
         */
        override fun applyFix(project: Project, descriptor: ProblemDescriptor) {
            val problemElement = descriptor.psiElement ?: return
            val function = problemElement.getNonStrictParentOfType<KtNamedFunction>() ?: return
            val functionDescriptor = function.resolveToDescriptorIfAny()
            if (!FileModificationService.getInstance().preparePsiElementForWrite(function)) return
            val bodyExpression = function.bodyExpression

            // Climbs from the reported element to the outermost enclosing qualified/call expression.
            fun getExpressionToWrapCall(): KtExpression? {
                var result = problemElement as? KtExpression ?: return null
                while (result.parent is KtQualifiedExpression || result.parent is KtCallExpression) {
                    result = result.parent as KtExpression
                }
                return result
            }

            var expressionToWrap = (if (wrapCallOnly) getExpressionToWrapCall() else bodyExpression) ?: return

            // Select (read-only) every `this@label.selector` whose label resolves to the analyzed function.
            // The actual replacement is deferred to the write command below, because PSI must not be modified
            // outside a write action.
            val labeledThisQualified = mutableListOf<KtDotQualifiedExpression>()
            if (functionDescriptor?.extensionReceiverParameter.ofCoroutineScopeType()) {
                val context = function.analyzeWithContent()
                expressionToWrap.forEachDescendantOfType<KtDotQualifiedExpression> { qualified ->
                    val receiverExpression = qualified.receiverExpression as? KtThisExpression
                    if (receiverExpression?.getTargetLabel() != null && qualified.selectorExpression != null) {
                        if (context[BindingContext.REFERENCE_TARGET, receiverExpression.instanceReference] == functionDescriptor) {
                            labeledThisQualified += qualified
                        }
                    }
                }
            }

            val factory = KtPsiFactory(function)
            project.executeWriteCommand(name, this) {
                // Drop the `this@label.` qualifier, presumably so that the selector is afterwards resolved
                // against the scope introduced by the wrapper rather than the function's receiver.
                for (qualified in labeledThisQualified) {
                    // An element may have been detached by an earlier replacement of one of its ancestors.
                    if (!qualified.isValid) continue
                    // Re-read the selector: it may contain nested expressions that were already replaced.
                    val selector = qualified.selectorExpression ?: continue
                    if (qualified === expressionToWrap) {
                        expressionToWrap = qualified.replaced(selector)
                    } else {
                        qualified.replace(selector)
                    }
                }

                val blockExpression = function.bodyBlockExpression
                val result = if (expressionToWrap == bodyExpression && blockExpression != null) {
                    // Block body: rebuild the block so that its statements end up inside the wrapper lambda.
                    // Statements are separated by line breaks; joining them with no separator would fuse
                    // neighbouring statements into one invalid line.
                    val bodyText = blockExpression.statements.joinToString(separator = "\n") { it.text }
                    blockExpression.replaced(
                        factory.createBlock("$COROUTINE_SCOPE_WRAPPER { $bodyText }")
                    )
                } else {
                    // A single call, or an expression body: wrap the expression itself.
                    expressionToWrap.replaced(
                        factory.createExpressionByPattern("$COROUTINE_SCOPE_WRAPPER { $0 }", expressionToWrap)
                    )
                }
                val reformatted = CodeStyleManager.getInstance(project).reformat(result)
                ShortenReferences.DEFAULT.process(reformatted as KtElement)
            }

            val receiverTypeReference = function.receiverTypeReference
            if (removeReceiver && !wrapCallOnly && receiverTypeReference != null) {
                UnusedReceiverParameterInspection.RemoveReceiverFix.apply(receiverTypeReference, project)
            }
        }
    }

    companion object {
        private const val COROUTINE_SCOPE = "kotlinx.coroutines.CoroutineScope"

        private const val COROUTINE_SCOPE_WRAPPER = "kotlinx.coroutines.coroutineScope"

        private const val COROUTINE_CONTEXT = "coroutineContext"

        private const val MESSAGE = "Ambiguous coroutineContext due to CoroutineScope receiver of suspend function"

        /** True when the type's declaration has the fully qualified name of `CoroutineScope`. */
        private fun KotlinType.isCoroutineScope(): Boolean =
            constructor.declarationDescriptor?.fqNameSafe?.asString() == COROUTINE_SCOPE

        /**
         * True when the receiver parameter exists and its type is `CoroutineScope`, or one of the supertypes
         * listed by `type.constructor.supertypes` is.
         */
        private fun ReceiverParameterDescriptor?.ofCoroutineScopeType(): Boolean {
            if (this == null) return false
            if (type.isCoroutineScope()) return true
            return type.constructor.supertypes.any { it.isCoroutineScope() }
        }
    }
}