blob: 1f17d72f669e18cbfdca68b852d9713fea56f8b6 [file] [log] [blame]
// Copyright (c) 2005 Brian Wellington (bwelling@xbill.org)
package org.xbill.DNS;
import java.io.*;
import java.net.*;
import java.nio.*;
import java.nio.channels.*;
/**
 * DNS-over-TCP client: sends one length-prefixed message and reads one
 * length-prefixed reply over a {@link SocketChannel}.
 *
 * <p>Every message on the wire is preceded by a two-byte, big-endian length
 * field, so a single message can be at most 65535 bytes long.  Waiting is
 * bounded by {@code endTime}, an absolute deadline compared against
 * {@link System#currentTimeMillis()}.
 */
final class TCPClient extends Client {

/** Largest payload whose length fits in the two-byte TCP length prefix. */
private static final int MAX_MESSAGE_LENGTH = 0xFFFF;

/**
 * Opens a new socket channel and hands it, together with the deadline, to
 * the {@link Client} constructor.
 *
 * @param endTime absolute deadline in milliseconds (same clock as
 *        {@code System.currentTimeMillis()}) for this client's operations
 * @throws IOException if the channel cannot be opened or the superclass
 *         constructor fails
 */
public
TCPClient(long endTime) throws IOException {
	super(SocketChannel.open(), endTime);
}

/**
 * Binds the local end of the socket; must be called before
 * {@link #connect(SocketAddress)} to take effect.
 *
 * @param addr local address to bind to
 * @throws IOException if the bind fails
 */
void
bind(SocketAddress addr) throws IOException {
	SocketChannel channel = (SocketChannel) key.channel();
	channel.socket().bind(addr);
}

/**
 * Connects to {@code addr}, waiting (bounded by {@code endTime}) for a
 * pending connection to complete.
 *
 * @param addr remote address to connect to
 * @throws IOException if the connection fails; a timeout is presumably
 *         reported by the inherited {@code blockUntil} — TODO confirm
 */
void
connect(SocketAddress addr) throws IOException {
	SocketChannel channel = (SocketChannel) key.channel();
	// connect() returns true only when the connection was established
	// immediately; otherwise it is in progress and must be finished.
	if (channel.connect(addr))
		return;
	key.interestOps(SelectionKey.OP_CONNECT);
	try {
		while (!channel.finishConnect()) {
			if (!key.isConnectable())
				blockUntil(key, endTime);
		}
	}
	finally {
		// Stop selecting for any event.  The key may already be cancelled
		// (e.g. the channel was closed), and interestOps() would throw.
		if (key.isValid())
			key.interestOps(0);
	}
}

/**
 * Writes {@code data} framed by its two-byte, big-endian length.
 *
 * @param data the wire-format message to send
 * @throws IOException if {@code data} is longer than 65535 bytes (the
 *         length prefix cannot represent it) or the write fails
 * @throws SocketTimeoutException if {@code endTime} passes with bytes
 *         still unsent
 */
void
send(byte [] data) throws IOException {
	// The TCP length prefix is only two bytes wide; a longer message would
	// silently get a truncated length and desynchronize the stream.
	if (data.length > MAX_MESSAGE_LENGTH)
		throw new IOException("DNS message too long for TCP: " +
				      data.length + " bytes (max " +
				      MAX_MESSAGE_LENGTH + ")");
	SocketChannel channel = (SocketChannel) key.channel();
	verboseLog("TCP write", data);
	byte [] lengthArray = new byte[2];
	lengthArray[0] = (byte)(data.length >>> 8);
	lengthArray[1] = (byte)(data.length & 0xFF);
	// Gathering write: the prefix and the body go out without copying
	// them into one array.
	ByteBuffer [] buffers = new ByteBuffer[2];
	buffers[0] = ByteBuffer.wrap(lengthArray);
	buffers[1] = ByteBuffer.wrap(data);
	int total = data.length + 2;
	int nsent = 0;
	key.interestOps(SelectionKey.OP_WRITE);
	try {
		while (nsent < total) {
			// NOTE(review): isWritable() reflects the ready set recorded
			// by the last select, so after a short write this loop retries
			// the write without selecting again until the deadline check
			// below fires.
			if (key.isWritable()) {
				long n = channel.write(buffers);
				if (n < 0)
					throw new EOFException();
				nsent += (int) n;
				if (nsent < total &&
				    System.currentTimeMillis() > endTime)
					throw new SocketTimeoutException();
			} else
				blockUntil(key, endTime);
		}
	}
	finally {
		if (key.isValid())
			key.interestOps(0);
	}
}

/**
 * Reads exactly {@code length} bytes from the channel.
 *
 * @param length number of bytes to read; 0 yields an empty array
 * @return a new array holding the bytes read
 * @throws EOFException if the peer closes the connection first
 * @throws SocketTimeoutException if {@code endTime} passes with bytes
 *         still outstanding
 */
private byte []
_recv(int length) throws IOException {
	SocketChannel channel = (SocketChannel) key.channel();
	int nrecvd = 0;
	byte [] data = new byte[length];
	ByteBuffer buffer = ByteBuffer.wrap(data);
	key.interestOps(SelectionKey.OP_READ);
	try {
		while (nrecvd < length) {
			// NOTE(review): as in send(), isReadable() is the ready set
			// from the last select; zero-byte reads retry until the
			// deadline check fires.
			if (key.isReadable()) {
				long n = channel.read(buffer);
				// read() returns -1 at end-of-stream (peer closed).
				if (n < 0)
					throw new EOFException();
				nrecvd += (int) n;
				if (nrecvd < length &&
				    System.currentTimeMillis() > endTime)
					throw new SocketTimeoutException();
			} else
				blockUntil(key, endTime);
		}
	}
	finally {
		if (key.isValid())
			key.interestOps(0);
	}
	return data;
}

/**
 * Reads one length-prefixed message: a two-byte, big-endian length
 * followed by that many bytes.
 *
 * @return the message body, without the length prefix
 * @throws IOException on end-of-stream, timeout, or read failure
 */
byte []
recv() throws IOException {
	byte [] buf = _recv(2);
	int length = ((buf[0] & 0xFF) << 8) + (buf[1] & 0xFF);
	byte [] data = _recv(length);
	verboseLog("TCP read", data);
	return data;
}

/**
 * Performs one TCP query/response exchange and always releases the
 * client afterwards.
 *
 * @param local local address to bind to, or {@code null} to skip binding
 * @param remote server address
 * @param data the wire-format query
 * @param endTime absolute deadline in milliseconds for the whole exchange
 * @return the wire-format response
 * @throws IOException on any connection, I/O, or timeout failure
 */
static byte []
sendrecv(SocketAddress local, SocketAddress remote, byte [] data, long endTime)
throws IOException
{
	TCPClient client = new TCPClient(endTime);
	try {
		if (local != null)
			client.bind(local);
		client.connect(remote);
		client.send(data);
		return client.recv();
	}
	finally {
		client.cleanup();
	}
}

/**
 * Performs one TCP query/response exchange from an unbound local address.
 *
 * @param addr server address
 * @param data the wire-format query
 * @param endTime absolute deadline in milliseconds for the whole exchange
 * @return the wire-format response
 * @throws IOException on any connection, I/O, or timeout failure
 */
static byte []
sendrecv(SocketAddress addr, byte [] data, long endTime) throws IOException {
	return sendrecv(null, addr, data, endTime);
}

}