blob: a817c47a15af04b4d26875c71e7b43af48e8a0c3 [file] [log] [blame]
package org.bouncycastle.crypto.tls.test;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.security.SecureRandom;
import junit.framework.TestCase;
import org.bouncycastle.crypto.tls.TlsClientProtocol;
import org.bouncycastle.crypto.tls.TlsServerProtocol;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.io.Streams;
/**
 * Loopback test for {@link TlsClientProtocol} and {@link TlsServerProtocol}.
 * <p>
 * The two endpoints are wired together with in-memory pipes. The client is driven on the test thread
 * and the server on a {@link ServerThread}. After the handshake the client sends random data, which the
 * server echoes back; the test then checks that the echo matches.
 */
public class TlsProtocolTest
    extends TestCase
{
    public void testClientServer()
        throws Exception
    {
        SecureRandom secureRandom = new SecureRandom();

        // Cross-connect the pipes: whatever one side writes, the other side reads.
        PipedInputStream clientRead = new PipedInputStream();
        PipedInputStream serverRead = new PipedInputStream();
        PipedOutputStream clientWrite = new PipedOutputStream(serverRead);
        PipedOutputStream serverWrite = new PipedOutputStream(clientRead);

        TlsClientProtocol clientProtocol = new TlsClientProtocol(clientRead, clientWrite, secureRandom);
        TlsServerProtocol serverProtocol = new TlsServerProtocol(serverRead, serverWrite, secureRandom);

        ServerThread serverThread = new ServerThread(serverProtocol);
        serverThread.start();

        try
        {
            MockTlsClient client = new MockTlsClient(null);
            clientProtocol.connect(client);

            // NOTE: Because we write-all before we read-any, this length must stay within roughly the pipe
            // buffer capacity (PipedInputStream's default buffer is 1024 bytes). Otherwise the server's echo
            // can fill the client's input pipe while the client is still blocked writing, and both sides
            // would stall.
            int length = 1000;
            byte[] data = new byte[length];
            secureRandom.nextBytes(data);

            OutputStream output = clientProtocol.getOutputStream();
            output.write(data);

            byte[] echo = new byte[data.length];
            int count = Streams.readFully(clientProtocol.getInputStream(), echo);

            // JUnit convention is assertEquals(expected, actual).
            assertEquals(data.length, count);
            assertTrue(Arrays.areEqual(data, echo));
        }
        finally
        {
            // Close the client even when connect/read/assertions fail. Otherwise the server thread would
            // stay blocked in its echo loop waiting for more input. (Closing the protocol is what the
            // client output stream's close() was used for originally.)
            clientProtocol.close();
        }

        serverThread.join();
    }

    /**
     * Runs the server side of the loopback connection. It accepts a handshake with a {@code MockTlsServer},
     * copies everything read from the client back to the client via {@link Streams#pipeAll}, then closes
     * the server protocol.
     */
    static class ServerThread
        extends Thread
    {
        private final TlsServerProtocol serverProtocol;

        ServerThread(TlsServerProtocol serverProtocol)
        {
            this.serverProtocol = serverProtocol;
        }

        public void run()
        {
            try
            {
                MockTlsServer server = new MockTlsServer();
                serverProtocol.accept(server);
                Streams.pipeAll(serverProtocol.getInputStream(), serverProtocol.getOutputStream());
                serverProtocol.close();
            }
            catch (Exception ignored)
            {
                // Deliberately not propagated. An exception thrown from this thread would not fail the test
                // anyway. A broken handshake or echo is caught by the client-side assertions in
                // testClientServer.
                // NOTE(review): presumably connection teardown can also raise here (e.g. replying to the
                // client's close after its pipe is already closed), which is why the rethrow was disabled.
                // Confirm before turning this into a hard failure.
            }
        }
    }
}