blob: f273971dc8bb4e93b4d05a51da589dff5b8b046d [file] [log] [blame]
/*
* Copyright (C) 2020 Square, Inc. and others.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.random.Random
@RunWith(Parameterized::class)
class CipherSinkTest(private val cipherAlgorithm: CipherAlgorithm) {
companion object {
@get:Parameterized.Parameters(name = "{0}")
@get:JvmStatic
val parameters: List<CipherAlgorithm>
get() = CipherAlgorithm.BLOCK_CIPHER_ALGORITHMS
}
@Test
fun encrypt() {
val random = Random(8912860393601532863)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val data = random.nextBytes(32)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.encrypt)
cipherSink.buffer().use { it.write(data) }
val actualEncryptedData = buffer.readByteArray()
val expectedEncryptedData = cipherFactory.encrypt.doFinal(data)
assertArrayEquals(expectedEncryptedData, actualEncryptedData)
}
@Test
fun encryptEmpty() {
val random = Random(3014415396541767201)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val data = ByteArray(0)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.encrypt)
cipherSink.buffer().close()
val actualEncryptedData = buffer.readByteArray()
val expectedEncryptedData = cipherFactory.encrypt.doFinal(data)
assertArrayEquals(expectedEncryptedData, actualEncryptedData)
}
@Test
fun encryptLarge() {
val random = Random(4800508322764694019)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val data = random.nextBytes(Segment.SIZE * 16 + Segment.SIZE / 2)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.encrypt)
cipherSink.buffer().use { it.write(data) }
val actualEncryptedData = buffer.readByteArray()
val expectedEncryptedData = cipherFactory.encrypt.doFinal(data)
assertArrayEquals(expectedEncryptedData, actualEncryptedData)
}
@Test
fun encryptSingleByteWrite() {
val random = Random(4374178522096702290)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val data = random.nextBytes(32)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.encrypt)
cipherSink.buffer().use {
data.forEach {
byte ->
it.writeByte(byte.toInt())
}
}
val actualEncryptedData = buffer.readByteArray()
val expectedEncryptedData = cipherFactory.encrypt.doFinal(data)
assertArrayEquals(expectedEncryptedData, actualEncryptedData)
}
/** Only relevant for algorithms which handle padding. */
@Test
fun encryptPaddingRequired() {
val random = Random(7515202505362968404)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val blockSize = cipherFactory.blockSize
val dataSize = blockSize * 4 + if (cipherAlgorithm.padding) blockSize / 2 else 0
val data = random.nextBytes(dataSize)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.encrypt)
cipherSink.buffer().use { it.write(data) }
val actualEncryptedData = buffer.readByteArray()
val expectedEncryptedData = cipherFactory.encrypt.doFinal(data)
assertArrayEquals(expectedEncryptedData, actualEncryptedData)
}
@Test
fun decrypt() {
val random = Random(488375923060579687)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val expectedData = random.nextBytes(32)
val encryptedData = cipherFactory.encrypt.doFinal(expectedData)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.decrypt)
cipherSink.buffer().use { it.write(encryptedData) }
val actualData = buffer.readByteArray()
assertArrayEquals(expectedData, actualData)
}
@Test
fun decryptEmpty() {
val random = Random(-9063010151894844496)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val expectedData = ByteArray(0)
val encryptedData = cipherFactory.encrypt.doFinal(expectedData)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.decrypt)
cipherSink.buffer().use { it.write(encryptedData) }
val actualData = buffer.readByteArray()
assertArrayEquals(expectedData, actualData)
}
@Test
fun decryptLarge() {
val random = Random(993064087526004362)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val expectedData = random.nextBytes(Segment.SIZE * 16 + Segment.SIZE / 2)
val encryptedData = cipherFactory.encrypt.doFinal(expectedData)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.decrypt)
cipherSink.buffer().use { it.write(encryptedData) }
val actualData = buffer.readByteArray()
assertArrayEquals(expectedData, actualData)
}
@Test
fun decryptSingleByteWrite() {
val random = Random(2621474675920878975)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val expectedData = random.nextBytes(32)
val encryptedData = cipherFactory.encrypt.doFinal(expectedData)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.decrypt)
cipherSink.buffer().use {
encryptedData.forEach { byte ->
it.writeByte(byte.toInt())
}
}
val actualData = buffer.readByteArray()
assertArrayEquals(expectedData, actualData)
}
/** Only relevant for algorithms which handle padding. */
@Test
fun decryptPaddingRequired() {
val random = Random(7689061926945836562)
val cipherFactory = cipherAlgorithm.createCipherFactory(random)
val blockSize = cipherFactory.blockSize
val dataSize = blockSize * 4 + if (cipherAlgorithm.padding) blockSize / 2 else 0
val expectedData = random.nextBytes(dataSize)
val encryptedData = cipherFactory.encrypt.doFinal(expectedData)
val buffer = Buffer()
val cipherSink = buffer.cipherSink(cipherFactory.decrypt)
cipherSink.buffer().use { it.write(encryptedData) }
val actualData = buffer.readByteArray()
assertArrayEquals(expectedData, actualData)
}
}