blob: 6d4e660e3334f538ce077b767e59a144268eab66 [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright (C) 2022 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for verify_overlaps_test.py."""
import io
import unittest
from signature_trie import InteriorNode
from signature_trie import signature_trie
class TestSignatureToElements(unittest.TestCase):
    """Exercises InteriorNode's signature <-> element-list conversions.

    Elements are ``(type, name)`` tuples. The expectations below use the
    types "package", "class", "member" and "wildcard".
    """

    @staticmethod
    def signature_to_elements(signature):
        """Forward to ``InteriorNode.signature_to_elements``."""
        return InteriorNode.signature_to_elements(signature)

    @staticmethod
    def elements_to_signature(elements):
        """Forward to ``InteriorNode.elements_to_selector``."""
        return InteriorNode.elements_to_selector(elements)

    def _check_round_trip(self, signature, elements, prefix=""):
        """Assert both conversion directions for one signature.

        Args:
            signature: The textual signature or pattern.
            elements: The element list that `signature` must parse into.
            prefix: Text prepended to the rebuilt selector before it is
                compared with `signature`; the class-signature tests pass
                "L" here, the wildcard tests pass nothing.
        """
        parsed = self.signature_to_elements(signature)
        self.assertEqual(elements, parsed)
        rebuilt = prefix + self.elements_to_signature(elements)
        self.assertEqual(signature, rebuilt)

    def _check_rejected(self, signature, fragment):
        """Assert that parsing `signature` raises with `fragment` in the text."""
        with self.assertRaises(Exception) as context:
            self.signature_to_elements(signature)
        # Inspect the message only after the context manager has captured it.
        self.assertIn(fragment, str(context.exception))

    # '$' separated inner classes, including a numeric (anonymous) one.
    def test_nested_inner_classes(self):
        self._check_round_trip(
            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "ProcessBuilder"),
                ("class", "Redirect"),
                ("class", "1"),
                ("member", "<init>()V"),
            ],
            prefix="L",
        )

    # Plain class plus a single member.
    def test_basic_member(self):
        self._check_round_trip(
            "Ljava/lang/Object;->hashCode()I",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "Object"),
                ("member", "hashCode()I"),
            ],
            prefix="L",
        )

    # '$$' is expected to yield an empty class element between the two names.
    def test_double_dollar_class(self):
        self._check_round_trip(
            ("Ljava/lang/CharSequence$$ExternalSyntheticLambda0;"
             "-><init>(Ljava/lang/CharSequence;)V"),
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "CharSequence"),
                ("class", ""),
                ("class", "ExternalSyntheticLambda0"),
                ("member", "<init>(Ljava/lang/CharSequence;)V"),
            ],
            prefix="L",
        )

    # Class-only signature: no ';->member' suffix.
    def test_no_member(self):
        self._check_round_trip(
            "Ljava/lang/CharSequence$$ExternalSyntheticLambda0",
            [
                ("package", "java"),
                ("package", "lang"),
                ("class", "CharSequence"),
                ("class", ""),
                ("class", "ExternalSyntheticLambda0"),
            ],
            prefix="L",
        )

    # Wildcard patterns carry no leading 'L', so no prefix is added back.
    def test_wildcard(self):
        self._check_round_trip(
            "java/lang/*",
            [
                ("package", "java"),
                ("package", "lang"),
                ("wildcard", "*"),
            ],
        )

    def test_recursive_wildcard(self):
        self._check_round_trip(
            "java/lang/**",
            [
                ("package", "java"),
                ("package", "lang"),
                ("wildcard", "**"),
            ],
        )

    def test_no_packages_wildcard(self):
        self._check_round_trip("*", [("wildcard", "*")])

    def test_no_packages_recursive_wildcard(self):
        self._check_round_trip("**", [("wildcard", "**")])

    # A pattern ending in a lower case name is neither a class nor a wildcard.
    def test_invalid_no_class_or_wildcard(self):
        self._check_rejected(
            "java/lang",
            "last element 'lang' is lower case but should be an "
            "upper case class name or wildcard",
        )

    # A lower case final name is accepted as a class when 'L' is present.
    def test_non_standard_class_name(self):
        self._check_round_trip(
            "Ljavax/crypto/extObjectInputStream",
            [
                ("package", "javax"),
                ("package", "crypto"),
                ("class", "extObjectInputStream"),
            ],
            prefix="L",
        )

    def test_invalid_pattern_wildcard(self):
        self._check_rejected("Ljava/lang/Class*", "invalid wildcard 'Class*'")

    def test_invalid_pattern_wildcard_and_member(self):
        self._check_rejected(
            "Ljava/lang/*;->hashCode()I",
            "contains wildcard '*' and member signature 'hashCode()I'",
        )
class TestValues(unittest.TestCase):
    """Checks that values added to a trie can be read back from a node."""

    # Adds three members under a/b/C (two of them in the nested class C$D),
    # walks down the first child at each level and reads the values back
    # from the class node.
    def test_add_then_get(self):
        trie = signature_trie()
        additions = (
            ("La/b/C;->l()", 1),
            ("La/b/C$D;->m()", "A"),
            ("La/b/C$D;->n()", {}),
        )
        for signature, value in additions:
            trie.add(signature, value)

        # Expected (type, selector) of the first child at each depth.
        expected_path = (
            ("package", "a"),
            ("package", "a/b"),
            ("class", "a/b/C"),
        )
        node = trie
        for node_type, selector in expected_path:
            node = next(iter(node.child_nodes()))
            self.assertEqual(node_type, node.type)
            self.assertEqual(selector, node.selector)

        # `node` is now the a/b/C class node; the filter accepts every value.
        self.assertEqual([1, "A", {}], node.values(lambda _: True))
class TestGetMatchingRows(unittest.TestCase):
    """Checks ``get_matching_rows`` against a small fixed set of signatures.

    Each fixture signature is stored in the trie as its own value, so the
    rows returned for a pattern are the matching signatures themselves.
    """

    # Fixture: one signature per line; blank first/last lines are stripped
    # by read_trie() before parsing.
    extractInput = """
Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;
Ljava/lang/Character;->serialVersionUID:J
Ljava/lang/Object;->hashCode()I
Ljava/lang/Object;->toString()Ljava/lang/String;
Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V
Ljava/util/zip/ZipFile;-><clinit>()V
"""

    def read_trie(self):
        """Build a fresh trie holding every line of ``extractInput``."""
        trie = signature_trie()
        with io.StringIO(self.extractInput.strip()) as stream:
            for raw_line in stream:
                signature = raw_line.rstrip()
                # The signature doubles as the stored value.
                trie.add(signature, signature)
        return trie

    def check_patterns(self, pattern, expected):
        """Match `pattern` from the root of a freshly built trie."""
        self.check_node_patterns(self.read_trie(), pattern, expected)

    def check_node_patterns(self, node, pattern, expected):
        """Assert that `node` yields exactly `expected` rows for `pattern`.

        The rows are sorted before comparison, so `expected` must be given
        in sorted order.
        """
        rows = sorted(node.get_matching_rows(pattern))
        self.assertEqual(expected, rows)

    # Patterns are written without the leading 'L' of the stored signatures.
    def test_member_pattern(self):
        expected = ["Ljava/util/zip/ZipFile;-><clinit>()V"]
        self.check_patterns("java/util/zip/ZipFile;-><clinit>()V", expected)

    def test_class_pattern(self):
        expected = [
            "Ljava/lang/Object;->hashCode()I",
            "Ljava/lang/Object;->toString()Ljava/lang/String;",
        ]
        self.check_patterns("java/lang/Object", expected)

    # pylint: disable=line-too-long
    # A class pattern is expected to also return members of nested classes.
    def test_nested_class_pattern(self):
        expected = [
            "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
            "Ljava/lang/Character;->serialVersionUID:J",
        ]
        self.check_patterns("java/lang/Character", expected)

    # '*' under java/lang: the java/util/zip row is not expected.
    def test_wildcard(self):
        expected = [
            "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
            "Ljava/lang/Character;->serialVersionUID:J",
            "Ljava/lang/Object;->hashCode()I",
            "Ljava/lang/Object;->toString()Ljava/lang/String;",
            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
        ]
        self.check_patterns("java/lang/*", expected)

    # '**' under java: every fixture row is expected, sub-packages included.
    def test_recursive_wildcard(self):
        expected = [
            "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
            "Ljava/lang/Character;->serialVersionUID:J",
            "Ljava/lang/Object;->hashCode()I",
            "Ljava/lang/Object;->toString()Ljava/lang/String;",
            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
            "Ljava/util/zip/ZipFile;-><clinit>()V",
        ]
        self.check_patterns("java/**", expected)

    # Same query as above, but issued from the first child of the root
    # rather than from the root itself.
    def test_node_wildcard(self):
        children = list(self.read_trie().child_nodes())
        first_child = children[0]
        expected = [
            "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
            "Ljava/lang/Character;->serialVersionUID:J",
            "Ljava/lang/Object;->hashCode()I",
            "Ljava/lang/Object;->toString()Ljava/lang/String;",
            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
            "Ljava/util/zip/ZipFile;-><clinit>()V",
        ]
        self.check_node_patterns(first_child, "**", expected)
    # pylint: enable=line-too-long
# When executed directly (not imported), run every test case in this module
# with per-test verbose output.
if __name__ == "__main__":
    unittest.main(verbosity=2)