blob: 296f422b33743cff288a39fbdf3ea567a2337340 [file] [log] [blame]
"""Classes representing TLS messages."""
from utils.compat import *
from utils.cryptomath import *
from errors import *
from utils.codec import *
from constants import *
from X509 import X509
from X509CertChain import X509CertChain
# The sha module is deprecated in Python 2.6
try:
import sha
except ImportError:
from hashlib import sha1 as sha
# The md5 module is deprecated in Python 2.6
try:
import md5
except ImportError:
from hashlib import md5
class RecordHeader3:
def __init__(self):
self.type = 0
self.version = (0,0)
self.length = 0
self.ssl2 = False
def create(self, version, type, length):
self.type = type
self.version = version
self.length = length
return self
def write(self):
w = Writer(5)
w.add(self.type, 1)
w.add(self.version[0], 1)
w.add(self.version[1], 1)
w.add(self.length, 2)
return w.bytes
def parse(self, p):
self.type = p.get(1)
self.version = (p.get(1), p.get(1))
self.length = p.get(2)
self.ssl2 = False
return self
class RecordHeader2:
def __init__(self):
self.type = 0
self.version = (0,0)
self.length = 0
self.ssl2 = True
def parse(self, p):
if p.get(1)!=128:
raise SyntaxError()
self.type = ContentType.handshake
self.version = (2,0)
#We don't support 2-byte-length-headers; could be a problem
self.length = p.get(1)
return self
class Msg:
def preWrite(self, trial):
if trial:
w = Writer()
else:
length = self.write(True)
w = Writer(length)
return w
def postWrite(self, w, trial):
if trial:
return w.index
else:
return w.bytes
class Alert(Msg):
def __init__(self):
self.contentType = ContentType.alert
self.level = 0
self.description = 0
def create(self, description, level=AlertLevel.fatal):
self.level = level
self.description = description
return self
def parse(self, p):
p.setLengthCheck(2)
self.level = p.get(1)
self.description = p.get(1)
p.stopLengthCheck()
return self
def write(self):
w = Writer(2)
w.add(self.level, 1)
w.add(self.description, 1)
return w.bytes
class HandshakeMsg(Msg):
def preWrite(self, handshakeType, trial):
if trial:
w = Writer()
w.add(handshakeType, 1)
w.add(0, 3)
else:
length = self.write(True)
w = Writer(length)
w.add(handshakeType, 1)
w.add(length-4, 3)
return w
class ClientHello(HandshakeMsg):
def __init__(self, ssl2=False):
self.contentType = ContentType.handshake
self.ssl2 = ssl2
self.client_version = (0,0)
self.random = createByteArrayZeros(32)
self.session_id = createByteArraySequence([])
self.cipher_suites = [] # a list of 16-bit values
self.certificate_types = [CertificateType.x509]
self.compression_methods = [] # a list of 8-bit values
self.srp_username = None # a string
self.channel_id = False
self.support_signed_cert_timestamps = False
def create(self, version, random, session_id, cipher_suites,
certificate_types=None, srp_username=None):
self.client_version = version
self.random = random
self.session_id = session_id
self.cipher_suites = cipher_suites
self.certificate_types = certificate_types
self.compression_methods = [0]
self.srp_username = srp_username
return self
def parse(self, p):
if self.ssl2:
self.client_version = (p.get(1), p.get(1))
cipherSpecsLength = p.get(2)
sessionIDLength = p.get(2)
randomLength = p.get(2)
self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3))
self.session_id = p.getFixBytes(sessionIDLength)
self.random = p.getFixBytes(randomLength)
if len(self.random) < 32:
zeroBytes = 32-len(self.random)
self.random = createByteArrayZeros(zeroBytes) + self.random
self.compression_methods = [0]#Fake this value
#We're not doing a stopLengthCheck() for SSLv2, oh well..
else:
p.startLengthCheck(3)
self.client_version = (p.get(1), p.get(1))
self.random = p.getFixBytes(32)
self.session_id = p.getVarBytes(1)
self.cipher_suites = p.getVarList(2, 2)
self.compression_methods = p.getVarList(1, 1)
if not p.atLengthCheck():
totalExtLength = p.get(2)
soFar = 0
while soFar != totalExtLength:
extType = p.get(2)
extLength = p.get(2)
if extType == 6:
self.srp_username = bytesToString(p.getVarBytes(1))
elif extType == 7:
self.certificate_types = p.getVarList(1, 1)
elif extType == ExtensionType.channel_id:
self.channel_id = True
elif extType == ExtensionType.signed_cert_timestamps:
if extLength:
raise SyntaxError()
self.support_signed_cert_timestamps = True
else:
p.getFixBytes(extLength)
soFar += 4 + extLength
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
w.add(self.client_version[0], 1)
w.add(self.client_version[1], 1)
w.addFixSeq(self.random, 1)
w.addVarSeq(self.session_id, 1, 1)
w.addVarSeq(self.cipher_suites, 2, 2)
w.addVarSeq(self.compression_methods, 1, 1)
extLength = 0
if self.certificate_types and self.certificate_types != \
[CertificateType.x509]:
extLength += 5 + len(self.certificate_types)
if self.srp_username:
extLength += 5 + len(self.srp_username)
if extLength > 0:
w.add(extLength, 2)
if self.certificate_types and self.certificate_types != \
[CertificateType.x509]:
w.add(7, 2)
w.add(len(self.certificate_types)+1, 2)
w.addVarSeq(self.certificate_types, 1, 1)
if self.srp_username:
w.add(6, 2)
w.add(len(self.srp_username)+1, 2)
w.addVarSeq(stringToBytes(self.srp_username), 1, 1)
return HandshakeMsg.postWrite(self, w, trial)
class ServerHello(HandshakeMsg):
def __init__(self):
self.contentType = ContentType.handshake
self.server_version = (0,0)
self.random = createByteArrayZeros(32)
self.session_id = createByteArraySequence([])
self.cipher_suite = 0
self.certificate_type = CertificateType.x509
self.compression_method = 0
self.channel_id = False
self.signed_cert_timestamps = None
def create(self, version, random, session_id, cipher_suite,
certificate_type):
self.server_version = version
self.random = random
self.session_id = session_id
self.cipher_suite = cipher_suite
self.certificate_type = certificate_type
self.compression_method = 0
return self
def parse(self, p):
p.startLengthCheck(3)
self.server_version = (p.get(1), p.get(1))
self.random = p.getFixBytes(32)
self.session_id = p.getVarBytes(1)
self.cipher_suite = p.get(2)
self.compression_method = p.get(1)
if not p.atLengthCheck():
totalExtLength = p.get(2)
soFar = 0
while soFar != totalExtLength:
extType = p.get(2)
extLength = p.get(2)
if extType == 7:
self.certificate_type = p.get(1)
else:
p.getFixBytes(extLength)
soFar += 4 + extLength
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
w.add(self.server_version[0], 1)
w.add(self.server_version[1], 1)
w.addFixSeq(self.random, 1)
w.addVarSeq(self.session_id, 1, 1)
w.add(self.cipher_suite, 2)
w.add(self.compression_method, 1)
extLength = 0
if self.certificate_type and self.certificate_type != \
CertificateType.x509:
extLength += 5
if self.channel_id:
extLength += 4
if self.signed_cert_timestamps:
extLength += 4 + len(self.signed_cert_timestamps)
if extLength != 0:
w.add(extLength, 2)
if self.certificate_type and self.certificate_type != \
CertificateType.x509:
w.add(7, 2)
w.add(1, 2)
w.add(self.certificate_type, 1)
if self.channel_id:
w.add(ExtensionType.channel_id, 2)
w.add(0, 2)
if self.signed_cert_timestamps:
w.add(ExtensionType.signed_cert_timestamps, 2)
w.addVarSeq(stringToBytes(self.signed_cert_timestamps), 1, 2)
return HandshakeMsg.postWrite(self, w, trial)
class Certificate(HandshakeMsg):
def __init__(self, certificateType):
self.certificateType = certificateType
self.contentType = ContentType.handshake
self.certChain = None
def create(self, certChain):
self.certChain = certChain
return self
def parse(self, p):
p.startLengthCheck(3)
if self.certificateType == CertificateType.x509:
chainLength = p.get(3)
index = 0
certificate_list = []
while index != chainLength:
certBytes = p.getVarBytes(3)
x509 = X509()
x509.parseBinary(certBytes)
certificate_list.append(x509)
index += len(certBytes)+3
if certificate_list:
self.certChain = X509CertChain(certificate_list)
elif self.certificateType == CertificateType.cryptoID:
s = bytesToString(p.getVarBytes(2))
if s:
try:
import cryptoIDlib.CertChain
except ImportError:
raise SyntaxError(\
"cryptoID cert chain received, cryptoIDlib not present")
self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
else:
raise AssertionError()
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
if self.certificateType == CertificateType.x509:
chainLength = 0
if self.certChain:
certificate_list = self.certChain.x509List
else:
certificate_list = []
#determine length
for cert in certificate_list:
bytes = cert.writeBytes()
chainLength += len(bytes)+3
#add bytes
w.add(chainLength, 3)
for cert in certificate_list:
bytes = cert.writeBytes()
w.addVarSeq(bytes, 1, 3)
elif self.certificateType == CertificateType.cryptoID:
if self.certChain:
bytes = stringToBytes(self.certChain.write())
else:
bytes = createByteArraySequence([])
w.addVarSeq(bytes, 1, 2)
else:
raise AssertionError()
return HandshakeMsg.postWrite(self, w, trial)
class CertificateRequest(HandshakeMsg):
def __init__(self):
self.contentType = ContentType.handshake
#Apple's Secure Transport library rejects empty certificate_types, so
#default to rsa_sign.
self.certificate_types = [ClientCertificateType.rsa_sign]
self.certificate_authorities = []
def create(self, certificate_types, certificate_authorities):
self.certificate_types = certificate_types
self.certificate_authorities = certificate_authorities
return self
def parse(self, p):
p.startLengthCheck(3)
self.certificate_types = p.getVarList(1, 1)
ca_list_length = p.get(2)
index = 0
self.certificate_authorities = []
while index != ca_list_length:
ca_bytes = p.getVarBytes(2)
self.certificate_authorities.append(ca_bytes)
index += len(ca_bytes)+2
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
trial)
w.addVarSeq(self.certificate_types, 1, 1)
caLength = 0
#determine length
for ca_dn in self.certificate_authorities:
caLength += len(ca_dn)+2
w.add(caLength, 2)
#add bytes
for ca_dn in self.certificate_authorities:
w.addVarSeq(ca_dn, 1, 2)
return HandshakeMsg.postWrite(self, w, trial)
class ServerKeyExchange(HandshakeMsg):
def __init__(self, cipherSuite):
self.cipherSuite = cipherSuite
self.contentType = ContentType.handshake
self.srp_N = 0L
self.srp_g = 0L
self.srp_s = createByteArraySequence([])
self.srp_B = 0L
self.signature = createByteArraySequence([])
def createSRP(self, srp_N, srp_g, srp_s, srp_B):
self.srp_N = srp_N
self.srp_g = srp_g
self.srp_s = srp_s
self.srp_B = srp_B
return self
def parse(self, p):
p.startLengthCheck(3)
self.srp_N = bytesToNumber(p.getVarBytes(2))
self.srp_g = bytesToNumber(p.getVarBytes(2))
self.srp_s = p.getVarBytes(1)
self.srp_B = bytesToNumber(p.getVarBytes(2))
if self.cipherSuite in CipherSuite.srpRsaSuites:
self.signature = p.getVarBytes(2)
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
trial)
w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
w.addVarSeq(self.srp_s, 1, 1)
w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
if self.cipherSuite in CipherSuite.srpRsaSuites:
w.addVarSeq(self.signature, 1, 2)
return HandshakeMsg.postWrite(self, w, trial)
def hash(self, clientRandom, serverRandom):
oldCipherSuite = self.cipherSuite
self.cipherSuite = None
try:
bytes = clientRandom + serverRandom + self.write()[4:]
s = bytesToString(bytes)
return stringToBytes(md5.md5(s).digest() + sha.sha(s).digest())
finally:
self.cipherSuite = oldCipherSuite
class ServerHelloDone(HandshakeMsg):
def __init__(self):
self.contentType = ContentType.handshake
def create(self):
return self
def parse(self, p):
p.startLengthCheck(3)
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial)
return HandshakeMsg.postWrite(self, w, trial)
class ClientKeyExchange(HandshakeMsg):
def __init__(self, cipherSuite, version=None):
self.cipherSuite = cipherSuite
self.version = version
self.contentType = ContentType.handshake
self.srp_A = 0
self.encryptedPreMasterSecret = createByteArraySequence([])
def createSRP(self, srp_A):
self.srp_A = srp_A
return self
def createRSA(self, encryptedPreMasterSecret):
self.encryptedPreMasterSecret = encryptedPreMasterSecret
return self
def parse(self, p):
p.startLengthCheck(3)
if self.cipherSuite in CipherSuite.srpSuites + \
CipherSuite.srpRsaSuites:
self.srp_A = bytesToNumber(p.getVarBytes(2))
elif self.cipherSuite in CipherSuite.rsaSuites:
if self.version in ((3,1), (3,2)):
self.encryptedPreMasterSecret = p.getVarBytes(2)
elif self.version == (3,0):
self.encryptedPreMasterSecret = \
p.getFixBytes(len(p.bytes)-p.index)
else:
raise AssertionError()
else:
raise AssertionError()
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
trial)
if self.cipherSuite in CipherSuite.srpSuites + \
CipherSuite.srpRsaSuites:
w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
elif self.cipherSuite in CipherSuite.rsaSuites:
if self.version in ((3,1), (3,2)):
w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
elif self.version == (3,0):
w.addFixSeq(self.encryptedPreMasterSecret, 1)
else:
raise AssertionError()
else:
raise AssertionError()
return HandshakeMsg.postWrite(self, w, trial)
class CertificateVerify(HandshakeMsg):
def __init__(self):
self.contentType = ContentType.handshake
self.signature = createByteArraySequence([])
def create(self, signature):
self.signature = signature
return self
def parse(self, p):
p.startLengthCheck(3)
self.signature = p.getVarBytes(2)
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify,
trial)
w.addVarSeq(self.signature, 1, 2)
return HandshakeMsg.postWrite(self, w, trial)
class ChangeCipherSpec(Msg):
def __init__(self):
self.contentType = ContentType.change_cipher_spec
self.type = 1
def create(self):
self.type = 1
return self
def parse(self, p):
p.setLengthCheck(1)
self.type = p.get(1)
p.stopLengthCheck()
return self
def write(self, trial=False):
w = Msg.preWrite(self, trial)
w.add(self.type,1)
return Msg.postWrite(self, w, trial)
class Finished(HandshakeMsg):
def __init__(self, version):
self.contentType = ContentType.handshake
self.version = version
self.verify_data = createByteArraySequence([])
def create(self, verify_data):
self.verify_data = verify_data
return self
def parse(self, p):
p.startLengthCheck(3)
if self.version == (3,0):
self.verify_data = p.getFixBytes(36)
elif self.version in ((3,1), (3,2)):
self.verify_data = p.getFixBytes(12)
else:
raise AssertionError()
p.stopLengthCheck()
return self
def write(self, trial=False):
w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
w.addFixSeq(self.verify_data, 1)
return HandshakeMsg.postWrite(self, w, trial)
class EncryptedExtensions(HandshakeMsg):
def __init__(self):
self.channel_id_key = None
self.channel_id_proof = None
def parse(self, p):
p.startLengthCheck(3)
soFar = 0
while soFar != p.lengthCheck:
extType = p.get(2)
extLength = p.get(2)
if extType == ExtensionType.channel_id:
if extLength != 32*4:
raise SyntaxError()
self.channel_id_key = p.getFixBytes(64)
self.channel_id_proof = p.getFixBytes(64)
else:
p.getFixBytes(extLength)
soFar += 4 + extLength
p.stopLengthCheck()
return self
class ApplicationData(Msg):
def __init__(self):
self.contentType = ContentType.application_data
self.bytes = createByteArraySequence([])
def create(self, bytes):
self.bytes = bytes
return self
def parse(self, p):
self.bytes = p.bytes
return self
def write(self):
return self.bytes