| package org.bouncycastle.crypto.tls; |
| |
| import java.io.ByteArrayOutputStream; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| |
| import org.bouncycastle.util.io.SimpleOutputStream; |
| |
| /** |
| * An implementation of the TLS 1.0/1.1/1.2 record layer, allowing downgrade to SSLv3. |
| */ |
| class RecordStream |
| { |
| private static int DEFAULT_PLAINTEXT_LIMIT = (1 << 14); |
| static final int TLS_HEADER_SIZE = 5; |
| static final int TLS_HEADER_TYPE_OFFSET = 0; |
| static final int TLS_HEADER_VERSION_OFFSET = 1; |
| static final int TLS_HEADER_LENGTH_OFFSET = 3; |
| |
| private TlsProtocol handler; |
| private InputStream input; |
| private OutputStream output; |
| private TlsCompression pendingCompression = null, readCompression = null, writeCompression = null; |
| private TlsCipher pendingCipher = null, readCipher = null, writeCipher = null; |
| private SequenceNumber readSeqNo = new SequenceNumber(), writeSeqNo = new SequenceNumber(); |
| private ByteArrayOutputStream buffer = new ByteArrayOutputStream(); |
| |
| private TlsHandshakeHash handshakeHash = null; |
| private SimpleOutputStream handshakeHashUpdater = new SimpleOutputStream() |
| { |
| public void write(byte[] buf, int off, int len) throws IOException |
| { |
| handshakeHash.update(buf, off, len); |
| } |
| }; |
| |
| private ProtocolVersion readVersion = null, writeVersion = null; |
| private boolean restrictReadVersion = true; |
| |
| private int plaintextLimit, compressedLimit, ciphertextLimit; |
| |
| RecordStream(TlsProtocol handler, InputStream input, OutputStream output) |
| { |
| this.handler = handler; |
| this.input = input; |
| this.output = output; |
| this.readCompression = new TlsNullCompression(); |
| this.writeCompression = this.readCompression; |
| } |
| |
| void init(TlsContext context) |
| { |
| this.readCipher = new TlsNullCipher(context); |
| this.writeCipher = this.readCipher; |
| this.handshakeHash = new DeferredHash(); |
| this.handshakeHash.init(context); |
| |
| setPlaintextLimit(DEFAULT_PLAINTEXT_LIMIT); |
| } |
| |
| int getPlaintextLimit() |
| { |
| return plaintextLimit; |
| } |
| |
| void setPlaintextLimit(int plaintextLimit) |
| { |
| this.plaintextLimit = plaintextLimit; |
| this.compressedLimit = this.plaintextLimit + 1024; |
| this.ciphertextLimit = this.compressedLimit + 1024; |
| } |
| |
| ProtocolVersion getReadVersion() |
| { |
| return readVersion; |
| } |
| |
| void setReadVersion(ProtocolVersion readVersion) |
| { |
| this.readVersion = readVersion; |
| } |
| |
| void setWriteVersion(ProtocolVersion writeVersion) |
| { |
| this.writeVersion = writeVersion; |
| } |
| |
| /** |
| * RFC 5246 E.1. "Earlier versions of the TLS specification were not fully clear on what the |
| * record layer version number (TLSPlaintext.version) should contain when sending ClientHello |
| * (i.e., before it is known which version of the protocol will be employed). Thus, TLS servers |
| * compliant with this specification MUST accept any value {03,XX} as the record layer version |
| * number for ClientHello." |
| */ |
| void setRestrictReadVersion(boolean enabled) |
| { |
| this.restrictReadVersion = enabled; |
| } |
| |
| void setPendingConnectionState(TlsCompression tlsCompression, TlsCipher tlsCipher) |
| { |
| this.pendingCompression = tlsCompression; |
| this.pendingCipher = tlsCipher; |
| } |
| |
| void sentWriteCipherSpec() |
| throws IOException |
| { |
| if (pendingCompression == null || pendingCipher == null) |
| { |
| throw new TlsFatalAlert(AlertDescription.handshake_failure); |
| } |
| this.writeCompression = this.pendingCompression; |
| this.writeCipher = this.pendingCipher; |
| this.writeSeqNo = new SequenceNumber(); |
| } |
| |
| void receivedReadCipherSpec() |
| throws IOException |
| { |
| if (pendingCompression == null || pendingCipher == null) |
| { |
| throw new TlsFatalAlert(AlertDescription.handshake_failure); |
| } |
| this.readCompression = this.pendingCompression; |
| this.readCipher = this.pendingCipher; |
| this.readSeqNo = new SequenceNumber(); |
| } |
| |
| void finaliseHandshake() |
| throws IOException |
| { |
| if (readCompression != pendingCompression || writeCompression != pendingCompression |
| || readCipher != pendingCipher || writeCipher != pendingCipher) |
| { |
| throw new TlsFatalAlert(AlertDescription.handshake_failure); |
| } |
| this.pendingCompression = null; |
| this.pendingCipher = null; |
| } |
| |
| void checkRecordHeader(byte[] recordHeader) throws IOException |
| { |
| short type = TlsUtils.readUint8(recordHeader, TLS_HEADER_TYPE_OFFSET); |
| |
| /* |
| * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an |
| * unexpected_message alert. |
| */ |
| checkType(type, AlertDescription.unexpected_message); |
| |
| if (!restrictReadVersion) |
| { |
| int version = TlsUtils.readVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET); |
| if ((version & 0xffffff00) != 0x0300) |
| { |
| throw new TlsFatalAlert(AlertDescription.illegal_parameter); |
| } |
| } |
| else |
| { |
| ProtocolVersion version = TlsUtils.readVersion(recordHeader, TLS_HEADER_VERSION_OFFSET); |
| if (readVersion == null) |
| { |
| // Will be set later in 'readRecord' |
| } |
| else if (!version.equals(readVersion)) |
| { |
| throw new TlsFatalAlert(AlertDescription.illegal_parameter); |
| } |
| } |
| |
| int length = TlsUtils.readUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET); |
| |
| checkLength(length, ciphertextLimit, AlertDescription.record_overflow); |
| } |
| |
| boolean readRecord() |
| throws IOException |
| { |
| byte[] recordHeader = TlsUtils.readAllOrNothing(TLS_HEADER_SIZE, input); |
| if (recordHeader == null) |
| { |
| return false; |
| } |
| |
| short type = TlsUtils.readUint8(recordHeader, TLS_HEADER_TYPE_OFFSET); |
| |
| /* |
| * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an |
| * unexpected_message alert. |
| */ |
| checkType(type, AlertDescription.unexpected_message); |
| |
| if (!restrictReadVersion) |
| { |
| int version = TlsUtils.readVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET); |
| if ((version & 0xffffff00) != 0x0300) |
| { |
| throw new TlsFatalAlert(AlertDescription.illegal_parameter); |
| } |
| } |
| else |
| { |
| ProtocolVersion version = TlsUtils.readVersion(recordHeader, TLS_HEADER_VERSION_OFFSET); |
| if (readVersion == null) |
| { |
| readVersion = version; |
| } |
| else if (!version.equals(readVersion)) |
| { |
| throw new TlsFatalAlert(AlertDescription.illegal_parameter); |
| } |
| } |
| |
| int length = TlsUtils.readUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET); |
| |
| checkLength(length, ciphertextLimit, AlertDescription.record_overflow); |
| |
| byte[] plaintext = decodeAndVerify(type, input, length); |
| handler.processRecord(type, plaintext, 0, plaintext.length); |
| return true; |
| } |
| |
| byte[] decodeAndVerify(short type, InputStream input, int len) |
| throws IOException |
| { |
| byte[] buf = TlsUtils.readFully(len, input); |
| |
| long seqNo = readSeqNo.nextValue(AlertDescription.unexpected_message); |
| byte[] decoded = readCipher.decodeCiphertext(seqNo, type, buf, 0, buf.length); |
| |
| checkLength(decoded.length, compressedLimit, AlertDescription.record_overflow); |
| |
| /* |
| * TODO RFC 5246 6.2.2. Implementation note: Decompression functions are responsible for |
| * ensuring that messages cannot cause internal buffer overflows. |
| */ |
| OutputStream cOut = readCompression.decompress(buffer); |
| if (cOut != buffer) |
| { |
| cOut.write(decoded, 0, decoded.length); |
| cOut.flush(); |
| decoded = getBufferContents(); |
| } |
| |
| /* |
| * RFC 5246 6.2.2. If the decompression function encounters a TLSCompressed.fragment that |
| * would decompress to a length in excess of 2^14 bytes, it should report a fatal |
| * decompression failure error. |
| */ |
| checkLength(decoded.length, plaintextLimit, AlertDescription.decompression_failure); |
| |
| /* |
| * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert, |
| * or ChangeCipherSpec content types. |
| */ |
| if (decoded.length < 1 && type != ContentType.application_data) |
| { |
| throw new TlsFatalAlert(AlertDescription.illegal_parameter); |
| } |
| |
| return decoded; |
| } |
| |
| void writeRecord(short type, byte[] plaintext, int plaintextOffset, int plaintextLength) |
| throws IOException |
| { |
| // Never send anything until a valid ClientHello has been received |
| if (writeVersion == null) |
| { |
| return; |
| } |
| |
| /* |
| * RFC 5246 6. Implementations MUST NOT send record types not defined in this document |
| * unless negotiated by some extension. |
| */ |
| checkType(type, AlertDescription.internal_error); |
| |
| /* |
| * RFC 5246 6.2.1 The length should not exceed 2^14. |
| */ |
| checkLength(plaintextLength, plaintextLimit, AlertDescription.internal_error); |
| |
| /* |
| * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert, |
| * or ChangeCipherSpec content types. |
| */ |
| if (plaintextLength < 1 && type != ContentType.application_data) |
| { |
| throw new TlsFatalAlert(AlertDescription.internal_error); |
| } |
| |
| OutputStream cOut = writeCompression.compress(buffer); |
| |
| long seqNo = writeSeqNo.nextValue(AlertDescription.internal_error); |
| |
| byte[] ciphertext; |
| if (cOut == buffer) |
| { |
| ciphertext = writeCipher.encodePlaintext(seqNo, type, plaintext, plaintextOffset, plaintextLength); |
| } |
| else |
| { |
| cOut.write(plaintext, plaintextOffset, plaintextLength); |
| cOut.flush(); |
| byte[] compressed = getBufferContents(); |
| |
| /* |
| * RFC 5246 6.2.2. Compression must be lossless and may not increase the content length |
| * by more than 1024 bytes. |
| */ |
| checkLength(compressed.length, plaintextLength + 1024, AlertDescription.internal_error); |
| |
| ciphertext = writeCipher.encodePlaintext(seqNo, type, compressed, 0, compressed.length); |
| } |
| |
| /* |
| * RFC 5246 6.2.3. The length may not exceed 2^14 + 2048. |
| */ |
| checkLength(ciphertext.length, ciphertextLimit, AlertDescription.internal_error); |
| |
| byte[] record = new byte[ciphertext.length + TLS_HEADER_SIZE]; |
| TlsUtils.writeUint8(type, record, TLS_HEADER_TYPE_OFFSET); |
| TlsUtils.writeVersion(writeVersion, record, TLS_HEADER_VERSION_OFFSET); |
| TlsUtils.writeUint16(ciphertext.length, record, TLS_HEADER_LENGTH_OFFSET); |
| System.arraycopy(ciphertext, 0, record, TLS_HEADER_SIZE, ciphertext.length); |
| output.write(record); |
| output.flush(); |
| } |
| |
| void notifyHelloComplete() |
| { |
| this.handshakeHash = handshakeHash.notifyPRFDetermined(); |
| } |
| |
| TlsHandshakeHash getHandshakeHash() |
| { |
| return handshakeHash; |
| } |
| |
| OutputStream getHandshakeHashUpdater() |
| { |
| return handshakeHashUpdater; |
| } |
| |
| TlsHandshakeHash prepareToFinish() |
| { |
| TlsHandshakeHash result = handshakeHash; |
| this.handshakeHash = handshakeHash.stopTracking(); |
| return result; |
| } |
| |
| void safeClose() |
| { |
| try |
| { |
| input.close(); |
| } |
| catch (IOException e) |
| { |
| } |
| |
| try |
| { |
| output.close(); |
| } |
| catch (IOException e) |
| { |
| } |
| } |
| |
| void flush() |
| throws IOException |
| { |
| output.flush(); |
| } |
| |
| private byte[] getBufferContents() |
| { |
| byte[] contents = buffer.toByteArray(); |
| buffer.reset(); |
| return contents; |
| } |
| |
| private static void checkType(short type, short alertDescription) |
| throws IOException |
| { |
| switch (type) |
| { |
| case ContentType.application_data: |
| case ContentType.alert: |
| case ContentType.change_cipher_spec: |
| case ContentType.handshake: |
| // case ContentType.heartbeat: |
| break; |
| default: |
| throw new TlsFatalAlert(alertDescription); |
| } |
| } |
| |
| private static void checkLength(int length, int limit, short alertDescription) |
| throws IOException |
| { |
| if (length > limit) |
| { |
| throw new TlsFatalAlert(alertDescription); |
| } |
| } |
| |
| private static class SequenceNumber |
| { |
| private long value = 0L; |
| private boolean exhausted = false; |
| |
| synchronized long nextValue(short alertDescription) throws TlsFatalAlert |
| { |
| if (exhausted) |
| { |
| throw new TlsFatalAlert(alertDescription); |
| } |
| long result = value; |
| if (++value == 0) |
| { |
| exhausted = true; |
| } |
| return result; |
| } |
| } |
| } |