/*
 * Copyright (c) 2000, 2003, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

/* @test
 * @summary Test selectors and socketchannels
 * @library ..
 */

import java.io.*;
import java.net.*;
import java.nio.*;
import java.nio.channels.*;
import java.nio.channels.spi.SelectorProvider;
import java.util.*;


/**
 * Test driver that runs a selector-based server and a set of non-blocking
 * clients against each other, either in one JVM (no arguments) or as a
 * stand-alone server or client (see the usage note on {@link #main}).
 *
 * The nested {@code Client} and {@code Server} classes extend
 * {@code TestThread}, which is not defined in this file; this class only
 * relies on its {@code start()}, {@code finish(timeout)} and {@code go()}
 * members and on a zero result from {@code finish} meaning failure.
 */
public class SelectorTest {
    // Clients still in progress; shared by every Client thread in this JVM.
    private static List clientList = new LinkedList();
    private static Random rnd = new Random();
    // Number of client connections the server waits for and the client creates.
    public static int NUM_CLIENTS = 30;
    // Default port; overridden by the optional <port> argument.
    public static int TEST_PORT = 31452;
    static PrintStream log = System.err;
    // Value handed to TestThread.finish(); presumably milliseconds -- confirm
    // against TestThread.
    private static int FINISH_TIME = 30000;

    /*
     * Usage note
     *
     * java SelectorTest [server] [client <host>] [<port>]
     *
     * No arguments runs both client and server in separate threads
     * using the default port of 31452.
     *
     * client runs the client on this machine and connects to server
     * at the given IP address.
     *
     * server runs the server on localhost.
     */
    public static void main(String[] args) throws Exception {
        if (args.length == 0) {
            InetSocketAddress isa
                = new InetSocketAddress(InetAddress.getLocalHost(), TEST_PORT);
            Server server = new Server(isa);
            server.start();
            // Best-effort delay so the server is listening before the
            // clients start connecting; an interrupt simply shortens it.
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) { }
            Client client = new Client(isa);
            client.start();
            // Non-short-circuit '&' so that both threads are always waited for.
            if ((server.finish(FINISH_TIME) & client.finish(FINISH_TIME)) == 0)
                throw new Exception("Failure");
            log.println();

        } else if (args[0].equals("server")) {

            if (args.length > 1)
                TEST_PORT = Integer.parseInt(args[1]);
            InetSocketAddress isa
                = new InetSocketAddress(InetAddress.getLocalHost(), TEST_PORT);
            Server server = new Server(isa);
            server.start();
            if (server.finish(FINISH_TIME) == 0)
                throw new Exception("Failure");
            log.println();

        } else if (args[0].equals("client")) {

            if (args.length < 2) {
                log.println("No host specified: terminating.");
                return;
            }
            String ip = args[1];
            if (args.length > 2)
                TEST_PORT = Integer.parseInt(args[2]);
            InetAddress ia = InetAddress.getByName(ip);
            InetSocketAddress isa = new InetSocketAddress(ia, TEST_PORT);
            Client client = new Client(isa);
            client.start();
            if (client.finish(FINISH_TIME) == 0)
                throw new Exception("Failure");
            log.println();

        } else {
            System.out.println("Usage note:");
            System.out.println("java SelectorTest [server] [client <host>] [<port>]");
            System.out.println("No arguments runs both client and server in separate threads using the default port of 31452.");
            System.out.println("client runs the client on this machine and connects to the server specified.");
            System.out.println("server runs the server on localhost.");
        }
    }

    /**
     * Creates NUM_CLIENTS RemoteEntity objects and repeatedly steps each one
     * (in shuffled order) until every entity reports that it has closed.
     */
    static class Client extends TestThread {
        InetSocketAddress isa;
        Client(InetSocketAddress isa) {
            super("Client", SelectorTest.log);
            this.isa = isa;
        }

        public void go() throws Exception {
            log.println("starting client...");
            for (int i=0; i<NUM_CLIENTS; i++)
                clientList.add(new RemoteEntity(i, isa, log));

            Collections.shuffle(clientList);

            log.println("created "+NUM_CLIENTS+" clients");
            do {
                for (Iterator i = clientList.iterator(); i.hasNext(); ) {
                    RemoteEntity re = (RemoteEntity) i.next();
                    // cycle() advances the entity one step and returns true
                    // once it has closed its channel.
                    if (re.cycle()) {
                        i.remove();
                    }
                }
                // Re-shuffle so the entities are not always stepped in the
                // same order.
                Collections.shuffle(clientList);
            } while (clientList.size() > 0);
        }
    }

    /**
     * Accepts NUM_CLIENTS connections through one selector and registers
     * each accepted channel for reading with a second selector, which is
     * serviced by a RequestHandler running in its own thread.
     */
    static class Server extends TestThread {
        private List socketList = new ArrayList();
        private ServerSocket ss;
        private int connectionsAccepted = 0;
        private Selector pollSelector;
        private Selector acceptSelector;
        private InetSocketAddress isa;
        private Set pkeys;
        private Set pskeys;

        Server(InetSocketAddress isa) {
            super("Server", SelectorTest.log);
            this.isa = isa;
        }

        /**
         * Runs the accept loop and then waits for the request handler to
         * finish. The listening channel and both selectors are closed on
         * every exit path.
         *
         * @throws Exception on any I/O failure, or a RuntimeException if
         *         accept() yields no channel for a ready key
         */
        public void go() throws Exception {
            log.println("starting server...");
            ServerSocketChannel ssc = null;
            try {
                acceptSelector = SelectorProvider.provider().openSelector();
                pollSelector = SelectorProvider.provider().openSelector();
                pkeys = pollSelector.keys();
                pskeys = pollSelector.selectedKeys();
                Set readyKeys = acceptSelector.selectedKeys();
                RequestHandler rh = new RequestHandler(pollSelector, log);
                Thread requestThread = new Thread(rh);

                requestThread.start();

                ssc = ServerSocketChannel.open();
                ssc.configureBlocking(false);
                ssc.socket().setReuseAddress(true);
                ssc.socket().bind(isa);

                SelectionKey acceptKey = ssc.register(acceptSelector,
                                                      SelectionKey.OP_ACCEPT);
                while(connectionsAccepted < SelectorTest.NUM_CLIENTS) {
                    int keysAdded = acceptSelector.select(100);
                    if (keysAdded > 0) {
                        Iterator i = readyKeys.iterator();
                        while(i.hasNext()) {
                            SelectionKey sk = (SelectionKey)i.next();
                            i.remove();
                            ServerSocketChannel nextReady =
                                (ServerSocketChannel)sk.channel();
                            SocketChannel sc = nextReady.accept();
                            if (sc != null) {
                                // Only count connections that really exist.
                                connectionsAccepted++;
                                sc.configureBlocking(false);
                                // The request handler holds this same lock
                                // while it processes ready keys.
                                synchronized (pkeys) {
                                   sc.register(pollSelector, SelectionKey.OP_READ);
                                }
                            } else {
                                throw new RuntimeException(
                                    "Socket does not support Channels");
                            }
                        }
                    }
                }
                acceptKey.cancel();
                // The handler returns once it has served NUM_CLIENTS
                // connections (or hit an error); its selector must stay
                // open until then.
                requestThread.join();
            } finally {
                // On a failure path, closing pollSelector also makes the
                // request handler's next select() fail, so its thread ends
                // instead of polling forever.
                closeResources(ssc);
            }
        }

        // Closes the listening channel (if opened) and both selectors,
        // attempting every close even if an earlier one throws.
        private void closeResources(ServerSocketChannel ssc)
            throws IOException
        {
            try {
                if (ssc != null)
                    ssc.close();
            } finally {
                try {
                    if (acceptSelector != null)
                        acceptSelector.close();
                } finally {
                    if (pollSelector != null)
                        pollSelector.close();
                }
            }
        }
    }
}

/**
 * One simulated client. Each call to {@link #cycle()} performs at most one
 * non-blocking step of the sequence: start connecting, finish connecting,
 * write the message a few bytes at a time, read the acknowledgement, close.
 */
class RemoteEntity {
    private static Random rnd = new Random();
    int id;
    // Message to send, flipped and ready for writing.
    ByteBuffer data;
    // Number of message bytes written to the channel so far.
    int dataWrittenIndex;
    // Length of the encoded message in bytes.
    int totalDataLength;
    boolean initiated = false;
    boolean connected = false;
    boolean written = false;
    boolean acked = false;
    boolean closed = false;
    private SocketChannel sc;
    ByteBuffer ackBuffer;
    PrintStream log;
    InetSocketAddress server;

    /**
     * Opens a non-blocking channel (not yet connected) and prepares the
     * message "Testing from NN", where NN is the id padded to two digits.
     *
     * @param id     number identifying this entity in the message and log
     * @param server address that cycle() will connect to
     * @param log    stream that progress messages are written to
     * @throws Exception if the channel cannot be opened or configured
     */
    RemoteEntity(int id, InetSocketAddress server, PrintStream log)
        throws Exception
    {
        this.id = id;
        this.log = log;
        this.server = server;

        sc = SocketChannel.open();
        sc.configureBlocking(false);

        // Prepare the data buffer to write out from this entity
        // Let's use both slow and fast buffers
        if (rnd.nextBoolean())
            data = ByteBuffer.allocateDirect(100);
        else
            data = ByteBuffer.allocate(100);
        String number = Integer.toString(id);
        if (number.length() == 1)
            number = "0"+number;
        String source = "Testing from " + number;
        byte[] payload = source.getBytes("8859_1");
        data.put(payload);
        data.flip();
        // Track the encoded byte count; writeCycle() compares it against
        // the number of bytes written, not characters.
        totalDataLength = payload.length;

        // Allocate an ack buffer
        ackBuffer = ByteBuffer.allocateDirect(10);
    }

    // Replaces the channel with a fresh non-blocking one after a failed
    // connection attempt.
    private void reset() throws Exception {
        sc.close();
        sc = SocketChannel.open();
        sc.configureBlocking(false);
    }

    // Starts a non-blocking connect; on ConnectException the channel is
    // replaced so that the next cycle() tries again.
    private void connect() throws Exception {
        try {
            connected = sc.connect(server);
            initiated = true;
        }  catch (ConnectException e) {
            initiated = false;
            reset();
        }
    }

    // Polls a pending connect; on failure the channel is replaced and the
    // connection is started over by the next cycle().
    private void finishConnect() throws Exception {
        try {
            connected = sc.finishConnect();
        }  catch (IOException e) {
            initiated = false;
            reset();
        }
    }

    int id() {
        return id;
    }

    /**
     * Performs the next pending step for this entity.
     *
     * @return true once the channel has been closed, i.e. the entity is done
     * @throws Exception on I/O failure, including the server closing the
     *         connection before sending an acknowledgement
     */
    boolean cycle() throws Exception {
        if (!initiated)
            connect();
        else if (!connected)
            finishConnect();
        else if (!written)
            writeCycle();
        else if (!acked)
            ackCycle();
        else if (!closed)
            close();
        return closed;
    }

    // Tries to read the acknowledgement. Any bytes at all count as the ack;
    // zero bytes means "not yet, try again on a later cycle".
    private void ackCycle() throws Exception {
        //log.println("acking from "+id);
        int bytesRead = sc.read(ackBuffer);
        if (bytesRead > 0) {
            acked = true;
        } else if (bytesRead < 0) {
            // End of stream: no acknowledgement can arrive any more, so
            // fail instead of polling this channel forever.
            throw new EOFException("Connection for client " + id
                                   + " closed before acknowledgement");
        }
    }

    private void close() throws Exception {
        sc.close();
        closed = true;
    }

    // Writes a random slice of 1 to 10 more bytes of the message and shuts
    // down the output side once the whole message has been written.
    private void writeCycle() throws Exception {
        log.println("writing from "+id);
        int numBytesToWrite = rnd.nextInt(10)+1;
        int newWriteTarget = dataWrittenIndex + numBytesToWrite;
        if (newWriteTarget > totalDataLength)
            newWriteTarget = totalDataLength;
        // The buffer's limit caps how much this write may send.
        data.limit(newWriteTarget);
        int bytesWritten = sc.write(data);
        if (bytesWritten > 0)
            dataWrittenIndex += bytesWritten;
        if (dataWrittenIndex == totalDataLength) {
            written = true;
            sc.socket().shutdownOutput();
        }
    }

}


/**
 * Server-side worker: polls a selector for readable channels, collects each
 * client's message until the client signals end-of-stream, then answers
 * with "ack" and closes the channel. Stops after SelectorTest.NUM_CLIENTS
 * connections have been handled.
 */
class RequestHandler implements Runnable {
    private static Random rnd = new Random();
    private Selector selector;
    private int connectionsHandled = 0;
    // Per-channel receive buffer, keyed by SocketChannel, kept while the
    // message is still incomplete.
    private HashMap dataBin = new HashMap();
    PrintStream log;

    /**
     * @param selector selector whose registered channels are read from
     * @param log      stream that progress and error messages are written to
     */
    public RequestHandler(Selector selector, PrintStream log) {
        this.selector = selector;
        this.log = log;
    }

    /**
     * Polls the selector until the expected number of connections has been
     * handled. Any exception ends the loop and is reported on the log
     * rather than rethrown.
     */
    public void run() {
        log.println("starting request handler...");

        Set nKeys = selector.keys();
        Set readyKeys = selector.selectedKeys();

        try {
            while(connectionsHandled < SelectorTest.NUM_CLIENTS) {
                int numKeys = selector.select(100);

                // Process channels with data. The key set is also the lock
                // the accepting thread takes around register().
                synchronized (nKeys) {
                    Iterator i = readyKeys.iterator();
                    while(i.hasNext()) {
                        SelectionKey sk = (SelectionKey)i.next();
                        i.remove();
                        SocketChannel sc = (SocketChannel)sk.channel();
                        if (sc.isOpen())
                            read(sk, sc);
                    }
                }

                // Give other threads a chance to run
                if (numKeys == 0) {
                    try {
                        Thread.sleep(1);
                    } catch (InterruptedException ignored) {
                        // The pause is only a courtesy; keep polling.
                    }
                }
            }
        } catch (Exception e) {
            log.println("Unexpected error 1: "+e);
            e.printStackTrace();
        }
    }

    // Drains whatever is currently readable from sc into its buffer. When
    // the peer has signalled end-of-stream the message is complete: it is
    // logged, acknowledged, and the channel is closed.
    private void read(SelectionKey sk, SocketChannel sc) throws Exception {
        ByteBuffer bin = (ByteBuffer)dataBin.get(sc);
        if (bin == null) {
            // Exercise both direct and heap buffers.
            if (rnd.nextBoolean())
                bin = ByteBuffer.allocateDirect(100);
            else
                bin = ByteBuffer.allocate(100);
            dataBin.put(sc, bin);
        }

        int bytesRead = 0;
        do {
            bytesRead = sc.read(bin);
        } while(bytesRead > 0);

        if (bytesRead == -1) {
            // Stop selecting this channel; it is about to be closed.
            sk.interestOps(0);
            bin.flip();
            byte[] data = new byte[bin.limit()];
            bin.get(data);
            String message = new String(data, "8859_1");
            // The buffer is no longer needed once the message is extracted.
            dataBin.remove(sc);
            connectionsHandled++;
            acknowledge(sc);
            log.println("Received >>>"+message + "<<<");
            log.println("Handled: "+connectionsHandled);
        }
    }

    // Sends the complete "ack" reply and closes the channel, even if the
    // write fails.
    private void acknowledge(SocketChannel sc) throws Exception {
        try {
            ByteBuffer ackBuffer = ByteBuffer.allocateDirect(10);
            String s = "ack";
            ackBuffer.put(s.getBytes("8859_1"));
            ackBuffer.flip();
            // A non-blocking write may accept only part of the buffer (or
            // nothing); keep writing until every byte has been sent.
            while(ackBuffer.hasRemaining()) {
                sc.write(ackBuffer);
            }
        } finally {
            sc.close();
        }
    }
}
