| package org.bouncycastle.crypto.tls; |
| |
| import java.io.ByteArrayOutputStream; |
| import java.io.IOException; |
| import java.util.Enumeration; |
| import java.util.Hashtable; |
| import java.util.Vector; |
| |
| import org.bouncycastle.util.Integers; |
| |
| class DTLSReliableHandshake |
| { |
| private final static int MAX_RECEIVE_AHEAD = 10; |
| |
| /* |
| * No 'final' modifiers so that it works in earlier JDKs |
| */ |
| private DTLSRecordLayer recordLayer; |
| |
| private TlsHandshakeHash handshakeHash; |
| |
| private Hashtable currentInboundFlight = new Hashtable(); |
| private Hashtable previousInboundFlight = null; |
| private Vector outboundFlight = new Vector(); |
| private boolean sending = true; |
| |
| private int message_seq = 0, next_receive_seq = 0; |
| |
| DTLSReliableHandshake(TlsContext context, DTLSRecordLayer transport) |
| { |
| this.recordLayer = transport; |
| this.handshakeHash = new DeferredHash(); |
| this.handshakeHash.init(context); |
| } |
| |
| void notifyHelloComplete() |
| { |
| this.handshakeHash = handshakeHash.notifyPRFDetermined(); |
| } |
| |
| TlsHandshakeHash getHandshakeHash() |
| { |
| return handshakeHash; |
| } |
| |
| TlsHandshakeHash prepareToFinish() |
| { |
| TlsHandshakeHash result = handshakeHash; |
| this.handshakeHash = handshakeHash.stopTracking(); |
| return result; |
| } |
| |
| void sendMessage(short msg_type, byte[] body) |
| throws IOException |
| { |
| TlsUtils.checkUint24(body.length); |
| |
| if (!sending) |
| { |
| checkInboundFlight(); |
| sending = true; |
| outboundFlight.removeAllElements(); |
| } |
| |
| Message message = new Message(message_seq++, msg_type, body); |
| |
| outboundFlight.addElement(message); |
| |
| writeMessage(message); |
| updateHandshakeMessagesDigest(message); |
| } |
| |
| byte[] receiveMessageBody(short msg_type) |
| throws IOException |
| { |
| Message message = receiveMessage(); |
| if (message.getType() != msg_type) |
| { |
| throw new TlsFatalAlert(AlertDescription.unexpected_message); |
| } |
| |
| return message.getBody(); |
| } |
| |
| Message receiveMessage() |
| throws IOException |
| { |
| if (sending) |
| { |
| sending = false; |
| prepareInboundFlight(); |
| } |
| |
| // Check if we already have the next message waiting |
| { |
| DTLSReassembler next = (DTLSReassembler)currentInboundFlight.get(Integers.valueOf(next_receive_seq)); |
| if (next != null) |
| { |
| byte[] body = next.getBodyIfComplete(); |
| if (body != null) |
| { |
| previousInboundFlight = null; |
| return updateHandshakeMessagesDigest(new Message(next_receive_seq++, next.getMsgType(), body)); |
| } |
| } |
| } |
| |
| byte[] buf = null; |
| |
| // TODO Check the conditions under which we should reset this |
| int readTimeoutMillis = 1000; |
| |
| for (;;) |
| { |
| int receiveLimit = recordLayer.getReceiveLimit(); |
| if (buf == null || buf.length < receiveLimit) |
| { |
| buf = new byte[receiveLimit]; |
| } |
| |
| // TODO Handle records containing multiple handshake messages |
| |
| try |
| { |
| for (; ; ) |
| { |
| int received = recordLayer.receive(buf, 0, receiveLimit, readTimeoutMillis); |
| if (received < 0) |
| { |
| break; |
| } |
| if (received < 12) |
| { |
| continue; |
| } |
| int fragment_length = TlsUtils.readUint24(buf, 9); |
| if (received != (fragment_length + 12)) |
| { |
| continue; |
| } |
| int seq = TlsUtils.readUint16(buf, 4); |
| if (seq > (next_receive_seq + MAX_RECEIVE_AHEAD)) |
| { |
| continue; |
| } |
| short msg_type = TlsUtils.readUint8(buf, 0); |
| int length = TlsUtils.readUint24(buf, 1); |
| int fragment_offset = TlsUtils.readUint24(buf, 6); |
| if (fragment_offset + fragment_length > length) |
| { |
| continue; |
| } |
| |
| if (seq < next_receive_seq) |
| { |
| /* |
| * NOTE: If we receive the previous flight of incoming messages in full |
| * again, retransmit our last flight |
| */ |
| if (previousInboundFlight != null) |
| { |
| DTLSReassembler reassembler = (DTLSReassembler)previousInboundFlight.get(Integers |
| .valueOf(seq)); |
| if (reassembler != null) |
| { |
| reassembler.contributeFragment(msg_type, length, buf, 12, fragment_offset, |
| fragment_length); |
| |
| if (checkAll(previousInboundFlight)) |
| { |
| resendOutboundFlight(); |
| |
| /* |
| * TODO[DTLS] implementations SHOULD back off handshake packet |
| * size during the retransmit backoff. |
| */ |
| readTimeoutMillis = Math.min(readTimeoutMillis * 2, 60000); |
| |
| resetAll(previousInboundFlight); |
| } |
| } |
| } |
| } |
| else |
| { |
| DTLSReassembler reassembler = (DTLSReassembler)currentInboundFlight.get(Integers.valueOf(seq)); |
| if (reassembler == null) |
| { |
| reassembler = new DTLSReassembler(msg_type, length); |
| currentInboundFlight.put(Integers.valueOf(seq), reassembler); |
| } |
| |
| reassembler.contributeFragment(msg_type, length, buf, 12, fragment_offset, fragment_length); |
| |
| if (seq == next_receive_seq) |
| { |
| byte[] body = reassembler.getBodyIfComplete(); |
| if (body != null) |
| { |
| previousInboundFlight = null; |
| return updateHandshakeMessagesDigest(new Message(next_receive_seq++, |
| reassembler.getMsgType(), body)); |
| } |
| } |
| } |
| } |
| } |
| catch (IOException e) |
| { |
| // NOTE: Assume this is a timeout for the moment |
| } |
| |
| resendOutboundFlight(); |
| |
| /* |
| * TODO[DTLS] implementations SHOULD back off handshake packet size during the |
| * retransmit backoff. |
| */ |
| readTimeoutMillis = Math.min(readTimeoutMillis * 2, 60000); |
| } |
| } |
| |
| void finish() |
| { |
| DTLSHandshakeRetransmit retransmit = null; |
| if (!sending) |
| { |
| checkInboundFlight(); |
| } |
| else if (currentInboundFlight != null) |
| { |
| /* |
| * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP], |
| * when in the FINISHED state, the node that transmits the last flight (the server in an |
| * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit |
| * of the peer's last flight with a retransmit of the last flight. |
| */ |
| retransmit = new DTLSHandshakeRetransmit() |
| { |
| public void receivedHandshakeRecord(int epoch, byte[] buf, int off, int len) |
| throws IOException |
| { |
| /* |
| * TODO Need to handle the case where the previous inbound flight contains |
| * messages from two epochs. |
| */ |
| if (len < 12) |
| { |
| return; |
| } |
| int fragment_length = TlsUtils.readUint24(buf, off + 9); |
| if (len != (fragment_length + 12)) |
| { |
| return; |
| } |
| int seq = TlsUtils.readUint16(buf, off + 4); |
| if (seq >= next_receive_seq) |
| { |
| return; |
| } |
| |
| short msg_type = TlsUtils.readUint8(buf, off); |
| |
| // TODO This is a hack that only works until we try to support renegotiation |
| int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0; |
| if (epoch != expectedEpoch) |
| { |
| return; |
| } |
| |
| int length = TlsUtils.readUint24(buf, off + 1); |
| int fragment_offset = TlsUtils.readUint24(buf, off + 6); |
| if (fragment_offset + fragment_length > length) |
| { |
| return; |
| } |
| |
| DTLSReassembler reassembler = (DTLSReassembler)currentInboundFlight.get(Integers.valueOf(seq)); |
| if (reassembler != null) |
| { |
| reassembler.contributeFragment(msg_type, length, buf, off + 12, fragment_offset, |
| fragment_length); |
| if (checkAll(currentInboundFlight)) |
| { |
| resendOutboundFlight(); |
| resetAll(currentInboundFlight); |
| } |
| } |
| } |
| }; |
| } |
| |
| recordLayer.handshakeSuccessful(retransmit); |
| } |
| |
| void resetHandshakeMessagesDigest() |
| { |
| handshakeHash.reset(); |
| } |
| |
| /** |
| * Check that there are no "extra" messages left in the current inbound flight |
| */ |
| private void checkInboundFlight() |
| { |
| Enumeration e = currentInboundFlight.keys(); |
| while (e.hasMoreElements()) |
| { |
| Integer key = (Integer)e.nextElement(); |
| if (key.intValue() >= next_receive_seq) |
| { |
| // TODO Should this be considered an error? |
| } |
| } |
| } |
| |
| private void prepareInboundFlight() |
| { |
| resetAll(currentInboundFlight); |
| previousInboundFlight = currentInboundFlight; |
| currentInboundFlight = new Hashtable(); |
| } |
| |
| private void resendOutboundFlight() |
| throws IOException |
| { |
| recordLayer.resetWriteEpoch(); |
| for (int i = 0; i < outboundFlight.size(); ++i) |
| { |
| writeMessage((Message)outboundFlight.elementAt(i)); |
| } |
| } |
| |
| private Message updateHandshakeMessagesDigest(Message message) |
| throws IOException |
| { |
| if (message.getType() != HandshakeType.hello_request) |
| { |
| byte[] body = message.getBody(); |
| byte[] buf = new byte[12]; |
| TlsUtils.writeUint8(message.getType(), buf, 0); |
| TlsUtils.writeUint24(body.length, buf, 1); |
| TlsUtils.writeUint16(message.getSeq(), buf, 4); |
| TlsUtils.writeUint24(0, buf, 6); |
| TlsUtils.writeUint24(body.length, buf, 9); |
| handshakeHash.update(buf, 0, buf.length); |
| handshakeHash.update(body, 0, body.length); |
| } |
| return message; |
| } |
| |
| private void writeMessage(Message message) |
| throws IOException |
| { |
| int sendLimit = recordLayer.getSendLimit(); |
| int fragmentLimit = sendLimit - 12; |
| |
| // TODO Support a higher minimum fragment size? |
| if (fragmentLimit < 1) |
| { |
| // TODO Should we be throwing an exception here? |
| throw new TlsFatalAlert(AlertDescription.internal_error); |
| } |
| |
| int length = message.getBody().length; |
| |
| // NOTE: Must still send a fragment if body is empty |
| int fragment_offset = 0; |
| do |
| { |
| int fragment_length = Math.min(length - fragment_offset, fragmentLimit); |
| writeHandshakeFragment(message, fragment_offset, fragment_length); |
| fragment_offset += fragment_length; |
| } |
| while (fragment_offset < length); |
| } |
| |
| private void writeHandshakeFragment(Message message, int fragment_offset, int fragment_length) |
| throws IOException |
| { |
| RecordLayerBuffer fragment = new RecordLayerBuffer(12 + fragment_length); |
| TlsUtils.writeUint8(message.getType(), fragment); |
| TlsUtils.writeUint24(message.getBody().length, fragment); |
| TlsUtils.writeUint16(message.getSeq(), fragment); |
| TlsUtils.writeUint24(fragment_offset, fragment); |
| TlsUtils.writeUint24(fragment_length, fragment); |
| fragment.write(message.getBody(), fragment_offset, fragment_length); |
| |
| fragment.sendToRecordLayer(recordLayer); |
| } |
| |
| private static boolean checkAll(Hashtable inboundFlight) |
| { |
| Enumeration e = inboundFlight.elements(); |
| while (e.hasMoreElements()) |
| { |
| if (((DTLSReassembler)e.nextElement()).getBodyIfComplete() == null) |
| { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| private static void resetAll(Hashtable inboundFlight) |
| { |
| Enumeration e = inboundFlight.elements(); |
| while (e.hasMoreElements()) |
| { |
| ((DTLSReassembler)e.nextElement()).reset(); |
| } |
| } |
| |
| static class Message |
| { |
| private final int message_seq; |
| private final short msg_type; |
| private final byte[] body; |
| |
| private Message(int message_seq, short msg_type, byte[] body) |
| { |
| this.message_seq = message_seq; |
| this.msg_type = msg_type; |
| this.body = body; |
| } |
| |
| public int getSeq() |
| { |
| return message_seq; |
| } |
| |
| public short getType() |
| { |
| return msg_type; |
| } |
| |
| public byte[] getBody() |
| { |
| return body; |
| } |
| } |
| |
| static class RecordLayerBuffer extends ByteArrayOutputStream |
| { |
| RecordLayerBuffer(int size) |
| { |
| super(size); |
| } |
| |
| void sendToRecordLayer(DTLSRecordLayer recordLayer) throws IOException |
| { |
| recordLayer.send(buf, 0, count); |
| buf = null; |
| } |
| } |
| } |