fix torch.set_float32_matmul_precision doc  (#119620)

Fixes #119606, clearify the explictly stored number of bits in doc

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119620
Approved by: https://github.com/eqy, https://github.com/malfet
diff --git a/torch/__init__.py b/torch/__init__.py
index 7bd2049..d65b0f6 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1015,24 +1015,24 @@
     Supports three settings:
 
         * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
-          bits) for internal computations.
+          bits with 23 bits explicitly stored) for internal computations.
         * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
-          mantissa bits) or treat each float32 number as the sum of two bfloat16 numbers
-          (approximately 16 mantissa bits), if the appropriate fast matrix multiplication
+          mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
+          (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
           algorithms are available.  Otherwise float32 matrix multiplications are computed
           as if the precision is "highest".  See below for more information on the bfloat16
           approach.
         * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
-          bits) for internal computations, if a fast matrix multiplication algorithm
+          bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
           using that datatype internally is available. Otherwise float32
           matrix multiplications are computed as if the precision is "high".
 
     When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
     that is more complicated than simply truncating to some smaller number mantissa bits
-    (e.g. 10 for TensorFloat32, 8 for bfloat16).  Refer to [Henry2019]_ for a complete
+    (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored).  Refer to [Henry2019]_ for a complete
     description of this algorithm.  To briefly explain here, the first step is to realize
     that we can perfectly encode a single float32 number as the sum of three bfloat16
-    numbers (because float32 has 24 mantissa bits while bfloat16 has 8, and both have the
+    numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
     same number of exponent bits).  This means that the product of two float32 numbers can
     be exactly given by the sum of nine products of bfloat16 numbers.  We can then trade
     accuracy for speed by dropping some of these products.  The "high" precision algorithm