blob: 586cc24fa36b75ff617c016782bb339e171ff35a [file] [log] [blame]
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectInputFilter;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Set;
import java.util.concurrent.atomic.LongAdder;
import javax.net.ssl.SSLEngineResult;
import org.testng.Assert;
import org.testng.annotations.Test;
import org.testng.annotations.DataProvider;
/* @test
* @build SerialFilterTest
* @run testng/othervm SerialFilterTest
*
* @summary Test ObjectInputFilters
*/
@Test
public class SerialFilterTest implements Serializable {
private static final long serialVersionUID = -6999613679881262446L;
/**
* Enable three arg lambda.
* @param <T> The pattern
* @param <U> The test object
* @param <V> Boolean for if the filter should allow or reject
*/
interface TriConsumer< T, U, V> {
void accept(T t, U u, V v);
}
/**
 * Misc object to use that should always be accepted.
 * It is handed out for the {@code "*"} pattern, as the fallback for package
 * hierarchy wildcards, and as the filler element of the maxrefs test array.
 */
private static final Object otherObject = Integer.valueOf(0);
/**
 * DataProvider for the individual patterns to test.
 * Every row carries exactly one filter pattern string: a class name,
 * class and package wildcards, the limit settings, and a class name
 * qualified with a module name.
 * @return an array of one-element arrays, each holding a pattern
 */
@DataProvider(name="Patterns")
static Object[][] patterns() {
    return new Object[][] {
        // class names and wildcards
        {"java.util.Hashtable"},
        {"java.util.Hash*"},
        {"javax.net.ssl.*"},
        {"javax.net.**"},
        {"*"},
        // limits
        {"maxarray=47"},
        {"maxdepth=5"},
        {"maxrefs=10"},
        {"maxbytes=100"},
        {"maxbytes=72"},
        {"maxbytes=+1024"},
        // module/class form
        {"java.base/java.util.Hashtable"},
    };
}
/**
 * DataProvider of malformed patterns.
 * Each row is a single pattern string that filter creation is expected
 * to refuse.
 * @return an array of one-element arrays, each holding an invalid pattern
 */
@DataProvider(name="InvalidPatterns")
static Object[][] invalidPatterns() {
    Object[][] malformed = {
        {".*"},
        {".**"},
        {"!"},
        {"/java.util.Hashtable"},
        {"java.base/"},
        {"/"},
    };
    return malformed;
}
/**
 * DataProvider of limit names paired with a value to test them at.
 * @return an array of {limit name, value} rows
 */
@DataProvider(name="Limits")
static Object[][] limits() {
    // The numbers are arbitrary > 1
    Object[][] nameAndValue = {
        {"maxrefs", 1}, // 0 is tested as n-1
        {"maxrefs", 10},
        {"maxdepth", 5},
        {"maxbytes", 100},
        {"maxarray", 16},
        {"maxbytes", Long.MAX_VALUE},
    };
    return nameAndValue;
}
/**
 * DataProvider of limit patterns that must be refused when a filter is
 * created from them: negative values, unknown limit names, non-decimal
 * or missing values, and values outside the range of a long.
 * Each pattern is listed once; a repeated row would only run the same
 * case a second time.
 * @return an array of one-element arrays, each holding an invalid limit pattern
 */
@DataProvider(name="InvalidLimits")
static Object[][] invalidLimits() {
    return new Object[][] {
        // negative values
        {"maxrefs=-1"},
        {"maxdepth=-1"},
        {"maxbytes=-1"},
        {"maxarray=-1"},
        // unknown limit names
        {"xyz=0"},
        {"xyz=-1"},
        // malformed or missing values
        {"maxrefs=0xabc"},
        {"maxrefs=abc"},
        {"maxrefs="},
        {"maxrefs=+"},
        // out of range or negative long values
        {"maxbytes=9223372036854775808"},
        {"maxbytes=-9223372036854775807"},
    };
}
/**
 * DataProvider of individual objects. Used to check the information
 * available to the filter.
 * Each row supplies the parameters of {@code t1}, in order: the object to
 * serialize, the expected count of filter callbacks, the expected maximum
 * array length, maximum references, maximum depth and maximum stream bytes,
 * and the set of classes the filter is expected to see.
 * @return Arrays of parameters with objects
 */
@DataProvider(name="Objects")
static Object[][] objects() {
// Only used below to obtain the byte[] class object
byte[] byteArray = new byte[0];
Object[] objArray = new Object[7];
// The last slot refers back to the array itself
objArray[objArray.length - 1] = objArray;
// This class is looked up by name; the provider fails if it cannot be loaded
Class<?> serClass = null;
String className = "java.util.concurrent.atomic.LongAdder$SerializationProxy";
try {
serClass = Class.forName(className);
} catch (Exception e) {
Assert.fail("missing class: " + className, e);
}
// A dynamic proxy implementing Runnable; only proxy.getClass() is used in the table
Class<?>[] interfaces = {Runnable.class};
Runnable proxy = (Runnable) Proxy.newProxyInstance(null,
interfaces, (p, m, args) -> p);
// A serializable method reference to noop()
Runnable runnable = (Runnable & Serializable) SerialFilterTest::noop;
// Columns: object, count, maxArray, maxRefs, maxDepth, maxBytes, classes
Object[][] objects = {
{ null, 0, -1, 0, 0, 0,
new HashSet<>()}, // no callback, no values
{ objArray, 3, 7, 8, 2, 55,
new HashSet<>(Arrays.asList(objArray.getClass()))},
{ Object[].class, 1, -1, 1, 1, 40,
new HashSet<>(Arrays.asList(Object[].class))},
{ new SerialFilterTest(), 1, -1, 1, 1, 37,
new HashSet<>(Arrays.asList(SerialFilterTest.class))},
{ new LongAdder(), 2, -1, 1, 1, 93,
new HashSet<>(Arrays.asList(LongAdder.class, serClass))},
{ new byte[14], 2, 14, 1, 1, 27,
new HashSet<>(Arrays.asList(byteArray.getClass()))},
{ runnable, 13, 0, 10, 2, 514,
new HashSet<>(Arrays.asList(java.lang.invoke.SerializedLambda.class,
SerialFilterTest.class,
objArray.getClass()))},
{ deepHashSet(10), 48, -1, 49, 11, 619,
new HashSet<>(Arrays.asList(HashSet.class))},
{ proxy.getClass(), 3, -1, 1, 1, 114,
new HashSet<>(Arrays.asList(Runnable.class,
java.lang.reflect.Proxy.class))},
};
return objects;
}
/**
 * DataProvider of arrays, one per element type, each paired with its length.
 * @return an array of {array, length} rows
 */
@DataProvider(name="Arrays")
static Object[][] arrays() {
    Object[][] arrayAndLength = {
        {new Object[16], 16},
        {new boolean[16], 16},
        {new byte[16], 16},
        {new char[16], 16},
        {new int[16], 16},
        {new long[16], 16},
        {new short[16], 16},
        {new float[16], 16},
        {new double[16], 16},
    };
    return arrayAndLength;
}
/**
 * Serialize one object, read it back through a recording filter, and
 * compare what the filter saw with the expected values: the number of
 * callbacks, the set of classes, and the largest array length, reference
 * count, depth and stream byte count reported.
 * This test ignores/is not dependent on the global filter settings.
 *
 * @param object a Serializable object
 * @param count the expected count of calls to the filter
 * @param maxArray the maximum array size
 * @param maxRefs the maximum references
 * @param maxDepth the maximum depth
 * @param maxBytes the maximum stream size
 * @param classes the expected (unique) classes
 * @throws IOException if serializing or deserializing the object fails
 */
@Test(dataProvider="Objects")
public static void t1(Object object,
        long count, long maxArray, long maxRefs, long maxDepth, long maxBytes,
        Set<Class<?>> classes) throws IOException {
    Validator recorder = new Validator();
    validate(writeObjects(object), recorder);
    System.out.printf("v: %s%n", recorder);

    Assert.assertEquals(recorder.count, count, "callback count wrong");
    Assert.assertEquals(recorder.classes, classes, "classes mismatch");
    Assert.assertEquals(recorder.maxArray, maxArray, "maxArray mismatch");
    Assert.assertEquals(recorder.maxRefs, maxRefs, "maxRefs wrong");
    Assert.assertEquals(recorder.maxDepth, maxDepth, "depth wrong");
    Assert.assertEquals(recorder.maxBytes, maxBytes, "maxBytes wrong");
}
/**
 * Test each pattern with an appropriate object.
 * For the pattern, a matching object is generated and run through the
 * three-argument {@code testPatterns} for both the allowed case and the
 * rejected case.
 * This test ignores/is not dependent on the global filter settings.
 *
 * @param pattern a pattern
 */
@Test(dataProvider="Patterns")
static void testPatterns(String pattern) {
    TriConsumer<String, Object, Boolean> check =
            (filterPattern, testObject, allowed) -> testPatterns(filterPattern, testObject, allowed);
    evalPattern(pattern, check);
}
/**
 * Test that the filter on an ObjectInputStream can be set only on a fresh
 * stream, before deserializing any objects.
 * This test is agnostic to the global filter being set or not.
 */
@Test
static void nonResettableFilter() {
    try {
        byte[] stream = writeObjects("text1"); // an object
        try (ByteArrayInputStream source = new ByteArrayInputStream(stream);
                ObjectInputStream ois = new ObjectInputStream(source)) {
            // A fresh stream starts out with the global filter; may be null
            ObjectInputFilter globalFilter = ObjectInputFilter.Config.getSerialFilter();
            Assert.assertEquals(globalFilter, ois.getObjectInputFilter(),
                    "initial filter should be the global filter");

            // Check if it can be set to null
            ois.setObjectInputFilter(null);
            Assert.assertNull(ois.getObjectInputFilter(), "set to null should be null");

            // Install a filter, read an object, then try to replace the filter
            ois.setObjectInputFilter(new Validator());
            ois.readObject();
            try {
                ois.setObjectInputFilter(new Validator());
                Assert.fail("Should not be able to set filter twice");
            } catch (IllegalStateException expected) {
                // success, the exception was expected
            }
        } catch (EOFException eof) {
            Assert.fail("Should not reach end-of-file", eof);
        } catch (ClassNotFoundException cnf) {
            Assert.fail("Deserializing", cnf);
        }
    } catch (IOException ex) {
        Assert.fail("Unexpected IOException", ex);
    }
}
/**
 * Test that if an object's readResolve method returns an array
 * the callback to the filter includes the proper array length.
 * @param array the array the test object resolves to
 * @param length the expected length of that array
 * @throws IOException if an error occurs
 */
@Test(dataProvider="Arrays")
static void testReadResolveToArray(Object array, int length) throws IOException {
    ReadResolveToArray resolver = new ReadResolveToArray(array, length);
    // the object is its own filter
    Object resolved = validate(writeObjects(resolver), resolver);
    Assert.assertEquals(resolved.getClass(), array.getClass(), "Filter not called with the array");
}
/**
 * Test repeated limits use the last value.
 * The filter is built from {@code name=value;name=(value-1)}.
 * Checking it with every value set to {@code value} must be rejected;
 * checking it with every value set to {@code value-1} must be undecided.
 * @param name the name of the limit to test
 * @param value a test value
 */
@Test(dataProvider="Limits")
static void testLimits(String name, long value) {
    final long lower = value - 1;
    ObjectInputFilter filter = ObjectInputFilter.Config.createFilter(
            String.format("%s=%d;%s=%d", name, value, name, lower));
    Class<?> arrayClass = int[].class;

    ObjectInputFilter.FilterInfo atValue =
            new FilterValues(arrayClass, value, value, value, value);
    Assert.assertEquals(
            filter.checkInput(atValue),
            ObjectInputFilter.Status.REJECTED,
            "last limit value not used: " + filter);

    ObjectInputFilter.FilterInfo belowValue =
            new FilterValues(arrayClass, lower, lower, lower, lower);
    Assert.assertEquals(
            filter.checkInput(belowValue),
            ObjectInputFilter.Status.UNDECIDED,
            "last limit value not used: " + filter);
}
/**
 * Test invalid limits.
 * Creating a filter from the pattern must throw IllegalArgumentException;
 * the exception is printed and then rethrown so the test framework sees it.
 * @param pattern a pattern to test
 */
@Test(dataProvider="InvalidLimits", expectedExceptions=java.lang.IllegalArgumentException.class)
static void testInvalidLimits(String pattern) {
    try {
        ObjectInputFilter.Config.createFilter(pattern);
    } catch (IllegalArgumentException expected) {
        System.out.printf(" success exception: %s%n", expected);
        throw expected;
    }
}
/**
 * Test that returning null from a filter causes deserialization to fail.
 * The InvalidClassException is printed and then rethrown so the test
 * framework sees it.
 * @throws IOException the expected InvalidClassException, or any other I/O failure
 */
@Test(expectedExceptions=InvalidClassException.class)
static void testNullStatus() throws IOException {
    ObjectInputFilter alwaysNull = info -> null;
    byte[] bytes = writeObjects(0); // an Integer
    try {
        validate(bytes, alwaysNull);
    } catch (InvalidClassException expected) {
        System.out.printf(" success exception: %s%n", expected);
        throw expected;
    }
}
/**
 * Verify that malformed patterns throw IllegalArgumentException.
 * The exception is printed and then rethrown so the test framework sees it.
 * @param pattern pattern from the data source
 */
@Test(dataProvider="InvalidPatterns", expectedExceptions=IllegalArgumentException.class)
static void testInvalidPatterns(String pattern) {
    try {
        ObjectInputFilter.Config.createFilter(pattern);
    } catch (IllegalArgumentException expected) {
        System.out.printf(" success exception: %s%n", expected);
        throw expected;
    }
}
/**
 * Test that Config.createFilter returns null if the argument does not
 * contain any patterns or limits.
 */
@Test()
static void testEmptyPattern() {
    Assert.assertNull(ObjectInputFilter.Config.createFilter(""),
            "empty pattern did not return null");
    Assert.assertNull(ObjectInputFilter.Config.createFilter(";;;;"),
            "pattern with only delimiters did not return null");
}
/**
 * Read one object from the serialized stream with the given filter installed.
 * Reaching end-of-file is treated as normal completion and yields
 * {@code null}; a missing class fails the test.
 *
 * @param bytes a byte array to read objects from
 * @param filter the ObjectInputFilter
 * @return the object deserialized if any
 * @throws IOException can be thrown
 */
static Object validate(byte[] bytes,
        ObjectInputFilter filter) throws IOException {
    try (ByteArrayInputStream source = new ByteArrayInputStream(bytes);
            ObjectInputStream ois = new ObjectInputStream(source)) {
        ois.setObjectInputFilter(filter);
        return ois.readObject();
    } catch (ClassNotFoundException cnf) {
        Assert.fail("Deserializing", cnf);
    } catch (EOFException eof) {
        // normal completion
    }
    return null;
}
/**
 * Serialize the given objects, in order, into a single object stream and
 * return the bytes of that stream.
 *
 * @param objects zero or more objects to serialize
 * @return the byte array of the serialized objects
 * @throws IOException if an exception occurs
 */
static byte[] writeObjects(Object... objects) throws IOException {
    try (ByteArrayOutputStream sink = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(sink)) {
        for (int i = 0; i < objects.length; i++) {
            oos.writeObject(objects[i]);
        }
        // Take the snapshot while both streams are still open
        return sink.toByteArray();
    }
}
/**
 * A filter that accumulates information about the checkInput callbacks
 * that can be checked after readObject completes.
 * It records the number of callbacks, the classes reported, and the largest
 * array length, reference count, depth and stream byte count seen.
 * It never decides: every callback returns {@code UNDECIDED}.
 */
static class Validator implements ObjectInputFilter {
    long count;          // Count of calls to checkInput
    HashSet<Class<?>> classes = new HashSet<>();
    long maxArray = -1;  // starts below any length that could be recorded as a maximum
    long maxRefs;
    long maxDepth;
    long maxBytes;

    Validator() {
    }

    /**
     * Record the values of one callback.
     * Lambda implementation classes are folded into {@code SerializedLambda}
     * and dynamic proxy classes into {@code Proxy}, so that expected class
     * sets do not depend on generated class names.
     * @param filter the values supplied by the stream for this callback
     * @return always {@code ObjectInputFilter.Status.UNDECIDED}
     */
    @Override
    public ObjectInputFilter.Status checkInput(FilterInfo filter) {
        count++;
        Class<?> serialClass = filter.serialClass();
        if (serialClass != null) {
            // Match on "$$Lambda" without a trailing '$': generated lambda
            // classes are named "...$$Lambda$NN/..." on older JDKs and
            // "...$$Lambda/0x..." on newer ones; both contain this marker.
            if (serialClass.getName().contains("$$Lambda")) {
                // TBD: proper identification of serialized Lambdas?
                // Fold the serialized Lambda into the SerializedLambda type
                classes.add(SerializedLambda.class);
            } else if (Proxy.isProxyClass(serialClass)) {
                classes.add(Proxy.class);
            } else {
                classes.add(serialClass);
            }
        }
        this.maxArray = Math.max(this.maxArray, filter.arrayLength());
        this.maxRefs = Math.max(this.maxRefs, filter.references());
        this.maxDepth = Math.max(this.maxDepth, filter.depth());
        this.maxBytes = Math.max(this.maxBytes, filter.streamBytes());
        return ObjectInputFilter.Status.UNDECIDED;
    }

    /**
     * @return the recorded count, classes and maxima in a single line
     */
    @Override
    public String toString() {
        return "count: " + count
                + ", classes: " + classes.toString()
                + ", maxArray: " + maxArray
                + ", maxRefs: " + maxRefs
                + ", maxDepth: " + maxDepth
                + ", maxBytes: " + maxBytes;
    }
}
/**
 * Serialize the object, create a filter from the pattern, deserialize
 * with that filter, and check that the outcome matches the expectation:
 * no exception when allowed, an InvalidClassException when not.
 *
 * @param pattern the pattern
 * @param object the test object
 * @param allowed the expected result from ObjectInputStream (exception or not)
 */
static void testPatterns(String pattern, Object object, boolean allowed) {
    try {
        byte[] serialized = SerialFilterTest.writeObjects(object);
        validate(serialized, ObjectInputFilter.Config.createFilter(pattern));
        // Deserialization completed, so the object must have been expected to pass
        Assert.assertTrue(allowed, "filter should have thrown an exception");
    } catch (InvalidClassException ice) {
        // Deserialization was refused, so the object must have been expected to fail
        Assert.assertFalse(allowed, "filter should not have thrown an exception: " + ice);
    } catch (IOException ioe) {
        Assert.fail("Unexpected IOException", ioe);
    } catch (IllegalArgumentException iae) {
        Assert.fail("bad format pattern", iae);
    }
}
/**
 * For a filter pattern, generate test objects and apply the action to them.
 * The action runs once with the pattern and an object generated for the
 * allowed case, then once for the rejected case. In the rejected case a
 * pattern containing "=" is reused unchanged with the object generated
 * for the disallowed case; any other pattern is negated with a leading "!".
 * @param pattern a pattern
 * @param action an action to perform on positive and negative cases
 */
static void evalPattern(String pattern, TriConsumer<String, Object, Boolean> action) {
    // Positive case
    Object accepted = genTestObject(pattern, true);
    Assert.assertNotNull(accepted, "success generation failed");
    action.accept(pattern, accepted, true);

    // Negative case
    Object rejected = genTestObject(pattern, false);
    Assert.assertNotNull(rejected, "fail generation failed");
    String negPattern;
    if (pattern.contains("=")) {
        negPattern = pattern;
    } else {
        negPattern = "!" + pattern;
    }
    action.accept(negPattern, rejected, false);
}
/**
 * Generate a test object based on the pattern.
 * Patterns containing "=" are delegated to {@code genTestLimit}, patterns
 * ending in "*" to {@code genTestObjectWildcard}. Anything else is taken
 * as a class name, optionally prefixed with "module/", and an instance is
 * created with the public no-argument constructor.
 * A class that cannot be loaded or instantiated fails the test, with the
 * underlying exception attached as the cause.
 * @param pattern a pattern
 * @param allowed a boolean indicating to generate the allowed or disallowed case
 * @return an object or {@code null} to indicate no suitable object could be generated
 */
static Object genTestObject(String pattern, boolean allowed) {
    if (pattern.contains("=")) {
        return genTestLimit(pattern, allowed);
    }
    if (pattern.endsWith("*")) {
        return genTestObjectWildcard(pattern, allowed);
    }
    // Class name, possibly "module/class". Only the part after the '/' is
    // needed to load the class; without a '/' indexOf yields -1 and the
    // whole pattern is used.
    String className = pattern.substring(pattern.indexOf('/') + 1);
    try {
        return Class.forName(className).getConstructor().newInstance();
    } catch (ClassNotFoundException ex) {
        Assert.fail("no such class available: " + pattern, ex);
    } catch (ReflectiveOperationException ex) {
        // no such constructor, not instantiable, not accessible, or the constructor threw
        Assert.fail("newInstance: " + ex, ex);
    }
    return null;
}
/**
 * Generate an object to be used with the various wildcard pattern forms.
 * Explicitly supports only specific package wildcards with specific objects;
 * any other wildcard fails the test.
 * @param pattern a wildcard pattern ending in "*"
 * @param allowed a boolean indicating to generate the allowed or disallowed case;
 *        only used in the failure message
 * @return an object within or outside the wildcard
 */
static Object genTestObjectWildcard(String pattern, boolean allowed) {
    final Object generated;
    if (pattern.endsWith(".**")) {
        // package hierarchy wildcard
        if (pattern.startsWith("javax.net.")) {
            generated = SSLEngineResult.Status.BUFFER_OVERFLOW;
        } else if (pattern.startsWith("java.")) {
            generated = Integer.valueOf(4);
        } else if (pattern.startsWith("javax.")) {
            generated = SSLEngineResult.Status.BUFFER_UNDERFLOW;
        } else {
            generated = otherObject;
        }
    } else if (pattern.endsWith(".*")) {
        // package wildcard
        generated = pattern.startsWith("javax.net.ssl")
                ? SSLEngineResult.Status.BUFFER_UNDERFLOW
                : null;
    } else if (pattern.equals("*")) {
        // class wildcard: any object will do
        generated = otherObject;
    } else if (pattern.startsWith("java.util.Hash")) {
        // class wildcard
        generated = new Hashtable<String, String>();
    } else {
        generated = null;
    }

    if (generated == null) {
        Assert.fail("Object could not be generated for pattern: "
                + pattern
                + ", allowed: " + allowed);
    }
    return generated;
}
/**
 * Generate a limit test object for the pattern.
 * For positive cases, the object exactly hits the limit.
 * For negative cases, the object is 1 greater than the limit.
 * An unrecognized limit keyword fails the test.
 * @param pattern the pattern, containing "=" and a maxXXX keyword
 * @param allowed a boolean indicating to generate the allowed or disallowed case
 * @return a suitable object
 */
static Object genTestLimit(String pattern, boolean allowed) {
    int eq = pattern.indexOf('=');
    Assert.assertNotEquals(eq, -1, "missing value in limit");
    long limit = Long.parseUnsignedLong(pattern.substring(eq + 1));

    if (pattern.startsWith("maxdepth=")) {
        // Return an object with the requested depth (or 1 greater):
        // single-element arrays wrapped inside one another
        long depth = allowed ? limit : limit + 1;
        Object[] outer = new Object[1];
        for (int level = 1; level < depth; level++) {
            outer = new Object[] { outer };
        }
        return outer;
    }
    if (pattern.startsWith("maxbytes=")) {
        // Return a byte array that when written to OOS creates
        // a stream of exactly the size requested.
        return genMaxBytesObject(allowed, limit);
    }
    if (pattern.startsWith("maxrefs=")) {
        int size = (int) limit;
        if (allowed) {
            size = size - 1;
        }
        Object[] refs = new Object[size];
        Arrays.fill(refs, otherObject);
        return refs;
    }
    if (pattern.startsWith("maxarray=")) {
        return new int[(int) limit + (allowed ? 0 : 1)];
    }

    Assert.fail("Object could not be generated for pattern: "
            + pattern
            + ", allowed: " + allowed);
    return null;
}
/**
 * Generate an object that will be serialized to some number of bytes,
 * or 1 greater if allowed is false.
 * It returns a two element Object array whose two slots hold the same
 * byte array, sized to achieve the desired total size. The array length
 * starts from an estimate and is corrected by the measured difference
 * until the serialized size matches.
 * @param allowed true if the stream should be allowed at that size,
 *        false if the stream should be larger
 * @param maxBytes the number of bytes desired in the stream;
 *        should not be less than 72 (due to protocol overhead).
 * @return an object that will be serialized to the length requested
 */
private static Object genMaxBytesObject(boolean allowed, long maxBytes) {
    final long target = allowed ? maxBytes : maxBytes + 1;
    final Object[] holder = new Object[2];
    long measured = target;          // no correction on the first pass
    long payloadLength = target - 72; // estimate needed array size
    for (;;) {
        payloadLength += target - measured;
        byte[] payload = new byte[(int) payloadLength];
        holder[0] = payload;
        holder[1] = payload;
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                ObjectOutputStream os = new ObjectOutputStream(baos)) {
            os.writeObject(holder);
            os.flush();
            measured = baos.size();
        } catch (IOException ie) {
            Assert.fail("exception generating stream", ie);
        }
        if (measured == target) {
            return holder;
        }
    }
}
/**
 * Returns a HashSet of a requested depth.
 * Each level adds two new sets, one holding a string and one left empty,
 * to both sets of the level above; the top level is the returned set
 * together with a second set that is not returned.
 * @param depth the depth
 * @return a HashSet of HashSets...
 */
static HashSet<Object> deepHashSet(int depth) {
    final HashSet<Object> root = new HashSet<>();
    HashSet<Object> parentA = root;
    HashSet<Object> parentB = new HashSet<>();
    for (int level = depth; level > 0; level--) {
        HashSet<Object> childA = new HashSet<>();
        HashSet<Object> childB = new HashSet<>();
        // make childA not equal to childB
        childA.add("by Jimminy");

        parentA.add(childA);
        parentA.add(childB);
        parentB.add(childA);
        parentB.add(childB);

        parentA = childA;
        parentB = childB;
    }
    return root;
}
/**
 * Simple method to use with Serialized Lambda.
 * It is the target of the serializable {@code SerialFilterTest::noop}
 * method reference built in the "Objects" data provider and
 * intentionally does nothing.
 */
private static void noop() {}
/**
 * Class that returns an array from readResolve and also implements
 * the ObjectInputFilter to check that it has the expected length.
 */
static class ReadResolveToArray implements Serializable, ObjectInputFilter {
    private static final long serialVersionUID = 123456789L;

    private final Object array;
    private final int length;

    /**
     * @param array the array that replaces this object when it is deserialized
     * @param length the length the filter expects to be reported for that array
     */
    ReadResolveToArray(Object array, int length) {
        this.array = array;
        this.length = length;
    }

    /**
     * @return the array, in place of this object
     */
    Object readResolve() {
        return array;
    }

    /**
     * Allow this class itself, reject any other class that is not the
     * array's class, and reject the array's class when a non-negative
     * array length is reported that differs from the expected length.
     * A callback without a class is left undecided rather than failing
     * with a NullPointerException, since {@code serialClass()} may be null.
     * @param filter the values supplied by the stream for this callback
     * @return ALLOWED, REJECTED or UNDECIDED as described above
     */
    @Override
    public ObjectInputFilter.Status checkInput(FilterInfo filter) {
        Class<?> serialClass = filter.serialClass();
        if (serialClass == null) {
            // nothing to judge by class; leave the decision open
            return ObjectInputFilter.Status.UNDECIDED;
        }
        if (ReadResolveToArray.class.isAssignableFrom(serialClass)) {
            return ObjectInputFilter.Status.ALLOWED;
        }
        long reportedLength = filter.arrayLength();
        if (serialClass != array.getClass() ||
                (reportedLength >= 0 && reportedLength != length)) {
            return ObjectInputFilter.Status.REJECTED;
        }
        return ObjectInputFilter.Status.UNDECIDED;
    }
}
/**
 * An immutable set of values handed to an ObjectInputFilter: each accessor
 * returns the value given to the constructor.
 */
static class FilterValues implements ObjectInputFilter.FilterInfo {
    private final Class<?> serialClass;
    private final long arrayLen;
    private final long graphDepth;
    private final long refCount;
    private final long byteCount;

    /**
     * @param clazz the class to report
     * @param arrayLength the array length to report
     * @param depth the depth to report
     * @param references the number of references to report
     * @param streamBytes the number of stream bytes to report
     */
    public FilterValues(Class<?> clazz, long arrayLength, long depth, long references, long streamBytes) {
        this.serialClass = clazz;
        this.arrayLen = arrayLength;
        this.graphDepth = depth;
        this.refCount = references;
        this.byteCount = streamBytes;
    }

    @Override
    public Class<?> serialClass() {
        return serialClass;
    }

    @Override
    public long arrayLength() {
        return arrayLen;
    }

    @Override
    public long depth() {
        return graphDepth;
    }

    @Override
    public long references() {
        return refCount;
    }

    @Override
    public long streamBytes() {
        return byteCount;
    }
}
}