blob: ff57f69ca4ad028094a9f709bf458e94d09e5aa9 [file] [log] [blame]
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cross-language tests for the "kid" header set by JWT primitives."""
import base64
import json
from typing import Optional
from absl.testing import absltest
from absl.testing import parameterized
import tink
from tink import jwt
from tink.proto import jwt_ecdsa_pb2
from tink.proto import jwt_hmac_pb2
from tink.proto import jwt_rsa_ssa_pkcs1_pb2
from tink.proto import jwt_rsa_ssa_pss_pb2
from tink.proto import tink_pb2
from util import testing_servers
from util import utilities
# Entry for the 'jwt' primitive in testing_servers' per-primitive table. The
# tests below pass each element as the `lang` argument of remote_primitive().
SUPPORTED_LANGUAGES = testing_servers.SUPPORTED_LANGUAGES_BY_PRIMITIVE['jwt']
def setUpModule():
  """Registers the JWT MAC and signature primitives and starts the servers.

  Module-level fixture: testing_servers.start('jwt') is called once, after
  both jwt registration calls.
  """
  jwt.register_jwt_mac()
  jwt.register_jwt_signature()
  testing_servers.start('jwt')
def tearDownModule():
  """Calls testing_servers.stop() once the tests in this module are done."""
  testing_servers.stop()
def base64_decode(encoded_data: bytes) -> bytes:
  """Decodes URL-safe base64 data whose '=' padding may have been stripped.

  Args:
    encoded_data: URL-safe base64 bytes, with or without trailing padding.

  Returns:
    The decoded bytes.
  """
  # Three pad characters are always appended instead of computing the exact
  # amount; this relies on the decoder tolerating surplus '=' characters.
  return base64.urlsafe_b64decode(b''.join((encoded_data, b'===')))
def decode_kid(compact: str) -> Optional[str]:
  """Returns the "kid" entry of a compact token's header, or None if absent.

  Args:
    compact: A token consisting of exactly three '.'-separated parts, the
      first of which is the base64url-encoded JSON header.

  Returns:
    The value stored under "kid" in the decoded header, or None when the
    header has no such entry.

  Raises:
    ValueError: If the token does not split into exactly three parts.
  """
  # Only the header is inspected; the other two parts are ignored.
  header_b64, _payload, _signature = compact.encode('utf8').split(b'.')
  header_fields = json.loads(base64_decode(header_b64))
  return header_fields.get('kid')
def generate_jwt_mac_keyset_with_custom_kid(
    template_name: str, custom_kid: str) -> tink_pb2.Keyset:
  """Generates a JWT MAC keyset whose first key has custom_kid set.

  Args:
    template_name: Name looked up in utilities.KEY_TEMPLATE. Only names that
      start with 'JWT_HS256' are supported.
    custom_kid: Value written to the custom_kid field of the JwtHmacKey.

  Returns:
    The keyset proto of a newly generated keyset handle, in which the key at
    index 0 has been re-serialized with custom_kid set.

  Raises:
    ValueError: If template_name does not start with 'JWT_HS256'.
  """
  key_template = utilities.KEY_TEMPLATE[template_name]
  # Reject unsupported templates before any key material is generated. The
  # message matches the one used by the signature variant of this helper.
  if not template_name.startswith('JWT_HS256'):
    raise ValueError('unknown template name')
  keyset_handle = tink.new_keyset_handle(key_template)
  keyset = keyset_handle._keyset
  key_data = keyset.key[0].key_data
  # Parse key_data.value, set custom_kid and serialize it back in place.
  hmac_key = jwt_hmac_pb2.JwtHmacKey.FromString(key_data.value)
  hmac_key.custom_kid.value = custom_kid
  key_data.value = hmac_key.SerializeToString()
  return keyset
def generate_jwt_signature_keyset_with_custom_kid(
    template_name: str, custom_kid: str) -> tink_pb2.Keyset:
  """Generates a JWT signature keyset whose first key has custom_kid set.

  Args:
    template_name: Name looked up in utilities.KEY_TEMPLATE. It must start with
      'JWT_ES256', 'JWT_RS256' or 'JWT_PS256'.
    custom_kid: Value written to public_key.custom_kid of the private key.

  Returns:
    The keyset proto of a newly generated keyset handle, in which the key at
    index 0 has been re-serialized with custom_kid set on its public key.

  Raises:
    ValueError: If template_name starts with none of the supported prefixes.
  """
  # Template-name prefix -> proto message used to parse the private key.
  private_key_types = (
      ('JWT_ES256', jwt_ecdsa_pb2.JwtEcdsaPrivateKey),
      ('JWT_RS256', jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PrivateKey),
      ('JWT_PS256', jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPrivateKey),
  )
  handle = tink.new_keyset_handle(utilities.KEY_TEMPLATE[template_name])
  key_data = handle._keyset.key[0].key_data
  for prefix, private_key_type in private_key_types:
    if template_name.startswith(prefix):
      private_key = private_key_type.FromString(key_data.value)
      break
  else:
    raise ValueError('unknown template name')
  # All three key types keep custom_kid on the embedded public key.
  private_key.public_key.custom_kid.value = custom_kid
  key_data.value = private_key.SerializeToString()
  return handle._keyset
class JwtKidTest(parameterized.TestCase):
  """Tests that all JWT primitives consistently add a "kid" header to tokens.

  Every test loops over a list of languages. Each iteration runs inside
  self.subTest(lang=...), so a failure report names the language and the
  remaining languages are still exercised.
  """

  @parameterized.parameters(['JWT_HS256'])
  def test_jwt_mac_sets_kid_for_tink_templates(self, template_name):
    # Tokens produced from a non-RAW MAC template must carry a "kid" header.
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      with self.subTest(lang=lang):
        jwt_mac = testing_servers.remote_primitive(lang, keyset, jwt.JwtMac)
        compact = jwt_mac.compute_mac_and_encode(raw_jwt)
        self.assertIsNotNone(decode_kid(compact))

  @parameterized.parameters(['JWT_HS256_RAW'])
  def test_jwt_mac_does_not_sets_kid_for_raw_templates(self, template_name):
    # Tokens produced from a RAW MAC template must not carry a "kid" header.
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      with self.subTest(lang=lang):
        jwt_mac = testing_servers.remote_primitive(lang, keyset, jwt.JwtMac)
        compact = jwt_mac.compute_mac_and_encode(raw_jwt)
        self.assertIsNone(decode_kid(compact))

  @parameterized.parameters(
      ['JWT_ES256', 'JWT_RS256_2048_F4', 'JWT_PS256_2048_F4'])
  def test_jwt_public_key_sign_sets_kid_for_tink_templates(self, template_name):
    # Tokens signed with a non-RAW signature template must carry a "kid".
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      with self.subTest(lang=lang):
        jwt_sign = testing_servers.remote_primitive(lang, keyset,
                                                    jwt.JwtPublicKeySign)
        compact = jwt_sign.sign_and_encode(raw_jwt)
        self.assertIsNotNone(decode_kid(compact))

  @parameterized.parameters(
      ['JWT_ES256_RAW', 'JWT_RS256_2048_F4_RAW', 'JWT_PS256_2048_F4_RAW'])
  def test_jwt_public_key_sign_does_not_sets_kid_for_raw_templates(
      self, template_name):
    # Tokens signed with a RAW signature template must not carry a "kid".
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      with self.subTest(lang=lang):
        jwt_sign = testing_servers.remote_primitive(lang, keyset,
                                                    jwt.JwtPublicKeySign)
        compact = jwt_sign.sign_and_encode(raw_jwt)
        self.assertIsNone(decode_kid(compact))

  @parameterized.parameters(['JWT_HS256_RAW'])
  def test_jwt_mac_sets_custom_kid_for_raw_keys(self, template_name):
    # A RAW MAC key with custom_kid set must emit that value as "kid".
    keyset = generate_jwt_mac_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    # The serialized keyset is the same for every language; compute it once.
    serialized_keyset = keyset.SerializeToString()
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      with self.subTest(lang=lang):
        jwt_mac = testing_servers.remote_primitive(lang, serialized_keyset,
                                                   jwt.JwtMac)
        compact = jwt_mac.compute_mac_and_encode(raw_jwt)
        self.assertEqual(decode_kid(compact), 'my kid')

  @parameterized.parameters(['JWT_HS256'])
  def test_jwt_mac_fails_for_tink_keys_with_custom_kid(self, template_name):
    # A non-RAW MAC key with custom_kid set must be rejected with TinkError,
    # either when the primitive is created or when it is used.
    keyset = generate_jwt_mac_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    serialized_keyset = keyset.SerializeToString()
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      with self.subTest(lang=lang):
        with self.assertRaises(
            tink.TinkError,
            msg=('%s supports JWT mac keys with TINK output prefix type '
                 'and custom_kid set unexpectedly') % lang):
          jwt_mac = testing_servers.remote_primitive(lang, serialized_keyset,
                                                     jwt.JwtMac)
          jwt_mac.compute_mac_and_encode(raw_jwt)

  @parameterized.parameters(
      ['JWT_ES256_RAW', 'JWT_RS256_2048_F4_RAW', 'JWT_PS256_2048_F4_RAW'])
  def test_jwt_public_key_sign_sets_custom_kid_for_raw_keys(
      self, template_name):
    # A RAW signature key with custom_kid set must emit that value as "kid".
    keyset = generate_jwt_signature_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    serialized_keyset = keyset.SerializeToString()
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      with self.subTest(lang=lang):
        jwt_sign = testing_servers.remote_primitive(lang, serialized_keyset,
                                                    jwt.JwtPublicKeySign)
        compact = jwt_sign.sign_and_encode(raw_jwt)
        self.assertEqual(decode_kid(compact), 'my kid')

  @parameterized.parameters(
      ['JWT_ES256', 'JWT_RS256_2048_F4', 'JWT_PS256_2048_F4'])
  def test_jwt_public_key_sign_fails_for_tink_keys_with_custom_kid(
      self, template_name):
    # A non-RAW signature key with custom_kid set must be rejected with
    # TinkError. This test iterates over SUPPORTED_LANGUAGES (all 'jwt'
    # languages) rather than the per-template language list.
    keyset = generate_jwt_signature_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    serialized_keyset = keyset.SerializeToString()
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      with self.subTest(lang=lang):
        with self.assertRaises(
            tink.TinkError,
            msg=('%s supports JWT signature keys with TINK output prefix type '
                 'and custom_kid set unexpectedly') % lang):
          jwt_sign = testing_servers.remote_primitive(lang, serialized_keyset,
                                                      jwt.JwtPublicKeySign)
          jwt_sign.sign_and_encode(raw_jwt)
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()