// blob: 9857eeb440e28068e6735577255aa03914a11cd5
package com.xtremelabs.robolectric;
import javassist.CannotCompileException;
import javassist.CtClass;
import javassist.CtField;
import javassist.NotFoundException;
import java.lang.reflect.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
public class ProxyDelegatingHandler implements ClassHandler {
/** Name of the public field that {@code instrument} adds to instrumented classes and that holds an instance's fake ("sheep"). */
public static final String SHEEP_FIELD_NAME = "__sheep__";
/** Shared instance, created lazily by {@code getInstance()}. */
private static ProxyDelegatingHandler singleton;
/** Maps the name of a real class to the name of the class registered to handle it (see {@code addProxyClass}). */
private Map<String, String> sheepClassMap = new HashMap<String, String>();
/** Cache of the reflective sheep {@link Field}, keyed by the runtime class of the instance it was looked up for. */
private Map<Class, Field> sheepFieldMap = new HashMap<Class, Field>();
/** When true, diagnostic messages are printed to {@code System.out}. */
public boolean debug = false;
/**
 * Returns the shared handler, creating it on the first call.
 * <p>
 * Sorry about the singleton! It really only makes sense to have one handler per ClassLoader
 * anyway though. [xw/hu]
 * <p>
 * Creation is not synchronized.
 *
 * @return the single {@code ProxyDelegatingHandler} instance
 */
public static ProxyDelegatingHandler getInstance() {
    if (singleton != null) {
        return singleton;
    }
    singleton = new ProxyDelegatingHandler();
    return singleton;
}
/** Private: instances are obtained through {@code getInstance()}. */
private ProxyDelegatingHandler() {
}
/**
 * Ensures that {@code ctClass} exposes a public {@code Object} field named
 * {@link #SHEEP_FIELD_NAME}. The field is only added when {@code ctClass.getField}
 * cannot find one already.
 *
 * @param ctClass the class being instrumented
 * @throws RuntimeException wrapping a {@link NotFoundException} if {@code java.lang.Object}
 *         cannot be found in the class pool, or a {@link CannotCompileException} if the
 *         field cannot be created or added
 */
@Override
public void instrument(CtClass ctClass) {
    // The Object type is resolved first, even when the field turns out to exist already.
    CtClass objectClass;
    try {
        objectClass = ctClass.getClassPool().get(Object.class.getName());
    } catch (NotFoundException e) {
        throw new RuntimeException(e);
    }

    boolean sheepFieldPresent = true;
    try {
        ctClass.getField(SHEEP_FIELD_NAME);
    } catch (NotFoundException notFound) {
        sheepFieldPresent = false;
    }
    if (sheepFieldPresent) {
        return;
    }

    try {
        CtField sheepField = new CtField(objectClass, SHEEP_FIELD_NAME, ctClass);
        sheepField.setModifiers(Modifier.PUBLIC);
        ctClass.addField(sheepField);
    } catch (CannotCompileException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Removes every real-class-to-handler mapping previously registered through
 * {@code addProxyClass}. The sheep field cache is not touched.
 */
@Override
public void beforeTest() {
sheepClassMap.clear();
}
/** Does nothing. */
@Override
public void afterTest() {
}
/**
 * Registers {@code handlerClass} as the class whose methods handle calls made on
 * {@code realClass}. Only the class names are stored; a later registration for the
 * same real class replaces the earlier one.
 *
 * @param realClass    the class whose calls are to be intercepted
 * @param handlerClass the class that provides the replacement implementations
 */
public void addProxyClass(Class<?> realClass, Class<?> handlerClass) {
    String realName = realClass.getName();
    String handlerName = handlerClass.getName();
    sheepClassMap.put(realName, handlerName);
    if (!debug) {
        return;
    }
    System.out.println("handle " + realClass + " with " + handlerClass);
}
/**
 * Dispatches an intercepted call to the matching method of the registered handling class.
 *
 * @param clazz      the class on which the call was intercepted
 * @param methodName the intercepted method name ({@code "<init>"} for constructors)
 * @param instance   the receiver, or {@code null} for a static call
 * @param paramTypes the parameter type names of the intercepted method
 * @param params     the actual arguments
 * @return the handler method's result, or {@code null} when no handler class or no
 *         matching handler method exists
 * @throws RuntimeException if the reflective call is rejected, if access is denied, or if
 *         the handler method throws a checked exception or an error; runtime exceptions
 *         thrown by the handler method are rethrown unchanged
 */
@Override
public Object methodInvoked(Class clazz, String methodName, Object instance, String[] paramTypes, Object[] params) {
    InvocationPlan invocationPlan = new InvocationPlan(clazz, methodName, instance, paramTypes);
    if (!invocationPlan.prepare()) return null;
    try {
        return invocationPlan.getMethod().invoke(invocationPlan.getSheep(), params);
    } catch (IllegalArgumentException e) {
        // For static calls the plan has no sheep; fall back to the class declaring the
        // handler method so that building the message cannot itself fail with an NPE.
        Object sheep = invocationPlan.getSheep();
        String targetName = sheep != null
                ? sheep.getClass().getName()
                : invocationPlan.getMethod().getDeclaringClass().getName();
        throw new RuntimeException(targetName + " is not assignable from " + invocationPlan.getHandlingClass().getName(), e);
    } catch (InvocationTargetException e) {
        if (e.getCause() instanceof RuntimeException) {
            throw (RuntimeException) e.getCause();
        } else {
            throw new RuntimeException("Did your fake implementation of a method throw an exception? Refer to the bottom of this stack trace.", e);
        }
    } catch (IllegalAccessException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Resolves a type name to a {@link Class}. Names recognised by
 * {@code Type.findPrimitiveClass} are returned directly; otherwise any trailing
 * {@code "[]"} pairs are stripped, the element type is resolved (again via
 * {@code Type.findPrimitiveClass}, then via {@code classLoader}), and the array
 * dimensions are re-applied.
 *
 * @param paramType   the type name, optionally followed by one or more {@code "[]"}
 * @param classLoader the loader used for element types that are not primitives
 * @return the resolved class
 * @throws RuntimeException wrapping the {@link ClassNotFoundException} if the element
 *         type cannot be loaded
 */
static Class<?> loadClass(String paramType, ClassLoader classLoader) {
    Class directMatch = Type.findPrimitiveClass(paramType);
    if (directMatch != null) {
        return directMatch;
    }

    // Peel off one "[]" per array dimension.
    String elementName = paramType;
    int dimensions = 0;
    for (; elementName.endsWith("[]"); dimensions++) {
        elementName = elementName.substring(0, elementName.length() - 2);
    }

    Class<?> resolved = Type.findPrimitiveClass(elementName);
    if (resolved == null) {
        try {
            resolved = classLoader.loadClass(elementName);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    // Wrap the element type back into an array type, one dimension at a time.
    for (int i = 0; i < dimensions; i++) {
        resolved = Array.newInstance(resolved, 0).getClass();
    }
    return resolved;
}
/**
 * Returns the fake ("sheep") attached to {@code instance}, creating and attaching one
 * on first use.
 * <p>
 * A new sheep is built with a public constructor that accepts the instance (or one of
 * its superclasses) when the handling class has one, and with the no-argument
 * constructor otherwise.
 *
 * @param instance an instance of an instrumented class
 * @return the existing or newly created sheep, never {@code null}
 * @throws RuntimeException if the instance's class has no sheep field, if no handling
 *         class is registered for the instance's class or any superclass, or if the
 *         sheep cannot be instantiated or stored
 */
public Object sheepFor(Object instance) {
    Field field = getSheepField(instance);
    Object sheep = readField(instance, field);
    if (sheep != null) {
        return sheep;
    }
    String sheepClassName = getHandlingClassName(instance.getClass());
    if (sheepClassName == null) {
        // Without this check a null name is handed to loadClass, which dereferences it.
        throw new RuntimeException("no fake class registered for " + instance.getClass().getName() + " or any of its superclasses");
    }
    if (debug)
        System.out.println("creating new " + sheepClassName + " as fake for " + instance.getClass().getName());
    try {
        Class<?> sheepClass = loadClass(sheepClassName, instance.getClass().getClassLoader());
        Constructor<?> constructor = findConstructor(instance, sheepClass);
        if (constructor != null) {
            sheep = constructor.newInstance(instance);
        } else {
            sheep = sheepClass.newInstance();
        }
        field.set(instance, sheep);
        return sheep;
    } catch (InstantiationException e) {
        throw new RuntimeException(e);
    } catch (IllegalAccessException e) {
        throw new RuntimeException(e);
    } catch (InvocationTargetException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Looks up the handling class registered for {@code clazz}, walking up the superclass
 * chain until a registration is found.
 *
 * @param clazz the class to start from; may be {@code null}
 * @return the name of the handling class of the nearest registered ancestor (including
 *         {@code clazz} itself), or {@code null} if none is registered
 */
private String getHandlingClassName(Class clazz) {
    for (Class current = clazz; current != null; current = current.getSuperclass()) {
        String handlerName = sheepClassMap.get(current.getName());
        if (handlerName != null) {
            return handlerName;
        }
    }
    return null;
}
/**
 * Finds a public single-argument constructor of {@code sheepClass} whose parameter type
 * is exactly the instance's class or, failing that, the nearest superclass for which
 * such a constructor exists.
 *
 * @param instance   the real object the sheep would wrap
 * @param sheepClass the handling class to search
 * @return the matching constructor, or {@code null} if there is none
 */
private Constructor<?> findConstructor(Object instance, Class<?> sheepClass) {
    Class candidate = instance.getClass();
    while (candidate != null) {
        try {
            return sheepClass.getConstructor(candidate);
        } catch (NoSuchMethodException expected) {
            // No constructor for this type; try its superclass next.
            candidate = candidate.getSuperclass();
        }
    }
    return null;
}
/**
 * Reads {@code field} from {@code instance}.
 *
 * @return the field's current value
 * @throws RuntimeException wrapping the {@link IllegalAccessException} if the field is
 *         not accessible
 */
private Object readField(Object instance, Field field) {
    Object value;
    try {
        value = field.get(instance);
    } catch (IllegalAccessException inaccessible) {
        throw new RuntimeException(inaccessible);
    }
    return value;
}
/**
 * Returns the public {@link #SHEEP_FIELD_NAME} field of the instance's runtime class,
 * consulting and filling the per-class cache.
 *
 * @param instance the object whose sheep field is wanted
 * @return the reflective field
 * @throws RuntimeException if the class has no such public field
 */
private Field getSheepField(Object instance) {
    Class instanceClass = instance.getClass();
    Field cached = sheepFieldMap.get(instanceClass);
    if (cached != null) {
        return cached;
    }

    Field resolved;
    try {
        resolved = instanceClass.getField(SHEEP_FIELD_NAME);
    } catch (NoSuchFieldException e) {
        throw new RuntimeException(instanceClass.getName() + " has no fake field", e);
    }
    sheepFieldMap.put(instanceClass, resolved);
    return resolved;
}
/**
 * Returns whatever is currently stored in the instance's sheep field, without creating
 * a sheep when there is none.
 *
 * @param instance an instance of an instrumented class; must not be {@code null}
 * @return the current sheep, or {@code null} if none has been attached
 * @throws RuntimeException if {@code instance} is {@code null} or its class has no sheep field
 */
public Object proxyFor(Object instance) {
    if (instance != null) {
        return readField(instance, getSheepField(instance));
    }
    throw new RuntimeException("no instance for which to get a proxy");
}
/**
 * Works out, for one intercepted call, which handler method should run and on which
 * sheep. Call {@link #prepare()} first; the getters are only meaningful once it has
 * returned {@code true}.
 */
private class InvocationPlan {
    private Class clazz;
    private String methodName;
    private Object instance;
    private String[] paramTypes;
    private Class<?> handlingClass;
    private Method method;
    private Object sheep;

    /**
     * @param clazz      the class on which the call was intercepted
     * @param methodName the intercepted method name ({@code "<init>"} for constructors)
     * @param instance   the receiver, or {@code null} for a static call
     * @param paramTypes the parameter type names of the intercepted method
     */
    public InvocationPlan(Class clazz, String methodName, Object instance, String... paramTypes) {
        this.clazz = clazz;
        this.methodName = methodName;
        this.instance = instance;
        this.paramTypes = paramTypes;
    }

    /** Returns the handling class loaded by {@link #prepare()}. */
    public Class<?> getHandlingClass() {
        return handlingClass;
    }

    /** Returns the handler method selected by {@link #prepare()}. */
    public Method getMethod() {
        return method;
    }

    /** Returns the sheep to invoke the handler method on; {@code null} for static calls. */
    public Object getSheep() {
        return sheep;
    }

    /**
     * Resolves the parameter classes, the handling class, the sheep and the handler method.
     *
     * @return {@code true} if a handler method was found and made accessible;
     *         {@code false} if no handling class is registered or it has no matching method
     * @throws RuntimeException if a class cannot be loaded, if the intercepted method is
     *         not declared on the intercepted class, or if exactly one of the intercepted
     *         call and the handler method is static
     */
    public boolean prepare() {
        int paramCount = paramTypes.length;
        Class<?>[] paramClasses = new Class<?>[paramCount];
        ClassLoader classLoader = clazz.getClassLoader();
        for (int i = 0; i < paramCount; i++) {
            paramClasses[i] = loadClass(paramTypes[i], classLoader);
        }

        Class<?> originalClass = loadClass(clazz.getName(), classLoader);
        boolean isConstructor = methodName.equals("<init>");
        Class<?> declaringClass = isConstructor
                ? originalClass
                : declaringClassOf(originalClass, paramClasses);

        String handlingClassName = getHandlingClassName(declaringClass);
        if (handlingClassName == null) {
            return false;
        }
        handlingClass = loadClass(handlingClassName, classLoader);

        // Constructors are handled by a method named __constructor__ on the handling class.
        if (isConstructor) {
            methodName = "__constructor__";
        }

        if (instance == null) {
            // Static call: there is no sheep; look the method up on the registered class.
            sheep = null;
            Class<?> sheepClass = loadClass(getHandlingClassName(clazz), classLoader);
            method = getMethod(sheepClass, methodName, paramClasses);
        } else {
            sheep = sheepFor(instance);
            method = getMethod(sheep.getClass(), methodName, paramClasses);
        }

        if (method == null) {
            if (debug) {
                System.out.println("No method found for " + clazz + "." + methodName + "(" + Arrays.asList(paramClasses) + ") on " + handlingClass.getName());
            }
            return false;
        }

        boolean calledStatically = instance == null;
        if (calledStatically != Modifier.isStatic(method.getModifiers())) {
            throw new RuntimeException("method staticness of " + clazz.getName() + "." + methodName + " and " + handlingClassName + "." + method.getName() + " don't match");
        }

        method.setAccessible(true);
        return true;
    }

    /** Returns the declaring class of the intercepted method as declared on {@code originalClass}. */
    private Class<?> declaringClassOf(Class<?> originalClass, Class<?>[] paramClasses) {
        try {
            return originalClass.getDeclaredMethod(methodName, paramClasses).getDeclaringClass();
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Looks a method up first among the public methods of {@code clazz} (inherited ones
     * included), then among the methods {@code clazz} itself declares.
     *
     * @return the method, or {@code null} if neither lookup finds it
     */
    private Method getMethod(Class<?> clazz, String methodName, Class<?>[] paramClasses) {
        try {
            return clazz.getMethod(methodName, paramClasses);
        } catch (NoSuchMethodException notPublic) {
            // Fall through to the declared-method lookup below.
        }
        try {
            return clazz.getDeclaredMethod(methodName, paramClasses);
        } catch (NoSuchMethodException notDeclared) {
            return null;
        }
    }
}
}