blob: d72e220b5ba8be58e90ed7a77bfc2d6251292883 [file] [log] [blame]
/*
* Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package org.jetbrains.kotlin.generators.protobuf
import org.jetbrains.kotlin.generators.util.GeneratorsFileUtil
import org.jetbrains.kotlin.library.metadata.DebugKlibMetadataProtoBuf
import org.jetbrains.kotlin.metadata.DebugExtOptionsProtoBuf
import org.jetbrains.kotlin.metadata.DebugProtoBuf
import org.jetbrains.kotlin.metadata.builtins.DebugBuiltInsProtoBuf
import org.jetbrains.kotlin.metadata.java.DebugJavaClassProtoBuf
import org.jetbrains.kotlin.metadata.js.DebugJsProtoBuf
import org.jetbrains.kotlin.metadata.jvm.DebugJvmProtoBuf
import org.jetbrains.kotlin.protobuf.Descriptors
import org.jetbrains.kotlin.utils.Printer
import java.io.File
import java.util.*
class GenerateProtoBufCompare {
companion object {
val DEST_FILE: File = File("build-common/src/org/jetbrains/kotlin/incremental/ProtoCompareGenerated.kt")
@JvmStatic
fun main(args: Array<String>) {
generate(DEST_FILE)
}
fun generate(destFile: File, forbidGenerationOnTeamcity: Boolean = true) {
GeneratorsFileUtil.writeFileIfContentChanged(
destFile, GenerateProtoBufCompare().generate(),
forbidGenerationOnTeamcity = forbidGenerationOnTeamcity
)
}
}
private val JAVA_TYPES_WITH_INLINED_EQUALS: EnumSet<Descriptors.FieldDescriptor.JavaType> = EnumSet.of(
Descriptors.FieldDescriptor.JavaType.INT,
Descriptors.FieldDescriptor.JavaType.LONG,
Descriptors.FieldDescriptor.JavaType.FLOAT,
Descriptors.FieldDescriptor.JavaType.DOUBLE,
Descriptors.FieldDescriptor.JavaType.BOOLEAN,
Descriptors.FieldDescriptor.JavaType.STRING,
Descriptors.FieldDescriptor.JavaType.ENUM
)
private val RESULT_NAME = "result"
private val STRING_INDEXES_NAME = "StringIndexes"
private val CLASS_ID_INDEXES_NAME = "ClassIdIndexes"
private val OLD_PREFIX = "old"
private val NEW_PREFIX = "new"
private val CHECK_EQUALS_NAME = "checkEquals"
private val CHECK_STRING_EQUALS_NAME = "checkStringEquals"
private val CHECK_CLASS_ID_EQUALS_NAME = "checkClassIdEquals"
private val HASH_CODE_NAME = "hashCode"
private val extensions = object {
val jvm = DebugJvmProtoBuf.getDescriptor().extensions
val js = DebugJsProtoBuf.getDescriptor().extensions
val java = DebugJavaClassProtoBuf.getDescriptor().extensions
val builtIns = DebugBuiltInsProtoBuf.getDescriptor().extensions
val klib = DebugKlibMetadataProtoBuf.getDescriptor().extensions
private val extensionsMap = (jvm + js + java + builtIns + klib).groupBy { it.containingType }
operator fun get(desc: Descriptors.Descriptor): List<Descriptors.FieldDescriptor>? = extensionsMap[desc]
fun getEnumName(fieldDescriptor: Descriptors.FieldDescriptor): String {
var extensionPrefix = ""
if (fieldDescriptor.isExtension) {
extensionPrefix = when {
fieldDescriptor in jvm -> "jvmExt_"
fieldDescriptor in js -> "jsExt_"
fieldDescriptor in java -> "javaExt_"
fieldDescriptor in builtIns -> "builtInsExt_"
fieldDescriptor in klib -> "klibExt_"
else -> error("Unknown extension")
}
}
return (extensionPrefix + fieldDescriptor.name.javaName + (if (fieldDescriptor.isRepeated) "List" else ""))
.replace("[A-Z]".toRegex()) { "_" + it.value }
.uppercase()
}
}
private val allMessages: MutableSet<Descriptors.Descriptor> = linkedSetOf()
private val messagesToProcess: Queue<Descriptors.Descriptor> = LinkedList()
private val repeatedFields: MutableSet<Descriptors.FieldDescriptor> = linkedSetOf()
fun generate(): String {
val sb = StringBuilder()
val p = Printer(sb)
p.println(File("license/COPYRIGHT.txt").readText())
p.println("@file:Suppress(\"UNUSED_PARAMETER\")")
p.println("")
p.println("package org.jetbrains.kotlin.incremental")
p.println()
p.println("import org.jetbrains.kotlin.library.metadata.KlibMetadataProtoBuf")
p.println("import org.jetbrains.kotlin.metadata.ProtoBuf")
p.println("import org.jetbrains.kotlin.metadata.builtins.BuiltInsProtoBuf")
p.println("import org.jetbrains.kotlin.metadata.deserialization.NameResolver")
p.println("import org.jetbrains.kotlin.metadata.java.JavaClassProtoBuf")
p.println("import org.jetbrains.kotlin.metadata.js.JsProtoBuf")
p.println("import org.jetbrains.kotlin.metadata.jvm.JvmProtoBuf")
p.println("import org.jetbrains.kotlin.metadata.serialization.Interner")
p.println("import org.jetbrains.kotlin.name.ClassId")
p.println("import org.jetbrains.kotlin.serialization.deserialization.getClassId")
p.println("import java.util.*")
p.println()
p.println("/** This file is generated by org.jetbrains.kotlin.generators.protobuf.GenerateProtoBufCompare. DO NOT MODIFY MANUALLY */")
p.println()
p.println("open class ProtoCompareGenerated(")
p.pushIndent()
p.println("val oldNameResolver: NameResolver,")
p.println("val newNameResolver: NameResolver,")
p.println("oldTypeTable: ProtoBuf.TypeTable?,")
p.println("newTypeTable: ProtoBuf.TypeTable?")
p.popIndent()
p.println(") {")
p.pushIndent()
p.println("private val strings = Interner<String>()")
p.println("val $OLD_PREFIX${STRING_INDEXES_NAME}Map: MutableMap<Int, Int> = hashMapOf()")
p.println("val $NEW_PREFIX${STRING_INDEXES_NAME}Map: MutableMap<Int, Int> = hashMapOf()")
p.println("val $OLD_PREFIX${CLASS_ID_INDEXES_NAME}Map: MutableMap<Int, Int> = hashMapOf()")
p.println("val $NEW_PREFIX${CLASS_ID_INDEXES_NAME}Map: MutableMap<Int, Int> = hashMapOf()")
p.println("val oldTypeTable: ProtoBuf.TypeTable = oldTypeTable ?: ProtoBuf.TypeTable.getDefaultInstance()")
p.println("val newTypeTable: ProtoBuf.TypeTable = newTypeTable ?: ProtoBuf.TypeTable.getDefaultInstance()")
p.println()
p.println("private val classIds = Interner<ClassId>()")
val fileDescriptor = DebugProtoBuf.getDescriptor()
addMessageToProcessIfNeeded(fileDescriptor.findMessageTypeByName("Package"))
addMessageToProcessIfNeeded(fileDescriptor.findMessageTypeByName("Class"))
val generateDifference = allMessages.toSet()
while (messagesToProcess.isNotEmpty()) {
p.println()
val message = messagesToProcess.poll()
generateForMessage(message, p)
if (message in generateDifference) {
generateDiffForMessage(message, p)
}
}
repeatedFields.forEach { generateHelperMethodForRepeatedField(it, p) }
p.println()
generateAuxiliaryMethods(p)
p.popIndent()
p.println("}")
allMessages.forEach { generateHashCodeFun(it, p) }
return sb.toString()
}
fun generateAuxiliaryMethods(p: Printer) {
p.println("fun oldGetTypeById(id: Int): ProtoBuf.Type = oldTypeTable.getType(id) ?: error(\"Unknown type id: ${'$'}id\")")
p.println("fun newGetTypeById(id: Int): ProtoBuf.Type = newTypeTable.getType(id) ?: error(\"Unknown type id: ${'$'}id\")")
p.println()
p.println("fun oldGetIndexOfString(index: Int): Int = getIndexOfString(index, oldStringIndexesMap, oldNameResolver)")
p.println("fun newGetIndexOfString(index: Int): Int = getIndexOfString(index, newStringIndexesMap, newNameResolver)")
p.println()
p.println("fun getIndexOfString(index: Int, map: MutableMap<Int, Int>, nameResolver: NameResolver): Int {")
p.println(" map[index]?.let { return it }")
p.println()
p.println(" val result = strings.intern(nameResolver.getString(index))")
p.println(" map[index] = result")
p.println(" return result")
p.println("}")
p.println()
p.println("fun oldGetIndexOfClassId(index: Int): Int = getIndexOfClassId(index, oldClassIdIndexesMap, oldNameResolver)")
p.println("fun newGetIndexOfClassId(index: Int): Int = getIndexOfClassId(index, newClassIdIndexesMap, newNameResolver)")
p.println()
p.println("fun getIndexOfClassId(index: Int, map: MutableMap<Int, Int>, nameResolver: NameResolver): Int {")
p.println(" map[index]?.let { return it }")
p.println()
p.println(" val result = classIds.intern(nameResolver.getClassId(index))")
p.println(" map[index] = result")
p.println(" return result")
p.println("}")
p.println()
p.println("private fun $CHECK_STRING_EQUALS_NAME(old: Int, new: Int): Boolean {")
p.println(" return oldGetIndexOfString(old) == newGetIndexOfString(new)")
p.println("}")
p.println()
p.println("private fun $CHECK_CLASS_ID_EQUALS_NAME(old: Int, new: Int): Boolean {")
p.println(" return oldGetIndexOfClassId(old) == newGetIndexOfClassId(new)")
p.println("}")
}
fun generateHashCodeFun(descriptor: Descriptors.Descriptor, p: Printer) {
val typeName = descriptor.typeName
val fields = descriptor.fields.filter { !it.shouldSkip }
val extFields = extensions[descriptor]?.filter { !it.shouldSkip } ?: emptyList()
p.println()
p.println("fun $typeName.$HASH_CODE_NAME(stringIndexes: (Int) -> Int, fqNameIndexes: (Int) -> Int, typeById: (Int) -> ProtoBuf.Type): Int {")
p.pushIndent()
p.println("var $HASH_CODE_NAME = 1")
fields.forEach { field -> generateHashCodeForField(field, p, false) }
extFields.forEach { field -> generateHashCodeForField(field, p, true) }
p.println()
p.println("return $HASH_CODE_NAME")
p.popIndent()
p.println("}")
}
fun generateHashCodeForField(field: Descriptors.FieldDescriptor, p: Printer, isExtensionField: Boolean) {
val fieldName = field.name.javaName
val capFieldName = fieldName.replaceFirstChar(Char::uppercaseChar)
val outerClassName = field.file.options.javaOuterClassname.removePrefix("Debug")
val fullFieldName = "$outerClassName.$fieldName"
val upperBound = if (isExtensionField) "getExtensionCount($fullFieldName)" else "${fieldName}Count"
val hasMethod = if (isExtensionField) "hasExtension($fullFieldName)" else "has$capFieldName()"
val fieldValue = if (isExtensionField) "getExtension($fullFieldName)" else fieldName
val repeatedFieldValue = if (isExtensionField) "getExtension($fullFieldName, i)" else "get$capFieldName(i)"
p.println()
if (field.isRepeated) {
p.println("for(i in 0..$upperBound - 1) {")
p.println(" $HASH_CODE_NAME = 31 * $HASH_CODE_NAME + ${fieldToHashCode(field, repeatedFieldValue)}")
p.println("}")
} else if (field.isRequired) {
p.println("$HASH_CODE_NAME = 31 * $HASH_CODE_NAME + ${fieldToHashCode(field, fieldValue)}")
} else if (field.isOptional) {
p.println("if ($hasMethod) {")
p.println(" $HASH_CODE_NAME = 31 * $HASH_CODE_NAME + ${fieldToHashCode(field, fieldValue)}")
p.println("}")
}
}
fun generateForMessage(descriptor: Descriptors.Descriptor, p: Printer) {
val typeName = descriptor.typeName
val fields = descriptor.fields.filter { !it.shouldSkip }
val extFields = extensions[descriptor]?.filter { !it.shouldSkip } ?: emptyList()
p.println("open fun $CHECK_EQUALS_NAME(old: $typeName, new: $typeName): Boolean {")
p.pushIndent()
fields.forEach { field -> FieldGeneratorImpl(field, p).generate() }
extFields.forEach { field -> ExtFieldGeneratorImpl(field, p).generate() }
p.println("return true")
p.popIndent()
p.println("}")
}
fun generateDiffForMessage(descriptor: Descriptors.Descriptor, p: Printer) {
val typeName = descriptor.typeName
val className = typeName.replace(".", "")
val fields = descriptor.fields.filter { !it.shouldSkip }
val extFields = extensions[descriptor]?.filter { !it.shouldSkip } ?: emptyList()
val allFields = fields + extFields
p.println("enum class ${className}Kind {")
p.println(allFields.joinToString(",\n ") { " " + extensions.getEnumName(it) })
p.println("}")
p.println()
p.println("fun difference(old: $typeName, new: $typeName): EnumSet<${className}Kind> {")
p.pushIndent()
p.println("val $RESULT_NAME = EnumSet.noneOf(${className}Kind::class.java)")
p.println()
fields.forEach { field -> FieldGeneratorForDiff(field, p).generate() }
extFields.forEach { field -> ExtFieldGeneratorForDiff(field, p).generate() }
p.println("return $RESULT_NAME")
p.popIndent()
p.println("}")
}
fun generateHelperMethodForRepeatedField(field: Descriptors.FieldDescriptor, p: Printer) {
assert(field.isRepeated) { "expected repeated field: ${field.name}" }
val typeName = field.containingType.typeName
val fieldName = field.name.javaName
val capFieldName = fieldName.replaceFirstChar(Char::uppercaseChar)
val methodName = field.helperMethodName()
p.println()
p.println("open fun $methodName(old: $typeName, new: $typeName): Boolean {")
p.pushIndent()
p.println("if (old.${fieldName}Count != new.${fieldName}Count) return false")
p.println()
p.println("for(i in 0..old.${fieldName}Count - 1) {")
p.printlnIfWithComparisonIndent(field, "get$capFieldName(i)")
p.println("}")
p.println()
p.println("return true")
p.popIndent()
p.println("}")
}
abstract inner class FieldGenerator(val field: Descriptors.FieldDescriptor, val p: Printer) {
val statement = field.getStatement()
abstract fun Descriptors.FieldDescriptor.getStatement(): String
fun generate() {
if (field.isRepeated) {
printRepeatedField()
} else if (field.isRequired) {
printRequiredField()
} else if (field.isOptional) {
printOptionalField()
}
p.println()
addMessageTypeToProcessIfNeeded(field)
}
abstract fun printRepeatedField()
abstract fun printRequiredField()
abstract fun printOptionalField()
}
open inner class FieldGeneratorImpl(field: Descriptors.FieldDescriptor, p: Printer) : FieldGenerator(field, p) {
val fieldName = field.name.javaName
val capFieldName = fieldName.replaceFirstChar(Char::uppercaseChar)
override fun printRepeatedField() {
repeatedFields.add(field)
p.println("if (!${field.helperMethodName()}(old, new)) $statement")
}
override fun printRequiredField() {
p.printlnIfWithComparison(field, fieldName, statement)
}
override fun printOptionalField() {
p.println("if (old.has$capFieldName() != new.has$capFieldName()) $statement")
p.println("if (old.has$capFieldName()) {")
p.printlnIfWithComparisonIndent(field, fieldName, statement)
p.println("}")
}
override fun Descriptors.FieldDescriptor.getStatement(): String = "return false"
}
open inner class ExtFieldGeneratorImpl(field: Descriptors.FieldDescriptor, p: Printer) : FieldGenerator(field, p) {
val outerClassName = field.file.options.javaOuterClassname.removePrefix("Debug")
val fieldName = field.name.javaName
val fullFieldName = "$outerClassName.$fieldName"
override fun printRepeatedField() {
p.printlnMultiline(
"""
if (old.getExtensionCount($fullFieldName) != new.getExtensionCount($fullFieldName)) {
$statement
}
else {
for(i in 0..old.getExtensionCount($fullFieldName) - 1) {
${ifWithComparison(field, "getExtension($fullFieldName, i)", statement)}
}
}
"""
)
}
override fun printRequiredField() {
p.printlnIfWithComparison(field, "getExtension($fullFieldName)", statement)
}
override fun printOptionalField() {
p.println("if (old.hasExtension($fullFieldName) != new.hasExtension($fullFieldName)) $statement")
p.println("if (old.hasExtension($fullFieldName)) {")
p.printlnIfWithComparisonIndent(field, "getExtension($fullFieldName)", statement)
p.println("}")
}
override fun Descriptors.FieldDescriptor.getStatement(): String = "return false"
}
inner class FieldGeneratorForDiff(field: Descriptors.FieldDescriptor, p: Printer) : FieldGeneratorImpl(field, p) {
override fun Descriptors.FieldDescriptor.getStatement(): String = statementForDiff
}
inner class ExtFieldGeneratorForDiff(field: Descriptors.FieldDescriptor, p: Printer) : ExtFieldGeneratorImpl(field, p) {
override fun Descriptors.FieldDescriptor.getStatement(): String = statementForDiff
}
private val Descriptors.FieldDescriptor.statementForDiff: String
get() = "$RESULT_NAME.add(${containingType.typeName.replace(".", "")}Kind.${extensions.getEnumName(this)})"
private val Descriptors.Descriptor.shouldSkip: Boolean
get() = options.getExtension(DebugExtOptionsProtoBuf.skipMessageInComparison)
private val Descriptors.FieldDescriptor.shouldSkip: Boolean
get() = options.getExtension(DebugExtOptionsProtoBuf.skipInComparison)
|| (type == Descriptors.FieldDescriptor.Type.MESSAGE && messageType.shouldSkip)
private fun addMessageTypeToProcessIfNeeded(field: Descriptors.FieldDescriptor) {
if (field.javaType == Descriptors.FieldDescriptor.JavaType.MESSAGE) {
addMessageToProcessIfNeeded(field.messageType)
}
}
private fun addMessageToProcessIfNeeded(descriptor: Descriptors.Descriptor) {
if (descriptor !in allMessages && !descriptor.shouldSkip) {
allMessages.add(descriptor)
messagesToProcess.add(descriptor)
}
}
private fun ifWithComparison(field: Descriptors.FieldDescriptor, expr: String, statement: String) =
when {
field.options.getExtension(DebugExtOptionsProtoBuf.typeIdInTable) ->
"if (!$CHECK_EQUALS_NAME(oldTypeTable.getType(old.$expr), newTypeTable.getType(new.$expr))) $statement"
field.options.getExtension(DebugExtOptionsProtoBuf.stringIdInTable) ||
field.options.getExtension(DebugExtOptionsProtoBuf.nameIdInTable) ->
"if (!$CHECK_STRING_EQUALS_NAME(old.$expr, new.$expr)) $statement"
field.options.getExtension(DebugExtOptionsProtoBuf.fqNameIdInTable) ->
"if (!$CHECK_CLASS_ID_EQUALS_NAME(old.$expr, new.$expr)) $statement"
field.javaType in JAVA_TYPES_WITH_INLINED_EQUALS ->
"if (old.$expr != new.$expr) $statement"
else ->
"if (!$CHECK_EQUALS_NAME(old.$expr, new.$expr)) $statement"
}
private fun Printer.printlnIfWithComparison(field: Descriptors.FieldDescriptor, expr: String, statement: String = "return false") {
this.println(ifWithComparison(field, expr, statement))
}
fun Printer.printlnIfWithComparisonIndent(field: Descriptors.FieldDescriptor, expr: String, statement: String = "return false") {
pushIndent()
printlnIfWithComparison(field, expr, statement)
popIndent()
}
private fun fieldToHashCode(field: Descriptors.FieldDescriptor, expr: String): String =
when {
field.options.getExtension(DebugExtOptionsProtoBuf.typeIdInTable) ->
"typeById($expr).$HASH_CODE_NAME(stringIndexes, fqNameIndexes, typeById)"
field.options.getExtension(DebugExtOptionsProtoBuf.stringIdInTable) ||
field.options.getExtension(DebugExtOptionsProtoBuf.nameIdInTable) ->
"stringIndexes($expr)"
field.options.getExtension(DebugExtOptionsProtoBuf.fqNameIdInTable) ->
"fqNameIndexes($expr)"
field.javaType == Descriptors.FieldDescriptor.JavaType.INT ->
"$expr"
field.javaType in JAVA_TYPES_WITH_INLINED_EQUALS ->
"$expr.$HASH_CODE_NAME()"
else ->
"$expr.$HASH_CODE_NAME(stringIndexes, fqNameIndexes, typeById)"
}
private val Descriptors.Descriptor.typeName: String
get() {
val outerClassName = file.options.javaOuterClassname.removePrefix("Debug")
val packageHeader = file.`package`
return outerClassName + fullName.removePrefix(packageHeader)
}
private fun Descriptors.FieldDescriptor.helperMethodName(): String {
val packageHeader = this.file.`package`
val descriptor = this.containingType
val className = descriptor.fullName.removePrefix(packageHeader).replace(".", "")
val capFieldName = this.name.javaName.replaceFirstChar(Char::uppercaseChar)
return "$CHECK_EQUALS_NAME$className$capFieldName"
}
private val String.javaName: String
get() = this.split("_").joinToString("") { it.replaceFirstChar(Char::uppercaseChar) }.replaceFirstChar(Char::lowercaseChar)
}
private fun Printer.printlnMultiline(string: String) {
string.trimIndent().split("\n").forEach { this.println(it) }
}