fix torch.set_float32_matmul_precision doc (#119620)
Fixes #119606, clearify the explictly stored number of bits in doc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119620
Approved by: https://github.com/eqy, https://github.com/malfet
diff --git a/torch/__init__.py b/torch/__init__.py
index 7bd2049..d65b0f6 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1015,24 +1015,24 @@
Supports three settings:
* "highest", float32 matrix multiplications use the float32 datatype (24 mantissa
- bits) for internal computations.
+ bits with 23 bits explicitly stored) for internal computations.
* "high", float32 matrix multiplications either use the TensorFloat32 datatype (10
- mantissa bits) or treat each float32 number as the sum of two bfloat16 numbers
- (approximately 16 mantissa bits), if the appropriate fast matrix multiplication
+ mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers
+ (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication
algorithms are available. Otherwise float32 matrix multiplications are computed
as if the precision is "highest". See below for more information on the bfloat16
approach.
* "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa
- bits) for internal computations, if a fast matrix multiplication algorithm
+ bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm
using that datatype internally is available. Otherwise float32
matrix multiplications are computed as if the precision is "high".
When using "high" precision, float32 multiplications may use a bfloat16-based algorithm
that is more complicated than simply truncating to some smaller number mantissa bits
- (e.g. 10 for TensorFloat32, 8 for bfloat16). Refer to [Henry2019]_ for a complete
+ (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete
description of this algorithm. To briefly explain here, the first step is to realize
that we can perfectly encode a single float32 number as the sum of three bfloat16
- numbers (because float32 has 24 mantissa bits while bfloat16 has 8, and both have the
+ numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the
same number of exponent bits). This means that the product of two float32 numbers can
be exactly given by the sum of nine products of bfloat16 numbers. We can then trade
accuracy for speed by dropping some of these products. The "high" precision algorithm