| /* |
| * Copyright 2016 The gRPC Authors |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package io.grpc.testing.integration; |
| |
| import static com.google.common.base.Preconditions.checkArgument; |
| |
| import io.netty.util.concurrent.DefaultThreadFactory; |
| import java.io.DataInputStream; |
| import java.io.DataOutputStream; |
| import java.io.IOException; |
| import java.net.InetSocketAddress; |
| import java.net.ServerSocket; |
| import java.net.Socket; |
| import java.util.concurrent.BlockingQueue; |
| import java.util.concurrent.DelayQueue; |
| import java.util.concurrent.Delayed; |
| import java.util.concurrent.LinkedBlockingQueue; |
| import java.util.concurrent.ThreadPoolExecutor; |
| import java.util.concurrent.TimeUnit; |
| import java.util.logging.Logger; |
| |
| public final class TrafficControlProxy { |
| |
| private static final int DEFAULT_BAND_BPS = 1024 * 1024; |
| private static final int DEFAULT_DELAY_NANOS = 200 * 1000 * 1000; |
| private static final Logger logger = Logger.getLogger(TrafficControlProxy.class.getName()); |
| |
| // TODO: make host and ports arguments |
| private final String localhost = "localhost"; |
| private final int serverPort; |
| private final int queueLength; |
| private final int chunkSize; |
| private final int bandwidth; |
| private final long latency; |
| private volatile boolean shutDown; |
| private ServerSocket clientAcceptor; |
| private Socket serverSock; |
| private Socket clientSock; |
| private final ThreadPoolExecutor executor = |
| new ThreadPoolExecutor(5, 10, 1, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(), |
| new DefaultThreadFactory("proxy-pool", true)); |
| |
| /** |
| * Returns a new TrafficControlProxy with default bandwidth and latency. |
| */ |
| public TrafficControlProxy(int serverPort) { |
| this(serverPort, DEFAULT_BAND_BPS, DEFAULT_DELAY_NANOS, TimeUnit.NANOSECONDS); |
| } |
| |
| /** |
| * Returns a new TrafficControlProxy with bandwidth set to targetBPS, and latency set to |
| * targetLatency in latencyUnits. |
| */ |
| public TrafficControlProxy(int serverPort, int targetBps, int targetLatency, |
| TimeUnit latencyUnits) { |
| checkArgument(targetBps > 0); |
| checkArgument(targetLatency > 0); |
| this.serverPort = serverPort; |
| bandwidth = targetBps; |
| // divide by 2 because latency is applied in both directions |
| latency = latencyUnits.toNanos(targetLatency) / 2; |
| queueLength = (int) Math.max(bandwidth * latency / TimeUnit.SECONDS.toNanos(1), 1); |
| chunkSize = Math.max(1, queueLength); |
| } |
| |
| /** |
| * Starts a new thread that waits for client and server and start reader/writer threads. |
| */ |
| public void start() throws IOException { |
| // ClientAcceptor uses a ServerSocket server so that the client can connect to the proxy as it |
| // normally would a server. serverSock then connects the server using a regular Socket as a |
| // client normally would. |
| clientAcceptor = new ServerSocket(); |
| clientAcceptor.bind(new InetSocketAddress(localhost, 0)); |
| executor.execute(new Runnable() { |
| @Override |
| public void run() { |
| try { |
| clientSock = clientAcceptor.accept(); |
| serverSock = new Socket(); |
| serverSock.connect(new InetSocketAddress(localhost, serverPort)); |
| startWorkers(); |
| } catch (IOException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| }); |
| logger.info("Started new proxy on port " + clientAcceptor.getLocalPort() |
| + " with Queue Length " + queueLength); |
| } |
| |
| public int getPort() { |
| return clientAcceptor.getLocalPort(); |
| } |
| |
| /** Interrupt all workers and close sockets. */ |
| public void shutDown() throws IOException { |
| // TODO: Handle case where a socket fails to close, therefore blocking the others from closing |
| logger.info("Proxy shutting down... "); |
| shutDown = true; |
| executor.shutdown(); |
| clientAcceptor.close(); |
| clientSock.close(); |
| serverSock.close(); |
| logger.info("Shutdown Complete"); |
| } |
| |
| private void startWorkers() throws IOException { |
| DataInputStream clientIn = new DataInputStream(clientSock.getInputStream()); |
| DataOutputStream clientOut = new DataOutputStream(serverSock.getOutputStream()); |
| DataInputStream serverIn = new DataInputStream(serverSock.getInputStream()); |
| DataOutputStream serverOut = new DataOutputStream(clientSock.getOutputStream()); |
| |
| MessageQueue clientPipe = new MessageQueue(clientIn, clientOut); |
| MessageQueue serverPipe = new MessageQueue(serverIn, serverOut); |
| |
| executor.execute(new Reader(clientPipe)); |
| executor.execute(new Writer(clientPipe)); |
| executor.execute(new Reader(serverPipe)); |
| executor.execute(new Writer(serverPipe)); |
| } |
| |
| private final class Reader implements Runnable { |
| |
| private final MessageQueue queue; |
| |
| Reader(MessageQueue queue) { |
| this.queue = queue; |
| } |
| |
| @Override |
| public void run() { |
| while (!shutDown) { |
| try { |
| queue.readIn(); |
| } catch (IOException e) { |
| shutDown = true; |
| } catch (InterruptedException e) { |
| shutDown = true; |
| } |
| } |
| } |
| |
| } |
| |
| private final class Writer implements Runnable { |
| |
| private final MessageQueue queue; |
| |
| Writer(MessageQueue queue) { |
| this.queue = queue; |
| } |
| |
| @Override |
| public void run() { |
| while (!shutDown) { |
| try { |
| queue.writeOut(); |
| } catch (IOException e) { |
| shutDown = true; |
| } catch (InterruptedException e) { |
| shutDown = true; |
| } |
| } |
| } |
| } |
| |
| /** |
| * A Delay Queue that counts by number of bytes instead of the number of elements. |
| */ |
| private class MessageQueue { |
| DataInputStream inStream; |
| DataOutputStream outStream; |
| int bytesQueued; |
| BlockingQueue<Message> queue = new DelayQueue<Message>(); |
| |
| MessageQueue(DataInputStream inputStream, DataOutputStream outputStream) { |
| inStream = inputStream; |
| outStream = outputStream; |
| } |
| |
| /** |
| * Take a message off the queue and write it to an endpoint. Blocks until a message becomes |
| * available. |
| */ |
| void writeOut() throws InterruptedException, IOException { |
| Message next = queue.take(); |
| outStream.write(next.message, 0, next.messageLength); |
| incrementBytes(-next.messageLength); |
| } |
| |
| /** |
| * Read bytes from an endpoint and add them as a message to the queue. Blocks if the queue is |
| * full. |
| */ |
| void readIn() throws InterruptedException, IOException { |
| byte[] request = new byte[getNextChunk()]; |
| int readableBytes = inStream.read(request); |
| long sendTime = System.nanoTime() + latency; |
| queue.put(new Message(sendTime, request, readableBytes)); |
| incrementBytes(readableBytes); |
| } |
| |
| /** |
| * Block until space on the queue becomes available. Returns how many bytes can be read on to |
| * the queue |
| */ |
| synchronized int getNextChunk() throws InterruptedException { |
| while (bytesQueued == queueLength) { |
| wait(); |
| } |
| return Math.max(0, Math.min(chunkSize, queueLength - bytesQueued)); |
| } |
| |
| synchronized void incrementBytes(int delta) { |
| bytesQueued += delta; |
| if (bytesQueued < queueLength) { |
| notifyAll(); |
| } |
| } |
| } |
| |
| private static class Message implements Delayed { |
| long sendTime; |
| byte[] message; |
| int messageLength; |
| |
| Message(long sendTime, byte[] message, int messageLength) { |
| this.sendTime = sendTime; |
| this.message = message; |
| this.messageLength = messageLength; |
| } |
| |
| @Override |
| public int compareTo(Delayed o) { |
| return ((Long) sendTime).compareTo(((Message) o).sendTime); |
| } |
| |
| @Override |
| public long getDelay(TimeUnit unit) { |
| return unit.convert(sendTime - System.nanoTime(), TimeUnit.NANOSECONDS); |
| } |
| } |
| } |