// blob: 88f64784bef8df0ed7ff00be5fd684c21c9e8679 [file] [log] [blame]
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import io.grpc.Attributes;
import io.grpc.Grpc;
import io.grpc.InternalChannelz;
import io.grpc.SecurityLevel;
import io.grpc.alts.internal.Handshaker.HandshakerResult;
import io.grpc.alts.internal.TsiFrameProtector.Consumer;
import io.grpc.alts.internal.TsiPeer.Property;
import io.grpc.internal.GrpcAttributes;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.Http2Connection;
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
import io.netty.handler.codec.http2.Http2FrameReader;
import io.netty.handler.codec.http2.Http2FrameWriter;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link AltsProtocolNegotiator}. */
@RunWith(JUnit4.class)
public class AltsProtocolNegotiatorTest {
// HTTP/2 handler handed to the negotiator; it records the attributes reported on completion.
private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();
// Reference-counted objects that teardown() releases after each test.
private final List<ReferenceCounted> references = new ArrayList<>();
// Every frame protector created through a DelegatingTsiHandshaker, in creation order.
private final LinkedBlockingQueue<InterceptingProtector> protectors = new LinkedBlockingQueue<>();
// Channel under test; built in setup().
private EmbeddedChannel channel;
// Most recent exception seen by the first handler of the pipeline, or null if none was seen.
private Throwable caughtException;
// Handshake completion event captured by the last handler of the pipeline; null until it fires.
private volatile TsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
// Handler produced by the negotiator in setup().
private ChannelHandler handler;
// Canned peer returned by every handshaker used in this test. Never reassigned, hence final.
private final TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
// Canned ALTS context returned by every handshaker used in this test. Never reassigned.
private final AltsAuthContext mockedAltsContext =
    new AltsAuthContext(
        HandshakerResult.newBuilder()
            .setPeerRpcVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
            .build());
// Fake handshaker for the remote side; reports the canned peer and context instead of its own.
private final TsiHandshaker mockHandshaker =
    new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) {
      @Override
      public TsiPeer extractPeer() throws GeneralSecurityException {
        return mockedTsiPeer;
      }

      @Override
      public Object extractPeerObject() throws GeneralSecurityException {
        return mockedAltsContext;
      }
    };
// Netty adapter around mockHandshaker; the tests drive it by hand against the channel's frames.
private final NettyTsiHandshaker serverHandshaker = new NettyTsiHandshaker(mockHandshaker);
/**
 * Builds the channel under test: an exception recorder, the handler created by the ALTS
 * negotiator, and a recorder for the handshake completion event, in that pipeline order.
 */
@Before
public void setup() throws Exception {
  // Every handshaker handed out by this factory reports the canned peer and ALTS context.
  TsiHandshakerFactory cannedPeerFactory =
      new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) {
        @Override
        public TsiHandshaker newHandshaker(String authority) {
          TsiHandshaker realHandshaker = super.newHandshaker(authority);
          return new DelegatingTsiHandshaker(realHandshaker) {
            @Override
            public TsiPeer extractPeer() throws GeneralSecurityException {
              return mockedTsiPeer;
            }

            @Override
            public Object extractPeerObject() throws GeneralSecurityException {
              return mockedAltsContext;
            }
          };
        }
      };
  handler =
      AltsProtocolNegotiator.createServerNegotiator(cannedPeerFactory).newHandler(grpcHandler);

  // Remembers the exception, then lets it continue along the pipeline.
  ChannelHandler exceptionRecorder =
      new ChannelDuplexHandler() {
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
          caughtException = cause;
          super.exceptionCaught(ctx, cause);
        }
      };

  // Swallows the handshake completion event after storing it; forwards every other event.
  ChannelHandler completionEventRecorder =
      new ChannelDuplexHandler() {
        @Override
        public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
          if (!(evt instanceof TsiHandshakeHandler.TsiHandshakeCompletionEvent)) {
            super.userEventTriggered(ctx, evt);
            return;
          }
          tsiEvent = (TsiHandshakeHandler.TsiHandshakeCompletionEvent) evt;
        }
      };

  channel = new EmbeddedChannel(exceptionRecorder, handler, completionEventRecorder);
}
/** Closes the channel, if one was created, and then releases every tracked reference. */
@After
public void teardown() throws Exception {
  if (channel != null) {
    // The close result is deliberately not inspected.
    @SuppressWarnings("unused") // go/futurereturn-lsc
    Future<?> possiblyIgnoredError = channel.close();
  }
  for (int i = 0; i < references.size(); i++) {
    ReferenceCountUtil.safeRelease(references.get(i));
  }
}
/**
 * Runs the full handshake exchange. The assertions live in {@code doHandshake()}, which fails
 * if a frame is missing, the server handshaker is still in progress, an exception was caught,
 * or no completion event was observed.
 */
@Test
public void handshakeShouldBeSuccessful() throws Exception {
doHandshake();
}
/**
 * Writes a message before the handshake, completes the handshake, and checks that the message
 * comes out protected, can be unprotected by the server side, and that a frame protected by the
 * server side is unprotected again when fed back into the channel.
 */
@Test
@SuppressWarnings("unchecked") // List cast
public void protectShouldRoundtrip() throws Exception {
  String message = "hello";
  final AtomicInteger successfulWrites = new AtomicInteger();
  ChannelFutureListener countSuccess =
      new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
          if (future.isSuccess()) {
            successfulWrites.incrementAndGet();
          }
        }
      };

  // Queue the message one character per write. These writes are expected to be buffered
  // without disturbing the handshake.
  for (int offset = 0; offset < message.length(); ++offset) {
    ByteBuf oneChar = Unpooled.copiedBuffer(message, offset, 1, UTF_8);
    @SuppressWarnings("unused") // go/futurereturn-lsc
    Future<?> possiblyIgnoredError = channel.write(oneChar).addListener(countSuccess);
  }
  channel.flush();

  // Completing the handshake should cause the buffered message to be protected and sent.
  doHandshake();

  // Exactly one protected frame should have reached the wire, and every write should be done.
  assertEquals(1, channel.outboundMessages().size());
  ByteBuf wireBytes = channel.<ByteBuf>readOutbound();
  assertEquals(message.length(), successfulWrites.get());

  // Unprotect on the server side and compare against the original text.
  ByteBufAllocator alloc = channel.alloc();
  TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(alloc);
  List<ByteBuf> unprotected = new ArrayList<>();
  serverProtector.unprotect(wireBytes, (List<Object>) (List<?>) unprotected, alloc);

  // We try our best to remove the HTTP2 handler as soon as possible, but just by constructing it
  // a settings frame is written (and an HTTP2 preface). This is hard coded into Netty, so the
  // leading settings frame is skipped here. See
  // {@code Http2ConnectionHandler.PrefaceDecode.sendPreface}.
  int settingsFrameLength = 9;
  CompositeByteBuf everything =
      new CompositeByteBuf(alloc, false, unprotected.size() + 1, unprotected);
  ByteBuf payload = everything.slice(settingsFrameLength, message.length());
  assertEquals(message, payload.toString(UTF_8));

  // Protect the recovered payload again, this time on the server side.
  final AtomicReference<ByteBuf> reprotected = new AtomicReference<>();
  Consumer<ByteBuf> keepFrame =
      new Consumer<ByteBuf>() {
        @Override
        public void accept(ByteBuf buf) {
          reprotected.set(buf);
        }
      };
  serverProtector.protectFlush(Collections.singletonList(payload), keepFrame, alloc);

  // Feed the server's frame into the channel and expect the original text to come out.
  channel.writeInbound(reprotected.get());
  assertEquals(1, channel.inboundMessages().size());
  assertEquals(message, channel.<ByteBuf>readInbound().toString(UTF_8));
}
/**
 * Checks that the channel can unprotect a single incoming frame produced by a server-side
 * protector configured with twice the standard frame size.
 */
@Test
public void unprotectLargeIncomingFrame() throws Exception {
  // We use a server frameprotector with twice the standard frame size.
  int serverFrameSize = 4096 * 2;
  // This should fit into one frame.
  byte[] unprotectedBytes = new byte[serverFrameSize - 500];
  Arrays.fill(unprotectedBytes, (byte) 7);
  ByteBuf unprotectedData = Unpooled.wrappedBuffer(unprotectedBytes);
  unprotectedData.writerIndex(unprotectedBytes.length);

  // Perform handshake.
  doHandshake();

  // Protect the message on the server and push each resulting frame into the channel.
  TsiFrameProtector serverProtector =
      serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
  serverProtector.protectFlush(
      Collections.singletonList(unprotectedData),
      new Consumer<ByteBuf>() {
        @Override
        public void accept(ByteBuf buf) {
          channel.writeInbound(buf);
        }
      },
      channel.alloc());
  channel.flushInbound();

  // Read the unprotected message at the client and verify that it matches the original message.
  assertEquals(1, channel.inboundMessages().size());
  ByteBuf receivedData = channel.<ByteBuf>readInbound();
  int receivedLength = receivedData.readableBytes();
  byte[] receivedBytes = new byte[receivedLength];
  receivedData.readBytes(receivedBytes, 0, receivedLength);

  // Truth expects assertThat(actual).isEqualTo(expected); the received bytes are the actual
  // value, so failure messages report them as such.
  assertThat(receivedBytes.length).isEqualTo(unprotectedBytes.length);
  assertThat(receivedBytes).isEqualTo(unprotectedBytes);
}
/**
 * After the handshake, installs a handler at the head of the pipeline whose {@code write}
 * always throws, then checks that every pending write promise is failed once flushed.
 */
@Test
public void flushShouldFailAllPromises() throws Exception {
  doHandshake();

  ChannelDuplexHandler alwaysFailingWriter =
      new ChannelDuplexHandler() {
        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
            throws Exception {
          throw new Exception("Fake exception");
        }
      };
  channel.pipeline().addFirst(alwaysFailingWriter);

  final AtomicInteger failedWrites = new AtomicInteger();
  ChannelFutureListener countFailure =
      new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
          if (!future.isSuccess()) {
            failedWrites.incrementAndGet();
          }
        }
      };

  // Write the message one character per write.
  String message = "hello";
  for (int offset = 0; offset < message.length(); ++offset) {
    ByteBuf oneChar = Unpooled.copiedBuffer(message, offset, 1, UTF_8);
    @SuppressWarnings("unused") // go/futurereturn-lsc
    Future<?> possiblyIgnoredError = channel.write(oneChar).addListener(countFailure);
  }
  channel.flush();

  // One failed promise is expected per character written.
  assertEquals(message.length(), failedWrites.get());
}
/**
 * Checks that flushing a non-empty write invokes the protector's {@code protectFlush} once,
 * and that flushing an empty buffer afterwards does not invoke it again.
 */
@Test
public void doNotFlushEmptyBuffer() throws Exception {
  doHandshake();

  // The handshake is expected to have created exactly one intercepted protector.
  assertEquals(1, protectors.size());
  InterceptingProtector protector = protectors.poll();
  assertEquals(0, protector.flushes.get());

  // A real payload must reach the protector.
  ByteBuf payload = Unpooled.copiedBuffer("hello", UTF_8);
  Future<?> payloadWrite = channel.write(payload);
  channel.flush();
  payloadWrite.get(5, TimeUnit.SECONDS);
  assertEquals(1, protector.flushes.get());

  // An empty buffer must leave the flush count untouched.
  Future<?> emptyWrite = channel.write(Unpooled.EMPTY_BUFFER);
  channel.flush();
  emptyWrite.get(5, TimeUnit.SECONDS);
  assertEquals(1, protector.flushes.get());
}
/**
 * Checks the attributes handed to the gRPC handler once negotiation completes: the canned peer
 * and ALTS context, the embedded channel's addresses, and the security level.
 */
@Test
public void peerPropagated() throws Exception {
  doHandshake();

  Attributes negotiated = grpcHandler.attrs;
  Object peer = negotiated.get(AltsProtocolNegotiator.TSI_PEER_KEY);
  assertThat(peer).isEqualTo(mockedTsiPeer);
  Object altsContext = negotiated.get(AltsProtocolNegotiator.ALTS_CONTEXT_KEY);
  assertThat(altsContext).isEqualTo(mockedAltsContext);

  String remoteAddress = negotiated.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString();
  assertThat(remoteAddress).isEqualTo("embedded");
  String localAddress = negotiated.get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR).toString();
  assertThat(localAddress).isEqualTo("embedded");

  assertThat(negotiated.get(GrpcAttributes.ATTR_SECURITY_LEVEL))
      .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY);
}
/**
 * Drives the handshake between the channel under test and {@code serverHandshaker}. Two
 * identical rounds are performed; each moves one frame from the channel to the server
 * handshaker and feeds the server's response back into the channel.
 *
 * @throws RuntimeException wrapping any exception recorded by the pipeline's exception handler
 */
private void doHandshake() throws Exception {
  for (int round = 0; round < 2; ++round) {
    // Capture the pending client frame and hand it to the server.
    assertEquals(1, channel.outboundMessages().size());
    ByteBuf clientFrame = channel.<ByteBuf>readOutbound();
    assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));

    // Collect the server's response and deliver it to the channel.
    ByteBuf serverFrame = channel.alloc().buffer();
    serverHandshaker.getBytesToSendToPeer(serverFrame);
    channel.writeInbound(serverFrame);
  }

  // Both sides must now agree that the handshake is over.
  assertFalse(serverHandshaker.isInProgress());
  if (caughtException != null) {
    throw new RuntimeException(caughtException);
  }
  assertNotNull(tsiEvent);
}
/**
 * Assembles a {@link CapturingGrpcHttp2ConnectionHandler} from default Netty HTTP/2 parts.
 *
 * <p>None of these parts are exercised by the tests for their own sake; they exist because an
 * Http2ConnectionHandler is tightly coupled to its dependencies.
 */
private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() {
  Http2Connection http2Connection = new DefaultHttp2Connection(true);
  Http2FrameWriter writer = new DefaultHttp2FrameWriter();
  Http2FrameReader reader = new DefaultHttp2FrameReader(false);
  DefaultHttp2ConnectionEncoder connectionEncoder =
      new DefaultHttp2ConnectionEncoder(http2Connection, writer);
  DefaultHttp2ConnectionDecoder connectionDecoder =
      new DefaultHttp2ConnectionDecoder(http2Connection, connectionEncoder, reader);
  Http2Settings settings = new Http2Settings();
  return new CapturingGrpcHttp2ConnectionHandler(connectionDecoder, connectionEncoder, settings);
}
/**
 * gRPC HTTP/2 handler that removes itself from the test channel's pipeline when protocol
 * negotiation completes and keeps the reported attributes for later inspection.
 */
private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
  // Attributes from the last negotiation-completed callback; null until it has been invoked.
  private Attributes attrs;

  private CapturingGrpcHttp2ConnectionHandler(
      Http2ConnectionDecoder connectionDecoder,
      Http2ConnectionEncoder connectionEncoder,
      Http2Settings settings) {
    super(null, connectionDecoder, connectionEncoder, settings);
  }

  /**
   * Detaches this handler from the channel's pipeline, then stores the negotiated attributes.
   * The security info is ignored.
   */
  @Override
  public void handleProtocolNegotiationCompleted(
      Attributes negotiatedAttrs, InternalChannelz.Security securityInfo) {
    // If we are added to the pipeline, we need to remove ourselves so the HTTP2 handler takes
    // no further part in the traffic the tests inspect.
    channel.pipeline().remove(this);
    this.attrs = negotiatedAttrs;
  }
}
/**
 * {@link TsiHandshakerFactory} that forwards to another factory. Intended as a base class so a
 * test can override {@link #newHandshaker} and decorate the handshaker it returns.
 */
private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFactory {
  // Assigned once in the constructor; final to match the other delegating wrappers in this file.
  private final TsiHandshakerFactory delegate;

  /**
   * @param delegate factory that actually creates the handshakers
   */
  DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate) {
    this.delegate = delegate;
  }

  /** Returns whatever the wrapped factory creates for {@code authority}. */
  @Override
  public TsiHandshaker newHandshaker(String authority) {
    return delegate.newHandshaker(authority);
  }
}
/**
 * {@link TsiHandshaker} that forwards every call to another handshaker. Frame protectors it
 * creates are wrapped in an {@link InterceptingProtector} and added to the enclosing test's
 * {@code protectors} queue.
 */
private class DelegatingTsiHandshaker implements TsiHandshaker {
  private final TsiHandshaker delegate;

  DelegatingTsiHandshaker(TsiHandshaker wrapped) {
    this.delegate = wrapped;
  }

  @Override
  public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
    delegate.getBytesToSendToPeer(bytes);
  }

  @Override
  public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
    return delegate.processBytesFromPeer(bytes);
  }

  @Override
  public boolean isInProgress() {
    return delegate.isInProgress();
  }

  @Override
  public TsiPeer extractPeer() throws GeneralSecurityException {
    return delegate.extractPeer();
  }

  @Override
  public Object extractPeerObject() throws GeneralSecurityException {
    return delegate.extractPeerObject();
  }

  @Override
  public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
    return wrapAndRecord(delegate.createFrameProtector(alloc));
  }

  @Override
  public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
    return wrapAndRecord(delegate.createFrameProtector(maxFrameSize, alloc));
  }

  /** Wraps {@code created} in an intercepting protector and records it in {@code protectors}. */
  private TsiFrameProtector wrapAndRecord(TsiFrameProtector created) {
    InterceptingProtector intercepted = new InterceptingProtector(created);
    protectors.add(intercepted);
    return intercepted;
  }
}
/**
 * {@link TsiFrameProtector} that forwards to another protector and counts how many times
 * {@link #protectFlush} has been called.
 */
private static class InterceptingProtector implements TsiFrameProtector {
  private final TsiFrameProtector delegate;
  // Number of protectFlush invocations, counted before delegating.
  final AtomicInteger flushes = new AtomicInteger();

  InterceptingProtector(TsiFrameProtector wrapped) {
    this.delegate = wrapped;
  }

  /** Bumps the flush counter and then forwards the call unchanged. */
  @Override
  public void protectFlush(
      List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
      throws GeneralSecurityException {
    flushes.incrementAndGet();
    delegate.protectFlush(unprotectedBufs, ctxWrite, alloc);
  }

  /** Forwards to the wrapped protector. */
  @Override
  public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
      throws GeneralSecurityException {
    delegate.unprotect(in, out, alloc);
  }

  /** Forwards to the wrapped protector. */
  @Override
  public void destroy() {
    delegate.destroy();
  }
}
}