| package org.bouncycastle.crypto.tls; |
| |
| import java.io.IOException; |
| |
| class DTLSRecordLayer |
| implements DatagramTransport |
| { |
| private static final int RECORD_HEADER_LENGTH = 13; |
| private static final int MAX_FRAGMENT_LENGTH = 1 << 14; |
| private static final long TCP_MSL = 1000L * 60 * 2; |
| private static final long RETRANSMIT_TIMEOUT = TCP_MSL * 2; |
| |
| private final DatagramTransport transport; |
| private final TlsContext context; |
| private final TlsPeer peer; |
| |
| private final ByteQueue recordQueue = new ByteQueue(); |
| |
| private volatile boolean closed = false; |
| private volatile boolean failed = false; |
| private volatile ProtocolVersion discoveredPeerVersion = null; |
| private volatile boolean inHandshake; |
| private volatile int plaintextLimit; |
| private DTLSEpoch currentEpoch, pendingEpoch; |
| private DTLSEpoch readEpoch, writeEpoch; |
| |
| private DTLSHandshakeRetransmit retransmit = null; |
| private DTLSEpoch retransmitEpoch = null; |
| private long retransmitExpiry = 0; |
| |
| DTLSRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, short contentType) |
| { |
| this.transport = transport; |
| this.context = context; |
| this.peer = peer; |
| |
| this.inHandshake = true; |
| |
| this.currentEpoch = new DTLSEpoch(0, new TlsNullCipher(context)); |
| this.pendingEpoch = null; |
| this.readEpoch = currentEpoch; |
| this.writeEpoch = currentEpoch; |
| |
| setPlaintextLimit(MAX_FRAGMENT_LENGTH); |
| } |
| |
| void setPlaintextLimit(int plaintextLimit) |
| { |
| this.plaintextLimit = plaintextLimit; |
| } |
| |
| ProtocolVersion getDiscoveredPeerVersion() |
| { |
| return discoveredPeerVersion; |
| } |
| |
| ProtocolVersion resetDiscoveredPeerVersion() |
| { |
| ProtocolVersion result = discoveredPeerVersion; |
| discoveredPeerVersion = null; |
| return result; |
| } |
| |
| void initPendingEpoch(TlsCipher pendingCipher) |
| { |
| if (pendingEpoch != null) |
| { |
| throw new IllegalStateException(); |
| } |
| |
| /* |
| * TODO "In order to ensure that any given sequence/epoch pair is unique, implementations |
| * MUST NOT allow the same epoch value to be reused within two times the TCP maximum segment |
| * lifetime." |
| */ |
| |
| // TODO Check for overflow |
| this.pendingEpoch = new DTLSEpoch(writeEpoch.getEpoch() + 1, pendingCipher); |
| } |
| |
| void handshakeSuccessful(DTLSHandshakeRetransmit retransmit) |
| { |
| if (readEpoch == currentEpoch || writeEpoch == currentEpoch) |
| { |
| // TODO |
| throw new IllegalStateException(); |
| } |
| |
| if (retransmit != null) |
| { |
| this.retransmit = retransmit; |
| this.retransmitEpoch = currentEpoch; |
| this.retransmitExpiry = System.currentTimeMillis() + RETRANSMIT_TIMEOUT; |
| } |
| |
| this.inHandshake = false; |
| this.currentEpoch = pendingEpoch; |
| this.pendingEpoch = null; |
| } |
| |
| void resetWriteEpoch() |
| { |
| if (retransmitEpoch != null) |
| { |
| this.writeEpoch = retransmitEpoch; |
| } |
| else |
| { |
| this.writeEpoch = currentEpoch; |
| } |
| } |
| |
| public int getReceiveLimit() |
| throws IOException |
| { |
| return Math.min(this.plaintextLimit, |
| readEpoch.getCipher().getPlaintextLimit(transport.getReceiveLimit() - RECORD_HEADER_LENGTH)); |
| } |
| |
| public int getSendLimit() |
| throws IOException |
| { |
| return Math.min(this.plaintextLimit, |
| writeEpoch.getCipher().getPlaintextLimit(transport.getSendLimit() - RECORD_HEADER_LENGTH)); |
| } |
| |
| public int receive(byte[] buf, int off, int len, int waitMillis) |
| throws IOException |
| { |
| byte[] record = null; |
| |
| for (;;) |
| { |
| int receiveLimit = Math.min(len, getReceiveLimit()) + RECORD_HEADER_LENGTH; |
| if (record == null || record.length < receiveLimit) |
| { |
| record = new byte[receiveLimit]; |
| } |
| |
| try |
| { |
| if (retransmit != null && System.currentTimeMillis() > retransmitExpiry) |
| { |
| retransmit = null; |
| retransmitEpoch = null; |
| } |
| |
| int received = receiveRecord(record, 0, receiveLimit, waitMillis); |
| if (received < 0) |
| { |
| return received; |
| } |
| if (received < RECORD_HEADER_LENGTH) |
| { |
| continue; |
| } |
| int length = TlsUtils.readUint16(record, 11); |
| if (received != (length + RECORD_HEADER_LENGTH)) |
| { |
| continue; |
| } |
| |
| short type = TlsUtils.readUint8(record, 0); |
| |
| // TODO Support user-specified custom protocols? |
| switch (type) |
| { |
| case ContentType.alert: |
| case ContentType.application_data: |
| case ContentType.change_cipher_spec: |
| case ContentType.handshake: |
| case ContentType.heartbeat: |
| break; |
| default: |
| // TODO Exception? |
| continue; |
| } |
| |
| int epoch = TlsUtils.readUint16(record, 3); |
| |
| DTLSEpoch recordEpoch = null; |
| if (epoch == readEpoch.getEpoch()) |
| { |
| recordEpoch = readEpoch; |
| } |
| else if (type == ContentType.handshake && retransmitEpoch != null |
| && epoch == retransmitEpoch.getEpoch()) |
| { |
| recordEpoch = retransmitEpoch; |
| } |
| |
| if (recordEpoch == null) |
| { |
| continue; |
| } |
| |
| long seq = TlsUtils.readUint48(record, 5); |
| if (recordEpoch.getReplayWindow().shouldDiscard(seq)) |
| { |
| continue; |
| } |
| |
| ProtocolVersion version = TlsUtils.readVersion(record, 1); |
| if (discoveredPeerVersion != null && !discoveredPeerVersion.equals(version)) |
| { |
| continue; |
| } |
| |
| byte[] plaintext = recordEpoch.getCipher().decodeCiphertext( |
| getMacSequenceNumber(recordEpoch.getEpoch(), seq), type, record, RECORD_HEADER_LENGTH, |
| received - RECORD_HEADER_LENGTH); |
| |
| recordEpoch.getReplayWindow().reportAuthenticated(seq); |
| |
| if (plaintext.length > this.plaintextLimit) |
| { |
| continue; |
| } |
| |
| if (discoveredPeerVersion == null) |
| { |
| discoveredPeerVersion = version; |
| } |
| |
| switch (type) |
| { |
| case ContentType.alert: |
| { |
| if (plaintext.length == 2) |
| { |
| short alertLevel = plaintext[0]; |
| short alertDescription = plaintext[1]; |
| |
| peer.notifyAlertReceived(alertLevel, alertDescription); |
| |
| if (alertLevel == AlertLevel.fatal) |
| { |
| fail(alertDescription); |
| throw new TlsFatalAlert(alertDescription); |
| } |
| |
| // TODO Can close_notify be a fatal alert? |
| if (alertDescription == AlertDescription.close_notify) |
| { |
| closeTransport(); |
| } |
| } |
| else |
| { |
| // TODO What exception? |
| } |
| |
| continue; |
| } |
| case ContentType.application_data: |
| { |
| if (inHandshake) |
| { |
| // TODO Consider buffering application data for new epoch that arrives |
| // out-of-order with the Finished message |
| continue; |
| } |
| break; |
| } |
| case ContentType.change_cipher_spec: |
| { |
| // Implicitly receive change_cipher_spec and change to pending cipher state |
| |
| for (int i = 0; i < plaintext.length; ++i) |
| { |
| short message = TlsUtils.readUint8(plaintext, i); |
| if (message != ChangeCipherSpec.change_cipher_spec) |
| { |
| continue; |
| } |
| |
| if (pendingEpoch != null) |
| { |
| readEpoch = pendingEpoch; |
| } |
| } |
| |
| continue; |
| } |
| case ContentType.handshake: |
| { |
| if (!inHandshake) |
| { |
| if (retransmit != null) |
| { |
| retransmit.receivedHandshakeRecord(epoch, plaintext, 0, plaintext.length); |
| } |
| |
| // TODO Consider support for HelloRequest |
| continue; |
| } |
| break; |
| } |
| case ContentType.heartbeat: |
| { |
| // TODO[RFC 6520] |
| continue; |
| } |
| } |
| |
| /* |
| * NOTE: If we receive any non-handshake data in the new epoch implies the peer has |
| * received our final flight. |
| */ |
| if (!inHandshake && retransmit != null) |
| { |
| this.retransmit = null; |
| this.retransmitEpoch = null; |
| } |
| |
| System.arraycopy(plaintext, 0, buf, off, plaintext.length); |
| return plaintext.length; |
| } |
| catch (IOException e) |
| { |
| // NOTE: Assume this is a timeout for the moment |
| throw e; |
| } |
| } |
| } |
| |
| public void send(byte[] buf, int off, int len) |
| throws IOException |
| { |
| short contentType = ContentType.application_data; |
| |
| if (this.inHandshake || this.writeEpoch == this.retransmitEpoch) |
| { |
| contentType = ContentType.handshake; |
| |
| short handshakeType = TlsUtils.readUint8(buf, off); |
| if (handshakeType == HandshakeType.finished) |
| { |
| DTLSEpoch nextEpoch = null; |
| if (this.inHandshake) |
| { |
| nextEpoch = pendingEpoch; |
| } |
| else if (this.writeEpoch == this.retransmitEpoch) |
| { |
| nextEpoch = currentEpoch; |
| } |
| |
| if (nextEpoch == null) |
| { |
| // TODO |
| throw new IllegalStateException(); |
| } |
| |
| // Implicitly send change_cipher_spec and change to pending cipher state |
| |
| // TODO Send change_cipher_spec and finished records in single datagram? |
| byte[] data = new byte[]{ 1 }; |
| sendRecord(ContentType.change_cipher_spec, data, 0, data.length); |
| |
| writeEpoch = nextEpoch; |
| } |
| } |
| |
| sendRecord(contentType, buf, off, len); |
| } |
| |
| public void close() |
| throws IOException |
| { |
| if (!closed) |
| { |
| if (inHandshake) |
| { |
| warn(AlertDescription.user_canceled, "User canceled handshake"); |
| } |
| closeTransport(); |
| } |
| } |
| |
| void fail(short alertDescription) |
| { |
| if (!closed) |
| { |
| try |
| { |
| raiseAlert(AlertLevel.fatal, alertDescription, null, null); |
| } |
| catch (Exception e) |
| { |
| // Ignore |
| } |
| |
| failed = true; |
| |
| closeTransport(); |
| } |
| } |
| |
| void warn(short alertDescription, String message) |
| throws IOException |
| { |
| raiseAlert(AlertLevel.warning, alertDescription, message, null); |
| } |
| |
| private void closeTransport() |
| { |
| if (!closed) |
| { |
| /* |
| * RFC 5246 7.2.1. Unless some other fatal alert has been transmitted, each party is |
| * required to send a close_notify alert before closing the write side of the |
| * connection. The other party MUST respond with a close_notify alert of its own and |
| * close down the connection immediately, discarding any pending writes. |
| */ |
| |
| try |
| { |
| if (!failed) |
| { |
| warn(AlertDescription.close_notify, null); |
| } |
| transport.close(); |
| } |
| catch (Exception e) |
| { |
| // Ignore |
| } |
| |
| closed = true; |
| } |
| } |
| |
| private void raiseAlert(short alertLevel, short alertDescription, String message, Exception cause) |
| throws IOException |
| { |
| peer.notifyAlertRaised(alertLevel, alertDescription, message, cause); |
| |
| byte[] error = new byte[2]; |
| error[0] = (byte)alertLevel; |
| error[1] = (byte)alertDescription; |
| |
| sendRecord(ContentType.alert, error, 0, 2); |
| } |
| |
| private int receiveRecord(byte[] buf, int off, int len, int waitMillis) |
| throws IOException |
| { |
| if (recordQueue.size() > 0) |
| { |
| int length = 0; |
| if (recordQueue.size() >= RECORD_HEADER_LENGTH) |
| { |
| byte[] lengthBytes = new byte[2]; |
| recordQueue.read(lengthBytes, 0, 2, 11); |
| length = TlsUtils.readUint16(lengthBytes, 0); |
| } |
| |
| int received = Math.min(recordQueue.size(), RECORD_HEADER_LENGTH + length); |
| recordQueue.removeData(buf, off, received, 0); |
| return received; |
| } |
| |
| int received = transport.receive(buf, off, len, waitMillis); |
| if (received >= RECORD_HEADER_LENGTH) |
| { |
| int fragmentLength = TlsUtils.readUint16(buf, off + 11); |
| int recordLength = RECORD_HEADER_LENGTH + fragmentLength; |
| if (received > recordLength) |
| { |
| recordQueue.addData(buf, off + recordLength, received - recordLength); |
| received = recordLength; |
| } |
| } |
| |
| return received; |
| } |
| |
| private void sendRecord(short contentType, byte[] buf, int off, int len) |
| throws IOException |
| { |
| if (len > this.plaintextLimit) |
| { |
| throw new TlsFatalAlert(AlertDescription.internal_error); |
| } |
| |
| /* |
| * RFC 5264 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert, |
| * or ChangeCipherSpec content types. |
| */ |
| if (len < 1 && contentType != ContentType.application_data) |
| { |
| throw new TlsFatalAlert(AlertDescription.internal_error); |
| } |
| |
| int recordEpoch = writeEpoch.getEpoch(); |
| long recordSequenceNumber = writeEpoch.allocateSequenceNumber(); |
| |
| byte[] ciphertext = writeEpoch.getCipher().encodePlaintext( |
| getMacSequenceNumber(recordEpoch, recordSequenceNumber), contentType, buf, off, len); |
| |
| // TODO Check the ciphertext length? |
| |
| byte[] record = new byte[ciphertext.length + RECORD_HEADER_LENGTH]; |
| TlsUtils.writeUint8(contentType, record, 0); |
| ProtocolVersion version = discoveredPeerVersion != null ? discoveredPeerVersion : context.getClientVersion(); |
| TlsUtils.writeVersion(version, record, 1); |
| TlsUtils.writeUint16(recordEpoch, record, 3); |
| TlsUtils.writeUint48(recordSequenceNumber, record, 5); |
| TlsUtils.writeUint16(ciphertext.length, record, 11); |
| System.arraycopy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.length); |
| |
| transport.send(record, 0, record.length); |
| } |
| |
| private static long getMacSequenceNumber(int epoch, long sequence_number) |
| { |
| return ((long)epoch << 48) | sequence_number; |
| } |
| } |