blob: 7355b456933955bedf75eb2894f379cb19641104 [file] [log] [blame]
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.security;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static io.grpc.xds.internal.security.SecurityProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.grpc.Attributes;
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.internal.FakeClock;
import io.grpc.internal.TestUtils.NoopChannelLogger;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiationEvent;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.netty.ProtocolNegotiationEvent;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.InternalXdsAttributes;
import io.grpc.xds.TlsContextManager;
import io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils;
import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSdsHandler;
import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSdsProtocolNegotiator;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http2.DefaultHttp2Connection;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.cert.CertStoreException;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link SecurityProtocolNegotiators}. */
@RunWith(JUnit4.class)
public class SecurityProtocolNegotiatorsTest {
private final GrpcHttp2ConnectionHandler grpcHandler =
FakeGrpcHttp2ConnectionHandler.newHandler();
private EmbeddedChannel channel = new EmbeddedChannel();
private ChannelPipeline pipeline = channel.pipeline();
private ChannelHandlerContext channelHandlerCtx;
@Test
public void clientSdsProtocolNegotiatorNewHandler_noTlsContextAttribute() {
ChannelHandler mockChannelHandler = mock(ChannelHandler.class);
ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class);
when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler);
ClientSdsProtocolNegotiator pn = new ClientSdsProtocolNegotiator(mockProtocolNegotiator);
ChannelHandler newHandler = pn.newHandler(grpcHandler);
assertThat(newHandler).isNotNull();
assertThat(newHandler).isSameInstanceAs(mockChannelHandler);
}
@Test
public void clientSdsProtocolNegotiatorNewHandler_noFallback_expectException() {
ClientSdsProtocolNegotiator pn =
new ClientSdsProtocolNegotiator(/* fallbackProtocolNegotiator= */ null);
try {
pn.newHandler(grpcHandler);
fail("exception expected!");
} catch (NullPointerException expected) {
assertThat(expected)
.hasMessageThat()
.contains("No TLS config and no fallbackProtocolNegotiator!");
}
}
@Test
public void clientSdsProtocolNegotiatorNewHandler_withTlsContextAttribute() {
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build());
ClientSdsProtocolNegotiator pn =
new ClientSdsProtocolNegotiator(InternalProtocolNegotiators.plaintext());
GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class);
ChannelLogger logger = mock(ChannelLogger.class);
doNothing().when(logger).log(any(ChannelLogLevel.class), anyString());
when(mockHandler.getNegotiationLogger()).thenReturn(logger);
TlsContextManager mockTlsContextManager = mock(TlsContextManager.class);
when(mockHandler.getEagAttributes())
.thenReturn(
Attributes.newBuilder()
.set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager))
.build());
ChannelHandler newHandler = pn.newHandler(mockHandler);
assertThat(newHandler).isNotNull();
assertThat(newHandler).isInstanceOf(ClientSdsHandler.class);
}
@Test
public void clientSdsHandler_addLast()
throws InterruptedException, TimeoutException, ExecutionException {
FakeClock executor = new FakeClock();
CommonCertProviderTestUtils.register(executor);
Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE,
CA_PEM_FILE, null, null, null, null);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil
.buildUpstreamTlsContext("google_cloud_private_spiffe-client", true);
SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(upstreamTlsContext,
new TlsContextManagerImpl(bootstrapInfoForClient));
SecurityProtocolNegotiators.ClientSdsHandler clientSdsHandler =
new SecurityProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier);
pipeline.addLast(clientSdsHandler);
channelHandlerCtx = pipeline.context(clientSdsHandler);
assertNotNull(channelHandlerCtx); // clientSdsHandler ctx is non-null since we just added it
// kick off protocol negotiation.
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
final SettableFuture<Object> future = SettableFuture.create();
sslContextProviderSupplier
.updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
future.set(sslContext);
}
@Override
protected void onException(Throwable throwable) {
future.set(throwable);
}
});
assertThat(executor.runDueTasks()).isEqualTo(1);
channel.runPendingTasks();
Object fromFuture = future.get(2, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks();
channelHandlerCtx = pipeline.context(clientSdsHandler);
assertThat(channelHandlerCtx).isNull();
// pipeline should have SslHandler and ClientTlsHandler
Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class);
// ProtocolNegotiators.ClientTlsHandler.class not accessible, get canonical name
assertThat(iterator.next().getValue().getClass().getCanonicalName())
.contains("ProtocolNegotiators.ClientTlsHandler");
CommonCertProviderTestUtils.register0();
}
@Test
public void serverSdsHandler_addLast()
throws InterruptedException, TimeoutException, ExecutionException {
FakeClock executor = new FakeClock();
CommonCertProviderTestUtils.register(executor);
// we need InetSocketAddress instead of EmbeddedSocketAddress as localAddress for this test
channel =
new EmbeddedChannel() {
@Override
public SocketAddress localAddress() {
return new InetSocketAddress("172.168.1.1", 80);
}
@Override
public SocketAddress remoteAddress() {
return new InetSocketAddress("172.168.2.2", 90);
}
};
pipeline = channel.pipeline();
Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE,
SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null);
DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildDownstreamTlsContext(
"google_cloud_private_spiffe-server", true, true);
TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(bootstrapInfoForServer);
SecurityProtocolNegotiators.HandlerPickerHandler handlerPickerHandler =
new SecurityProtocolNegotiators.HandlerPickerHandler(grpcHandler,
InternalProtocolNegotiators.serverPlaintext());
pipeline.addLast(handlerPickerHandler);
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler
// kick off protocol negotiation: should replace HandlerPickerHandler with ServerSdsHandler
ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault();
Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event)
.toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER,
new SslContextProviderSupplier(downstreamTlsContext, tlsContextManager)).build();
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr));
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNull();
channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSdsHandler.class);
assertThat(channelHandlerCtx).isNotNull();
SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(downstreamTlsContext, tlsContextManager);
final SettableFuture<Object> future = SettableFuture.create();
sslContextProviderSupplier
.updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
future.set(sslContext);
}
@Override
protected void onException(Throwable throwable) {
future.set(throwable);
}
});
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
assertThat(executor.runDueTasks()).isEqualTo(1);
Object fromFuture = future.get(2, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks();
channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSdsHandler.class);
assertThat(channelHandlerCtx).isNull();
// pipeline should only have SslHandler and ServerTlsHandler
Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class);
// ProtocolNegotiators.ServerTlsHandler.class is not accessible, get canonical name
assertThat(iterator.next().getValue().getClass().getCanonicalName())
.contains("ProtocolNegotiators.ServerTlsHandler");
CommonCertProviderTestUtils.register0();
}
@Test
public void serverSdsHandler_defaultDownstreamTlsContext_expectFallbackProtocolNegotiator()
throws IOException {
ChannelHandler mockChannelHandler = mock(ChannelHandler.class);
ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class);
when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler);
// we need InetSocketAddress instead of EmbeddedSocketAddress as localAddress for this test
channel =
new EmbeddedChannel() {
@Override
public SocketAddress localAddress() {
return new InetSocketAddress("172.168.1.1", 80);
}
};
pipeline = channel.pipeline();
SecurityProtocolNegotiators.HandlerPickerHandler handlerPickerHandler =
new SecurityProtocolNegotiators.HandlerPickerHandler(
grpcHandler, mockProtocolNegotiator);
pipeline.addLast(handlerPickerHandler);
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler
// kick off protocol negotiation: should replace HandlerPickerHandler with ServerSdsHandler
ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault();
Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event)
.toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, null).build();
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr));
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNull();
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
assertThat(iterator.next().getValue()).isSameInstanceAs(mockChannelHandler);
// no more handlers in the pipeline
assertThat(iterator.hasNext()).isFalse();
}
@Test
public void serverSdsHandler_nullTlsContext_expectFallbackProtocolNegotiator() {
ChannelHandler mockChannelHandler = mock(ChannelHandler.class);
ProtocolNegotiator mockProtocolNegotiator = mock(ProtocolNegotiator.class);
when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler);
SecurityProtocolNegotiators.HandlerPickerHandler handlerPickerHandler =
new SecurityProtocolNegotiators.HandlerPickerHandler(
grpcHandler, mockProtocolNegotiator);
pipeline.addLast(handlerPickerHandler);
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler
// kick off protocol negotiation
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNull();
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
assertThat(iterator.next().getValue()).isSameInstanceAs(mockChannelHandler);
// no more handlers in the pipeline
assertThat(iterator.hasNext()).isFalse();
}
@Test
public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() {
SecurityProtocolNegotiators.HandlerPickerHandler handlerPickerHandler =
new SecurityProtocolNegotiators.HandlerPickerHandler(
grpcHandler, null);
pipeline.addLast(handlerPickerHandler);
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler
// kick off protocol negotiation
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
channelHandlerCtx = pipeline.context(handlerPickerHandler);
assertThat(channelHandlerCtx).isNotNull(); // HandlerPickerHandler still there
try {
channel.checkException();
fail("exception expected!");
} catch (Exception e) {
assertThat(e).isInstanceOf(CertStoreException.class);
assertThat(e).hasMessageThat().contains("No certificate source found!");
}
}
@Test
public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent()
throws InterruptedException, TimeoutException, ExecutionException {
FakeClock executor = new FakeClock();
CommonCertProviderTestUtils.register(executor);
Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils
.buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE,
CA_PEM_FILE, null, null, null, null);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil
.buildUpstreamTlsContext("google_cloud_private_spiffe-client", true);
SslContextProviderSupplier sslContextProviderSupplier =
new SslContextProviderSupplier(upstreamTlsContext,
new TlsContextManagerImpl(bootstrapInfoForClient));
SecurityProtocolNegotiators.ClientSdsHandler clientSdsHandler =
new SecurityProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier);
pipeline.addLast(clientSdsHandler);
channelHandlerCtx = pipeline.context(clientSdsHandler);
assertNotNull(channelHandlerCtx); // non-null since we just added it
// kick off protocol negotiation.
pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
final SettableFuture<Object> future = SettableFuture.create();
sslContextProviderSupplier
.updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) {
@Override
public void updateSecret(SslContext sslContext) {
future.set(sslContext);
}
@Override
protected void onException(Throwable throwable) {
future.set(throwable);
}
});
executor.runDueTasks();
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
Object fromFuture = future.get(5, TimeUnit.SECONDS);
assertThat(fromFuture).isInstanceOf(SslContext.class);
channel.runPendingTasks();
channelHandlerCtx = pipeline.context(clientSdsHandler);
assertThat(channelHandlerCtx).isNull();
Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
pipeline.fireUserEventTriggered(sslEvent);
channel.runPendingTasks(); // need this for tasks to execute on eventLoop
assertTrue(channel.isOpen());
CommonCertProviderTestUtils.register0();
}
private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
FakeGrpcHttp2ConnectionHandler(
ChannelPromise channelUnused,
Http2ConnectionDecoder decoder,
Http2ConnectionEncoder encoder,
Http2Settings initialSettings) {
super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger());
}
static FakeGrpcHttp2ConnectionHandler newHandler() {
DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false);
DefaultHttp2ConnectionEncoder encoder =
new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter());
DefaultHttp2ConnectionDecoder decoder =
new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader());
Http2Settings settings = new Http2Settings();
return new FakeGrpcHttp2ConnectionHandler(
/*channelUnused=*/ null, decoder, encoder, settings);
}
@Override
public String getAuthority() {
return "authority";
}
}
}