| /** |
| * Copyright (C) 2021 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include <dlfcn.h> |
| #include <string.h> |
| #include <openssl/ssl.h> |
| #include <openssl/crypto.h> |
| #include <openssl/bn.h> |
| #include <memory> |
| |
/**
 * Allocation-size selectors consulted by the malloc() override in this file.
 * A request is a candidate for forced failure when its size equals
 * sizeof(BN_ULONG) * MALLOC_SIZE_64BITS (when BN_ULONG is 8 bytes) or
 * sizeof(BN_ULONG) * MALLOC_SIZE_32BITS (otherwise).
 *
 * NOTE: The original note ties these values to the BIGNUM declared in
 * kBN2DecTests. That name is not defined in this file; the only decimal
 * input here is kTest, so it is presumably what is meant -- TODO confirm.
 * The values must be updated whenever that test number is changed.
 */
#define MALLOC_SIZE_32BITS 11
#define MALLOC_SIZE_64BITS 6

/** Decimal string that main() converts to a BIGNUM and back to text. */
static const char *kTest =
    "123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890";

/**
 * Indexed by loopIndex: how many size-matching allocations malloc() lets
 * through before it starts returning nullptr. One table per BN_ULONG width.
 */
static const int sMallocSkipCount32[2] = {1, 0};
static const int sMallocSkipCount64[2] = {0, 0};

/** Count of size-matching allocations that malloc() has let through. */
static int sCount{0};

/** While false, malloc() forwards every request to the real allocator. */
static bool sOverloadMalloc{false};

/** Current pass of the loop in main(); selects the skip-count table entry. */
int loopIndex{0};
| |
/**
 * Deleter functor: releases a buffer by passing it to OPENSSL_free().
 */
template <class T>
struct OpenSSLFree {
    void operator()(T *ptr) {
        OPENSSL_free(ptr);
    }
};

/** Owning pointer to a char buffer that is released via OPENSSL_free(). */
using ScopedOpenSSLString = std::unique_ptr<char, OpenSSLFree<char>>;
| namespace crypto { |
| template<typename T, void (*func)(T*)> |
| struct OpenSSLDeleter { |
| void operator()(T *obj) { |
| func(obj); |
| } |
| }; |
| |
| template<typename Type, void (*Destroyer)(Type*)> |
| struct OpenSSLDestroyer { |
| void operator()(Type* ptr) const { |
| Destroyer(ptr); |
| } |
| }; |
| |
| template<typename T, void (*func)(T*)> |
| using ScopedOpenSSLType = std::unique_ptr<T, OpenSSLDeleter<T, func>>; |
| template<typename PointerType, void (*Destroyer)(PointerType*)> |
| using ScopedOpenSSL = |
| std::unique_ptr<PointerType, OpenSSLDestroyer<PointerType, Destroyer>>; |
| |
| struct OpenSSLFree { |
| void operator()(uint8_t* ptr) const { |
| OPENSSL_free(ptr); |
| } |
| }; |
| |
| using ScopedBIGNUM = ScopedOpenSSL<BIGNUM, BN_free>; |
| using ScopedBN_CTX = ScopedOpenSSLType<BN_CTX, BN_CTX_free>; |
| } // namespace crypto |
| |
| static int DecimalToBIGNUM(crypto::ScopedBIGNUM *out, const char *in) { |
| BIGNUM *raw = nullptr; |
| int ret = BN_dec2bn(&raw, in); |
| out->reset(raw); |
| return ret; |
| } |
| |
/** The next "malloc" in symbol search order; null until mtraceInit() runs. */
void* (*realMalloc)(size_t) = nullptr;

/**
 * Looks up the next definition of "malloc" via dlsym(RTLD_NEXT, ...) and
 * stores it in realMalloc. The lookup result is stored as-is, without a
 * null check.
 */
void mtraceInit() {
    using MallocFn = void *(*)(size_t);
    realMalloc = reinterpret_cast<MallocFn>(dlsym(RTLD_NEXT, "malloc"));
}
| |
| void *malloc(size_t size) { |
| if (realMalloc == nullptr) { |
| mtraceInit(); |
| } |
| if (!sOverloadMalloc) { |
| return realMalloc(size); |
| } |
| int mallocSize = MALLOC_SIZE_32BITS; |
| int mallocSkipCount = sMallocSkipCount32[loopIndex]; |
| if (sizeof(BN_ULONG) == 8) { |
| mallocSize = MALLOC_SIZE_64BITS; |
| mallocSkipCount = sMallocSkipCount64[loopIndex]; |
| } |
| if (size == (sizeof(BN_ULONG) * mallocSize)) { |
| if (sCount >= mallocSkipCount) { |
| return nullptr; |
| } |
| ++sCount; |
| } |
| return realMalloc(size); |
| } |
| |
| using namespace crypto; |
| int main() { |
| CRYPTO_library_init(); |
| ScopedBN_CTX ctx(BN_CTX_new()); |
| if (!ctx) { |
| return EXIT_FAILURE; |
| } |
| for(loopIndex = 0; loopIndex < 2; ++loopIndex) { |
| ScopedBIGNUM bn; |
| int ret = DecimalToBIGNUM(&bn, kTest); |
| if (!ret) { |
| return EXIT_FAILURE; |
| } |
| sOverloadMalloc = true; |
| ScopedOpenSSLString dec(BN_bn2dec(bn.get())); |
| sOverloadMalloc = false; |
| if (!dec) { |
| return EXIT_FAILURE; |
| } |
| if (strcmp(dec.get(), kTest)) { |
| return EXIT_FAILURE; |
| } |
| } |
| return EXIT_SUCCESS; |
| } |