blob: 3e999fa63b45566eb1d3ba1aa0e08f1be5f500b6 [file] [log] [blame]
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.List;
/**
 * Performs encryption and decryption with AES-GCM using JCE. All methods are thread-compatible.
 *
 * <p>The crypter keeps one counter per direction ({@code outCounter} for {@link #encrypt} and
 * {@code inCounter} for {@link #decrypt}). Each operation advances the relevant counter and hands
 * the value it had <em>before</em> the increment to the underlying {@link AeadCrypter}. Both
 * directions share a single scratch array for that previous value, so an instance must not be
 * used concurrently without external synchronization.
 */
final class AltsChannelCrypter implements ChannelCrypterNetty {
  private static final int KEY_LENGTH = AesGcmHkdfAeadCrypter.getKeyLength();
  private static final int COUNTER_LENGTH = 12;
  // Only the first COUNTER_OVERFLOW_LENGTH bytes of a counter are ever incremented.
  // The counter will overflow after 2^64 operations and encryption/decryption will stop working.
  private static final int COUNTER_OVERFLOW_LENGTH = 8;
  private static final int TAG_LENGTH = 16;

  private final AeadCrypter aeadCrypter;
  private final byte[] outCounter = new byte[COUNTER_LENGTH];
  private final byte[] inCounter = new byte[COUNTER_LENGTH];
  // Scratch space shared by both directions; holds the pre-increment value of whichever counter
  // was advanced most recently.
  private final byte[] oldCounter = new byte[COUNTER_LENGTH];

  /**
   * Creates a crypter for one side of a channel.
   *
   * @param key the key material; its length must equal {@link #getKeyLength()}
   * @param isClient whether this crypter is used by the client side. It selects which of the two
   *     counters starts with {@code 0x80} in its last byte: the inbound counter for a client, the
   *     outbound counter otherwise.
   * @throws IllegalArgumentException if {@code key} has the wrong length
   */
  AltsChannelCrypter(byte[] key, boolean isClient) {
    checkArgument(key.length == KEY_LENGTH);
    byte[] counter = isClient ? inCounter : outCounter;
    counter[counter.length - 1] = (byte) 0x80;
    this.aeadCrypter = new AesGcmHkdfAeadCrypter(key);
  }

  /** Returns the key length, in bytes, that the constructor requires. */
  static int getKeyLength() {
    return KEY_LENGTH;
  }

  /** Returns the length, in bytes, of the per-direction counters. */
  static int getCounterLength() {
    return COUNTER_LENGTH;
  }

  /**
   * Encrypts the concatenation of {@code plainBufs} into the writable region of {@code outBuf}.
   *
   * <p>The plaintext is first copied into {@code outBuf} and then encrypted in place. The writable
   * region of {@code outBuf} must be exactly as large as the total plaintext plus
   * {@link #getSuffixLength()} bytes; on return {@code outBuf} has no writable bytes left.
   *
   * @param outBuf destination buffer; must be backed by exactly one NIO buffer
   * @param plainBufs plaintext buffers whose readable bytes are encrypted, in list order
   * @throws GeneralSecurityException if the outbound counter has overflowed or the underlying
   *     crypter fails
   */
  @SuppressWarnings("BetaApi") // verify is stable in Guava
  @Override
  public void encrypt(ByteBuf outBuf, List<ByteBuf> plainBufs) throws GeneralSecurityException {
    checkArgument(outBuf.nioBufferCount() == 1);
    // Copy plaintext buffers into outBuf for in-place encryption on single direct buffer.
    ByteBuf plainBuf = outBuf.slice(outBuf.writerIndex(), outBuf.writableBytes());
    plainBuf.writerIndex(0);
    for (ByteBuf inBuf : plainBufs) {
      plainBuf.writeBytes(inBuf);
    }

    // The remaining space after the plaintext must be exactly the room reserved for the tag.
    verify(outBuf.writableBytes() == plainBuf.readableBytes() + TAG_LENGTH);
    ByteBuffer out = outBuf.internalNioBuffer(outBuf.writerIndex(), outBuf.writableBytes());
    // Same memory as 'out', but limited to the plaintext portion (everything except the tag).
    ByteBuffer plain = out.duplicate();
    plain.limit(out.limit() - TAG_LENGTH);

    byte[] counter = incrementOutCounter();
    int outPosition = out.position();
    aeadCrypter.encrypt(out, plain, counter);
    int bytesWritten = out.position() - outPosition;
    outBuf.writerIndex(outBuf.writerIndex() + bytesWritten);
    verify(!outBuf.isWritable());
  }

  /**
   * Decrypts ciphertext that is split across several buffers, with the tag in a separate buffer.
   *
   * <p>The ciphertext buffers followed by {@code tag} are copied into the writable region of
   * {@code out}, and the result is then decrypted via {@link #decrypt(ByteBuf, ByteBuf)}.
   *
   * @param out destination buffer; its writable region must be exactly as large as the combined
   *     ciphertext and tag
   * @param tag buffer holding the tag that follows the ciphertext
   * @param ciphertextBufs ciphertext buffers, in order
   * @throws GeneralSecurityException if the inbound counter has overflowed or the underlying
   *     crypter fails
   */
  @Override
  public void decrypt(ByteBuf out, ByteBuf tag, List<ByteBuf> ciphertextBufs)
      throws GeneralSecurityException {
    ByteBuf cipherTextAndTag = out.slice(out.writerIndex(), out.writableBytes());
    cipherTextAndTag.writerIndex(0);
    for (ByteBuf inBuf : ciphertextBufs) {
      cipherTextAndTag.writeBytes(inBuf);
    }
    cipherTextAndTag.writeBytes(tag);

    decrypt(out, cipherTextAndTag);
  }

  /**
   * Decrypts the readable bytes of {@code ciphertextAndTag} into the writable region of
   * {@code out}.
   *
   * <p>On return the reader index of {@code ciphertextAndTag} has been advanced past everything
   * that was readable on entry, the writer index of {@code out} has been advanced by the number of
   * bytes produced, and {@code out} has exactly {@link #getSuffixLength()} writable bytes left.
   *
   * @param out destination buffer; must be backed by exactly one NIO buffer and have as many
   *     writable bytes as {@code ciphertextAndTag} has readable bytes
   * @param ciphertextAndTag ciphertext immediately followed by its tag; must be backed by exactly
   *     one NIO buffer
   * @throws GeneralSecurityException if the inbound counter has overflowed or the underlying
   *     crypter fails
   */
  @SuppressWarnings("BetaApi") // verify is stable in Guava
  @Override
  public void decrypt(ByteBuf out, ByteBuf ciphertextAndTag) throws GeneralSecurityException {
    int bytesRead = ciphertextAndTag.readableBytes();
    checkArgument(bytesRead == out.writableBytes());

    checkArgument(out.nioBufferCount() == 1);
    ByteBuffer outBuffer = out.internalNioBuffer(out.writerIndex(), out.writableBytes());

    checkArgument(ciphertextAndTag.nioBufferCount() == 1);
    ByteBuffer ciphertextAndTagBuffer =
        ciphertextAndTag.nioBuffer(ciphertextAndTag.readerIndex(), bytesRead);

    byte[] counter = incrementInCounter();
    int outPosition = outBuffer.position();
    aeadCrypter.decrypt(outBuffer, ciphertextAndTagBuffer, counter);
    int bytesWritten = outBuffer.position() - outPosition;
    out.writerIndex(out.writerIndex() + bytesWritten);
    // Mark the whole input as consumed. The new index has to be derived from ciphertextAndTag's
    // own reader index; deriving it from out.readerIndex() is only correct when the two buffers
    // happen to share the same reader index.
    ciphertextAndTag.readerIndex(ciphertextAndTag.readerIndex() + bytesRead);
    verify(out.writableBytes() == TAG_LENGTH);
  }

  /** Returns the number of bytes reserved after the ciphertext for the tag. */
  @Override
  public int getSuffixLength() {
    return TAG_LENGTH;
  }

  @Override
  public void destroy() {
    // no destroy required
  }

  /**
   * Increments {@code counter}, store the unincremented value in {@code oldCounter}.
   *
   * <p>The first {@code COUNTER_OVERFLOW_LENGTH} bytes are treated as a little-endian integer:
   * byte 0 is incremented first and a carry moves on to the next byte. The remaining bytes are
   * never modified.
   *
   * @param counter the counter to advance in place
   * @param oldCounter receives a copy of {@code counter} as it was before the increment; must be
   *     at least as long as {@code counter}
   * @throws GeneralSecurityException if the increment carries out of the last incrementable byte.
   *     In that case {@code counter} is left at its previous value, so every later call fails the
   *     same way.
   */
  static void incrementCounter(byte[] counter, byte[] oldCounter) throws GeneralSecurityException {
    System.arraycopy(counter, 0, oldCounter, 0, counter.length);
    int i = 0;
    for (; i < COUNTER_OVERFLOW_LENGTH; i++) {
      counter[i]++;
      // A non-zero result means this byte did not wrap, so there is no carry to propagate.
      if (counter[i] != (byte) 0x00) {
        break;
      }
    }

    if (i == COUNTER_OVERFLOW_LENGTH) {
      // Restore old counter value to ensure that encrypt and decrypt keep failing.
      System.arraycopy(oldCounter, 0, counter, 0, counter.length);
      throw new GeneralSecurityException("Counter has overflowed.");
    }
  }

  /**
   * Increments the input counter, returning the previous (unincremented) value.
   *
   * <p>The returned array is the shared scratch buffer; its contents are overwritten by the next
   * increment of either counter.
   */
  private byte[] incrementInCounter() throws GeneralSecurityException {
    incrementCounter(inCounter, oldCounter);
    return oldCounter;
  }

  /**
   * Increments the output counter, returning the previous (unincremented) value.
   *
   * <p>The returned array is the shared scratch buffer; its contents are overwritten by the next
   * increment of either counter.
   */
  private byte[] incrementOutCounter() throws GeneralSecurityException {
    incrementCounter(outCounter, oldCounter);
    return oldCounter;
  }

  /** Advances the input counter {@code n} times. */
  @VisibleForTesting
  void incrementInCounterForTesting(int n) throws GeneralSecurityException {
    for (int i = 0; i < n; i++) {
      incrementInCounter();
    }
  }

  /** Advances the output counter {@code n} times. */
  @VisibleForTesting
  void incrementOutCounterForTesting(int n) throws GeneralSecurityException {
    for (int i = 0; i < n; i++) {
      incrementOutCounter();
    }
  }
}