blob: fcd92587c22949fadb0f2b54c11c030dbeb9d3a8 [file] [log] [blame]
/**
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <dlfcn.h>
#include <string.h>
#include <openssl/ssl.h>
#include <openssl/crypto.h>
#include <openssl/bn.h>
#include <memory>
/** NOTE: These values are for the BIGNUM declared in kTest and */
/** must be updated if kTest is changed. */
/** Number of BN_ULONG words in the allocation that malloc() targets: it only */
/** intercepts requests of exactly sizeof(BN_ULONG) * this value bytes. The   */
/** 64-bit value is used when sizeof(BN_ULONG) == 8, the 32-bit one otherwise. */
#define MALLOC_SIZE_32BITS 11
#define MALLOC_SIZE_64BITS 6
/** Per-iteration skip counts, indexed by loopIndex: how many targeted        */
/** allocations malloc() lets through before it starts returning nullptr.    */
/** There is one entry per iteration of the loop in main().                  */
static const int sMallocSkipCount32[] = {1,0};
static const int sMallocSkipCount64[] = {0,0};
/** Decimal value parsed into a BIGNUM and expected back from BN_bn2dec(). */
static const char *kTest =
"123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890";
/** Number of targeted allocations malloc() has let through so far. */
static int sCount = 0;
/** Arms the failure injection in malloc(); main() sets it only around BN_bn2dec(). */
static bool sOverloadMalloc = false;
/** Current iteration of main()'s loop; selects the skip count used by malloc(). */
/** NOTE(review): unlike the other globals, this one has external linkage. */
int loopIndex = 0;
// unique_ptr deleter for memory that OpenSSL handed out and that is to be
// released with OPENSSL_free().
template <typename T>
struct OpenSSLFree {
  void operator()(T *mem) { OPENSSL_free(mem); }
};
// Owning pointer for OpenSSL-allocated C strings, e.g. the result of BN_bn2dec().
using ScopedOpenSSLString = std::unique_ptr<char, OpenSSLFree<char>>;
namespace crypto {

// Stateless unique_ptr deleter that forwards the pointer to the C free
// function |Fn| supplied as a non-type template argument.
template <typename T, void (*Fn)(T*)>
struct OpenSSLDeleter {
  void operator()(T *target) { Fn(target); }
};

// Same contract as OpenSSLDeleter, but with a const call operator.
// NOTE(review): the two deleters are interchangeable; each backs a different
// alias below, so both are kept.
template <typename Type, void (*Destroyer)(Type*)>
struct OpenSSLDestroyer {
  void operator()(Type *target) const { Destroyer(target); }
};

// unique_ptr owning a |T| that is released by calling |func|.
template <typename T, void (*func)(T*)>
using ScopedOpenSSLType = std::unique_ptr<T, OpenSSLDeleter<T, func>>;

// unique_ptr owning a |PointerType| that is released by calling |Destroyer|.
template <typename PointerType, void (*Destroyer)(PointerType*)>
using ScopedOpenSSL =
    std::unique_ptr<PointerType, OpenSSLDestroyer<PointerType, Destroyer>>;

// Deleter for byte buffers that must be released with OPENSSL_free().
// NOTE(review): unused in this file; after `using namespace crypto` an
// unqualified `OpenSSLFree` would be ambiguous with the global template.
struct OpenSSLFree {
  void operator()(uint8_t *target) const { OPENSSL_free(target); }
};

using ScopedBIGNUM = ScopedOpenSSL<BIGNUM, BN_free>;
using ScopedBN_CTX = ScopedOpenSSLType<BN_CTX, BN_CTX_free>;

}  // namespace crypto
/**
 * Parses the decimal string |in| with BN_dec2bn() and transfers the resulting
 * BIGNUM into |*out| (whatever |*out| previously owned is released by reset()).
 *
 * Returns BN_dec2bn()'s result unchanged. Per the OpenSSL documentation, that
 * is the number of characters consumed, or 0 on error.
 */
static int DecimalToBIGNUM(crypto::ScopedBIGNUM *out, const char *in) {
  BIGNUM *parsed = nullptr;
  const int consumed = BN_dec2bn(&parsed, in);
  // reset() runs unconditionally, so |*out| always reflects this parse attempt.
  out->reset(parsed);
  return consumed;
}
// The malloc() implementation that this file's malloc() forwards to; it is
// resolved lazily by mtraceInit().
void *(*realMalloc)(size_t) = nullptr;
// Looks up the next "malloc" after this object in the dynamic symbol search
// order (RTLD_NEXT) and stores it in realMalloc.
// NOTE(review): a failed lookup leaves realMalloc null and is not reported.
void mtraceInit(void) {
  void *const nextMalloc = dlsym(RTLD_NEXT, "malloc");
  realMalloc = reinterpret_cast<decltype(realMalloc)>(nextMalloc);
}
void *malloc(size_t size) {
if (realMalloc == nullptr) {
mtraceInit();
}
if (!sOverloadMalloc) {
return realMalloc(size);
}
int mallocSize = MALLOC_SIZE_32BITS;
int mallocSkipCount = sMallocSkipCount32[loopIndex];
if (sizeof(BN_ULONG) == 8) {
mallocSize = MALLOC_SIZE_64BITS;
mallocSkipCount = sMallocSkipCount64[loopIndex];
}
if (size == (sizeof(BN_ULONG) * mallocSize)) {
if (sCount >= mallocSkipCount) {
return nullptr;
}
++sCount;
}
return realMalloc(size);
}
using namespace crypto;
int main() {
CRYPTO_library_init();
ScopedBN_CTX ctx(BN_CTX_new());
if (!ctx) {
return EXIT_FAILURE;
}
for(loopIndex = 0; loopIndex < 2; ++loopIndex) {
ScopedBIGNUM bn;
int ret = DecimalToBIGNUM(&bn, kTest);
if (!ret) {
return EXIT_FAILURE;
}
sOverloadMalloc = true;
ScopedOpenSSLString dec(BN_bn2dec(bn.get()));
sOverloadMalloc = false;
if (!dec) {
return EXIT_FAILURE;
}
if (strcmp(dec.get(), kTest)) {
return EXIT_FAILURE;
}
}
return EXIT_SUCCESS;
}