blob: f32554a472c58b3881f9b556119c5f7d0f7e5177 [file] [log] [blame]
package org.bouncycastle.crypto.tls.test;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.security.SecureRandom;
import junit.framework.TestCase;
import org.bouncycastle.crypto.tls.TlsClientProtocol;
import org.bouncycastle.crypto.tls.TlsServerProtocol;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.io.Streams;
/**
 * End-to-end test that runs a TLS handshake between {@code MockPSKTlsClient} and
 * {@code MockPSKTlsServer}, then checks that application data written by the client is echoed
 * back unchanged by the server.
 * <p>
 * The client and server protocols are connected through two pairs of piped streams; the server
 * side runs on its own {@link ServerThread}.
 */
public class TlsPSKProtocolTest
    extends TestCase
{
    /** Upper bound on how long the test waits for the echo server thread to finish. */
    private static final long SERVER_JOIN_TIMEOUT_MILLIS = 30000L;

    /**
     * Performs the handshake, sends random bytes from client to server, and asserts the echoed
     * bytes match what was sent.
     */
    public void testClientServer() throws Exception
    {
        SecureRandom secureRandom = new SecureRandom();

        // Cross-connect the pipes: what the client writes the server reads, and vice versa.
        PipedInputStream clientRead = new PipedInputStream();
        PipedInputStream serverRead = new PipedInputStream();
        PipedOutputStream clientWrite = new PipedOutputStream(serverRead);
        PipedOutputStream serverWrite = new PipedOutputStream(clientRead);

        TlsClientProtocol clientProtocol = new TlsClientProtocol(clientRead, clientWrite, secureRandom);
        TlsServerProtocol serverProtocol = new TlsServerProtocol(serverRead, serverWrite, secureRandom);

        ServerThread serverThread = new ServerThread(serverProtocol);
        serverThread.start();

        MockPSKTlsClient client = new MockPSKTlsClient(null);
        clientProtocol.connect(client);

        // NOTE: Because we write-all before we read-any, this length can't be more than the pipe capacity
        int length = 1000;
        byte[] data = new byte[length];
        secureRandom.nextBytes(data);

        OutputStream output = clientProtocol.getOutputStream();
        output.write(data);

        byte[] echo = new byte[data.length];
        int count = Streams.readFully(clientProtocol.getInputStream(), echo);

        // JUnit's contract is assertEquals(expected, actual).
        assertEquals(data.length, count);
        assertTrue(Arrays.areEqual(data, echo));

        output.close();

        // Closing the client output is presumably what lets the server's echo loop reach end of
        // stream; bound the wait so a stuck server fails the test instead of hanging the run.
        serverThread.join(SERVER_JOIN_TIMEOUT_MILLIS);
        assertFalse("server thread did not terminate", serverThread.isAlive());
    }

    /**
     * Runs the server side of the test: accepts a connection with a {@code MockPSKTlsServer},
     * copies every byte it reads back to the client, then closes the server protocol.
     */
    static class ServerThread
        extends Thread
    {
        private final TlsServerProtocol serverProtocol;

        ServerThread(TlsServerProtocol serverProtocol)
        {
            this.serverProtocol = serverProtocol;
        }

        @Override
        public void run()
        {
            try
            {
                MockPSKTlsServer server = new MockPSKTlsServer();
                serverProtocol.accept(server);
                Streams.pipeAll(serverProtocol.getInputStream(), serverProtocol.getOutputStream());
                serverProtocol.close();
            }
            catch (Exception ignored)
            {
                // Deliberately ignored (a rethrow was previously left commented out).
                // NOTE(review): presumably because tearing down the piped streams once the client
                // has closed can throw; handshake or echo failures are expected to surface on the
                // client side of the test instead. Confirm before tightening this.
            }
        }
    }
}