| # Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Primitive Creation consistency tests.""" |
| |
| from typing import Any |
| |
| from absl.testing import absltest |
| from absl.testing import parameterized |
| import tink |
| from tink import aead |
| from tink import daead |
| from tink import hybrid |
| from tink import jwt |
| from tink import mac |
| from tink import prf |
| from tink import signature |
| from tink import streaming_aead |
| |
| import tink_config |
| from util import test_keys |
| from util import testing_servers |
| from util import utilities |
| |
| |
# The primitives have to be registered at import time rather than in
# setUpModule(): the keysets used to parameterize the tests are created while
# this module is being loaded, which is before setUpModule() is called.
for _register in (
    aead.register,
    daead.register,
    jwt.register_jwt_mac,
    jwt.register_jwt_signature,
    mac.register,
    hybrid.register,
    prf.register,
    signature.register,
    streaming_aead.register,
):
  _register()
# Do not leave the loop variable behind as a module attribute.
del _register
| |
| |
def setUpModule():
  """Module-level fixture: starts the testing servers for this test module.

  'primitive_creation' is passed through to testing_servers.start unchanged.
  """
  testing_servers.start('primitive_creation')
| |
| |
def tearDownModule():
  """Module-level fixture: stops the testing servers again."""
  testing_servers.stop()
| |
| |
def single_key_keysets():
  """Yields one keyset for each of a fixed list of key template names.

  Every name is looked up in utilities.KEY_TEMPLATE and the resulting template
  is passed to test_keys.new_or_stored_keyset. Keysets are produced lazily, in
  the order in which the template names are listed.
  """
  # TODO(tholenst): Add all templates.
  template_names = (
      'AES128_EAX',
      'AES256_GCM',
      'ECIES_P256_HKDF_HMAC_SHA256_AES128_GCM',
  )
  yield from (
      test_keys.new_or_stored_keyset(utilities.KEY_TEMPLATE[template_name])
      for template_name in template_names)
| |
| |
def named_testcases():
  """Yields the parameter dicts for every (language, keyset, primitive) triple.

  Each dict carries the keys 'testcase_name', 'lang', 'primitive' and 'keyset'.
  The test case name is built as
  '<running index>-<lang>-<primitive class name>-<first key type in keyset>',
  where the running index grows by one with every yielded test case.
  """
  index = 0
  for lang in utilities.ALL_LANGUAGES:
    for keyset in single_key_keysets():
      for primitive in tink_config.all_primitives():
        first_key_type = utilities.key_types_in_keyset(keyset)[0]
        name_parts = [str(index), lang, primitive.__name__, first_key_type]
        yield dict(
            testcase_name='-'.join(name_parts),
            lang=lang,
            primitive=primitive,
            keyset=keyset,
        )
        index += 1
| |
| |
| def _is_b243759652_test_case(lang: str, keyset: bytes, primitive: Any) -> bool: |
| """Returns whether the test case falls under b/243759652. |
| |
| When calling hybrid.NewHybridDecrypt or hybrid.NewHybridEncrypt, Tink asks |
| each key manager to create a primitive (whose type is fixed for each key |
| manager). Because of duck-typing, if the key manager returns an Aead, Tink |
| happily carries on in case it wants a HybridEncrypt/HybridDecrypt. |
| |
| Args: |
| lang: A string describing the language. |
| keyset: A serialized keyset |
| primitive: One of the primitives |
| Returns: |
| True iff this test case falls under b/243759652. |
| """ |
| # The bug only exists in go. |
| if lang != 'go': |
| return False |
| # The bug only happens if we create a HybridEncrypt or a HybridDecrypt |
| if primitive not in [tink.hybrid.HybridDecrypt, tink.hybrid.HybridEncrypt]: |
| return False |
| |
| keytypes = utilities.key_types_in_keyset(keyset) |
| primitives = [tink_config.primitive_for_keytype(k) for k in keytypes] |
| # For the bug to occur, we must only at least one AEAD keytype (as it only |
| # happens if at least one key type should *not* work). |
| if not any(p == aead.Aead for p in primitives): |
| return False |
| # For the bug to occur, all key types must be either for 'primitive' or |
| # for Aead (otherwise primitive creation fails). |
| if not all(p == aead.Aead or p == primitive for p in primitives): |
| return False |
| # For the bug to occur, we must not have an AesEaxKey: these are unsupported |
| # in go, and so if we have them, primitive creation fails. |
| if any(k == 'AesEaxKey' for k in keytypes): |
| return False |
| return True |
| |
| |
class SupportedKeyTypesTest(parameterized.TestCase):
  """Checks primitive creation for each language, keyset and primitive."""

  def _check_primitive_creation(self, lang: str, keyset: bytes,
                                primitive: Any) -> None:
    """Asserts that creating 'primitive' from 'keyset' behaves as expected.

    Creation must succeed exactly when 'lang' is listed as supporting the key
    type of the (single) key in the keyset and 'primitive' is the primitive
    configured for that key type; in every other case a tink.TinkError is
    expected. Test cases falling under b/243759652 are the one exception.

    Args:
      lang: The language to test.
      keyset: A byte string representing a keyset with exactly one key.
      primitive: The primitive to try and instantiate.
    """
    keytypes = utilities.key_types_in_keyset(keyset)
    self.assertLen(keytypes, 1)
    keytype = keytypes[0]

    if _is_b243759652_test_case(lang, keyset, primitive):
      # TODO(b/243759652): This should raise a TinkError, but doesn't
      _ = testing_servers.remote_primitive(lang, keyset, primitive)
      return

    if (lang in tink_config.supported_languages_for_key_type(keytype) and
        primitive == tink_config.primitive_for_keytype(keytype)):
      _ = testing_servers.remote_primitive(lang, keyset, primitive)
    else:
      with self.assertRaises(tink.TinkError):
        _ = testing_servers.remote_primitive(lang, keyset, primitive)

  @parameterized.named_parameters(named_testcases())
  def test_create(self, lang: str, keyset: bytes, primitive: Any):
    """Tests primitive creation (see top level comment).

    This test should pass for every keyset, as long as the keyset can be
    correctly parsed.

    Args:
      lang: The language to test
      keyset: A byte string representing a keyset. The keyset needs to be valid.
      primitive: The primitive to try and instantiate
    """
    self._check_primitive_creation(lang, keyset, primitive)

  @parameterized.named_parameters(named_testcases())
  def test_create_with_public_keyset(self, lang: str, keyset: bytes,
                                     primitive: Any):
    """Tests primitive creation, after getting a public keyset.

    The test is skipped when no public keyset can be obtained for 'keyset'.

    Args:
      lang: The language to test
      keyset: A byte string representing a keyset. The keyset needs to be valid.
      primitive: The primitive to try and instantiate
    """
    try:
      public_keyset = testing_servers.public_keyset(lang, keyset)
    except tink.TinkError:
      self.skipTest('Cannot get the public keyset')

    self._check_primitive_creation(lang, public_keyset, primitive)
| |
| |
# Hand control to the absltest runner only when executed as a script, not on
# import.
if __name__ == '__main__':
  absltest.main()