blob: 8f0a24ba32ea7fb46bcaed210ebdee70b5e45fdf [file]
/*
* Copyright (C) 2026 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <stdexcept>
#include <string>
#include <type_traits>
namespace android {
/**
 * A fixed-size set of N bits, similar to std::bitset but with a few differences:
 *  - No constructors from, or conversions to, an unsigned long long.
 *  - The underlying storage is exposed through data()/dataSize() for easy
 *    serialization & deserialization.
 *  - Converts implicitly to bool with the same result as any(); conversion to any other
 *    integral type is deleted.
 *  - Members are declared constexpr on all standards (we're not quite C++23 everywhere yet).
 *
 * Bit `pos` is stored in 64-bit word `pos / 64` at bit `pos % 64`. When N is not a multiple
 * of 64 the top bits of the last word are padding. The mutating members of this class keep
 * the padding zero, but bytes written through data() can set it, so the whole-set queries
 * (all/any/none/count/operator==) and operator>>= ignore whatever the padding contains.
 */
template <size_t N>
class BitSet {
private:
    using WordT = uint64_t;
    static constexpr size_t kBitsPerWord = 64;
    static constexpr size_t kNumWords = (N == 0) ? 0 : (N - 1) / kBitsPerWord + 1;
    // Selects the bits of the last storage word that belong to the set (all of them when N
    // is a multiple of the word size).
    static constexpr WordT kLastWordMask =
            (N % kBitsPerWord == 0) ? ~WordT(0) : (WordT(1) << (N % kBitsPerWord)) - 1;

    std::array<WordT, kNumWords> mBits = {};

    // Index of the storage word that holds bit `pos`.
    static constexpr size_t word_index(size_t pos) noexcept { return pos / kBitsPerWord; }

    // Single-bit mask selecting bit `pos` inside its storage word.
    static constexpr WordT bit_mask(size_t pos) noexcept {
        return WordT(1) << (pos % kBitsPerWord);
    }

    // Storage word `i` with any padding bits read as zero.
    constexpr WordT logical_word(size_t i) const noexcept {
        return (i + 1 == kNumWords) ? (mBits[i] & kLastWordMask) : mBits[i];
    }

    // Zeroes the padding bits of the last word. Nothing to do when N is a multiple of the
    // word size (which includes N == 0).
    constexpr void clean_unused_bits() noexcept {
        if constexpr (N % kBitsPerWord != 0) {
            mBits[kNumWords - 1] &= kLastWordMask;
        }
    }

public:
    /**
     * Proxy for a single bit, returned by the non-const operator[].
     *
     * Every operation forwards to the range-checked BitSet::set/test/flip. The members are
     * noexcept, so if that range check throws for a position >= N the exception cannot
     * propagate and std::terminate is called instead.
     */
    class reference {
        friend class BitSet;
        BitSet* mBitSet;
        size_t mPos;
        constexpr reference(BitSet* bitset, size_t pos) noexcept : mBitSet(bitset), mPos(pos) {}

    public:
        // Declared explicitly: relying on the implicit copy constructor is deprecated once a
        // copy assignment operator is user-declared.
        constexpr reference(const reference&) noexcept = default;

        /** Writes `x` to the referenced bit. */
        constexpr reference& operator=(bool x) noexcept {
            mBitSet->set(mPos, x);
            return *this;
        }
        /** Copies the value of the bit `x` refers to; does not rebind this reference. */
        constexpr reference& operator=(const reference& x) noexcept {
            mBitSet->set(mPos, x.mBitSet->test(x.mPos));
            return *this;
        }
        /** Returns the inverse of the referenced bit without modifying it. */
        constexpr bool operator~() const noexcept { return !mBitSet->test(mPos); }
        /** Returns the value of the referenced bit. */
        constexpr operator bool() const noexcept { return mBitSet->test(mPos); }
        /** Inverts the referenced bit in place. */
        constexpr reference& flip() noexcept {
            mBitSet->flip(mPos);
            return *this;
        }
    };

    /** Creates a set with every bit cleared. */
    explicit constexpr BitSet() = default;

    /**
     * Creates a set in which exactly the listed positions are set.
     * Each position goes through set(pos), so a position >= N is reported the same way.
     */
    explicit constexpr BitSet(std::initializer_list<size_t> list) {
        for (size_t pos : list) {
            set(pos);
        }
    }

    constexpr BitSet(const BitSet&) = default;
    constexpr BitSet(BitSet&&) = default;
    BitSet& operator=(const BitSet&) = default;
    BitSet& operator=(BitSet&&) = default;

    /**
     * Returns the storage words viewed as raw bytes (dataSize() of them), i.e. the object
     * representation of the word array.
     * Declared constexpr like the rest of the class, but the reinterpret_cast means the call
     * itself can never be evaluated in a constant expression.
     */
    constexpr const std::byte* data() const noexcept {
        return reinterpret_cast<const std::byte*>(mBits.data());
    }
    /** Mutable overload of data(); writing through it may set padding bits. */
    constexpr std::byte* data() noexcept { return reinterpret_cast<std::byte*>(mBits.data()); }
    /** Size in bytes of the buffer returned by data(). */
    constexpr size_t dataSize() const noexcept { return mBits.size() * sizeof(WordT); }

    /** True when both sets hold the same N bits; padding is not compared. */
    constexpr bool operator==(const BitSet& other) const noexcept {
        for (size_t i = 0; i < kNumWords; ++i) {
            if (logical_word(i) != other.logical_word(i)) return false;
        }
        return true;
    }
    constexpr bool operator!=(const BitSet& other) const noexcept { return !(*this == other); }

    /**
     * Returns the value of bit `pos`.
     * Calls std::__throw_overflow_error("BitSet::test") when pos >= N.
     */
    constexpr bool test(size_t pos) const {
        if (pos >= N) std::__throw_overflow_error("BitSet::test");
        return (mBits[word_index(pos)] & bit_mask(pos)) != 0;
    }

    /** True when every one of the N bits is set (trivially true for N == 0). */
    constexpr bool all() const noexcept {
        if constexpr (N == 0) {
            return true;
        } else {
            // Every word below the last one has no padding and must be completely set.
            for (size_t i = 0; i + 1 < kNumWords; ++i) {
                if (mBits[i] != ~WordT(0)) return false;
            }
            return logical_word(kNumWords - 1) == kLastWordMask;
        }
    }

    /** True when at least one of the N bits is set. */
    constexpr bool any() const noexcept {
        for (size_t i = 0; i < kNumWords; ++i) {
            if (logical_word(i) != 0) return true;
        }
        return false;
    }

    /** True when none of the N bits is set. */
    constexpr bool none() const noexcept { return !any(); }

    /** Implicit conversion to bool; same result as any(). */
    constexpr operator bool() const noexcept { return any(); }

    // This is a bit of a hack to prevent implicit conversion to integral types, while
    // still allowing implicit conversion to bool. The implicit conversion to bool is used
    // much too broadly to switch to explicit, and without this there would be a silent
    // implicit conversion to uint64_t by converting to bool -> uint64_t, which is definitely
    // wrong.
    template <typename T,
              typename = std::enable_if_t<std::is_integral_v<T> && !std::is_same_v<T, bool>>>
    constexpr operator T() const noexcept = delete;

    /** Number of bits that are set among the N bits. */
    constexpr size_t count() const noexcept {
        size_t total = 0;
        for (size_t i = 0; i < kNumWords; ++i) {
            total += static_cast<size_t>(__builtin_popcountll(logical_word(i)));
        }
        return total;
    }

    /** Number of bits in the set, i.e. N. */
    constexpr size_t size() const noexcept { return N; }

    /** Unchecked read of bit `pos`; the caller must pass pos < N. */
    constexpr bool operator[](size_t pos) const noexcept {
        return (mBits[word_index(pos)] & bit_mask(pos)) != 0;
    }
    /** Returns a proxy through which bit `pos` can be read or written. */
    constexpr reference operator[](size_t pos) noexcept { return reference(this, pos); }

    /** Sets all N bits. */
    constexpr BitSet& set() noexcept {
        for (size_t i = 0; i < kNumWords; ++i) {
            mBits[i] = ~WordT(0);
        }
        clean_unused_bits();
        return *this;
    }

    /**
     * Sets bit `pos` to `value`.
     * Calls std::__throw_overflow_error("BitSet::set") when pos >= N.
     */
    constexpr BitSet& set(size_t pos, bool value = true) {
        if (pos >= N) std::__throw_overflow_error("BitSet::set");
        WordT& word = mBits[word_index(pos)];
        if (value) {
            word |= bit_mask(pos);
        } else {
            word &= ~bit_mask(pos);
        }
        return *this;
    }

    /** Clears every bit, including any padding. */
    constexpr BitSet& reset() noexcept {
        for (size_t i = 0; i < kNumWords; ++i) {
            mBits[i] = 0;
        }
        return *this;
    }

    /**
     * Clears bit `pos`.
     * Calls std::__throw_overflow_error("BitSet::reset") when pos >= N.
     */
    constexpr BitSet& reset(size_t pos) {
        if (pos >= N) std::__throw_overflow_error("BitSet::reset");
        mBits[word_index(pos)] &= ~bit_mask(pos);
        return *this;
    }

    /** Inverts all N bits. */
    constexpr BitSet& flip() noexcept {
        for (size_t i = 0; i < kNumWords; ++i) {
            mBits[i] = ~mBits[i];
        }
        clean_unused_bits();
        return *this;
    }

    /**
     * Inverts bit `pos`.
     * Calls std::__throw_overflow_error("BitSet::flip") when pos >= N.
     */
    constexpr BitSet& flip(size_t pos) {
        if (pos >= N) std::__throw_overflow_error("BitSet::flip");
        mBits[word_index(pos)] ^= bit_mask(pos);
        return *this;
    }

    /** Bitwise AND with `other`, in place. */
    constexpr BitSet& operator&=(const BitSet& other) noexcept {
        for (size_t i = 0; i < kNumWords; ++i) {
            mBits[i] &= other.mBits[i];
        }
        return *this;
    }

    /** Bitwise OR with `other`, in place. */
    constexpr BitSet& operator|=(const BitSet& other) noexcept {
        for (size_t i = 0; i < kNumWords; ++i) {
            mBits[i] |= other.mBits[i];
        }
        return *this;
    }

    /** Bitwise XOR with `other`, in place. */
    constexpr BitSet& operator^=(const BitSet& other) noexcept {
        for (size_t i = 0; i < kNumWords; ++i) {
            mBits[i] ^= other.mBits[i];
        }
        return *this;
    }

    /** Returns a copy with all N bits inverted. */
    constexpr BitSet operator~() const noexcept {
        BitSet result(*this);
        result.flip();
        return result;
    }

    /**
     * Renders the set as N characters, most significant bit (position N - 1) first, using
     * `one` for set bits and `zero` for clear bits.
     */
    std::string to_string(char zero = '0', char one = '1') const {
        std::string str(N, zero);
        for (size_t i = 0; i < N; ++i) {
            if (test(i)) {
                str[N - 1 - i] = one;
            }
        }
        return str;
    }

    /**
     * Moves every bit `shift` positions towards higher indices, filling with zeros.
     * Bits moved to position N or beyond are dropped; shift >= N clears the set.
     */
    constexpr BitSet& operator<<=(size_t shift) noexcept {
        if (shift == 0) return *this;
        if (shift >= N) return reset();

        const size_t word_shift = shift / kBitsPerWord;
        const size_t bit_shift = shift % kBitsPerWord;
        // Walk from the top word down so each source word is read before it is overwritten.
        if (bit_shift == 0) {
            for (size_t i = kNumWords; i-- > word_shift;) {
                mBits[i] = mBits[i - word_shift];
            }
        } else {
            // A destination word combines the low part of one source word with the bits
            // carried out of the source word below it.
            const size_t carry_shift = kBitsPerWord - bit_shift;
            for (size_t i = kNumWords; i-- > word_shift + 1;) {
                mBits[i] = (mBits[i - word_shift] << bit_shift) |
                           (mBits[i - word_shift - 1] >> carry_shift);
            }
            // The lowest destination word has no lower neighbour to carry from.
            mBits[word_shift] = mBits[0] << bit_shift;
        }
        for (size_t i = 0; i < word_shift; ++i) {
            mBits[i] = 0;
        }
        // Bits pushed past position N - 1 ended up in the padding; drop them.
        clean_unused_bits();
        return *this;
    }

    /**
     * Moves every bit `shift` positions towards lower indices, filling with zeros.
     * shift >= N clears the set.
     */
    constexpr BitSet& operator>>=(size_t shift) noexcept {
        if (shift == 0) return *this;
        if (shift >= N) return reset();

        // Padding written through data() must not be shifted down into valid positions.
        clean_unused_bits();

        const size_t word_shift = shift / kBitsPerWord;
        const size_t bit_shift = shift % kBitsPerWord;
        // Number of low words that receive data; at least 1 because shift < N.
        const size_t kept = kNumWords - word_shift;
        // Walk from the bottom word up so each source word is read before it is overwritten.
        if (bit_shift == 0) {
            for (size_t i = 0; i < kept; ++i) {
                mBits[i] = mBits[i + word_shift];
            }
        } else {
            const size_t carry_shift = kBitsPerWord - bit_shift;
            for (size_t i = 0; i + 1 < kept; ++i) {
                mBits[i] = (mBits[i + word_shift] >> bit_shift) |
                           (mBits[i + word_shift + 1] << carry_shift);
            }
            // The highest destination word has no upper neighbour to carry from.
            mBits[kept - 1] = mBits[kNumWords - 1] >> bit_shift;
        }
        for (size_t i = kept; i < kNumWords; ++i) {
            mBits[i] = 0;
        }
        return *this;
    }

    /**
     * Returns a copy shifted as by operator<<=. `pos` is converted to size_t first, so a
     * negative value becomes a very large shift and yields an empty set.
     */
    template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
    constexpr BitSet operator<<(T pos) const noexcept {
        BitSet result(*this);
        result <<= static_cast<size_t>(pos);
        return result;
    }

    /**
     * Returns a copy shifted as by operator>>=. `pos` is converted to size_t first, so a
     * negative value becomes a very large shift and yields an empty set.
     */
    template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
    constexpr BitSet operator>>(T pos) const noexcept {
        BitSet result(*this);
        result >>= static_cast<size_t>(pos);
        return result;
    }

    /** Bitwise AND of two sets. */
    friend constexpr BitSet operator&(const BitSet& lhs, const BitSet& rhs) noexcept {
        BitSet result(lhs);
        result &= rhs;
        return result;
    }

    /** Bitwise OR of two sets. */
    friend constexpr BitSet operator|(const BitSet& lhs, const BitSet& rhs) noexcept {
        BitSet result(lhs);
        result |= rhs;
        return result;
    }

    /** Bitwise XOR of two sets. */
    friend constexpr BitSet operator^(const BitSet& lhs, const BitSet& rhs) noexcept {
        BitSet result(lhs);
        result ^= rhs;
        return result;
    }
};
} // namespace android