blob: e2a553203ff745a0e6d3657124f60448e66ecd98 [file] [log] [blame]
/*
* Copyright (C) 2014 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.squareup.okhttp;
import com.squareup.okhttp.internal.NamedRunnable;
import com.squareup.okhttp.internal.Util;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.net.Proxy;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.Okio;
/**
* A limited implementation of SOCKS Protocol Version 5, intended to be similar to MockWebServer.
* See <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC 1928</a>.
*/
public final class SocksProxy {
  private static final int VERSION_5 = 5;
  private static final int METHOD_NONE = 0xff;
  private static final int METHOD_NO_AUTHENTICATION_REQUIRED = 0;
  private static final int ADDRESS_TYPE_IPV4 = 1;
  private static final int ADDRESS_TYPE_DOMAIN_NAME = 3;
  private static final int ADDRESS_TYPE_IPV6 = 4;
  private static final int COMMAND_CONNECT = 1;
  private static final int REPLY_SUCCEEDED = 0;

  private static final Logger logger = Logger.getLogger(SocksProxy.class.getName());

  private final ExecutorService executor = Executors.newCachedThreadPool(
      Util.threadFactory("SocksProxy", false));
  private final AtomicInteger connectionCount = new AtomicInteger();
  private ServerSocket serverSocket;

  /**
   * Binds a server socket to an ephemeral port and starts accepting connections on a background
   * thread. Each accepted connection is serviced on its own executor thread.
   */
  public void play() throws IOException {
    serverSocket = new ServerSocket(0);
    executor.execute(new NamedRunnable("SocksProxy %s", serverSocket.getLocalPort()) {
      @Override protected void execute() {
        try {
          while (true) {
            Socket socket = serverSocket.accept();
            connectionCount.incrementAndGet();
            service(socket);
          }
        } catch (SocketException e) {
          // Expected when shutdown() closes the server socket.
          logger.info(name + " done accepting connections: " + e.getMessage());
        } catch (IOException e) {
          logger.log(Level.WARNING, name + " failed unexpectedly", e);
        }
      }
    });
  }

  /**
   * Returns a SOCKS proxy pointing at this server's port on an unresolved {@code localhost}.
   * Must be called after {@link #play()}, which is where the server socket is created.
   */
  public Proxy proxy() {
    return new Proxy(Proxy.Type.SOCKS, InetSocketAddress.createUnresolved(
        "localhost", serverSocket.getLocalPort()));
  }

  /** Returns the number of client connections accepted since {@link #play()}. */
  public int connectionCount() {
    return connectionCount.get();
  }

  /**
   * Stops accepting connections and shuts down the executor, waiting up to 5 seconds for its
   * threads to finish. This method does not close relayed connections that are still open; if
   * they keep executor threads busy past the timeout, an {@link IOException} is thrown.
   */
  public void shutdown() throws Exception {
    // Tolerate shutdown() without a prior play().
    if (serverSocket != null) serverSocket.close();
    executor.shutdown();
    if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
      throw new IOException("Gave up waiting for executor to shut down");
    }
  }

  /**
   * Performs the SOCKS handshake and command on {@code from} in a background thread. On any
   * handshake failure the client socket is closed; on success it is owned by the relay threads.
   */
  private void service(final Socket from) {
    executor.execute(new NamedRunnable("SocksProxy %s", from.getRemoteSocketAddress()) {
      @Override protected void execute() {
        try {
          BufferedSource fromSource = Okio.buffer(Okio.source(from));
          BufferedSink fromSink = Okio.buffer(Okio.sink(from));
          hello(fromSource, fromSink);
          acceptCommand(from.getInetAddress(), fromSource, fromSink);
        } catch (IOException e) {
          logger.log(Level.WARNING, name + " failed", e);
          Util.closeQuietly(from);
        }
      }
    });
  }

  /**
   * Reads the client's version/method-selection message and replies with the selected method.
   * Only "no authentication required" is supported.
   *
   * @throws ProtocolException if the version isn't 5 or no supported method was offered.
   */
  private void hello(BufferedSource fromSource, BufferedSink fromSink) throws IOException {
    int version = fromSource.readByte() & 0xff;
    if (version != VERSION_5) {
      throw new ProtocolException("unsupported version: " + version);
    }

    int methodCount = fromSource.readByte() & 0xff;
    int selectedMethod = METHOD_NONE;
    // Consume every offered method so the stream stays aligned with the next message.
    for (int i = 0; i < methodCount; i++) {
      int candidateMethod = fromSource.readByte() & 0xff;
      if (candidateMethod == METHOD_NO_AUTHENTICATION_REQUIRED) {
        selectedMethod = candidateMethod;
      }
    }

    // RFC 1928: when no offered method is acceptable the server replies X'FF', so the client
    // learns why the connection is being closed instead of seeing an abrupt EOF.
    fromSink.writeByte(VERSION_5);
    fromSink.writeByte(selectedMethod);
    fromSink.emit();

    if (selectedMethod != METHOD_NO_AUTHENTICATION_REQUIRED) {
      throw new ProtocolException("unsupported method: " + selectedMethod);
    }
  }

  /**
   * Reads the client's request (version, command, reserved byte, destination address and port)
   * and executes it. Only CONNECT is supported; destinations may be IPv4, IPv6 or a domain name.
   *
   * @throws ProtocolException on a malformed request or an unsupported command/address type.
   */
  private void acceptCommand(InetAddress fromAddress, BufferedSource fromSource,
      BufferedSink fromSink) throws IOException {
    // Read the command.
    int version = fromSource.readByte() & 0xff;
    if (version != VERSION_5) throw new ProtocolException("unexpected version: " + version);
    int command = fromSource.readByte() & 0xff;
    int reserved = fromSource.readByte() & 0xff;
    if (reserved != 0) throw new ProtocolException("unexpected reserved: " + reserved);

    int addressType = fromSource.readByte() & 0xff;
    InetAddress toAddress;
    switch (addressType) {
      case ADDRESS_TYPE_IPV4:
        toAddress = InetAddress.getByAddress(fromSource.readByteArray(4L));
        break;
      case ADDRESS_TYPE_DOMAIN_NAME:
        int domainNameLength = fromSource.readByte() & 0xff;
        String domainName = fromSource.readUtf8(domainNameLength);
        toAddress = InetAddress.getByName(domainName);
        break;
      case ADDRESS_TYPE_IPV6:
        toAddress = InetAddress.getByAddress(fromSource.readByteArray(16L));
        break;
      default:
        throw new ProtocolException("unsupported address type: " + addressType);
    }
    int port = fromSource.readShort() & 0xffff;

    switch (command) {
      case COMMAND_CONNECT:
        connect(fromAddress, toAddress, port, fromSource, fromSink);
        break;
      default:
        throw new ProtocolException("unexpected command: " + command);
    }
  }

  /**
   * Opens a connection to {@code toAddress:port}, sends the success reply to the client, and
   * starts relaying bytes in both directions. The outbound socket is closed if any step before
   * the relays are running fails.
   */
  private void connect(InetAddress fromAddress, InetAddress toAddress, int port,
      BufferedSource fromSource, BufferedSink fromSink) throws IOException {
    Socket toSocket = new Socket(toAddress, port);
    boolean relaying = false;
    try {
      InetAddress boundAddress = toSocket.getLocalAddress();
      byte[] boundAddressBytes = boundAddress.getAddress();
      int boundAddressType;
      if (boundAddressBytes.length == 4) {
        boundAddressType = ADDRESS_TYPE_IPV4;
      } else if (boundAddressBytes.length == 16) {
        boundAddressType = ADDRESS_TYPE_IPV6;
      } else {
        throw new ProtocolException("unexpected address: " + boundAddress);
      }

      // Write the reply: VER, REP, RSV, ATYP, BND.ADDR, BND.PORT.
      fromSink.writeByte(VERSION_5);
      fromSink.writeByte(REPLY_SUCCEEDED);
      fromSink.writeByte(0);
      fromSink.writeByte(boundAddressType);
      fromSink.write(boundAddressBytes);
      fromSink.writeShort(toSocket.getLocalPort());
      fromSink.emit();

      logger.log(Level.INFO, "SocksProxy connected " + fromAddress + " to " + toAddress);

      // Copy sources to sinks in both directions. Each relay is named after its direction.
      BufferedSource toSource = Okio.buffer(Okio.source(toSocket));
      BufferedSink toSink = Okio.buffer(Okio.sink(toSocket));
      transfer(fromAddress, toAddress, fromSource, toSink);
      transfer(toAddress, fromAddress, toSource, fromSink);
      relaying = true;
    } finally {
      // Don't leak the outbound socket if we failed before handing it to the relays.
      if (!relaying) Util.closeQuietly(toSocket);
    }
  }

  /**
   * Copies bytes from {@code source} to {@code sink} on a background thread until EOF or an
   * error, then closes both. The addresses are used only to name the thread.
   */
  private void transfer(final InetAddress fromAddress, final InetAddress toAddress,
      final BufferedSource source, final BufferedSink sink) {
    executor.execute(new NamedRunnable("SocksProxy %s to %s", fromAddress, toAddress) {
      @Override protected void execute() {
        Buffer buffer = new Buffer();
        try {
          while (true) {
            long byteCount = source.read(buffer, 2048L);
            if (byteCount == -1L) break;
            sink.write(buffer, byteCount);
            sink.emit();
          }
        } catch (SocketException e) {
          // Typically the peer relay closed the shared socket.
          logger.info(name + " done: " + e.getMessage());
        } catch (IOException e) {
          logger.log(Level.WARNING, name + " failed", e);
        }

        try {
          source.close();
        } catch (IOException e) {
          logger.log(Level.WARNING, name + " failed", e);
        }

        try {
          sink.close();
        } catch (IOException e) {
          logger.log(Level.WARNING, name + " failed", e);
        }
      }
    });
  }
}