| /* |
| * Copyright 2018 The gRPC Authors |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package io.grpc.alts.internal; |
| |
| import static com.google.common.base.Preconditions.checkNotNull; |
| import static com.google.common.base.Preconditions.checkState; |
| |
| import io.grpc.alts.internal.TsiFrameProtector.Consumer; |
| import io.netty.buffer.ByteBuf; |
| import io.netty.channel.ChannelException; |
| import io.netty.channel.ChannelHandlerContext; |
| import io.netty.channel.ChannelOutboundHandler; |
| import io.netty.channel.ChannelPromise; |
| import io.netty.channel.PendingWriteQueue; |
| import io.netty.handler.codec.ByteToMessageDecoder; |
| import java.net.SocketAddress; |
| import java.security.GeneralSecurityException; |
| import java.util.ArrayList; |
| import java.util.List; |
| import java.util.logging.Level; |
| import java.util.logging.Logger; |
| |
| /** |
| * Encrypts and decrypts TSI Frames. Writes are buffered here until {@link #flush} is called. Writes |
| * must not be made before the TSI handshake is complete. |
| */ |
| public final class TsiFrameHandler extends ByteToMessageDecoder implements ChannelOutboundHandler { |
| |
| private static final Logger logger = Logger.getLogger(TsiFrameHandler.class.getName()); |
| |
| private TsiFrameProtector protector; |
| private PendingWriteQueue pendingUnprotectedWrites; |
| private boolean closeInitiated; |
| |
| public TsiFrameHandler(TsiFrameProtector protector) { |
| this.protector = checkNotNull(protector, "protector"); |
| } |
| |
| @Override |
| public void handlerAdded(ChannelHandlerContext ctx) throws Exception { |
| super.handlerAdded(ctx); |
| assert pendingUnprotectedWrites == null; |
| pendingUnprotectedWrites = new PendingWriteQueue(checkNotNull(ctx)); |
| } |
| |
| @Override |
| protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception { |
| checkState(protector != null, "decode() called after close()"); |
| protector.unprotect(in, out, ctx.alloc()); |
| } |
| |
| @Override |
| @SuppressWarnings("FutureReturnValueIgnored") // for setSuccess |
| public void write(ChannelHandlerContext ctx, Object message, ChannelPromise promise) { |
| if (protector == null) { |
| promise.setFailure(new IllegalStateException("write() called after close()")); |
| return; |
| } |
| ByteBuf msg = (ByteBuf) message; |
| if (!msg.isReadable()) { |
| // Nothing to encode. |
| promise.setSuccess(); |
| return; |
| } |
| |
| // Just add the message to the pending queue. We'll write it on the next flush. |
| pendingUnprotectedWrites.add(msg, promise); |
| } |
| |
| @Override |
| public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { |
| if (pendingUnprotectedWrites != null && !pendingUnprotectedWrites.isEmpty()) { |
| pendingUnprotectedWrites.removeAndFailAll( |
| new ChannelException("Pending write on removal of TSI handler")); |
| } |
| destroyProtector(); |
| } |
| |
| @Override |
| public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { |
| doClose(ctx); |
| ctx.disconnect(promise); |
| } |
| |
| @Override |
| public void close(ChannelHandlerContext ctx, ChannelPromise promise) { |
| doClose(ctx); |
| ctx.close(promise); |
| } |
| |
| private void doClose(ChannelHandlerContext ctx) { |
| if (closeInitiated) { |
| return; |
| } |
| closeInitiated = true; |
| try { |
| // flush any remaining writes before close |
| if (!pendingUnprotectedWrites.isEmpty()) { |
| flush(ctx); |
| } |
| } catch (GeneralSecurityException e) { |
| logger.log(Level.FINE, "Ignored error on flush before close", e); |
| } finally { |
| pendingUnprotectedWrites = null; |
| destroyProtector(); |
| } |
| } |
| |
| @Override |
| @SuppressWarnings("FutureReturnValueIgnored") // for aggregatePromise.doneAllocatingPromises |
| public void flush(final ChannelHandlerContext ctx) throws GeneralSecurityException { |
| if (pendingUnprotectedWrites.isEmpty()) { |
| // Return early if there's nothing to write. Otherwise protector.protectFlush() below may |
| // not check for "no-data" and go on writing the 0-byte "data" to the socket with the |
| // protection framing. |
| return; |
| } |
| // Flushes can happen after close, but only when there are no pending writes. |
| checkState(protector != null, "flush() called after close()"); |
| final ProtectedPromise aggregatePromise = |
| new ProtectedPromise(ctx.channel(), ctx.executor(), pendingUnprotectedWrites.size()); |
| List<ByteBuf> bufs = new ArrayList<>(pendingUnprotectedWrites.size()); |
| |
| // Drain the unprotected writes. |
| while (!pendingUnprotectedWrites.isEmpty()) { |
| ByteBuf in = (ByteBuf) pendingUnprotectedWrites.current(); |
| bufs.add(in.retain()); |
| // Remove and release the buffer and add its promise to the aggregate. |
| aggregatePromise.addUnprotectedPromise(pendingUnprotectedWrites.remove()); |
| } |
| |
| final class ProtectedFrameWriteFlusher implements Consumer<ByteBuf> { |
| |
| @Override |
| public void accept(ByteBuf byteBuf) { |
| ctx.writeAndFlush(byteBuf, aggregatePromise.newPromise()); |
| } |
| } |
| |
| protector.protectFlush(bufs, new ProtectedFrameWriteFlusher(), ctx.alloc()); |
| // We're done writing, start the flow of promise events. |
| aggregatePromise.doneAllocatingPromises(); |
| } |
| |
| // Only here to fulfill ChannelOutboundHandler |
| @Override |
| public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { |
| ctx.bind(localAddress, promise); |
| } |
| |
| // Only here to fulfill ChannelOutboundHandler |
| @Override |
| public void connect( |
| ChannelHandlerContext ctx, |
| SocketAddress remoteAddress, |
| SocketAddress localAddress, |
| ChannelPromise promise) { |
| ctx.connect(remoteAddress, localAddress, promise); |
| } |
| |
| // Only here to fulfill ChannelOutboundHandler |
| @Override |
| public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { |
| ctx.deregister(promise); |
| } |
| |
| // Only here to fulfill ChannelOutboundHandler |
| @Override |
| public void read(ChannelHandlerContext ctx) { |
| ctx.read(); |
| } |
| |
| private void destroyProtector() { |
| if (protector != null) { |
| try { |
| protector.destroy(); |
| } finally { |
| protector = null; |
| } |
| } |
| } |
| } |