blob: f39a1a24628e9c704b7b932255964f984b366ef9 [file] [log] [blame]
/*
* Copyright (c) 2016 Mockito contributors
* This program is made available under the terms of the MIT License.
*/
package org.mockito.internal.creation.bytebuddy;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.method.MethodDescription;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.dynamic.scaffold.MethodGraph;
import net.bytebuddy.implementation.bind.annotation.Argument;
import net.bytebuddy.implementation.bind.annotation.This;
import net.bytebuddy.implementation.bytecode.assign.Assigner;
import org.mockito.exceptions.base.MockitoException;
import org.mockito.internal.debugging.LocationImpl;
import org.mockito.internal.exceptions.stacktrace.ConditionalStackTraceFilter;
import org.mockito.internal.invocation.RealMethod;
import org.mockito.internal.invocation.SerializableMethod;
import org.mockito.internal.invocation.mockref.MockReference;
import org.mockito.internal.invocation.mockref.MockWeakReference;
import org.mockito.internal.util.concurrent.WeakConcurrentMap;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.ref.SoftReference;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
/**
 * Dispatcher backing inline mocks.
 * <p>
 * The static {@code @Advice} methods of this class and of its nested {@code For*} classes look up a
 * {@link MockMethodDispatcher} via {@link MockMethodDispatcher#get(String, Object)} and delegate to
 * it. This class implements that dispatcher by routing calls on registered mock instances to their
 * {@link MockMethodInterceptor}.
 */
public class MockMethodAdvice extends MockMethodDispatcher {

    /** Interceptors of registered mocks, keyed by mock instance. */
    final WeakConcurrentMap<Object, MockMethodInterceptor> interceptors;

    /** Identifier under which this dispatcher is looked up; also needed to rebuild serializable real calls. */
    private final String identifier;

    /** Per-thread marker of the instance whose next intercepted call must execute the real method. */
    private final SelfCallInfo selfCallInfo = new SelfCallInfo();

    private final MethodGraph.Compiler compiler = MethodGraph.Compiler.Default.forJavaHierarchy();

    /** Cache of compiled method graphs per runtime class; soft references let the JVM reclaim entries. */
    private final WeakConcurrentMap<Class<?>, SoftReference<MethodGraph>> graphs
        = new WeakConcurrentMap.WithInlinedExpunction<Class<?>, SoftReference<MethodGraph>>();

    /**
     * @param interceptors map from mock instance to the interceptor handling its calls
     * @param identifier   identifier of this dispatcher
     */
    public MockMethodAdvice(WeakConcurrentMap<Object, MockMethodInterceptor> interceptors, String identifier) {
        this.interceptors = interceptors;
        this.identifier = identifier;
    }

    /**
     * Enter advice for mockable methods.
     * <p>
     * Returns {@code null} to let the original method body run: when no dispatcher is found, when the
     * instance is not (currently) mocked, or when {@code origin} is overridden in the instance's class.
     * Otherwise returns the result of {@link #handle}. With {@code skipOn = OnNonDefaultValue}, a
     * non-null result skips the original body.
     */
    @SuppressWarnings("unused")
    @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
    private static Callable<?> enter(@Identifier String identifier,
                                     @Advice.This Object mock,
                                     @Advice.Origin Method origin,
                                     @Advice.AllArguments Object[] arguments) throws Throwable {
        MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, mock);
        if (dispatcher == null || !dispatcher.isMocked(mock) || dispatcher.isOverridden(mock, origin)) {
            return null;
        } else {
            return dispatcher.handle(mock, origin, arguments);
        }
    }

    /**
     * Exit advice: when the enter advice produced a {@link Callable}, its result replaces the
     * method's return value.
     */
    @SuppressWarnings({"unused", "UnusedAssignment"})
    @Advice.OnMethodExit
    private static void exit(@Advice.Return(readOnly = false, typing = Assigner.Typing.DYNAMIC) Object returned,
                             @Advice.Enter Callable<?> mocked) throws Throwable {
        if (mocked != null) {
            returned = mocked.call();
        }
    }

    /**
     * Rewrites {@code throwable}'s stack trace to hide a reflective re-invocation of a
     * {@code targetType} method.
     * <p>
     * Starting just above the bottom {@code current} frames, the method walks towards the top of the
     * trace until it reaches a frame whose class is {@code targetType}. It then drops that frame and
     * every frame between it and the bottom {@code current} frames.
     *
     * @param throwable  the throwable whose stack trace is rewritten in place
     * @param current    number of frames at the bottom of the trace to keep unchanged
     * @param targetType the class whose frame marks the start of the removed section
     * @return {@code throwable}; its trace is left unchanged if it does not have the expected shape
     */
    static Throwable hideRecursiveCall(Throwable throwable, int current, Class<?> targetType) {
        try {
            StackTraceElement[] stack = throwable.getStackTrace();
            int skip = 0;
            StackTraceElement next;
            do {
                next = stack[stack.length - current - ++skip];
            } while (!next.getClassName().equals(targetType.getName()));
            int top = stack.length - current - skip;
            StackTraceElement[] cleared = new StackTraceElement[stack.length - skip];
            System.arraycopy(stack, 0, cleared, 0, top);
            System.arraycopy(stack, top + skip, cleared, top, current);
            throwable.setStackTrace(cleared);
            return throwable;
        } catch (RuntimeException ignored) {
            // This should not happen unless someone instrumented or manipulated exception stack traces.
            return throwable;
        }
    }

    /**
     * Routes a call on {@code instance} to its registered interceptor.
     *
     * @return {@code null} if no interceptor is registered for {@code instance}; otherwise a
     *         {@link Callable} yielding the interceptor's result
     */
    @Override
    public Callable<?> handle(Object instance, Method origin, Object[] arguments) throws Throwable {
        MockMethodInterceptor interceptor = interceptors.get(instance);
        if (interceptor == null) {
            return null;
        }
        RealMethod realMethod;
        if (instance instanceof Serializable) {
            realMethod = new SerializableRealMethodCall(identifier, origin, instance, arguments);
        } else {
            realMethod = new RealMethodCall(selfCallInfo, origin, instance, arguments);
        }
        Throwable t = new Throwable();
        t.setStackTrace(skipInlineMethodElement(t.getStackTrace()));
        return new ReturnValueWrapper(interceptor.doIntercept(instance,
            origin,
            arguments,
            realMethod,
            new LocationImpl(t)));
    }

    /** Returns whether {@code instance} has a registered interceptor. */
    @Override
    public boolean isMock(Object instance) {
        // We need to exclude 'interceptors.target' explicitly to avoid a recursive check on whether
        // the map is a mock object, which requires reading from the map.
        return instance != interceptors.target && interceptors.containsKey(instance);
    }

    /**
     * Returns whether the current call on {@code instance} should be intercepted.
     * <p>
     * If {@code instance} is this thread's self-call marker, the marker is cleared and {@code false}
     * is returned, so the real method runs.
     */
    @Override
    public boolean isMocked(Object instance) {
        return selfCallInfo.checkSuperCall(instance) && isMock(instance);
    }

    /**
     * Returns {@code true} when {@code origin} is not the method that the runtime class of
     * {@code instance} resolves for that signature.
     * <p>
     * That is the case when the signature is not uniquely resolved, or when it resolves to a method
     * declared in a class other than {@code origin}'s declaring class.
     */
    @Override
    public boolean isOverridden(Object instance, Method origin) {
        SoftReference<MethodGraph> reference = graphs.get(instance.getClass());
        MethodGraph methodGraph = reference == null ? null : reference.get();
        if (methodGraph == null) {
            methodGraph = compiler.compile(new TypeDescription.ForLoadedType(instance.getClass()));
            graphs.put(instance.getClass(), new SoftReference<MethodGraph>(methodGraph));
        }
        MethodGraph.Node node = methodGraph.locate(new MethodDescription.ForLoadedMethod(origin).asSignatureToken());
        return !node.getSort().isResolved() || !node.getRepresentative().asDefined().getDeclaringType().represents(origin.getDeclaringClass());
    }

    /** Real-method call for non-serializable mocks; holds the mock only weakly. */
    private static class RealMethodCall implements RealMethod {

        private final SelfCallInfo selfCallInfo;

        private final Method origin;

        private final MockWeakReference<Object> instanceRef;

        private final Object[] arguments;

        private RealMethodCall(SelfCallInfo selfCallInfo, Method origin, Object instance, Object[] arguments) {
            this.selfCallInfo = selfCallInfo;
            this.origin = origin;
            this.instanceRef = new MockWeakReference<Object>(instance);
            this.arguments = arguments;
        }

        @Override
        public boolean isInvokable() {
            return true;
        }

        /**
         * Invokes the original method on the mock.
         * <p>
         * The mock is marked as this thread's self-call so the enter advice runs the real code instead
         * of intercepting. The previous marker is restored afterwards. Without that restore, a marker
         * that was never consumed (for example because the reflective call failed before reaching the
         * advice) would leak into a later, unrelated call on the same mock.
         */
        @Override
        public Object invoke() throws Throwable {
            if (!Modifier.isPublic(origin.getDeclaringClass().getModifiers() & origin.getModifiers())) {
                origin.setAccessible(true);
            }
            Object instance = instanceRef.get();
            Object previous = selfCallInfo.replace(instance);
            try {
                return tryInvoke(origin, instance, arguments);
            } finally {
                selfCallInfo.set(previous);
            }
        }
    }

    /**
     * Real-method call for serializable mocks.
     * <p>
     * It stores a {@link SerializableMethod} and the dispatcher identifier instead of a direct
     * {@link SelfCallInfo} reference. The dispatcher is looked up again at invocation time.
     */
    private static class SerializableRealMethodCall implements RealMethod {

        private final String identifier;

        private final SerializableMethod origin;

        private final MockReference<Object> instanceRef;

        private final Object[] arguments;

        private SerializableRealMethodCall(String identifier, Method origin, Object instance, Object[] arguments) {
            this.origin = new SerializableMethod(origin);
            this.identifier = identifier;
            this.instanceRef = new MockWeakReference<Object>(instance);
            this.arguments = arguments;
        }

        @Override
        public boolean isInvokable() {
            return true;
        }

        /**
         * Invokes the original method with the mock marked as this thread's self-call, then restores
         * the previous marker.
         *
         * @throws MockitoException if the dispatcher registered for the mock is not a {@code MockMethodAdvice}
         */
        @Override
        public Object invoke() throws Throwable {
            Method method = origin.getJavaMethod();
            if (!Modifier.isPublic(method.getDeclaringClass().getModifiers() & method.getModifiers())) {
                method.setAccessible(true);
            }
            Object instance = instanceRef.get();
            MockMethodDispatcher mockMethodDispatcher = MockMethodDispatcher.get(identifier, instance);
            if (!(mockMethodDispatcher instanceof MockMethodAdvice)) {
                throw new MockitoException("Unexpected dispatcher for advice-based super call");
            }
            SelfCallInfo selfCallInfo = ((MockMethodAdvice) mockMethodDispatcher).selfCallInfo;
            Object previous = selfCallInfo.replace(instance);
            try {
                return tryInvoke(method, instance, arguments);
            } finally {
                selfCallInfo.set(previous);
            }
        }
    }

    /**
     * Reflectively invokes {@code origin} on {@code instance}.
     * <p>
     * If the target throws, its exception is unwrapped from the {@link InvocationTargetException}. The
     * frames of the re-invocation are hidden, the trace is filtered, and the exception is rethrown.
     */
    private static Object tryInvoke(Method origin, Object instance, Object[] arguments) throws Throwable {
        try {
            return origin.invoke(instance, arguments);
        } catch (InvocationTargetException exception) {
            Throwable cause = exception.getCause();
            new ConditionalStackTraceFilter().filter(hideRecursiveCall(cause, new Throwable().getStackTrace().length, origin.getDeclaringClass()));
            throw cause;
        }
    }

    // With inline mocking, mocks for concrete classes are not subclassed, so elements of the stubbing methods are not filtered out.
    // Therefore, if the method is inlined, skip the element.
    /** Returns a copy of {@code elements} without the element directly following each {@code MockMethodAdvice#handle} frame. */
    private static StackTraceElement[] skipInlineMethodElement(StackTraceElement[] elements) {
        List<StackTraceElement> list = new ArrayList<StackTraceElement>(elements.length);
        for (int i = 0; i < elements.length; i++) {
            StackTraceElement element = elements[i];
            list.add(element);
            if (element.getClassName().equals(MockMethodAdvice.class.getName()) && element.getMethodName().equals("handle")) {
                // If the current element is MockMethodAdvice#handle(), the next is assumed to be an inlined method.
                i++;
            }
        }
        return list.toArray(new StackTraceElement[list.size()]);
    }

    /** {@link Callable} that returns a precomputed value. */
    private static class ReturnValueWrapper implements Callable<Object> {

        private final Object returned;

        private ReturnValueWrapper(Object returned) {
            this.returned = returned;
        }

        @Override
        public Object call() {
            return returned;
        }
    }

    /** Thread-local holder of the instance whose next intercepted call should run the real method. */
    private static class SelfCallInfo extends ThreadLocal<Object> {

        /** Sets {@code value} as the marker and returns the previous marker. */
        Object replace(Object value) {
            Object current = get();
            set(value);
            return current;
        }

        /**
         * Returns {@code false} and clears the marker if {@code value} is the current marker.
         * Otherwise returns {@code true}.
         */
        boolean checkSuperCall(Object value) {
            if (value == get()) {
                set(null);
                return false;
            } else {
                return true;
            }
        }
    }

    /** Marks advice parameters that receive the dispatcher identifier. */
    @Retention(RetentionPolicy.RUNTIME)
    @interface Identifier {

    }

    /** Advice for {@code hashCode()}: for mocks, the body is skipped and the identity hash code is returned. */
    static class ForHashCode {

        @SuppressWarnings("unused")
        @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
        private static boolean enter(@Identifier String id,
                                     @Advice.This Object self) {
            MockMethodDispatcher dispatcher = MockMethodDispatcher.get(id, self);
            return dispatcher != null && dispatcher.isMock(self);
        }

        @SuppressWarnings({"unused", "UnusedAssignment"})
        @Advice.OnMethodExit
        private static void exit(@Advice.This Object self,
                                 @Advice.Return(readOnly = false) int hashCode,
                                 @Advice.Enter boolean skipped) {
            if (skipped) {
                hashCode = System.identityHashCode(self);
            }
        }
    }

    /** Advice for {@code equals(Object)}: for mocks, the body is skipped and identity comparison is used. */
    static class ForEquals {

        @SuppressWarnings("unused")
        @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class)
        private static boolean enter(@Identifier String identifier,
                                     @Advice.This Object self) {
            MockMethodDispatcher dispatcher = MockMethodDispatcher.get(identifier, self);
            return dispatcher != null && dispatcher.isMock(self);
        }

        @SuppressWarnings({"unused", "UnusedAssignment"})
        @Advice.OnMethodExit
        private static void exit(@Advice.This Object self,
                                 @Advice.Argument(0) Object other,
                                 @Advice.Return(readOnly = false) boolean equals,
                                 @Advice.Enter boolean skipped) {
            if (skipped) {
                equals = self == other;
            }
        }
    }

    /** {@code readObject} delegation target that re-registers a deserialized mock's interceptor. */
    public static class ForReadObject {

        /**
         * Reads the default serialized fields of {@code thiz}. If a dispatcher is found for it,
         * the method then registers {@code thiz}'s interceptor with that dispatcher.
         */
        @SuppressWarnings("unused")
        public static void doReadObject(@Identifier String identifier,
                                        @This MockAccess thiz,
                                        @Argument(0) ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            objectInputStream.defaultReadObject();
            MockMethodAdvice mockMethodAdvice = (MockMethodAdvice) MockMethodDispatcher.get(identifier, thiz);
            if (mockMethodAdvice != null) {
                mockMethodAdvice.interceptors.put(thiz, thiz.getMockitoInterceptor());
            }
        }
    }
}