| package org.bouncycastle.crypto.tls.test; |
| |
| import java.io.OutputStream; |
| import java.io.PipedInputStream; |
| import java.io.PipedOutputStream; |
| import java.security.SecureRandom; |
| |
| import junit.framework.TestCase; |
| |
| import org.bouncycastle.crypto.tls.TlsClientProtocol; |
| import org.bouncycastle.crypto.tls.TlsServerProtocol; |
| import org.bouncycastle.util.Arrays; |
| import org.bouncycastle.util.io.Streams; |
| |
| public class TlsPSKProtocolTest |
| extends TestCase |
| { |
| public void testClientServer() throws Exception |
| { |
| SecureRandom secureRandom = new SecureRandom(); |
| |
| PipedInputStream clientRead = new PipedInputStream(); |
| PipedInputStream serverRead = new PipedInputStream(); |
| PipedOutputStream clientWrite = new PipedOutputStream(serverRead); |
| PipedOutputStream serverWrite = new PipedOutputStream(clientRead); |
| |
| TlsClientProtocol clientProtocol = new TlsClientProtocol(clientRead, clientWrite, secureRandom); |
| TlsServerProtocol serverProtocol = new TlsServerProtocol(serverRead, serverWrite, secureRandom); |
| |
| ServerThread serverThread = new ServerThread(serverProtocol); |
| serverThread.start(); |
| |
| MockPSKTlsClient client = new MockPSKTlsClient(null); |
| clientProtocol.connect(client); |
| |
| // NOTE: Because we write-all before we read-any, this length can't be more than the pipe capacity |
| int length = 1000; |
| |
| byte[] data = new byte[length]; |
| secureRandom.nextBytes(data); |
| |
| OutputStream output = clientProtocol.getOutputStream(); |
| output.write(data); |
| |
| byte[] echo = new byte[data.length]; |
| int count = Streams.readFully(clientProtocol.getInputStream(), echo); |
| |
| assertEquals(count, data.length); |
| assertTrue(Arrays.areEqual(data, echo)); |
| |
| output.close(); |
| |
| serverThread.join(); |
| } |
| |
| static class ServerThread |
| extends Thread |
| { |
| private final TlsServerProtocol serverProtocol; |
| |
| ServerThread(TlsServerProtocol serverProtocol) |
| { |
| this.serverProtocol = serverProtocol; |
| } |
| |
| public void run() |
| { |
| try |
| { |
| MockPSKTlsServer server = new MockPSKTlsServer(); |
| serverProtocol.accept(server); |
| Streams.pipeAll(serverProtocol.getInputStream(), serverProtocol.getOutputStream()); |
| serverProtocol.close(); |
| } |
| catch (Exception e) |
| { |
| // throw new RuntimeException(e); |
| } |
| } |
| } |
| } |