blob: 2862daa215946115b9b63e8d717618c7c3859328 [file] [log] [blame]
/*
* Copyright 2010-2017 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jetbrains.kotlin.js.coroutine
import org.jetbrains.kotlin.js.backend.ast.*
import org.jetbrains.kotlin.js.backend.ast.metadata.*
import org.jetbrains.kotlin.js.inline.util.collectFreeVariables
import org.jetbrains.kotlin.js.inline.util.replaceNames
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils
import org.jetbrains.kotlin.js.translate.utils.JsAstUtils.pureFqn
import org.jetbrains.kotlin.js.translate.utils.splitToRanges
/**
 * Determines which nodes of this subtree have to be split apart when the coroutine body is turned into a state machine.
 *
 * A node is put into the result when
 *  - it is a suspend expression used as an expression statement, or the suspend right-hand side of such an assignment;
 *  - it is a `return` and either the traversal root is already in the result or it is nested in a `try` that has a
 *    `finally` block;
 *  - it is a `break`/`continue` whose target statement (looked up in [breakContinueTargets]) is already in the result;
 *  - any of its descendants was put into the result during the same pass.
 *
 * Several of these rules depend on nodes collected earlier, so passes are repeated until one adds nothing.
 * `throw` is intentionally not treated like `return`: the exception is caught by the surrounding catch (a synthetic one
 * is generated when there is no `catch` clause) and dispatched through the exception state.
 *
 * @param breakContinueTargets the statement targeted by each `break`/`continue` in this subtree; a missing entry fails
 *   the `!!` check.
 * @return every node that must be split.
 */
fun JsNode.collectNodesToSplit(breakContinueTargets: Map<JsContinue, JsStatement>): Set<JsNode> {
    val root = this
    val toSplit = mutableSetOf<JsNode>()

    val collector = object : RecursiveJsVisitor() {
        // True once the subtree currently being visited has contributed a node to `toSplit`.
        var subtreeHasSplit = false

        // Number of enclosing `try` statements that own a `finally` block.
        var enclosingFinallyCount = 0

        private fun addToSplit(node: JsNode) {
            toSplit += node
            subtreeHasSplit = true
        }

        override fun visitExpressionStatement(x: JsExpressionStatement) {
            super.visitExpressionStatement(x)
            val expression = x.expression
            if (expression.isSuspend) {
                addToSplit(expression)
                return
            }
            // `a = suspendCall()`: split at the suspend call on the right-hand side.
            val assignedValue = JsAstUtils.decomposeAssignment(expression)?.second
            if (assignedValue != null && assignedValue.isSuspend) {
                addToSplit(assignedValue)
            }
        }

        override fun visitReturn(x: JsReturn) {
            super.visitReturn(x)
            if (enclosingFinallyCount > 0 || root in toSplit) {
                addToSplit(x)
            }
        }

        override fun visitBreak(x: JsBreak) {
            super.visitBreak(x)
            if (breakContinueTargets[x]!! in toSplit) {
                addToSplit(x)
            }
        }

        override fun visitContinue(x: JsContinue) {
            super.visitContinue(x)
            if (breakContinueTargets[x]!! in toSplit) {
                addToSplit(x)
            }
        }

        override fun visitTry(x: JsTry) {
            val ownsFinally = x.finallyBlock != null
            if (ownsFinally) enclosingFinallyCount++
            super.visitTry(x)
            if (ownsFinally) enclosingFinallyCount--
        }

        override fun visitElement(node: JsNode) {
            // Track this subtree on its own; restore the outer flag only if nothing below was collected,
            // otherwise leave it set so the collection propagates to the ancestors.
            val outerHasSplit = subtreeHasSplit
            subtreeHasSplit = false
            node.acceptChildren(this)
            if (subtreeHasSplit) {
                toSplit += node
            }
            else {
                subtreeHasSplit = outerHasSplit
            }
        }
    }

    // Iterate to a fixed point: a jump or `return` may be visited before the node that makes it splittable is collected.
    do {
        val sizeBeforePass = toSplit.size
        collector.accept(this)
    } while (toSplit.size != sizeBeforePass)

    return toSplit
}
/**
 * Lowers the `JsDebugger` placeholders left in the blocks of this state machine into state-field assignments.
 *
 * A block is identified by its index in this list. For every placeholder:
 *  - a `targetBlock` becomes `<receiver>.<stateName> = <index>`, tagged with `targetBlock`;
 *  - a `targetExceptionBlock` becomes `<receiver>.<exceptionStateName> = <index>`, tagged with `targetExceptionBlock`;
 *  - a non-empty `finallyPath` becomes `<receiver>.<finallyPathName> = [<index>, ...]`, tagged with `finallyPath`;
 *    an empty one just removes the placeholder.
 * Generated assignments keep the placeholder's source. `<receiver>` is `JsAstUtils.stateMachineReceiver()`.
 *
 * @param context supplies the state, exception-state and finally-path field names.
 */
fun List<CoroutineBlock>.replaceCoroutineFlowStatements(context: CoroutineTransformationContext) {
    // Position of each block in this list; that position is the numeric state written by the assignments below.
    val indexByBlock = mapIndexed { index, block -> block to index }.toMap()

    val blockReplacementVisitor = object : JsVisitorWithContextImpl() {
        override fun endVisit(x: JsDebugger, ctx: JsContext<in JsStatement>) {
            x.targetBlock?.let { target ->
                val assignment = JsAstUtils.assignment(
                        JsNameRef(context.metadata.stateName, JsAstUtils.stateMachineReceiver()),
                        JsIntLiteral(indexByBlock[target]!!)
                )
                ctx.replaceMe(JsExpressionStatement(assignment.source(x.source)).apply { targetBlock = true })
            }

            x.targetExceptionBlock?.let { target ->
                val assignment = JsAstUtils.assignment(
                        JsNameRef(context.metadata.exceptionStateName, JsAstUtils.stateMachineReceiver()),
                        JsIntLiteral(indexByBlock[target]!!)
                )
                ctx.replaceMe(JsExpressionStatement(assignment.source(x.source)).apply { targetExceptionBlock = true })
            }

            val finallyPath = x.finallyPath ?: return
            if (finallyPath.isEmpty()) {
                // Nothing to record: drop the placeholder.
                ctx.removeMe()
            }
            else {
                val lhs = JsNameRef(context.metadata.finallyPathName, JsAstUtils.stateMachineReceiver())
                val pathIndices = JsArrayLiteral(finallyPath.map { JsIntLiteral(indexByBlock[it]!!) })
                ctx.replaceMe(JsExpressionStatement(JsAstUtils.assignment(lhs, pathIndices).source(x.source)).apply {
                    this.finallyPath = true
                })
            }
        }
    }

    for (block in this) {
        blockReplacementVisitor.accept(block.jsBlock)
    }
}
/**
 * Builds the control-flow graph of the state machine, starting from this (entry) block.
 *
 * The result maps each reached block to its successors:
 *  - the `targetExceptionBlock`/`targetBlock` of every `JsDebugger` placeholder in the block (see `collectTargetBlocks`);
 *  - for every finally path `[f1, f2, ..., fn]` recorded in a block `b` (see `collectFinallyPaths`),
 *    the chain of edges `b -> f1 -> f2 -> ... -> fn`;
 *  - for the entry block only, [globalCatchBlock] when it is not null.
 *
 * @param globalCatchBlock extra successor of the entry block, or null for none.
 * @return adjacency sets keyed by block; every visited block has an entry, possibly an empty one.
 */
fun CoroutineBlock.buildGraph(globalCatchBlock: CoroutineBlock?): Map<CoroutineBlock, Set<CoroutineBlock>> {
// That's a little more than DFS due to need of tracking finally paths
val visitedBlocks = mutableSetOf<CoroutineBlock>()
val graph = mutableMapOf<CoroutineBlock, MutableSet<CoroutineBlock>>()
fun visitBlock(block: CoroutineBlock) {
if (block in visitedBlocks) return
// Finally-path edges may start at blocks other than `block`. When such an edge is new and its source was already
// visited, the source is un-marked so that reaching it again re-traverses its (now larger) successor set.
for (finallyPath in block.collectFinallyPaths()) {
for ((finallySource, finallyTarget) in (listOf(block) + finallyPath).zip(finallyPath)) {
if (graph.getOrPut(finallySource) { mutableSetOf() }.add(finallyTarget)) {
visitedBlocks -= finallySource
}
}
}
visitedBlocks += block
val successors = graph.getOrPut(block) { mutableSetOf() }
successors += block.collectTargetBlocks()
if (block == this && globalCatchBlock != null) {
successors += globalCatchBlock
}
// NOTE(review): `successors` is the live set stored in `graph`. If a recursive visit adds a finally edge whose source
// is `block`, the set is modified while being iterated, which the default (LinkedHashSet) iterator reports as a
// ConcurrentModificationException. Presumably finally paths never route through a block in this position — confirm.
successors.forEach(::visitBlock)
}
visitBlock(this)
return graph
}
/**
 * Returns the blocks this block can transfer control to: the `targetExceptionBlock` and `targetBlock` of every
 * `JsDebugger` placeholder in its body, in encounter order and without duplicates.
 */
private fun CoroutineBlock.collectTargetBlocks(): Set<CoroutineBlock> {
    val targets = mutableSetOf<CoroutineBlock>()
    val collector = object : RecursiveJsVisitor() {
        override fun visitDebugger(x: JsDebugger) {
            x.targetExceptionBlock?.let { targets.add(it) }
            x.targetBlock?.let { targets.add(it) }
        }
    }
    jsBlock.accept(collector)
    return targets
}
/**
 * Returns the finally paths recorded on the `JsDebugger` placeholders in this block's body, in encounter order.
 * Each path is kept as one list element; placeholders without a finally path contribute nothing.
 */
private fun CoroutineBlock.collectFinallyPaths(): List<List<CoroutineBlock>> {
    val paths = mutableListOf<List<CoroutineBlock>>()
    val collector = object : RecursiveJsVisitor() {
        override fun visitDebugger(x: JsDebugger) {
            val path = x.finallyPath ?: return
            paths.add(path)
        }
    }
    jsBlock.accept(collector)
    return paths
}
/**
 * Rewrites coroutine-specific references in this block for use inside the state machine body:
 *  - every `this` becomes `this.<receiverFieldName>`;
 *  - a name reference flagged `coroutineReceiver` becomes `this`;
 *  - one flagged `coroutineController` becomes `<qualifier>.<controllerFieldName>`, marked as pure;
 *  - one flagged `coroutineResult` becomes `<qualifier>.<resultName>`, marked as depending on state.
 * Controller and result replacements keep the original source. Nested functions are not traversed.
 *
 * @param context supplies the receiver, controller and result field names.
 */
fun JsBlock.replaceSpecialReferences(context: CoroutineTransformationContext) {
    val visitor = object : JsVisitorWithContextImpl() {
        override fun visit(x: JsFunction, ctx: JsContext<*>) = false

        override fun endVisit(x: JsThisRef, ctx: JsContext<in JsNode>) {
            ctx.replaceMe(JsNameRef(context.receiverFieldName, JsThisRef()))
        }

        override fun endVisit(x: JsNameRef, ctx: JsContext<in JsNode>) {
            val replacement: JsNode = when {
                x.coroutineReceiver -> JsThisRef()
                x.coroutineController -> JsNameRef(context.controllerFieldName, x.qualifier).apply {
                    source = x.source
                    sideEffects = SideEffectKind.PURE
                }
                x.coroutineResult -> JsNameRef(context.metadata.resultName, x.qualifier).apply {
                    source = x.source
                    sideEffects = SideEffectKind.DEPENDS_ON_STATE
                }
                else -> return
            }
            ctx.replaceMe(replacement)
        }
    }
    visitor.accept(this)
}
/**
 * Rewrites coroutine-specific references in this block for the "simple function" form, where the continuation is the
 * explicit parameter [continuationParam]:
 *  - a name reference flagged `coroutineReceiver` becomes a pure reference to [continuationParam];
 *  - one flagged `coroutineController` becomes `this`;
 *  - one flagged `coroutineResult` whose qualifier is a name reference with [continuationParam]'s name becomes a
 *    pure reference to [resultVar].
 * All replacements keep the original source. Nested functions are not traversed.
 *
 * @param continuationParam parameter holding the continuation.
 * @param resultVar variable that replaces reads of the coroutine result.
 */
fun JsBlock.replaceSpecialReferencesInSimpleFunction(continuationParam: JsParameter, resultVar: JsName) {
    // Qualifiers are visited before their parent reference, so a `coroutineReceiver` qualifier has presumably already
    // been turned into a reference to the continuation parameter when this check runs.
    fun JsExpression?.refersToContinuation() = this is JsNameRef && this.name == continuationParam.name

    val visitor = object : JsVisitorWithContextImpl() {
        override fun visit(x: JsFunction, ctx: JsContext<*>) = false

        override fun endVisit(x: JsNameRef, ctx: JsContext<in JsNode>) {
            val replacement: JsNode = when {
                x.coroutineReceiver -> pureFqn(continuationParam.name, null).source(x.source)
                x.coroutineController -> JsThisRef().apply { source = x.source }
                x.coroutineResult && x.qualifier.refersToContinuation() -> pureFqn(resultVar, null).source(x.source)
                else -> return
            }
            ctx.replaceMe(replacement)
        }
    }
    visitor.accept(this)
}
/**
 * Computes which of [localVariables] cannot stay plain JS locals because their values are needed across the blocks
 * of this state machine.
 *
 * Every statement of every block (identified by its index in this list) is scanned. For each tracked variable, the
 * scan records the blocks in which the variable is
 *  - declared: it appears in a `var` declaration;
 *  - defined: it is declared with an initializer, or a `JsNameRef` with its name is the left operand of an assignment
 *    operator, or it is the name of a nested `JsFunction`;
 *  - used: any other `JsNameRef` with its name.
 * Bodies of nested functions, label names and `break`/`continue` statements are not scanned.
 *
 * A variable is left out of the result when
 *  - it is in [parameters] and is never declared, defined or used in any block; or
 *  - it is not in [parameters] and either is never used, or is defined in exactly one block and used in exactly one
 *    block, that is the same block, and is declared somewhere.
 *
 * @param localVariables candidate variables; only these are tracked.
 * @param parameters candidates to treat as parameters (presumably the function's parameters — confirm with callers).
 * @return the subset of [localVariables] that has to survive between blocks.
 */
fun List<CoroutineBlock>.collectVariablesSurvivingBetweenBlocks(localVariables: Set<JsName>, parameters: Set<JsName>): Set<JsName> {
// Per tracked variable: indices of the blocks where it is defined / declared / used.
val varDefinedIn = localVariables.associate { it to mutableSetOf<Int>() }
val varDeclaredIn = localVariables.associate { it to mutableSetOf<Int>() }
val varUsedIn = localVariables.associate { it to mutableSetOf<Int>() }
for ((blockIndex, block) in withIndex()) {
for (statement in block.statements) {
statement.accept(object : RecursiveJsVisitor() {
override fun visitNameRef(nameRef: JsNameRef) {
super.visitNameRef(nameRef)
varUsedIn[nameRef.name]?.add(blockIndex)
}
override fun visit(x: JsVars.JsVar) {
varDeclaredIn[x.name]?.add(blockIndex)
if (x.initExpression != null) {
varDefinedIn[x.name]?.add(blockIndex)
}
super.visit(x)
}
override fun visitBinaryExpression(x: JsBinaryOperation) {
val lhs = x.arg1
if (x.operator.isAssignment && lhs is JsNameRef) {
// Tracked assignment target: record a definition and scan only the right-hand side, so the target itself is
// not counted as a use.
// NOTE(review): the lhs qualifier is not scanned on this path, and if `isAssignment` also covers compound
// operators (`+=` etc.) their implicit read is not recorded as a use — confirm this is intended.
varDefinedIn[lhs.name]?.add(blockIndex)?.let {
accept(x.arg2)
return
}
}
super.visitBinaryExpression(x)
}
override fun visitFunction(x: JsFunction) {
// The body is deliberately not scanned; only a named function counts as a definition of its name.
x.name?.let {
varDefinedIn[it]?.add(blockIndex)
}
}
override fun visitLabel(x: JsLabel) {
accept(x.statement)
}
// Skipped entirely, presumably so that label references are not counted as variable uses.
override fun visitBreak(x: JsBreak) {}
override fun visitContinue(x: JsContinue) {}
})
}
}
// True when the variable's value never has to cross a block boundary: it is defined and used in one and the same
// block (and declared somewhere), or it is never used at all.
fun JsName.isLocalInBlock(): Boolean {
val def = varDefinedIn[this]!!
val use = varUsedIn[this]!!
val decl = varDeclaredIn[this]!!
if (def.size == 1 && use.size == 1) {
val singleDef = def.single()
val singleUse = use.single()
return singleDef == singleUse && decl.isNotEmpty()
}
return use.isEmpty()
}
// The `!!` lookups are safe: every map is keyed by exactly `localVariables`.
return localVariables.filterNot { localVar ->
if (localVar in parameters) {
varUsedIn[localVar]!!.isEmpty() && varDefinedIn[localVar]!!.isEmpty() && varDeclaredIn[localVar]!!.isEmpty()
}
else {
localVar.isLocalInBlock()
}
}.toSet()
}
/**
 * Moves [localVariables] of this coroutine body into fields of the state machine object.
 *
 * First applies [replaceSpecialReferences]. Then, without descending into nested functions:
 *  - an unqualified reference to one of [localVariables] becomes `this.<field>` (field name from
 *    `context.getFieldName`), keeping the original source;
 *  - a `var` statement is split into ranges with `splitToRanges`. Each range of declarations of [localVariables]
 *    becomes one comma-sequence of `this.<field> = <init>` assignments; declarations without an initializer produce
 *    nothing. Ranges of other declarations stay `var` statements, and the relative order is preserved;
 *  - a nested function that has any of [localVariables] among its free variables is wrapped into an immediately
 *    invoked function. The wrapper receives the current `this.<field>` values as arguments, and the nested function's
 *    body is rewritten to use the wrapper's parameters instead.
 *
 * @param context supplies the field names of the state machine.
 * @param localVariables variables to turn into fields.
 */
fun JsBlock.replaceLocalVariables(context: CoroutineTransformationContext, localVariables: Set<JsName>) {
    replaceSpecialReferences(context)

    val visitor = object : JsVisitorWithContextImpl() {
        // Nested functions are not traversed; they are handled as a whole in endVisit(JsFunction).
        override fun visit(x: JsFunction, ctx: JsContext<*>): Boolean = false

        override fun endVisit(x: JsFunction, ctx: JsContext<in JsNode>) {
            val freeVars = x.collectFreeVariables().intersect(localVariables)
            if (freeVars.isNotEmpty()) {
                // Build `(function(tmp1, ...) { return <x>; })(this.field1, ...)` so the nested function reads the
                // captured values through the wrapper's parameters.
                val wrapperFunction = JsFunction(x.scope.parent, JsBlock(), "")
                val wrapperInvocation = JsInvocation(wrapperFunction)
                wrapperFunction.body.statements += JsReturn(x)
                val nameMap = freeVars.associate { it to JsScope.declareTemporaryName(it.ident) }
                for (freeVar in freeVars) {
                    wrapperFunction.parameters += JsParameter(nameMap[freeVar]!!)
                    wrapperInvocation.arguments += JsNameRef(context.getFieldName(freeVar), JsThisRef())
                }
                x.body = replaceNames(x.body, nameMap.mapValues { it.value.makeRef() })
                ctx.replaceMe(wrapperInvocation)
            }
        }

        override fun endVisit(x: JsNameRef, ctx: JsContext<in JsNode>) {
            if (x.qualifier == null && x.name in localVariables) {
                val fieldName = context.getFieldName(x.name!!)
                ctx.replaceMe(JsNameRef(fieldName, JsThisRef()).source(x.source))
            }
        }

        override fun endVisit(x: JsVars, ctx: JsContext<in JsStatement>) {
            if (x.vars.none { it.name in localVariables }) return

            val statements = mutableListOf<JsStatement>()
            for ((range, shouldReplace) in x.vars.splitToRanges { it.name in localVariables }) {
                if (shouldReplace) {
                    // Only the declarations of this range. Mapping over all of `x.vars` would turn non-field
                    // variables into field assignments, evaluate their initializers twice (here and in the kept
                    // `var`), and repeat every initializer once per replaced range.
                    val assignments = range.mapNotNull { jsVar ->
                        val fieldName = context.getFieldName(jsVar.name)
                        jsVar.initExpression?.let { init ->
                            JsAstUtils.assignment(JsNameRef(fieldName, JsThisRef()), init)
                        }
                    }
                    if (assignments.isNotEmpty()) {
                        statements += JsExpressionStatement(JsAstUtils.newSequence(assignments))
                    }
                }
                else {
                    statements += JsVars(*range.toTypedArray())
                }
            }

            if (statements.size == 1) {
                ctx.replaceMe(statements[0])
            }
            else {
                ctx.removeMe()
                ctx.addPrevious(statements)
            }
        }
    }
    visitor.accept(this)
}
/**
 * Returns true when this expression is a name reference flagged `coroutineResult` whose qualifier is itself an
 * unqualified name reference flagged `coroutineReceiver`; false otherwise, including for null.
 */
internal fun JsExpression?.isStateMachineResult(): Boolean {
    if (this !is JsNameRef || !coroutineResult) return false
    val receiver = qualifier
    return receiver is JsNameRef && receiver.coroutineReceiver && receiver.qualifier == null
}