Make BitSetT::Mask return a bit set

This allows a Mask function to be implemented for BitSetArray, which is
added in this CL.  When using larger bitsets on 32-bit systems, the
current Mask implementation prohibits its use.

This is in preparation for a follow up change that uses Mask on such a
bitset.

Bug: angleproject:7369
Change-Id: If995d96ec1583a546f20bff277f3223e2f2490f5
Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/4031493
Reviewed-by: Charlie Lao <cclao@google.com>
Commit-Queue: Shahbaz Youssefi <syoussefi@chromium.org>
diff --git a/src/common/bitset_utils.h b/src/common/bitset_utils.h
index ee9a3f1..9ebee5d 100644
--- a/src/common/bitset_utils.h
+++ b/src/common/bitset_utils.h
@@ -163,7 +163,12 @@
     constexpr ParamT last() const;
 
     // Produces a mask of ones up to the "x"th bit.
-    constexpr static BitsT Mask(std::size_t x) { return BitMask<BitsT>(static_cast<ParamT>(x)); }
+    constexpr static BitSetT Mask(std::size_t x)
+    {
+        BitSetT result;
+        result.mBits = BitMask<BitsT>(static_cast<ParamT>(x));
+        return result;
+    }
 
   private:
     BitsT mBits;
@@ -177,7 +182,7 @@
 }
 
 template <size_t N, typename BitsT, typename ParamT>
-constexpr BitSetT<N, BitsT, ParamT>::BitSetT(BitsT value) : mBits(value & Mask(N))
+constexpr BitSetT<N, BitsT, ParamT>::BitSetT(BitsT value) : mBits(value & Mask(N).bits())
 {}
 
 template <size_t N, typename BitsT, typename ParamT>
@@ -187,7 +192,7 @@
     {
         mBits |= Bit<BitsT>(element);
     }
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
 }
 
 template <size_t N, typename BitsT, typename ParamT>
@@ -228,21 +233,21 @@
 template <size_t N, typename BitsT, typename ParamT>
 constexpr bool BitSetT<N, BitsT, ParamT>::all() const
 {
-    ASSERT(mBits == (mBits & Mask(N)));
-    return mBits == Mask(N);
+    ASSERT(mBits == (mBits & Mask(N).bits()));
+    return mBits == Mask(N).bits();
 }
 
 template <size_t N, typename BitsT, typename ParamT>
 constexpr bool BitSetT<N, BitsT, ParamT>::any() const
 {
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
     return (mBits != 0);
 }
 
 template <size_t N, typename BitsT, typename ParamT>
 constexpr bool BitSetT<N, BitsT, ParamT>::none() const
 {
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
     return (mBits == 0);
 }
 
@@ -276,7 +281,7 @@
 template <size_t N, typename BitsT, typename ParamT>
 constexpr BitSetT<N, BitsT, ParamT> BitSetT<N, BitsT, ParamT>::operator~() const
 {
-    return BitSetT<N, BitsT, ParamT>(~mBits & Mask(N));
+    return BitSetT<N, BitsT, ParamT>(~mBits & Mask(N).bits());
 }
 
 template <size_t N, typename BitsT, typename ParamT>
@@ -290,7 +295,7 @@
 constexpr BitSetT<N, BitsT, ParamT> &BitSetT<N, BitsT, ParamT>::operator|=(BitsT value)
 {
     mBits |= value;
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
     return *this;
 }
 
@@ -298,20 +303,20 @@
 constexpr BitSetT<N, BitsT, ParamT> &BitSetT<N, BitsT, ParamT>::operator^=(BitsT value)
 {
     mBits ^= value;
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
     return *this;
 }
 
 template <size_t N, typename BitsT, typename ParamT>
 constexpr BitSetT<N, BitsT, ParamT> BitSetT<N, BitsT, ParamT>::operator<<(std::size_t pos) const
 {
-    return BitSetT<N, BitsT, ParamT>((mBits << pos) & Mask(N));
+    return BitSetT<N, BitsT, ParamT>((mBits << pos) & Mask(N).bits());
 }
 
 template <size_t N, typename BitsT, typename ParamT>
 constexpr BitSetT<N, BitsT, ParamT> &BitSetT<N, BitsT, ParamT>::operator<<=(std::size_t pos)
 {
-    mBits = mBits << pos & Mask(N);
+    mBits = mBits << pos & Mask(N).bits();
     return *this;
 }
 
@@ -324,15 +329,15 @@
 template <size_t N, typename BitsT, typename ParamT>
 constexpr BitSetT<N, BitsT, ParamT> &BitSetT<N, BitsT, ParamT>::operator>>=(std::size_t pos)
 {
-    mBits = (mBits >> pos) & Mask(N);
+    mBits = (mBits >> pos) & Mask(N).bits();
     return *this;
 }
 
 template <size_t N, typename BitsT, typename ParamT>
 constexpr BitSetT<N, BitsT, ParamT> &BitSetT<N, BitsT, ParamT>::set()
 {
-    ASSERT(mBits == (mBits & Mask(N)));
-    mBits = Mask(N);
+    ASSERT(mBits == (mBits & Mask(N).bits()));
+    mBits = Mask(N).bits();
     return *this;
 }
 
@@ -348,14 +353,14 @@
     {
         reset(pos);
     }
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
     return *this;
 }
 
 template <size_t N, typename BitsT, typename ParamT>
 constexpr BitSetT<N, BitsT, ParamT> &BitSetT<N, BitsT, ParamT>::reset()
 {
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
     mBits = 0;
     return *this;
 }
@@ -364,7 +369,7 @@
 constexpr BitSetT<N, BitsT, ParamT> &BitSetT<N, BitsT, ParamT>::reset(ParamT pos)
 {
     ASSERT(static_cast<size_t>(pos) < N);
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
     mBits &= ~Bit<BitsT>(pos);
     return *this;
 }
@@ -372,8 +377,8 @@
 template <size_t N, typename BitsT, typename ParamT>
 constexpr BitSetT<N, BitsT, ParamT> &BitSetT<N, BitsT, ParamT>::flip()
 {
-    ASSERT(mBits == (mBits & Mask(N)));
-    mBits ^= Mask(N);
+    ASSERT(mBits == (mBits & Mask(N).bits()));
+    mBits ^= Mask(N).bits();
     return *this;
 }
 
@@ -382,7 +387,7 @@
 {
     ASSERT(static_cast<size_t>(pos) < N);
     mBits ^= Bit<BitsT>(pos);
-    ASSERT(mBits == (mBits & Mask(N)));
+    ASSERT(mBits == (mBits & Mask(N).bits()));
     return *this;
 }
 
@@ -685,6 +690,9 @@
 
     constexpr value_type bits(size_t index) const;
 
+    // Produces a mask of ones up to the "x"th bit.
+    constexpr static BitSetArray Mask(std::size_t x);
+
   private:
     static constexpr std::size_t kDefaultBitSetSizeMinusOne = priv::kDefaultBitSetSize - 1;
     static constexpr std::size_t kShiftForDivision =
@@ -692,8 +700,10 @@
     static constexpr std::size_t kArraySize =
         ((N + kDefaultBitSetSizeMinusOne) >> kShiftForDivision);
     constexpr static std::size_t kLastElementCount = (N & kDefaultBitSetSizeMinusOne);
-    constexpr static std::size_t kLastElementMask  = priv::BaseBitSetType::Mask(
-         kLastElementCount == 0 ? priv::kDefaultBitSetSize : kLastElementCount);
+    constexpr static std::size_t kLastElementMask =
+        priv::BaseBitSetType::Mask(kLastElementCount == 0 ? priv::kDefaultBitSetSize
+                                                          : kLastElementCount)
+            .bits();
 
     std::array<BaseBitSet, kArraySize> mBaseBitSetArray;
 };
@@ -1059,6 +1069,25 @@
 {
     return mBaseBitSetArray[index].bits();
 }
+
+template <std::size_t N>
+constexpr BitSetArray<N> BitSetArray<N>::Mask(std::size_t x)
+{
+    BitSetArray result;
+
+    for (size_t arrayIndex = 0; arrayIndex < kArraySize; ++arrayIndex)
+    {
+        const size_t bitOffset = arrayIndex * priv::kDefaultBitSetSize;
+        if (x <= bitOffset)
+        {
+            break;
+        }
+        const size_t bitsInThisIndex        = std::min(x - bitOffset, priv::kDefaultBitSetSize);
+        result.mBaseBitSetArray[arrayIndex] = BaseBitSet::Mask(bitsInThisIndex);
+    }
+
+    return result;
+}
 }  // namespace angle
 
 template <size_t N, typename BitsT, typename ParamT>
diff --git a/src/common/bitset_utils_unittest.cpp b/src/common/bitset_utils_unittest.cpp
index 4f11ee8..8c4d0c1 100644
--- a/src/common/bitset_utils_unittest.cpp
+++ b/src/common/bitset_utils_unittest.cpp
@@ -28,6 +28,7 @@
 using BitSetTypes = ::testing::Types<BitSet<12>, BitSet32<12>, BitSet64<12>>;
 TYPED_TEST_SUITE(BitSetTest, BitSetTypes);
 
+// Basic test of various bitset functionalities
 TYPED_TEST(BitSetTest, Basic)
 {
     EXPECT_EQ(TypeParam::Zero().bits(), 0u);
@@ -120,6 +121,7 @@
     }
 }
 
+// Test bitwise operations
 TYPED_TEST(BitSetTest, BitwiseOperators)
 {
     TypeParam mBits = this->mBits;
@@ -182,6 +184,30 @@
     EXPECT_EQ(mBits.bits() & ~kMask, 0u);
 }
 
+// Test BitSetT::Mask
+TYPED_TEST(BitSetTest, Mask)
+{
+    // Test constexpr usage
+    TypeParam bits = TypeParam::Mask(0);
+    EXPECT_EQ(bits.bits(), 0u);
+
+    bits = TypeParam::Mask(1);
+    EXPECT_EQ(bits.bits(), 1u);
+
+    bits = TypeParam::Mask(2);
+    EXPECT_EQ(bits.bits(), 3u);
+
+    bits = TypeParam::Mask(TypeParam::size());
+    EXPECT_EQ(bits.bits(), (1u << TypeParam::size()) - 1);
+
+    // Complete test
+    for (size_t i = 0; i < TypeParam::size(); ++i)
+    {
+        bits = TypeParam::Mask(i);
+        EXPECT_EQ(bits.bits(), (1u << i) - 1);
+    }
+}
+
 template <typename T>
 class BitSetIteratorTest : public testing::Test
 {
@@ -407,6 +433,7 @@
     ::testing::Types<BitSetArray<65>, BitSetArray<128>, BitSetArray<130>, BitSetArray<511>>;
 TYPED_TEST_SUITE(BitSetArrayTest, BitSetArrayTypes);
 
+// Basic test of various BitSetArray functionalities
 TYPED_TEST(BitSetArrayTest, BasicTest)
 {
     TypeParam &mBits = this->mBitSet;
@@ -610,6 +637,7 @@
     }
 }
 
+// Test iteration over BitSetArray where there are gaps
 TYPED_TEST(BitSetArrayTest, IterationWithGaps)
 {
     TypeParam &mBits = this->mBitSet;
@@ -630,6 +658,66 @@
     mBits.reset();
 }
 
+// Test BitSetArray::Mask
+TYPED_TEST(BitSetArrayTest, Mask)
+{
+    // Test constexpr usage
+    TypeParam bits = TypeParam::Mask(0);
+    for (size_t i = 0; i < bits.size(); ++i)
+    {
+        EXPECT_FALSE(bits[i]) << i;
+    }
+
+    bits = TypeParam::Mask(1);
+    for (size_t i = 0; i < bits.size(); ++i)
+    {
+        if (i < 1)
+        {
+            EXPECT_TRUE(bits[i]) << i;
+        }
+        else
+        {
+            EXPECT_FALSE(bits[i]) << i;
+        }
+    }
+
+    bits = TypeParam::Mask(2);
+    for (size_t i = 0; i < bits.size(); ++i)
+    {
+        if (i < 2)
+        {
+            EXPECT_TRUE(bits[i]) << i;
+        }
+        else
+        {
+            EXPECT_FALSE(bits[i]) << i;
+        }
+    }
+
+    bits = TypeParam::Mask(TypeParam::size());
+    for (size_t i = 0; i < bits.size(); ++i)
+    {
+        EXPECT_TRUE(bits[i]) << i;
+    }
+
+    // Complete test
+    for (size_t i = 0; i < TypeParam::size(); ++i)
+    {
+        bits = TypeParam::Mask(i);
+        for (size_t j = 0; j < bits.size(); ++j)
+        {
+            if (j < i)
+            {
+                EXPECT_TRUE(bits[j]) << j;
+            }
+            else
+            {
+                EXPECT_FALSE(bits[j]) << j;
+            }
+        }
+    }
+}
+
 // Unit test for angle::Bit
 TEST(Bit, Test)
 {