blob: 627e52a6dbc5c85066bdc0a5054b88354e525fed [file] [log] [blame]
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.protobuf.ByteString;
import java.nio.ByteBuffer;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Matchers;
/** Unit tests for {@link AltsTsiHandshaker}. */
@RunWith(JUnit4.class)
public class AltsTsiHandshakerTest {
private static final String TEST_KEY_DATA = "super secret 123";
private static final String TEST_APPLICATION_PROTOCOL = "grpc";
private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
private static final String TEST_CLIENT_SERVICE_ACCOUNT = "client@developer.gserviceaccount.com";
private static final String TEST_SERVER_SERVICE_ACCOUNT = "server@developer.gserviceaccount.com";
private static final int OUT_FRAME_SIZE = 100;
private static final int TRANSPORT_BUFFER_SIZE = 200;
private static final int TEST_MAX_RPC_VERSION_MAJOR = 3;
private static final int TEST_MAX_RPC_VERSION_MINOR = 2;
private static final int TEST_MIN_RPC_VERSION_MAJOR = 2;
private static final int TEST_MIN_RPC_VERSION_MINOR = 1;
private static final RpcProtocolVersions TEST_RPC_PROTOCOL_VERSIONS =
RpcProtocolVersions.newBuilder()
.setMaxRpcVersion(
RpcProtocolVersions.Version.newBuilder()
.setMajor(TEST_MAX_RPC_VERSION_MAJOR)
.setMinor(TEST_MAX_RPC_VERSION_MINOR)
.build())
.setMinRpcVersion(
RpcProtocolVersions.Version.newBuilder()
.setMajor(TEST_MIN_RPC_VERSION_MAJOR)
.setMinor(TEST_MIN_RPC_VERSION_MINOR)
.build())
.build();
private AltsHandshakerClient mockClient;
private AltsHandshakerClient mockServer;
private AltsTsiHandshaker handshakerClient;
private AltsTsiHandshaker handshakerServer;
@Before
public void setUp() throws Exception {
mockClient = mock(AltsHandshakerClient.class);
mockServer = mock(AltsHandshakerClient.class);
handshakerClient = new AltsTsiHandshaker(true, mockClient);
handshakerServer = new AltsTsiHandshaker(false, mockServer);
}
private HandshakerResult getHandshakerResult(boolean isClient) {
HandshakerResult.Builder builder =
HandshakerResult.newBuilder()
.setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
.setRecordProtocol(TEST_RECORD_PROTOCOL)
.setKeyData(ByteString.copyFromUtf8(TEST_KEY_DATA))
.setPeerRpcVersions(TEST_RPC_PROTOCOL_VERSIONS);
if (isClient) {
builder.setPeerIdentity(
Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
builder.setLocalIdentity(
Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
} else {
builder.setPeerIdentity(
Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
builder.setLocalIdentity(
Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
}
return builder.build();
}
@Test
public void processBytesFromPeerFalseStart() throws Exception {
verify(mockClient, never()).startClientHandshake();
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
verify(mockClient, never()).next(Matchers.<ByteBuffer>any());
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerStartServer() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
// Mock transport buffer all consumed by processBytesFromPeer and there is an output frame.
transportBuffer.position(transportBuffer.limit());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
when(mockServer.isFinished()).thenReturn(false);
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerStartServerEmptyOutput() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
// Mock transport buffer all consumed by processBytesFromPeer and output frame is empty.
// Expect processBytesFromPeer return False, because more data are needed from the peer.
transportBuffer.position(transportBuffer.limit());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
when(mockServer.isFinished()).thenReturn(false);
assertFalse(handshakerServer.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerStartServerFinished() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
// Mock handshake complete after processBytesFromPeer.
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
when(mockServer.isFinished()).thenReturn(true);
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerNoBytesConsumed() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
verify(mockServer, never()).startClientHandshake();
verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
when(mockServer.isFinished()).thenReturn(false);
try {
assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
fail("Expected IllegalStateException");
} catch (IllegalStateException expected) {
assertEquals("Handshaker did not consume any bytes.", expected.getMessage());
}
}
@Test
public void processBytesFromPeerClientNext() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
when(mockClient.isFinished()).thenReturn(false);
handshakerClient.getBytesToSendToPeer(transportBuffer);
transportBuffer.position(transportBuffer.limit());
assertFalse(handshakerClient.processBytesFromPeer(transportBuffer));
}
@Test
public void processBytesFromPeerClientNextFinished() throws Exception {
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
when(mockClient.isFinished()).thenReturn(true);
handshakerClient.getBytesToSendToPeer(transportBuffer);
assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));
}
@Test
public void extractPeerFailure() throws Exception {
when(mockClient.isFinished()).thenReturn(false);
try {
handshakerClient.extractPeer();
fail("Expected IllegalStateException");
} catch (IllegalStateException expected) {
assertEquals("Handshake is not complete.", expected.getMessage());
}
}
@Test
public void extractPeerObjectFailure() throws Exception {
when(mockClient.isFinished()).thenReturn(false);
try {
handshakerClient.extractPeerObject();
fail("Expected IllegalStateException");
} catch (IllegalStateException expected) {
assertEquals("Handshake is not complete.", expected.getMessage());
}
}
@Test
public void extractClientPeerSuccess() throws Exception {
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
when(mockClient.startClientHandshake()).thenReturn(outputFrame);
when(mockClient.isFinished()).thenReturn(true);
when(mockClient.getResult()).thenReturn(getHandshakerResult(/* isClient = */ true));
handshakerClient.getBytesToSendToPeer(transportBuffer);
TsiPeer clientPeer = handshakerClient.extractPeer();
assertEquals(1, clientPeer.getProperties().size());
assertEquals(
TEST_SERVER_SERVICE_ACCOUNT,
clientPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());
AltsAuthContext clientContext = (AltsAuthContext) handshakerClient.extractPeerObject();
assertEquals(TEST_APPLICATION_PROTOCOL, clientContext.getApplicationProtocol());
assertEquals(TEST_RECORD_PROTOCOL, clientContext.getRecordProtocol());
assertEquals(TEST_SERVER_SERVICE_ACCOUNT, clientContext.getPeerServiceAccount());
assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, clientContext.getLocalServiceAccount());
assertEquals(TEST_RPC_PROTOCOL_VERSIONS, clientContext.getPeerRpcVersions());
}
@Test
public void extractServerPeerSuccess() throws Exception {
ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
when(mockServer.startServerHandshake(Matchers.<ByteBuffer>any())).thenReturn(outputFrame);
when(mockServer.isFinished()).thenReturn(true);
when(mockServer.getResult()).thenReturn(getHandshakerResult(/* isClient = */ false));
handshakerServer.processBytesFromPeer(transportBuffer);
handshakerServer.getBytesToSendToPeer(transportBuffer);
TsiPeer serverPeer = handshakerServer.extractPeer();
assertEquals(1, serverPeer.getProperties().size());
assertEquals(
TEST_CLIENT_SERVICE_ACCOUNT,
serverPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());
AltsAuthContext serverContext = (AltsAuthContext) handshakerServer.extractPeerObject();
assertEquals(TEST_APPLICATION_PROTOCOL, serverContext.getApplicationProtocol());
assertEquals(TEST_RECORD_PROTOCOL, serverContext.getRecordProtocol());
assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, serverContext.getPeerServiceAccount());
assertEquals(TEST_SERVER_SERVICE_ACCOUNT, serverContext.getLocalServiceAccount());
assertEquals(TEST_RPC_PROTOCOL_VERSIONS, serverContext.getPeerRpcVersions());
}
}