blob: a6832ad8de28c597ca0e4bcc95a49e954c2f7beb [file] [log] [blame]
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import static org.junit.Assert.assertEquals;
import com.google.common.testing.GcFinalization;
import io.grpc.alts.internal.ByteBufTestUtils.RegisterRef;
import io.grpc.alts.internal.TsiTest.Handshakers;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakDetector.Level;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link AltsTsiHandshaker}. */
@RunWith(JUnit4.class)
public class AltsTsiTest {
private static final int OVERHEAD =
FakeChannelCrypter.getTagBytes() + AltsTsiFrameProtector.getHeaderBytes();
private final List<ReferenceCounted> references = new ArrayList<>();
private AltsHandshakerClient client;
private AltsHandshakerClient server;
private final RegisterRef ref =
new RegisterRef() {
@Override
public ByteBuf register(ByteBuf buf) {
if (buf != null) {
references.add(buf);
}
return buf;
}
};
@Before
public void setUp() throws Exception {
ResourceLeakDetector.setLevel(Level.PARANOID);
// Use MockAltsHandshakerStub for all the tests.
AltsHandshakerOptions handshakerOptions = new AltsHandshakerOptions(null);
MockAltsHandshakerStub clientStub = new MockAltsHandshakerStub();
MockAltsHandshakerStub serverStub = new MockAltsHandshakerStub();
client = new AltsHandshakerClient(clientStub, handshakerOptions);
server = new AltsHandshakerClient(serverStub, handshakerOptions);
}
@After
public void tearDown() {
for (ReferenceCounted reference : references) {
reference.release();
}
references.clear();
// Increase our chances to detect ByteBuf leaks.
GcFinalization.awaitFullGc();
}
private Handshakers newHandshakers() {
TsiHandshaker clientHandshaker = new AltsTsiHandshaker(true, client);
TsiHandshaker serverHandshaker = new AltsTsiHandshaker(false, server);
return new Handshakers(clientHandshaker, serverHandshaker);
}
@Test
public void verifyHandshakePeer() throws Exception {
Handshakers handshakers = newHandshakers();
TsiTest.performHandshake(TsiTest.getDefaultTransportBufferSize(), handshakers);
TsiPeer clientPeer = handshakers.getClient().extractPeer();
assertEquals(1, clientPeer.getProperties().size());
assertEquals(
MockAltsHandshakerResp.getTestPeerAccount(),
clientPeer.getProperty("service_account").getValue());
TsiPeer serverPeer = handshakers.getServer().extractPeer();
assertEquals(1, serverPeer.getProperties().size());
assertEquals(
MockAltsHandshakerResp.getTestPeerAccount(),
serverPeer.getProperty("service_account").getValue());
}
@Test
public void handshake() throws GeneralSecurityException {
TsiTest.handshakeTest(newHandshakers());
}
@Test
public void handshakeSmallBuffer() throws GeneralSecurityException {
TsiTest.handshakeSmallBufferTest(newHandshakers());
}
@Test
public void pingPong() throws GeneralSecurityException {
TsiTest.pingPongTest(newHandshakers(), ref);
}
@Test
public void pingPongExactFrameSize() throws GeneralSecurityException {
TsiTest.pingPongExactFrameSizeTest(newHandshakers(), ref);
}
@Test
public void pingPongSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallBufferTest(newHandshakers(), ref);
}
@Test
public void pingPongSmallFrame() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameTest(OVERHEAD, newHandshakers(), ref);
}
@Test
public void pingPongSmallFrameSmallBuffer() throws GeneralSecurityException {
TsiTest.pingPongSmallFrameSmallBufferTest(OVERHEAD, newHandshakers(), ref);
}
@Test
public void corruptedCounter() throws GeneralSecurityException {
TsiTest.corruptedCounterTest(newHandshakers(), ref);
}
@Test
public void corruptedCiphertext() throws GeneralSecurityException {
TsiTest.corruptedCiphertextTest(newHandshakers(), ref);
}
@Test
public void corruptedTag() throws GeneralSecurityException {
TsiTest.corruptedTagTest(newHandshakers(), ref);
}
@Test
public void reflectedCiphertext() throws GeneralSecurityException {
TsiTest.reflectedCiphertextTest(newHandshakers(), ref);
}
private static class MockAltsHandshakerStub extends AltsHandshakerStub {
private boolean started = false;
@Override
public HandshakerResp send(HandshakerReq req) {
if (started) {
// Expect handshake next message.
if (req.getReqOneofCase().getNumber() != 3) {
return MockAltsHandshakerResp.getErrorResponse();
}
return MockAltsHandshakerResp.getFinishedResponse(req.getNext().getInBytes().size());
} else {
List<String> recordProtocols;
int bytesConsumed = 0;
switch (req.getReqOneofCase().getNumber()) {
case 1:
recordProtocols = req.getClientStart().getRecordProtocolsList();
break;
case 2:
recordProtocols =
req.getServerStart()
.getHandshakeParametersMap()
.get(HandshakeProtocol.ALTS.getNumber())
.getRecordProtocolsList();
bytesConsumed = req.getServerStart().getInBytes().size();
break;
default:
return MockAltsHandshakerResp.getErrorResponse();
}
if (recordProtocols.isEmpty()) {
return MockAltsHandshakerResp.getErrorResponse();
}
started = true;
return MockAltsHandshakerResp.getOkResponse(bytesConsumed);
}
}
@Override
public void close() {}
}
}