blob: a71b6c774262574a587ca653cafc4bf6fb0d71d5 [file] [log] [blame]
/*
* Copyright 2014 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc;
import static com.google.common.truth.Truth.assertThat;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.isA;
import static org.mockito.Matchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import io.grpc.ClientInterceptors.CheckedForwardingClientCall;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.testing.TestMethodDescriptors;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
/**
 * Unit tests for {@link ClientInterceptors}.
 *
 * <p>Most tests install interceptors in front of a mocked {@link Channel} whose {@code newCall}
 * returns the shared {@link BaseClientCall} fake, so assertions can inspect exactly what reached
 * the underlying call.
 */
@RunWith(JUnit4.class)
public class ClientInterceptorsTest {
  @Mock
  private Channel channel;

  /** Fake call returned by the mocked {@link #channel}; records everything forwarded to it. */
  private final BaseClientCall call = new BaseClientCall();

  private final MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod();

  /**
   * Initializes the {@code @Mock} fields and stubs {@link #channel} so that every
   * {@code newCall} returns {@link #call}.
   */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);
    when(channel.newCall(
            Mockito.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class)))
        .thenReturn(call);
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullChannel() {
    ClientInterceptors.intercept(null, Arrays.<ClientInterceptor>asList());
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullInterceptorList() {
    ClientInterceptors.intercept(channel, (List<ClientInterceptor>) null);
  }

  @Test(expected = NullPointerException.class)
  public void npeForNullInterceptor() {
    ClientInterceptors.intercept(channel, (ClientInterceptor) null);
  }

  /** With no interceptors, {@code intercept} hands back the very same channel instance. */
  @Test
  public void noop() {
    assertSame(channel, ClientInterceptors.intercept(channel, Arrays.<ClientInterceptor>asList()));
  }

  /** Each {@code newCall} goes through the interceptor once and reaches the channel once. */
  @Test
  public void channelAndInterceptorCalled() {
    ClientInterceptor interceptor =
        mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor()));
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    CallOptions callOptions = CallOptions.DEFAULT;
    // First call
    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel).newCall(same(method), same(callOptions));
    verify(interceptor).interceptCall(same(method), same(callOptions), Mockito.<Channel>any());
    verifyNoMoreInteractions(channel, interceptor);
    // Second call
    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel, times(2)).newCall(same(method), same(callOptions));
    verify(interceptor, times(2))
        .interceptCall(same(method), same(callOptions), Mockito.<Channel>any());
    verifyNoMoreInteractions(channel, interceptor);
  }

  /** An interceptor may invoke {@code next.newCall} more than once. */
  @Test
  public void callNextTwice() {
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        // Calling next twice is permitted, although should only rarely be useful.
        assertSame(call, next.newCall(method, callOptions));
        return next.newCall(method, callOptions);
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    verify(channel, times(2)).newCall(same(method), same(CallOptions.DEFAULT));
    verifyNoMoreInteractions(channel);
  }

  /** {@code intercept} runs the last listed interceptor first, then earlier ones, then the channel. */
  @Test
  public void ordered() {
    List<String> order = new ArrayList<>();
    Channel intercepted = ClientInterceptors.intercept(
        newOrderRecordingChannel(order),
        newOrderRecordingInterceptor(order, "i1"),
        newOrderRecordingInterceptor(order, "i2"));
    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    assertEquals(Arrays.asList("i2", "i1", "channel"), order);
  }

  /** {@code interceptForward} runs interceptors in list order, then the channel. */
  @Test
  public void orderedForward() {
    List<String> order = new ArrayList<>();
    Channel intercepted = ClientInterceptors.interceptForward(
        newOrderRecordingChannel(order),
        newOrderRecordingInterceptor(order, "i1"),
        newOrderRecordingInterceptor(order, "i2"));
    assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
    assertEquals(Arrays.asList("i1", "i2", "channel"), order);
  }

  /** Options substituted by an interceptor, not the caller's originals, reach the channel. */
  @Test
  public void callOptions() {
    final CallOptions initialCallOptions = CallOptions.DEFAULT.withDeadlineAfter(100, NANOSECONDS);
    final CallOptions newCallOptions = initialCallOptions.withDeadlineAfter(300, NANOSECONDS);
    assertNotSame(initialCallOptions, newCallOptions);
    ClientInterceptor interceptor =
        mock(ClientInterceptor.class, delegatesTo(new ClientInterceptor() {
          @Override
          public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
              MethodDescriptor<ReqT, RespT> method,
              CallOptions callOptions,
              Channel next) {
            return next.newCall(method, newCallOptions);
          }
        }));
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    intercepted.newCall(method, initialCallOptions);
    verify(interceptor).interceptCall(
        same(method), same(initialCallOptions), Mockito.<Channel>any());
    verify(channel).newCall(same(method), same(newCallOptions));
  }

  /** Headers added by an interceptor in {@code start} are visible to the underlying call. */
  @Test
  public void addOutboundHeaders() {
    final Metadata.Key<String> credKey = Metadata.Key.of("Cred", Metadata.ASCII_STRING_MARSHALLER);
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        ClientCall<ReqT, RespT> nextCall = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(nextCall) {
          @Override
          public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
            headers.put(credKey, "abcd");
            super.start(responseListener, headers);
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    @SuppressWarnings("unchecked")
    ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    // start() on the intercepted call will eventually reach the call created by the real channel
    interceptedCall.start(listener, new Metadata());
    // The headers passed to the real channel call will contain the information inserted by the
    // interceptor.
    assertSame(listener, call.listener);
    assertEquals("abcd", call.headers.get(credKey));
  }

  /**
   * A listener wrapped by an interceptor observes inbound headers. It also still forwards them
   * to the application's listener.
   */
  @Test
  public void examineInboundHeaders() {
    final List<Metadata> examinedHeaders = new ArrayList<>();
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        ClientCall<ReqT, RespT> nextCall = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(nextCall) {
          @Override
          public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
            super.start(new SimpleForwardingClientCallListener<RespT>(responseListener) {
              @Override
              public void onHeaders(Metadata headers) {
                examinedHeaders.add(headers);
                super.onHeaders(headers);
              }
            }, headers);
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    @SuppressWarnings("unchecked")
    ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    interceptedCall.start(listener, new Metadata());
    // call.listener is the wrapping listener the interceptor handed to the underlying call.
    // Invoking it directly simulates headers arriving from the transport.
    Metadata inboundHeaders = new Metadata();
    call.listener.onHeaders(inboundHeaders);
    assertThat(examinedHeaders).containsExactly(inboundHeaders);
    // The interceptor's listener must still pass the headers on to the application listener.
    verify(listener).onHeaders(same(inboundHeaders));
  }

  /** A pass-through forwarding call relays start, messages, half-close and requests unchanged. */
  @Test
  public void normalCall() {
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        ClientCall<ReqT, RespT> nextCall = next.newCall(method, callOptions);
        return new SimpleForwardingClientCall<ReqT, RespT>(nextCall) { };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    assertNotSame(call, interceptedCall);
    @SuppressWarnings("unchecked")
    ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
    Metadata headers = new Metadata();
    interceptedCall.start(listener, headers);
    assertSame(listener, call.listener);
    assertSame(headers, call.headers);
    interceptedCall.sendMessage(null /*request*/);
    assertThat(call.messages).containsExactly((Void) null /*request*/);
    interceptedCall.halfClose();
    assertTrue(call.halfClosed);
    interceptedCall.request(1);
    assertThat(call.requests).containsExactly(1);
  }

  /**
   * An exception from {@code checkedStart} closes the listener with that exception as the cause.
   * Nothing further is forwarded to the underlying call.
   */
  @Test
  public void exceptionInStart() {
    final Exception error = new Exception("emulated error");
    ClientInterceptor interceptor = new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        ClientCall<ReqT, RespT> nextCall = next.newCall(method, callOptions);
        return new CheckedForwardingClientCall<ReqT, RespT>(nextCall) {
          @Override
          protected void checkedStart(ClientCall.Listener<RespT> responseListener, Metadata headers)
              throws Exception {
            throw error;
            // delegate().start will not be called
          }
        };
      }
    };
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    @SuppressWarnings("unchecked")
    ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
    ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
    assertNotSame(call, interceptedCall);
    // Mark the real call done *before* using the intercepted call. From here on, start, request,
    // cancel, halfClose and sendMessage on the real call throw IllegalStateException. Any
    // invocation that leaks through after the failed start() therefore fails this test.
    call.done = true;
    interceptedCall.start(listener, new Metadata());
    interceptedCall.sendMessage(null /*request*/);
    interceptedCall.halfClose();
    interceptedCall.request(1);
    ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
    verify(listener).onClose(captor.capture(), any(Metadata.class));
    assertSame(error, captor.getValue().getCause());
    // Make sure nothing bad happens after the exception.
    ClientCall<?, ?> noop = ((CheckedForwardingClientCall<?, ?>) interceptedCall).delegate();
    // Should not throw, even on bad input
    noop.cancel("Cancel for test", null);
    noop.start(null, null);
    noop.request(-1);
    noop.halfClose();
    noop.sendMessage(null);
    assertFalse(noop.isReady());
  }

  /** {@code authority()} on the intercepted channel returns the underlying channel's authority. */
  @Test
  public void authorityIsDelegated() {
    ClientInterceptor interceptor = new NoopInterceptor();
    when(channel.authority()).thenReturn("auth");
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    assertEquals("auth", intercepted.authority());
  }

  /** Custom {@link CallOptions.Key} values set by the caller are visible to interceptors. */
  @Test
  public void customOptionAccessible() {
    CallOptions.Key<String> customOption = CallOptions.Key.create("custom");
    CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value");
    ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class);
    ClientInterceptor interceptor =
        mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor()));
    Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
    assertSame(call, intercepted.newCall(method, callOptions));
    verify(channel).newCall(same(method), same(callOptions));
    verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class));
    assertSame("value", passedOptions.getValue().getOption(customOption));
  }

  /**
   * Returns a channel that appends {@code "channel"} to {@code order} on every {@code newCall}
   * and returns {@link #call}.
   */
  private Channel newOrderRecordingChannel(final List<String> order) {
    return new Channel() {
      // BaseClientCall is shared across method types in these tests; the cast is never checked.
      @SuppressWarnings("unchecked")
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
          MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
        order.add("channel");
        return (ClientCall<ReqT, RespT>) call;
      }

      @Override
      public String authority() {
        return null;
      }
    };
  }

  /** Returns an interceptor that appends {@code name} to {@code order} and then delegates. */
  private static ClientInterceptor newOrderRecordingInterceptor(
      final List<String> order, final String name) {
    return new ClientInterceptor() {
      @Override
      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
          MethodDescriptor<ReqT, RespT> method,
          CallOptions callOptions,
          Channel next) {
        order.add(name);
        return next.newCall(method, callOptions);
      }
    };
  }

  /** Interceptor that forwards straight to {@code next} with the unchanged method and options. */
  private static class NoopInterceptor implements ClientInterceptor {
    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
        CallOptions callOptions, Channel next) {
      return next.newCall(method, callOptions);
    }
  }

  /**
   * Recording fake {@link ClientCall}. It stores the listener, headers, messages, requests and
   * half-close/cancel state it receives.
   *
   * <p>Setting {@link #done} makes every overridden method throw {@link IllegalStateException}.
   * {@code request}, {@code halfClose} and {@code sendMessage} also throw if {@code start} has not
   * been called yet.
   */
  private static class BaseClientCall extends ClientCall<String, Integer> {
    private boolean started;
    private boolean done;
    private ClientCall.Listener<Integer> listener;
    private Metadata headers;
    private final List<Integer> requests = new ArrayList<>();
    private final List<String> messages = new ArrayList<>();
    private boolean halfClosed;
    private Throwable cancelCause;
    private String cancelMessage;

    @Override
    public void start(ClientCall.Listener<Integer> listener, Metadata headers) {
      checkNotDone();
      started = true;
      this.listener = listener;
      this.headers = headers;
    }

    @Override
    public void request(int numMessages) {
      checkNotDone();
      checkStarted();
      requests.add(numMessages);
    }

    @Override
    public void cancel(String message, Throwable cause) {
      checkNotDone();
      this.cancelMessage = message;
      this.cancelCause = cause;
    }

    @Override
    public void halfClose() {
      checkNotDone();
      checkStarted();
      this.halfClosed = true;
    }

    @Override
    public void sendMessage(String message) {
      checkNotDone();
      checkStarted();
      messages.add(message);
    }

    private void checkNotDone() {
      if (done) {
        throw new IllegalStateException("no more methods should be called");
      }
    }

    private void checkStarted() {
      if (!started) {
        throw new IllegalStateException("should have called start");
      }
    }
  }
}