Documenting `torch.distributions.utils.clamp_probs` (#128136)
Fixes https://github.com/pytorch/pytorch/issues/127889
This PR adds docstring to the `torch.distributions.utils.clamp_probs` function.
Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128136
Approved by: https://github.com/janeyx99, https://github.com/svekars, https://github.com/malfet
diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py
index 7a6d31a..91e4345 100644
--- a/torch/distributions/utils.py
+++ b/torch/distributions/utils.py
@@ -90,6 +90,27 @@
def clamp_probs(probs):
+ """Clamps the probabilities to be in the open interval `(0, 1)`.
+
+ The probabilities would be clamped between `eps` and `1 - eps`,
+ and `eps` would be the smallest representable positive number for the input data type.
+
+ Args:
+ probs (Tensor): A tensor of probabilities.
+
+ Returns:
+ Tensor: The clamped probabilities.
+
+ Examples:
+ >>> probs = torch.tensor([0.0, 0.5, 1.0])
+ >>> clamp_probs(probs)
+ tensor([1.1921e-07, 5.0000e-01, 1.0000e+00])
+
+ >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
+ >>> clamp_probs(probs)
+ tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64)
+
+ """
eps = torch.finfo(probs.dtype).eps
return probs.clamp(min=eps, max=1 - eps)