blob: 70b30fe2f1bd4fe1fb7d4801d3bcb9d8898f7f2e [file] [log] [blame]
/*
* Copyright 2018 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.alts.internal;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import io.grpc.alts.internal.TsiFrameProtector.Consumer;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.channel.PendingWriteQueue;
import io.netty.handler.codec.ByteToMessageDecoder;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Encrypts and decrypts TSI Frames. Writes are buffered here until {@link #flush} is called. Writes
* must not be made before the TSI handshake is complete.
*/
public final class TsiFrameHandler extends ByteToMessageDecoder implements ChannelOutboundHandler {
private static final Logger logger = Logger.getLogger(TsiFrameHandler.class.getName());
private TsiFrameProtector protector;
private PendingWriteQueue pendingUnprotectedWrites;
private boolean closeInitiated;
public TsiFrameHandler(TsiFrameProtector protector) {
this.protector = checkNotNull(protector, "protector");
}
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
assert pendingUnprotectedWrites == null;
pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx));
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
checkState(protector != null, "decode() called after close()");
protector.unprotect(in, out, ctx.alloc());
}
@Override
@SuppressWarnings("FutureReturnValueIgnored") // for setSuccess
public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) {
if (protector == null) {
promise.setFailure(new IllegalStateException("write() called after close()"));
return;
}
ByteBuf msg = (ByteBuf) message;
if (!msg.isReadable()) {
// Nothing to encode.
promise.setSuccess();
return;
}
// Just add the message to the pending queue. We'll write it on the next flush.
pendingUnprotectedWrites.add(msg, promise);
}
@Override
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (pendingUnprotectedWrites != null && !pendingUnprotectedWrites.isEmpty()) {
pendingUnprotectedWrites.removeAndFailAll(
new ChannelException("Pending write on removal of TSI handler"));
}
destroyProtector();
}
@Override
public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) {
doClose(ctx);
ctx.disconnect(promise);
}
@Override
public void close(ChannelHandlerContext ctx, ChannelPromise promise) {
doClose(ctx);
ctx.close(promise);
}
private void doClose(ChannelHandlerContext ctx) {
if (closeInitiated) {
return;
}
closeInitiated = true;
try {
// flush any remaining writes before close
if (!pendingUnprotectedWrites.isEmpty()) {
flush(ctx);
}
} catch (GeneralSecurityException e) {
logger.log(Level.FINE, "Ignored error on flush before close", e);
} finally {
pendingUnprotectedWrites = null;
destroyProtector();
}
}
@Override
@SuppressWarnings("FutureReturnValueIgnored") // for aggregatePromise.doneAllocatingPromises
public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException {
if (pendingUnprotectedWrites.isEmpty()) {
// Return early if there's nothing to write. Otherwise protector.protectFlush() below may
// not check for "no-data" and go on writing the 0-byte "data" to the socket with the
// protection framing.
return;
}
// Flushes can happen after close, but only when there are no pending writes.
checkState(protector != null, "flush() called after close()");
final ProtectedPromise aggregatePromise =
new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size());
List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size());
// Drain the unprotected writes.
while (!pendingUnprotectedWrites.isEmpty()) {
ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current();
bufs.add(in.retain());
// Remove and release the buffer and add its promise to the aggregate.
aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove());
}
final class ProtectedFrameWriteFlusher implements Consumer<ByteBuf> {
@Override
public void accept(ByteBuf byteBuf) {
ctx.writeAndFlush(byteBuf, aggregatePromise.newPromise());
}
}
protector.protectFlush(bufs, new ProtectedFrameWriteFlusher(), ctx.alloc());
// We're done writing, start the flow of promise events.
aggregatePromise.doneAllocatingPromises();
}
// Only here to fulfill ChannelOutboundHandler
@Override
public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) {
ctx.bind(localAddress, promise);
}
// Only here to fulfill ChannelOutboundHandler
@Override
public void connect(
ChannelHandlerContext ctx,
SocketAddress remoteAddress,
SocketAddress localAddress,
ChannelPromise promise) {
ctx.connect(remoteAddress, localAddress, promise);
}
// Only here to fulfill ChannelOutboundHandler
@Override
public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) {
ctx.deregister(promise);
}
// Only here to fulfill ChannelOutboundHandler
@Override
public void read(ChannelHandlerContext ctx) {
ctx.read();
}
private void destroyProtector() {
if (protector != null) {
try {
protector.destroy();
} finally {
protector = null;
}
}
}
}