blob: 09af443b57c6f5dcd9a69e5f250ed99bb1929a7e [file] [log] [blame]
## This file is part of Scapy
## Copyright (C) 2017 Maxence Tury
## This program is published under a GPLv2 license
"""
TLS 1.3 key exchange logic.
"""
import math
from scapy.config import conf, crypto_validator
from scapy.error import log_runtime, warning
from scapy.fields import *
from scapy.packet import Packet, Raw, Padding
from scapy.layers.tls.cert import PubKeyRSA, PrivKeyRSA
from scapy.layers.tls.session import _GenericTLSSessionInheritance
from scapy.layers.tls.basefields import _tls_version, _TLSClientVersionField
from scapy.layers.tls.extensions import TLS_Ext_Unknown, _tls_ext
from scapy.layers.tls.crypto.pkcs1 import pkcs_i2osp, pkcs_os2ip
from scapy.layers.tls.crypto.groups import (_tls_named_ffdh_groups,
_tls_named_curves, _ffdh_groups,
_tls_named_groups)
if conf.crypto_valid:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import dh, ec
if conf.crypto_valid_advanced:
from cryptography.hazmat.primitives.asymmetric import x25519
class KeyShareEntry(Packet):
    """
    One (group, key_exchange) entry of a TLS 1.3 key_share extension.

    When building from scratch, we create a DH private key, and when
    dissecting, we create a DH public key. Default group is secp256r1.

    The generated private key is stored in the 'privkey' slot and the
    dissected public key in the 'pubkey' slot. Both remain None when the
    group is not handled below or when key handling raised ImportError.
    """
    __slots__ = ["privkey", "pubkey"]
    name = "Key Share Entry"
    fields_desc = [ShortEnumField("group", None, _tls_named_groups),
                   FieldLenField("kxlen", None, length_of="key_exchange"),
                   StrLenField("key_exchange", "",
                               length_from=lambda pkt: pkt.kxlen) ]

    def __init__(self, *args, **kargs):
        # The slots must exist before Packet.__init__ runs, because
        # dissection (and thus post_dissection) may happen inside it.
        self.privkey = None
        self.pubkey = None
        super(KeyShareEntry, self).__init__(*args, **kargs)

    def do_build(self):
        """
        We need this hack, else 'self' would be replaced by __iter__.next().

        The 'explicit' flag is forced to True for the duration of the build
        and restored afterwards, even if the build raises.
        """
        tmp = self.explicit
        self.explicit = True
        try:
            return super(KeyShareEntry, self).do_build()
        finally:
            self.explicit = tmp

    @crypto_validator
    def create_privkey(self):
        """
        This is called by post_build() for key creation.

        Generates a fresh private key for self.group, stores it in
        self.privkey and writes the matching public value, as a byte
        string, into self.key_exchange. Nothing happens for x448, for
        x25519 without conf.crypto_valid_advanced, or for unknown groups.
        """
        if self.group in _tls_named_ffdh_groups:
            params = _ffdh_groups[_tls_named_ffdh_groups[self.group]][0]
            privkey = params.generate_private_key()
            self.privkey = privkey
            pubkey = privkey.public_key()
            # key_exchange is a byte-string field: encode the public value
            # as a big-endian integer left-padded to the size of p
            # (RFC 8446, section 4.2.8.1) instead of storing a Python int,
            # which would break len() and concatenation in post_build().
            p_len = (params.parameter_numbers().p.bit_length() + 7) // 8
            self.key_exchange = pkcs_i2osp(pubkey.public_numbers().y, p_len)
        elif self.group in _tls_named_curves:
            curve_name = _tls_named_curves[self.group]
            if curve_name == "x25519":
                if conf.crypto_valid_advanced:
                    privkey = x25519.X25519PrivateKey.generate()
                    self.privkey = privkey
                    pubkey = privkey.public_key()
                    self.key_exchange = pubkey.public_bytes()
            elif curve_name != "x448":
                curve = ec._CURVE_TYPES[curve_name]()
                privkey = ec.generate_private_key(curve, default_backend())
                self.privkey = privkey
                pubkey = privkey.public_key()
                self.key_exchange = pubkey.public_numbers().encode_point()

    def post_build(self, pkt, pay):
        """
        Fill in the default group, the key material and the length, then
        return the serialized entry followed by the payload 'pay'.

        The entry is re-assembled from the (possibly just completed) field
        values rather than from 'pkt'.
        """
        if self.group is None:
            self.group = 23     # secp256r1

        if not self.key_exchange:
            try:
                self.create_privkey()
            except ImportError:
                # Best effort: without the crypto library the entry is
                # built with whatever key_exchange currently holds.
                pass

        if self.kxlen is None:
            self.kxlen = len(self.key_exchange)

        group = struct.pack("!H", self.group)
        kxlen = struct.pack("!H", self.kxlen)
        return group + kxlen + self.key_exchange + pay

    @crypto_validator
    def register_pubkey(self):
        """
        Build a public key object from self.key_exchange according to
        self.group and store it in self.pubkey. Nothing happens for x448,
        for x25519 without conf.crypto_valid_advanced, or for unknown groups.
        """
        if self.group in _tls_named_ffdh_groups:
            params = _ffdh_groups[_tls_named_ffdh_groups[self.group]][0]
            pn = params.parameter_numbers()
            # DHPublicNumbers expects the public value as an integer,
            # whereas the dissected field holds its big-endian encoding.
            y = pkcs_os2ip(self.key_exchange)
            public_numbers = dh.DHPublicNumbers(y, pn)
            self.pubkey = public_numbers.public_key(default_backend())
        elif self.group in _tls_named_curves:
            curve_name = _tls_named_curves[self.group]
            if curve_name == "x25519":
                if conf.crypto_valid_advanced:
                    import_point = x25519.X25519PublicKey.from_public_bytes
                    self.pubkey = import_point(self.key_exchange)
            elif curve_name != "x448":
                curve = ec._CURVE_TYPES[curve_name]()
                import_point = ec.EllipticCurvePublicNumbers.from_encoded_point
                public_numbers = import_point(curve, self.key_exchange)
                self.pubkey = public_numbers.public_key(default_backend())

    def post_dissection(self, r):
        """
        Try to turn the dissected key_exchange into a public key object;
        an ImportError raised while doing so is ignored.
        """
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def extract_padding(self, s):
        # Everything after this entry is handed back as padding, none of it
        # as payload.
        return "", s
class TLS_Ext_KeyShare_CH(TLS_Ext_Unknown):
    """
    Key share extension as sent in a ClientHello: a length-prefixed list of
    KeyShareEntry packets.

    Unless the TLS session is frozen, building records the private keys of
    the entries in tls_session.tls13_client_privshares and dissecting
    records their public keys in tls_session.tls13_client_pubshares. Both
    dictionaries are keyed by the group name from _tls_named_groups.
    """
    name = "TLS Extension - Key Share (for ClientHello)"
    fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("client_shares_len", None,
                                 length_of="client_shares"),
                   PacketListField("client_shares", [], KeyShareEntry,
                            length_from=lambda pkt: pkt.client_shares_len) ]

    def post_build(self, pkt, pay):
        """
        Register the private key of every built entry in the session, then
        defer to the parent post_build(). Registration stops at the first
        group that is already present.
        """
        if not self.tls_session.frozen:
            privshares = self.tls_session.tls13_client_privshares
            for kse in self.client_shares:
                if kse.privkey:
                    # Use the same table for the duplicate check and for
                    # the store; _tls_named_curves alone raises KeyError
                    # on FFDH groups.
                    group_name = _tls_named_groups[kse.group]
                    if group_name in privshares:
                        # 'pkt' is the already assembled byte string of
                        # this layer, so walk up from the packet itself.
                        pkt_info = self.firstlayer().summary()
                        log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)
                        break
                    privshares[group_name] = kse.privkey
        return super(TLS_Ext_KeyShare_CH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        """
        Register the public key of every dissected entry in the session,
        then defer to the parent post_dissection(). Registration stops at
        the first group that is already present.
        """
        if not self.tls_session.frozen:
            pubshares = self.tls_session.tls13_client_pubshares
            for kse in self.client_shares:
                if kse.pubkey:
                    # Keyed by _tls_named_groups so that FFDH groups work
                    # and so that the ServerHello lookup finds the entry.
                    group_name = _tls_named_groups[kse.group]
                    if group_name in pubshares:
                        pkt_info = r.firstlayer().summary()
                        log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)
                        break
                    pubshares[group_name] = kse.pubkey
        return super(TLS_Ext_KeyShare_CH, self).post_dissection(r)
class TLS_Ext_KeyShare_HRR(TLS_Ext_Unknown):
    """
    Key share extension as sent in a HelloRetryRequest: it carries only
    the selected group identifier, no key material.
    """
    name = "TLS Extension - Key Share (for HelloRetryRequest)"
    fields_desc = [
        ShortEnumField("type", 0x28, _tls_ext),
        ShortField("len", None),
        ShortEnumField("selected_group", None, _tls_named_groups),
    ]
class TLS_Ext_KeyShare_SH(TLS_Ext_Unknown):
    """
    Key share extension as sent in a ServerHello: a single KeyShareEntry.

    Unless the TLS session is frozen, building records the server private
    key in tls_session.tls13_server_privshare and dissecting records the
    server public key in tls_session.tls13_server_pubshare. When the
    matching client share for the same group is already known to the
    session, the shared secret is computed and stored in
    tls_session.tls13_dhe_secret.
    """
    name = "TLS Extension - Key Share (for ServerHello)"
    fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
                   ShortField("len", None),
                   PacketField("server_share", None, KeyShareEntry) ]

    @staticmethod
    def _dh_exchange(group_name, privkey, pubkey):
        """
        Run the key agreement between 'privkey' and 'pubkey' for the group
        called 'group_name' and return the shared secret, or None if the
        name belongs to neither the FFDH nor the curve table.
        """
        if group_name in six.itervalues(_tls_named_ffdh_groups):
            return privkey.exchange(pubkey)
        if group_name in six.itervalues(_tls_named_curves):
            if group_name == "x25519":
                return privkey.exchange(pubkey)
            return privkey.exchange(ec.ECDH(), pubkey)
        return None

    def post_build(self, pkt, pay):
        """
        Register the server private key in the session and, if the client
        public share for the same group is known, derive the DHE secret.
        Then defer to the parent post_build().
        """
        share = self.server_share
        # server_share defaults to None: skip the bookkeeping in that case
        # instead of failing on the attribute access.
        if (not self.tls_session.frozen and share is not None
                and share.privkey):
            # if there is a privkey, we assume the crypto library is ok
            privshare = self.tls_session.tls13_server_privshare
            if len(privshare) > 0:
                # 'pkt' is the already assembled byte string of this
                # layer, so walk up from the packet itself.
                pkt_info = self.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)
            group_name = _tls_named_groups[share.group]
            privshare[group_name] = share.privkey

            client_pubshares = self.tls_session.tls13_client_pubshares
            if group_name in client_pubshares:
                pms = self._dh_exchange(group_name, share.privkey,
                                        client_pubshares[group_name])
                if pms is not None:
                    self.tls_session.tls13_dhe_secret = pms
        return super(TLS_Ext_KeyShare_SH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        """
        Register the server public key in the session and, if the client
        private share for the same group is known, derive the DHE secret.
        Then defer to the parent post_dissection().
        """
        share = self.server_share
        # server_share may be None; skip the bookkeeping in that case.
        if (not self.tls_session.frozen and share is not None
                and share.pubkey):
            # if there is a pubkey, we assume the crypto library is ok
            pubshare = self.tls_session.tls13_server_pubshare
            if len(pubshare) > 0:
                pkt_info = r.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)
            group_name = _tls_named_groups[share.group]
            pubshare[group_name] = share.pubkey

            client_privshares = self.tls_session.tls13_client_privshares
            if group_name in client_privshares:
                pms = self._dh_exchange(group_name,
                                        client_privshares[group_name],
                                        share.pubkey)
                if pms is not None:
                    self.tls_session.tls13_dhe_secret = pms
        return super(TLS_Ext_KeyShare_SH, self).post_dissection(r)
# Key share extension class to use for each integer message code.
# NOTE(review): the keys look like TLS handshake types (1 ClientHello,
# 2 ServerHello, 6 HelloRetryRequest) -- confirm against the code that
# consumes this table.
_tls_ext_keyshare_cls = {
    1: TLS_Ext_KeyShare_CH,
    2: TLS_Ext_KeyShare_SH,
    6: TLS_Ext_KeyShare_HRR,
}
class Ticket(Packet):
    """
    Session ticket layout: a 16-byte key name, a 16-byte IV, a
    length-prefixed encrypted state and a 32-byte MAC.
    """
    name = "Recommended Ticket Construction (from RFC 5077)"
    fields_desc = [
        StrFixedLenField("key_name", None, 16),
        StrFixedLenField("iv", None, 16),
        FieldLenField("encstatelen", None, length_of="encstate"),
        StrLenField("encstate", "", length_from=lambda p: p.encstatelen),
        StrFixedLenField("mac", None, 32),
    ]
class TicketField(PacketField):
    """
    PacketField holding a Ticket whose byte length is obtained from the
    'length_from' callable applied to the enclosing packet.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None, **kargs):
        self.length_from = length_from
        super(TicketField, self).__init__(name, default, Ticket, **kargs)

    def m2i(self, pkt, m):
        """
        Dissect the first length_from(pkt) bytes of 'm' with the field's
        packet class and attach the remaining bytes as Padding.
        """
        size = self.length_from(pkt)
        ticket_bytes = m[:size]
        leftover = m[size:]
        return self.cls(ticket_bytes) / Padding(leftover)
class PSKIdentity(Packet):
    """
    One pre-shared key identity: a length-prefixed ticket followed by a
    32-bit obfuscated ticket age.
    """
    name = "PSK Identity"
    fields_desc = [
        FieldLenField("identity_len", None, length_of="identity"),
        TicketField("identity", "", length_from=lambda p: p.identity_len),
        IntField("obfuscated_ticket_age", 0),
    ]
class PSKBinderEntry(Packet):
    """
    One PSK binder: a one-byte length followed by that many binder bytes.
    """
    name = "PSK Binder Entry"
    fields_desc = [
        FieldLenField("binder_len", None, fmt="B", length_of="binder"),
        StrLenField("binder", "", length_from=lambda p: p.binder_len),
    ]
class TLS_Ext_PreSharedKey_CH(TLS_Ext_Unknown):
    """
    Pre-shared key extension as sent in a ClientHello: a length-prefixed
    list of PSKIdentity packets followed by a length-prefixed list of
    PSKBinderEntry packets.
    """
    #XXX define post_build and post_dissection methods
    name = "TLS Extension - Pre Shared Key (for ClientHello)"
    # The default type is 0x29, the same value the ServerHello variant
    # uses; 0x28 is the type the key share extensions in this file use.
    fields_desc = [ShortEnumField("type", 0x29, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("identities_len", None,
                                 length_of="identities"),
                   PacketListField("identities", [], PSKIdentity,
                            length_from=lambda pkt: pkt.identities_len),
                   FieldLenField("binders_len", None,
                                 length_of="binders"),
                   PacketListField("binders", [], PSKBinderEntry,
                            length_from=lambda pkt: pkt.binders_len) ]
class TLS_Ext_PreSharedKey_SH(TLS_Ext_Unknown):
    """
    Pre-shared key extension as sent in a ServerHello: it carries only the
    16-bit selected identity value.
    """
    name = "TLS Extension - Pre Shared Key (for ServerHello)"
    fields_desc = [
        ShortEnumField("type", 0x29, _tls_ext),
        ShortField("len", None),
        ShortField("selected_identity", None),
    ]
# Pre-shared key extension class to use for each integer message code.
# NOTE(review): the keys look like TLS handshake types (1 ClientHello,
# 2 ServerHello) -- confirm against the code that consumes this table.
_tls_ext_presharedkey_cls = {
    1: TLS_Ext_PreSharedKey_CH,
    2: TLS_Ext_PreSharedKey_SH,
}