blob: e078c50c7cd94196f79343b860bee4c2018a396e [file] [log] [blame]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cross-language tests for the JWT primitives."""
import datetime
import json
from absl.testing import absltest
from absl.testing import parameterized
import tink
from tink import jwt
from util import testing_servers
from util import utilities
# Languages that testing_servers lists for the 'jwt' primitive.
# NOTE(review): not referenced by the tests visible in this module, which use
# utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME instead -- confirm whether it
# is still needed.
SUPPORTED_LANGUAGES = testing_servers.SUPPORTED_LANGUAGES_BY_PRIMITIVE['jwt']
def setUpModule():
  """Module-level fixture: calls testing_servers.start for the 'jwt' tests."""
  testing_servers.start('jwt')
def tearDownModule():
  """Module-level fixture: calls testing_servers.stop after the tests ran."""
  testing_servers.stop()
class JwtTest(parameterized.TestCase):
  """Cross-language tests for JWT MAC and JWT public-key signature primitives.

  Every test is parameterized over key template names. For each template a
  keyset is generated in the first supported language, and tokens created in
  each supported language are checked in each supported language.
  """

  @staticmethod
  def _verifier_from_jwk_set(lang, jwk_set):
    """Imports a public JWK set in `lang` and returns a verifier for it.

    Args:
      lang: language in which the JWK set is converted and the primitive is
        obtained.
      jwk_set: the public JWK set, as a JSON string.

    Returns:
      The jwt.JwtPublicKeyVerify primitive returned by
      testing_servers.remote_primitive for the converted keyset.
    """
    public_keyset = testing_servers.jwk_set_to_keyset(lang, jwk_set)
    return testing_servers.remote_primitive(lang, public_keyset,
                                            jwt.JwtPublicKeyVerify)

  @parameterized.parameters(utilities.tinkey_template_names_for(jwt.JwtMac))
  def test_compute_verify_jwt_mac(self, key_template_name):
    """A token MACed in any supported language verifies in all of them."""
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        key_template_name]
    self.assertNotEmpty(supported_langs)
    key_template = utilities.KEY_TEMPLATE[key_template_name]
    # Take the first supported language to generate the keyset.
    keyset = testing_servers.new_keyset(supported_langs[0], key_template)
    supported_jwt_macs = [
        testing_servers.remote_primitive(lang, keyset, jwt.JwtMac)
        for lang in supported_langs
    ]
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    raw_jwt = jwt.new_raw_jwt(
        issuer='issuer', expiration=now + datetime.timedelta(seconds=100))
    # The validator does not depend on the loop variables; build it once.
    validator = jwt.new_validator(expected_issuer='issuer', fixed_now=now)
    for p in supported_jwt_macs:
      compact = p.compute_mac_and_encode(raw_jwt)
      for p2 in supported_jwt_macs:
        verified_jwt = p2.verify_mac_and_decode(compact, validator)
        self.assertEqual(verified_jwt.issuer(), 'issuer')

  @parameterized.parameters(
      utilities.tinkey_template_names_for(jwt.JwtPublicKeySign))
  def test_jwt_public_key_sign_verify(self, key_template_name):
    """A token signed in any supported language verifies in all of them."""
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        key_template_name]
    self.assertNotEmpty(supported_langs)
    key_template = utilities.KEY_TEMPLATE[key_template_name]
    # Take the first supported language to generate the private keyset.
    private_keyset = testing_servers.new_keyset(supported_langs[0],
                                                key_template)
    supported_signers = {
        lang: testing_servers.remote_primitive(lang, private_keyset,
                                               jwt.JwtPublicKeySign)
        for lang in supported_langs
    }
    # Derive the public keyset in the language that generated the private
    # keyset rather than in a hard-coded one, so the test does not depend on a
    # language that may not support this key template.
    public_keyset = testing_servers.public_keyset(supported_langs[0],
                                                  private_keyset)
    supported_verifiers = {
        lang: testing_servers.remote_primitive(lang, public_keyset,
                                               jwt.JwtPublicKeyVerify)
        for lang in supported_langs
    }
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    raw_jwt = jwt.new_raw_jwt(
        issuer='issuer', expiration=now + datetime.timedelta(seconds=100))
    # The validator does not depend on the loop variables; build it once.
    validator = jwt.new_validator(expected_issuer='issuer', fixed_now=now)
    for signer in supported_signers.values():
      compact = signer.sign_and_encode(raw_jwt)
      for verifier in supported_verifiers.values():
        verified_jwt = verifier.verify_and_decode(compact, validator)
        self.assertEqual(verified_jwt.issuer(), 'issuer')

  @parameterized.parameters(
      utilities.tinkey_template_names_for(jwt.JwtPublicKeySign))
  def test_jwt_public_key_sign_export_import_verify(self, key_template_name):
    """Tokens verify with a public keyset exported to and imported from JWK.

    For every pair (lang1, lang2) of supported languages: lang1 signs a token
    and exports the public keyset as a JWK set, lang2 imports that JWK set and
    verifies the token. The JWK's "kid" property is then modified to check
    how lang2 treats a mismatching, missing or added "kid".
    """
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        key_template_name]
    self.assertNotEmpty(supported_langs)
    key_template = utilities.KEY_TEMPLATE[key_template_name]
    # Take the first supported language to generate the private keyset.
    private_keyset = testing_servers.new_keyset(supported_langs[0],
                                                key_template)
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    raw_jwt = jwt.new_raw_jwt(
        issuer='issuer', expiration=now + datetime.timedelta(seconds=100))
    validator = jwt.new_validator(expected_issuer='issuer', fixed_now=now)

    for lang1 in supported_langs:
      # In lang1: sign the token and export the public keyset to a JWK set.
      signer = testing_servers.remote_primitive(lang1, private_keyset,
                                                jwt.JwtPublicKeySign)
      compact = signer.sign_and_encode(raw_jwt)
      public_keyset = testing_servers.public_keyset(lang1, private_keyset)
      public_jwk_set = testing_servers.jwk_set_from_keyset(lang1, public_keyset)

      for lang2 in supported_langs:
        # In lang2: import the public JWK set and verify the token.
        verifier = self._verifier_from_jwk_set(lang2, public_jwk_set)
        verified_jwt = verifier.verify_and_decode(compact, validator)
        self.assertEqual(verified_jwt.issuer(), 'issuer')

        # Additional tests for the "kid" property of the JWK and the "kid"
        # header of the token. Either of them may be missing, but they must not
        # have different values.
        jwks = json.loads(public_jwk_set)
        if 'kid' in jwks['keys'][0]:
          # Change the "kid" property of the JWK: verification must fail.
          jwks['keys'][0]['kid'] = 'unknown kid'
          # Build the verifier outside assertRaises so that an error while
          # importing the JWK set is not mistaken for the expected failure.
          verifier = self._verifier_from_jwk_set(lang2, json.dumps(jwks))
          with self.assertRaises(
              tink.TinkError,
              msg='%s accepts tokens with an incorrect kid unexpectedly' %
              lang2):
            verifier.verify_and_decode(compact, validator)

          # Remove the "kid" property of the JWK: verification must succeed.
          del jwks['keys'][0]['kid']
          verifier = self._verifier_from_jwk_set(lang2, json.dumps(jwks))
          verified_jwt = verifier.verify_and_decode(compact, validator)
          self.assertEqual(verified_jwt.issuer(), 'issuer')
        else:
          # Add a "kid" property to the JWK: verification must succeed.
          jwks['keys'][0]['kid'] = 'unknown kid'
          verifier = self._verifier_from_jwk_set(lang2, json.dumps(jwks))
          verified_jwt = verifier.verify_and_decode(compact, validator)
          self.assertEqual(verified_jwt.issuer(), 'issuer')
if __name__ == '__main__':
  # Executed as a script: hand control to the absltest runner.
  absltest.main()