Fix composable calls in anonymous object initializer

Compose allows calls to functions from anonymous object initializers. K1 checker was updated to use non-local return functionality (similar to suspend functions) which doesn't allow this kind of manipulation. This change works around that by extracting some logic from `checkInlineUsage` function in compiler that resolves overloads correctly while using old way of checking whether argument is inline.

Fixes: 320261458
Tests: Compiler and runtime tests
Change-Id: I757373170fc484891d26f461d480ed142a02a189
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/ComposerParamTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/ComposerParamTransformTests.kt
index cd9027b..4cdfe6b 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/ComposerParamTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/ComposerParamTransformTests.kt
@@ -510,4 +510,34 @@
                 })
             }
         )
+
+    @Test
+    fun composableCallInAnonymousObjectInitializer() =
+        verifyGoldenComposeIrTransform(
+            extra = """
+                import androidx.compose.runtime.*
+
+                @Composable fun Foo(): State<Int> = TODO()
+            """,
+            source = """
+                import androidx.compose.runtime.*
+
+                @Composable fun Test(inputs: List<Int>) {
+                    val objs = inputs.map {
+                        object {
+                            init {
+                                Foo()
+                            }
+
+                            val state = Foo()
+                            val value by Foo()
+                        }
+                    }
+                    objs.forEach {
+                        println(it.state)
+                        println(it.value)
+                    }
+                }
+            """
+        )
 }
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableCheckerTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableCheckerTests.kt
index eb8fa34..8cb5eba 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableCheckerTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableCheckerTests.kt
@@ -1259,10 +1259,19 @@
                 val x = object {
                   val b = remember { mutableStateOf(3) }
                 }
+                val y = run {
+                    object {
+                      val b = remember { mutableStateOf(3) }
+                      val a = object {
+                        val b = remember { mutableStateOf(3) }
+                      }
+                    }
+                }
                 class Bar {
                   val a = <!COMPOSABLE_INVOCATION!>remember<!> { mutableStateOf(5) }
                 }
                 print(x)
+                print(y)
             }
         """
         )
diff --git "a/compose/compiler/compiler-hosted/integration-tests/src/test/resources/golden/androidx.compose.compiler.plugins.kotlin.ComposerParamTransformTests/composableCallInAnonymousObjectInitializer\133useFir = false\135.txt" "b/compose/compiler/compiler-hosted/integration-tests/src/test/resources/golden/androidx.compose.compiler.plugins.kotlin.ComposerParamTransformTests/composableCallInAnonymousObjectInitializer\133useFir = false\135.txt"
new file mode 100644
index 0000000..2354258
--- /dev/null
+++ "b/compose/compiler/compiler-hosted/integration-tests/src/test/resources/golden/androidx.compose.compiler.plugins.kotlin.ComposerParamTransformTests/composableCallInAnonymousObjectInitializer\133useFir = false\135.txt"
@@ -0,0 +1,68 @@
+//
+// Source
+// ------------------------------------------
+
+import androidx.compose.runtime.*
+
+@Composable fun Test(inputs: List<Int>) {
+    val objs = inputs.map {
+        object {
+            init {
+                Foo()
+            }
+
+            val state = Foo()
+            val value by Foo()
+        }
+    }
+    objs.forEach {
+        println(it.state)
+        println(it.value)
+    }
+}
+
+//
+// Transformed IR
+// ------------------------------------------
+
+@Composable
+fun Test(inputs: List<Int>, %composer: Composer?, %changed: Int) {
+  %composer = %composer.startRestartGroup(<>)
+  sourceInformation(%composer, "C(Test):Test.kt")
+  val %dirty = %changed
+  if (%changed and 0b0110 == 0) {
+    %dirty = %dirty or if (%composer.changedInstance(inputs)) 0b0100 else 0b0010
+  }
+  if (%dirty and 0b0011 != 0b0010 || !%composer.skipping) {
+    if (isTraceInProgress()) {
+      traceEventStart(<>, %dirty, -1, <>)
+    }
+    val objs = inputs.map { it: Int ->
+      val tmp0_return = <block>{
+        object {
+          init {
+            Foo(%composer, 0)
+          }
+          val state: State<Int> = Foo(%composer, 0)
+          val value: State<Int> = Foo(%composer, 0)
+            get() {
+              return <this>.value%delegate.getValue(<this>, ::value)
+            }
+        }
+      }
+      tmp0_return
+    }
+    objs.forEach { it: <no name provided> ->
+      println(it.state)
+      println(it.value)
+    }
+    if (isTraceInProgress()) {
+      traceEventEnd()
+    }
+  } else {
+    %composer.skipToGroupEnd()
+  }
+  %composer.endRestartGroup()?.updateScope { %composer: Composer?, %force: Int ->
+    Test(inputs, %composer, updateChangedFlags(%changed or 0b0001))
+  }
+}
diff --git "a/compose/compiler/compiler-hosted/integration-tests/src/test/resources/golden/androidx.compose.compiler.plugins.kotlin.ComposerParamTransformTests/composableCallInAnonymousObjectInitializer\133useFir = true\135.txt" "b/compose/compiler/compiler-hosted/integration-tests/src/test/resources/golden/androidx.compose.compiler.plugins.kotlin.ComposerParamTransformTests/composableCallInAnonymousObjectInitializer\133useFir = true\135.txt"
new file mode 100644
index 0000000..2354258
--- /dev/null
+++ "b/compose/compiler/compiler-hosted/integration-tests/src/test/resources/golden/androidx.compose.compiler.plugins.kotlin.ComposerParamTransformTests/composableCallInAnonymousObjectInitializer\133useFir = true\135.txt"
@@ -0,0 +1,68 @@
+//
+// Source
+// ------------------------------------------
+
+import androidx.compose.runtime.*
+
+@Composable fun Test(inputs: List<Int>) {
+    val objs = inputs.map {
+        object {
+            init {
+                Foo()
+            }
+
+            val state = Foo()
+            val value by Foo()
+        }
+    }
+    objs.forEach {
+        println(it.state)
+        println(it.value)
+    }
+}
+
+//
+// Transformed IR
+// ------------------------------------------
+
+@Composable
+fun Test(inputs: List<Int>, %composer: Composer?, %changed: Int) {
+  %composer = %composer.startRestartGroup(<>)
+  sourceInformation(%composer, "C(Test):Test.kt")
+  val %dirty = %changed
+  if (%changed and 0b0110 == 0) {
+    %dirty = %dirty or if (%composer.changedInstance(inputs)) 0b0100 else 0b0010
+  }
+  if (%dirty and 0b0011 != 0b0010 || !%composer.skipping) {
+    if (isTraceInProgress()) {
+      traceEventStart(<>, %dirty, -1, <>)
+    }
+    val objs = inputs.map { it: Int ->
+      val tmp0_return = <block>{
+        object {
+          init {
+            Foo(%composer, 0)
+          }
+          val state: State<Int> = Foo(%composer, 0)
+          val value: State<Int> = Foo(%composer, 0)
+            get() {
+              return <this>.value%delegate.getValue(<this>, ::value)
+            }
+        }
+      }
+      tmp0_return
+    }
+    objs.forEach { it: <no name provided> ->
+      println(it.state)
+      println(it.value)
+    }
+    if (isTraceInProgress()) {
+      traceEventEnd()
+    }
+  } else {
+    %composer.skipToGroupEnd()
+  }
+  %composer.endRestartGroup()?.updateScope { %composer: Composer?, %force: Int ->
+    Test(inputs, %composer, updateChangedFlags(%changed or 0b0001))
+  }
+}
diff --git a/compose/compiler/compiler-hosted/runtime-tests/src/commonTest/kotlin/androidx/compose/compiler/test/CompositionTests.kt b/compose/compiler/compiler-hosted/runtime-tests/src/commonTest/kotlin/androidx/compose/compiler/test/CompositionTests.kt
new file mode 100644
index 0000000..bd5a443
--- /dev/null
+++ b/compose/compiler/compiler-hosted/runtime-tests/src/commonTest/kotlin/androidx/compose/compiler/test/CompositionTests.kt
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package androidx.compose.compiler.test
+
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mock.Text
+import androidx.compose.runtime.mock.compositionTest
+import androidx.compose.runtime.mock.validate
+import androidx.compose.runtime.rememberUpdatedState
+import kotlin.test.Test
+
+class CompositionTests {
+    @Test
+    fun composableInAnonymousObjectDeclaration() = compositionTest {
+        val list = listOf("a", "b")
+        compose {
+            list.forEach { s ->
+                val obj = object {
+                    val value by rememberUpdatedState(s)
+                }
+                Text(obj.value)
+            }
+        }
+
+        validate {
+            list.forEach {
+                Text(it)
+            }
+        }
+    }
+}
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k1/ComposableCallChecker.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k1/ComposableCallChecker.kt
index 1a27a2d..a776a4d 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k1/ComposableCallChecker.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k1/ComposableCallChecker.kt
@@ -23,7 +23,6 @@
 import org.jetbrains.kotlin.container.useInstance
 import org.jetbrains.kotlin.descriptors.CallableDescriptor
 import org.jetbrains.kotlin.descriptors.ClassDescriptor
-import org.jetbrains.kotlin.descriptors.DeclarationDescriptor
 import org.jetbrains.kotlin.descriptors.FunctionDescriptor
 import org.jetbrains.kotlin.descriptors.ModuleDescriptor
 import org.jetbrains.kotlin.descriptors.PropertyDescriptor
@@ -61,12 +60,8 @@
 import org.jetbrains.kotlin.resolve.calls.model.VariableAsFunctionResolvedCall
 import org.jetbrains.kotlin.resolve.calls.util.getResolvedCall
 import org.jetbrains.kotlin.resolve.calls.util.getValueArgumentForExpression
-import org.jetbrains.kotlin.resolve.inline.InlineUtil
 import org.jetbrains.kotlin.resolve.inline.InlineUtil.isInlinedArgument
 import org.jetbrains.kotlin.resolve.sam.getSingleAbstractMethodOrNull
-import org.jetbrains.kotlin.resolve.scopes.LexicalScope
-import org.jetbrains.kotlin.resolve.scopes.LexicalScopeKind
-import org.jetbrains.kotlin.resolve.scopes.utils.parents
 import org.jetbrains.kotlin.resolve.source.PsiSourceElement
 import org.jetbrains.kotlin.types.KotlinType
 import org.jetbrains.kotlin.types.TypeUtils
@@ -182,46 +177,26 @@
                         return
                     }
 
-                    val containingScope = context.scope.parents.firstOrNull {
-                        it is LexicalScope &&
-                            it.kind == LexicalScopeKind.FUNCTION_INNER_SCOPE &&
-                            (it.ownerDescriptor as? FunctionDescriptor)
-                                ?.hasComposableAnnotation() == true
-                    }
-                    val containingComposable = (containingScope as? LexicalScope)?.ownerDescriptor
-
-                    if (containingComposable != null) {
-                        // TODO(lmr): in future, we should check for CALLS_IN_PLACE contract
-                        val isInlined = checkInlineUsage(
-                            containingComposable,
-                            context,
-                            resolvedCall
-                        )
-                        if (!isInlined) {
-                            illegalCall(context, reportOn)
-                            return
-                        } else {
-                            // since the function is inlined, we continue going up the PSI tree
-                            // until we find a composable context. We also mark this lambda
-                            context.trace.record(
-                                FrontendWritableSlices.LAMBDA_CAPABLE_OF_COMPOSER_CAPTURE,
-                                descriptor,
-                                true
-                            )
-                        }
+                    val isResolvedInline = bindingContext.get(
+                        BindingContext.NEW_INFERENCE_IS_LAMBDA_FOR_OVERLOAD_RESOLUTION_INLINE,
+                        node.functionLiteral
+                    ) == true
+                    val isInlined = isResolvedInline || isInlinedArgument(
+                        node.functionLiteral,
+                        bindingContext,
+                        true
+                    )
+                    if (!isInlined) {
+                        illegalCall(context, reportOn)
+                        return
                     } else {
-                        // if we didn't find a containing composable, the call is invalid. Stop
-                        // iteration when lambda is not inlined, as the lambda itself should be
-                        // composable to resolve compilation error here.
-                        val isInlined = isInlinedArgument(
-                            node.functionLiteral,
-                            bindingContext,
+                        // since the function is inlined, we continue going up the PSI tree
+                        // until we find a composable context. We also mark this lambda
+                        context.trace.record(
+                            FrontendWritableSlices.LAMBDA_CAPABLE_OF_COMPOSER_CAPTURE,
+                            descriptor,
                             true
                         )
-                        if (!isInlined) {
-                            illegalCall(context, reportOn)
-                            return
-                        }
                     }
                 }
                 is KtTryExpression -> {
@@ -345,17 +320,6 @@
         }
     }
 
-    private fun checkInlineUsage(
-        containingComposable: DeclarationDescriptor,
-        context: CallCheckerContext,
-        resolvedCall: ResolvedCall<*>
-    ): Boolean =
-        InlineUtil.checkNonLocalReturnUsage(
-            containingComposable as FunctionDescriptor,
-            resolvedCall.call.callElement as KtExpression,
-            context.resolutionContext
-        )
-
     private fun missingDisallowedComposableCallPropagation(
         context: CallCheckerContext,
         unmarkedParamEl: PsiElement,
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/IrSourcePrinter.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/IrSourcePrinter.kt
index 63c8129..c242640 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/IrSourcePrinter.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/IrSourcePrinter.kt
@@ -26,6 +26,7 @@
 import org.jetbrains.kotlin.ir.IrStatement
 import org.jetbrains.kotlin.ir.ObsoleteDescriptorBasedAPI
 import org.jetbrains.kotlin.ir.declarations.IrAnnotationContainer
+import org.jetbrains.kotlin.ir.declarations.IrAnonymousInitializer
 import org.jetbrains.kotlin.ir.declarations.IrClass
 import org.jetbrains.kotlin.ir.declarations.IrConstructor
 import org.jetbrains.kotlin.ir.declarations.IrDeclaration
@@ -1414,6 +1415,15 @@
         print("<<TYPEALIAS>>")
     }
 
+    override fun visitAnonymousInitializer(declaration: IrAnonymousInitializer) {
+        println("init {")
+        indented {
+            declaration.body.print()
+        }
+        println()
+        println("}")
+    }
+
     private fun IrType.renderSrc() =
         "${renderTypeAnnotations(annotations)}${renderTypeInner()}"