// blob: f9f61e62a9dfc631d2fa19b7947a688dde44c688 [file] [log] [blame]
/*
* Copyright (C) 2014 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio
import okio.ByteString.Companion.encodeUtf8
import org.junit.Assume
import java.io.IOException
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.io.Serializable
import java.util.Random
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
object TestUtil {
// Necessary to make an internal member visible to Java.
/** Mirror of [SegmentPool.MAX_SIZE], exposed as a plain field through `@JvmField`. */
@JvmField val SEGMENT_POOL_MAX_SIZE = SegmentPool.MAX_SIZE
/** Compile-time alias for [Segment.SIZE]. */
const val SEGMENT_SIZE = Segment.SIZE
/** Compile-time alias for the package-level `okio.REPLACEMENT_CODE_POINT` constant. */
const val REPLACEMENT_CODE_POINT: Int = okio.REPLACEMENT_CODE_POINT
/** Returns the current value of [SegmentPool.byteCount]. */
@JvmStatic fun segmentPoolByteCount() = SegmentPool.byteCount
/**
 * Forwards to the package-level `okio.segmentSizes` so that callers can reach it through
 * [TestUtil]. Presumably the list holds one size per segment of [buffer].
 */
@JvmStatic
fun segmentSizes(buffer: Buffer): List<Int> {
  return okio.segmentSizes(buffer)
}
/** Fails the test if any entry reported by [segmentSizes] for [buffer] is zero. */
@JvmStatic
fun assertNoEmptySegments(buffer: Buffer) {
  val hasNoEmptySegment = segmentSizes(buffer).none { it == 0 }
  assertTrue(hasNoEmptySegment, "Expected all segments to be non-empty")
}
/**
 * Asserts that [a] and [b] hold the same bytes by comparing their `contentToString()`
 * renderings, so a failure message shows both arrays in readable form.
 */
@JvmStatic
fun assertByteArraysEquals(a: ByteArray, b: ByteArray) {
  val expectedText = a.contentToString()
  val actualText = b.contentToString()
  assertEquals(expectedText, actualText)
}
/** Asserts that decoding [b] as UTF-8 yields exactly [expectedUtf8]. */
@JvmStatic
fun assertByteArrayEquals(expectedUtf8: String, b: ByteArray) {
  val decoded = String(b, Charsets.UTF_8)
  assertEquals(expectedUtf8, decoded)
}
/**
 * Returns a [ByteString] of [length] pseudo-random bytes. The generator is always seeded
 * with 0, so the same length produces the same bytes on every call.
 */
@JvmStatic
fun randomBytes(length: Int): ByteString {
  val bytes = ByteArray(length).also { Random(0).nextBytes(it) }
  return ByteString.of(*bytes)
}
/**
 * Returns a [Source] that produces [size] pseudo-random bytes (generator seeded with 0) and
 * then reports exhaustion by returning -1. Reading after `close()` throws
 * [IllegalStateException].
 */
@JvmStatic
fun randomSource(size: Long): Source {
  return object : Source {
    private val generator = Random(0)
    private var remaining = size
    private var isClosed = false

    @Throws(IOException::class)
    override fun read(sink: Buffer, byteCount: Long): Long {
      check(!isClosed) { "closed" }
      if (remaining == 0L) return -1L

      // Never hand out more than what is left in this source.
      val toRead = minOf(byteCount, remaining)

      if (toRead < Segment.SIZE) {
        // Partial segment: generate into a scratch array and copy it into the sink.
        val chunk = ByteArray(toRead.toInt())
        generator.nextBytes(chunk)
        sink.write(chunk)
        remaining -= toRead
        return toRead
      }

      // A full segment can be filled in place, which saves a copy.
      val fullSegmentSize = Segment.SIZE.toLong()
      val target = sink.writableSegment(Segment.SIZE)
      generator.nextBytes(target.data)
      target.limit += Segment.SIZE
      sink.size += fullSegmentSize
      remaining -= fullSegmentSize
      return fullSegmentSize
    }

    override fun timeout() = Timeout.NONE

    @Throws(IOException::class)
    override fun close() {
      isClosed = true
    }
  }
}
/**
 * Asserts that [b1] and [b2] are interchangeable: equal in both directions, with matching
 * hash codes, string forms, sizes and bytes. Also asserts that [b1] differs (in equality and
 * hash code) from null, from an unrelated object, and from a byte string that is different
 * in content.
 */
@JvmStatic
fun assertEquivalent(b1: ByteString, b2: ByteString) {
  // Equality must be reflexive and symmetric.
  assertTrue(b1 == b2)
  assertTrue(b1 == b1)
  assertTrue(b2 == b1)

  // Hash codes and string forms must agree.
  assertEquals(b1.hashCode().toLong(), b2.hashCode().toLong())
  assertEquals(b1.hashCode().toLong(), b1.hashCode().toLong())
  assertEquals(b1.toString(), b2.toString())

  // Byte-for-byte content must agree.
  assertEquals(b1.size.toLong(), b2.size.toLong())
  val b2Bytes = b2.toByteArray()
  b2Bytes.forEachIndexed { index, expectedByte ->
    assertEquals(expectedByte.toLong(), b1[index].toLong())
  }
  assertByteArraysEquals(b1.toByteArray(), b2Bytes)

  // Must not equal null or an unrelated object.
  assertFalse(b1 == null)
  assertFalse(b1 == Any())

  // Must not equal a byte string with different content: either the same bytes with the
  // last one bumped, or "a" when there are no bytes to alter.
  val different = if (b2Bytes.isNotEmpty()) {
    val altered = b2Bytes.clone()
    altered[altered.lastIndex]++
    ByteString(altered)
  } else {
    "a".encodeUtf8()
  }
  assertFalse(b1 == different)
  assertFalse(b1.hashCode() == different.hashCode())
}
/**
 * Asserts that [b1] and [b2] are interchangeable: equal in both directions, with matching
 * hash codes, string forms, sizes and bytes. Also asserts that [b1] differs (in equality and
 * hash code) from an unrelated object and from a buffer that is different in content.
 *
 * Neither buffer is consumed: the bytes of [b2] are read from a private copy.
 */
@JvmStatic
fun assertEquivalent(b1: Buffer, b2: Buffer) {
  // Equals.
  assertTrue(b1 == b2)
  assertTrue(b1 == b1)
  assertTrue(b2 == b1)

  // Hash code.
  assertEquals(b1.hashCode().toLong(), b2.hashCode().toLong())
  assertEquals(b1.hashCode().toLong(), b1.hashCode().toLong())
  assertEquals(b1.toString(), b2.toString())

  // Content. Read the bytes from a copy of b2; reading b2 itself would drain the caller's
  // buffer as a side effect of what is supposed to be a pure assertion.
  assertEquals(b1.size, b2.size)
  val b2Copy = Buffer()
  b2.copyTo(b2Copy, 0, b2.size)
  val b2Bytes = b2Copy.readByteArray()
  for (i in b2Bytes.indices) {
    assertEquals(b2Bytes[i].toLong(), b1[i.toLong()].toLong())
  }

  // Doesn't equal a different buffer.
  assertFalse(b1 == Any())
  if (b2Bytes.isNotEmpty()) {
    // Same bytes with the last one bumped.
    val b3Bytes = b2Bytes.clone()
    b3Bytes[b3Bytes.lastIndex]++
    val b3 = Buffer().write(b3Bytes)
    assertFalse(b1 == b3)
    assertFalse(b1.hashCode() == b3.hashCode())
  } else {
    // Nothing to alter in an empty buffer, so compare against a non-empty one.
    val b3 = Buffer().writeUtf8("a")
    assertFalse(b1 == b3)
    assertFalse(b1.hashCode() == b3.hashCode())
  }
}
/**
 * Serializes [original] to bytes with Java serialization, then deserializes those bytes and
 * returns the result.
 *
 * Both object streams are closed via `use`. Closing the output stream flushes anything it
 * still holds before the bytes are read back, rather than relying on `writeObject` having
 * already pushed everything into the underlying buffer.
 */
// Assume serialization doesn't change types.
@Suppress("UNCHECKED_CAST")
@Throws(Exception::class)
@JvmStatic
fun <T : Serializable> reserialize(original: T): T {
  val buffer = Buffer()
  ObjectOutputStream(buffer.outputStream()).use { out ->
    out.writeObject(original)
  }
  return ObjectInputStream(buffer.inputStream()).use { input ->
    input.readObject() as T
  }
}
/**
 * Builds a buffer holding exactly the bytes of [data], split into chunks whose sizes and
 * in-segment offsets are drawn from [dice].
 */
@Throws(IOException::class)
@JvmStatic
fun bufferWithRandomSegmentLayout(dice: Random, data: ByteArray): Buffer {
  val result = Buffer()

  // Writing to result directly will yield packed segments. Instead, stage each chunk in
  // its own buffer and move that buffer into result.
  var pos = 0
  while (pos < data.size) {
    // Pick a chunk size of at least half a segment, capped at the bytes still to write.
    val rolledSize = Segment.SIZE / 2 + dice.nextInt(Segment.SIZE / 2)
    val byteCount = minOf(rolledSize, data.size - pos)
    val offset = dice.nextInt(Segment.SIZE - byteCount)

    // Write `offset` throwaway bytes ahead of the chunk, then skip them again.
    val staging = Buffer().apply {
      write(ByteArray(offset))
      write(data, pos, byteCount)
      skip(offset.toLong())
    }

    result.write(staging, byteCount.toLong())
    pos += byteCount
  }
  return result
}
/**
 * Builds a buffer from the concatenation of [segments], attempting to give each string its
 * own segment in the result. Each string is staged in a separate buffer behind some
 * throwaway padding, and a clone of that buffer is written into the result so that segments
 * are shared, preventing compaction from occurring.
 */
@Throws(Exception::class)
@JvmStatic
fun bufferWithSegments(vararg segments: String): Buffer {
  val result = Buffer()
  segments.forEach { text ->
    // Strings shorter than a segment get padding of half the unused space; others get none.
    val padding = if (text.length >= Segment.SIZE) 0 else (Segment.SIZE - text.length) / 2

    val staging = Buffer()
    staging.writeUtf8("_".repeat(padding))
    staging.writeUtf8(text)
    staging.skip(padding.toLong())

    result.write(staging.clone(), staging.size)
  }
  return result
}
/**
 * Returns a snapshot of a buffer built from [source] one byte at a time. Every byte is
 * written through a writable segment requested with the full [SEGMENT_SIZE] as its minimum
 * capacity.
 */
@JvmStatic
fun makeSegments(source: ByteString): ByteString {
  val buffer = Buffer()
  repeat(source.size) { index ->
    val target = buffer.writableSegment(SEGMENT_SIZE)
    target.data[target.pos] = source[index]
    // Account for the byte on both the segment and the owning buffer.
    target.limit += 1
    buffer.size += 1L
  }
  return buffer.snapshot()
}
/**
 * Takes segments from [SegmentPool] until its byte count is no longer positive, and returns
 * them in the order they were taken.
 */
@JvmStatic
internal fun takeAllPoolSegments(): List<Segment> {
  val taken = ArrayList<Segment>()
  while (SegmentPool.byteCount > 0) {
    taken.add(SegmentPool.take())
  }
  return taken
}
/**
 * Returns a copy of [original] whose segments are all created with `unsharedCopy()`. An
 * empty [original] yields a new empty buffer.
 */
@JvmStatic
fun deepCopy(original: Buffer): Buffer {
  val result = Buffer()
  if (original.size == 0L) return result

  // Start the copy as a one-element circular list: the head links to itself both ways.
  val sourceHead = original.head!!
  val copiedHead = sourceHead.unsharedCopy()
  copiedHead.prev = copiedHead
  copiedHead.next = copiedHead
  result.head = copiedHead

  // Walk the rest of the original's ring, appending a copy of each segment after the
  // current tail (the head's prev).
  var current = sourceHead.next
  while (current !== sourceHead) {
    copiedHead.prev!!.push(current!!.unsharedCopy())
    current = current.next
  }

  result.size = original.size
  return result
}
/** Returns this [Int] with the order of its four bytes reversed. */
@JvmStatic
fun Int.reverseBytes(): Int {
  // Shift each byte to its mirrored position first, then mask away the neighbouring bytes.
  val topToBottom = this ushr 24
  val upperMiddleDown = (this ushr 8) and 0x0000ff00
  val lowerMiddleUp = (this shl 8) and 0x00ff0000
  val bottomToTop = this shl 24
  return topToBottom or upperMiddleDown or lowerMiddleUp or bottomToTop
}
/** Returns this [Short] with its two bytes swapped. */
@JvmStatic
fun Short.reverseBytes(): Short {
  // Work on the unsigned 16-bit value so that sign extension cannot leak into the result.
  val bits = toInt() and 0xffff
  val highToLow = (bits and 0xff00) ushr 8
  val lowToHigh = (bits and 0x00ff) shl 8
  return (highToLow or lowToHigh).toShort()
}
/**
 * Skips the calling test, through a JUnit assumption, when the `os.name` system property
 * identifies Windows.
 *
 * The name is matched as a case-insensitive "windows" prefix. A plain "win" substring test
 * would also match names such as "Darwin". A missing `os.name` property is treated as
 * not-Windows.
 */
fun assumeNotWindows() {
  val osName = System.getProperty("os.name").orEmpty()
  Assume.assumeFalse(osName.startsWith("windows", ignoreCase = true))
}
}