Bag of Distributions doc fixes (#10894)
Summary:
- Added `__repr__` for Constraints and Transforms.
- Arguments passed to the constructor are now rendered with :attr:
Closes https://github.com/pytorch/pytorch/issues/10884
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10894
Differential Revision: D9514161
Pulled By: apaszke
fbshipit-source-id: 4abf60335d876449f2b6477eb9655afed9d5b80b
diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py
index 7b773b8..09099e07 100644
--- a/torch/distributions/bernoulli.py
+++ b/torch/distributions/bernoulli.py
@@ -9,7 +9,7 @@
class Bernoulli(ExponentialFamily):
r"""
- Creates a Bernoulli distribution parameterized by `probs` or `logits`.
+ Creates a Bernoulli distribution parameterized by :attr:`probs` or :attr:`logits` (but not both).
Samples are binary (0 or 1). They take the value `1` with probability `p`
and `0` with probability `1 - p`.
diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py
index fd7192f..f23415d 100644
--- a/torch/distributions/beta.py
+++ b/torch/distributions/beta.py
@@ -9,7 +9,7 @@
class Beta(ExponentialFamily):
r"""
- Beta distribution parameterized by `concentration1` and `concentration0`.
+ Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.
Example::
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index 28756d4..da2b106 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -7,9 +7,9 @@
class Binomial(Distribution):
r"""
- Creates a Binomial distribution parameterized by `total_count` and
- either `probs` or `logits` (but not both). `total_count` must be
- broadcastable with `probs`/`logits`.
+ Creates a Binomial distribution parameterized by :attr:`total_count` and
+ either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be
+ broadcastable with :attr:`probs`/:attr:`logits`.
Example::
diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py
index 444deff..5c42991 100644
--- a/torch/distributions/categorical.py
+++ b/torch/distributions/categorical.py
@@ -14,7 +14,7 @@
It is equivalent to the distribution that :func:`torch.multinomial`
samples from.
- Samples are integers from `0 ... K-1` where `K` is probs.size(-1).
+ Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``.
If :attr:`probs` is 1D with length-`K`, each element is the relative
probability of sampling the class at that index.
diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py
index f9dc43d..fcb0c5b 100644
--- a/torch/distributions/chi2.py
+++ b/torch/distributions/chi2.py
@@ -4,8 +4,8 @@
class Chi2(Gamma):
r"""
- Creates a Chi2 distribution parameterized by shape parameter `df`.
- This is exactly equivalent to Gamma(alpha=0.5*df, beta=0.5)
+ Creates a Chi2 distribution parameterized by shape parameter :attr:`df`.
+ This is exactly equivalent to ``Gamma(alpha=0.5*df, beta=0.5)``
Example::
diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py
index 0b6eb53..f214cf1 100644
--- a/torch/distributions/constraints.py
+++ b/torch/distributions/constraints.py
@@ -60,6 +60,9 @@
"""
raise NotImplementedError
+ def __repr__(self):
+ return self.__class__.__name__[1:] + '()'
+
class _Dependent(Constraint):
"""
@@ -111,6 +114,11 @@
def check(self, value):
return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ return fmt_string
+
class _IntegerLessThan(Constraint):
"""
@@ -122,6 +130,11 @@
def check(self, value):
return (value % 1 == 0) & (value <= self.upper_bound)
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(upper_bound={})'.format(self.upper_bound)
+ return fmt_string
+
class _IntegerGreaterThan(Constraint):
"""
@@ -133,6 +146,11 @@
def check(self, value):
return (value % 1 == 0) & (value >= self.lower_bound)
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ return fmt_string
+
class _Real(Constraint):
"""
@@ -152,6 +170,11 @@
def check(self, value):
return self.lower_bound < value
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ return fmt_string
+
class _GreaterThanEq(Constraint):
"""
@@ -163,6 +186,11 @@
def check(self, value):
return self.lower_bound <= value
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={})'.format(self.lower_bound)
+ return fmt_string
+
class _LessThan(Constraint):
"""
@@ -174,6 +202,11 @@
def check(self, value):
return value < self.upper_bound
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(upper_bound={})'.format(self.upper_bound)
+ return fmt_string
+
class _Interval(Constraint):
"""
@@ -186,6 +219,11 @@
def check(self, value):
return (self.lower_bound <= value) & (value <= self.upper_bound)
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ return fmt_string
+
class _HalfOpenInterval(Constraint):
"""
@@ -198,6 +236,11 @@
def check(self, value):
return (self.lower_bound <= value) & (value < self.upper_bound)
+ def __repr__(self):
+ fmt_string = self.__class__.__name__[1:]
+ fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
+ return fmt_string
+
class _Simplex(Constraint):
"""
@@ -240,7 +283,7 @@
batch_shape = value.unsqueeze(0).shape[:-2]
# TODO: replace with batched linear algebra routine when one becomes available
# note that `symeig()` returns eigenvalues in ascending order
- flattened_value = value.contiguous().view((-1,) + matrix_shape)
+ flattened_value = value.reshape((-1,) + matrix_shape)
return torch.stack([v.symeig(eigenvectors=False)[0][:1] > 0.0
for v in flattened_value]).view(batch_shape)
diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py
index fb66a5b..ca014e0 100644
--- a/torch/distributions/dirichlet.py
+++ b/torch/distributions/dirichlet.py
@@ -37,7 +37,7 @@
class Dirichlet(ExponentialFamily):
r"""
- Creates a Dirichlet distribution parameterized by concentration `concentration`.
+ Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.
Example::
diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py
index 64d8f8a..85decc0 100644
--- a/torch/distributions/exponential.py
+++ b/torch/distributions/exponential.py
@@ -8,7 +8,7 @@
class Exponential(ExponentialFamily):
r"""
- Creates a Exponential distribution parameterized by `rate`.
+ Creates a Exponential distribution parameterized by :attr:`rate`.
Example::
diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py
index 2391559..2026212 100644
--- a/torch/distributions/fishersnedecor.py
+++ b/torch/distributions/fishersnedecor.py
@@ -10,7 +10,7 @@
class FisherSnedecor(Distribution):
r"""
- Creates a Fisher-Snedecor distribution parameterized by `df1` and `df2`.
+ Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`.
Example::
diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py
index 0b25cd3..f1dae3d 100644
--- a/torch/distributions/gamma.py
+++ b/torch/distributions/gamma.py
@@ -13,7 +13,7 @@
class Gamma(ExponentialFamily):
r"""
- Creates a Gamma distribution parameterized by shape `concentration` and `rate`.
+ Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`.
Example::
diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py
index 7812ab0..c4aa997 100644
--- a/torch/distributions/geometric.py
+++ b/torch/distributions/geometric.py
@@ -9,11 +9,11 @@
class Geometric(Distribution):
r"""
- Creates a Geometric distribution parameterized by `probs`, where `probs` is the probability of success of Bernoulli
- trials. It represents the probability that in k + 1 Bernoulli trials, the first k trials failed, before
- seeing a success.
+ Creates a Geometric distribution parameterized by :attr:`probs`, where :attr:`probs` is the probability of
+ success of Bernoulli trials. It represents the probability that in :math:`k + 1` Bernoulli trials, the
+ first :math:`k` trials failed, before seeing a success.
- Samples are non-negative integers [0, inf).
+ Samples are non-negative integers [0, :math:`\inf`).
Example::
diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py
index 5df1fe9..d3c09fa 100644
--- a/torch/distributions/laplace.py
+++ b/torch/distributions/laplace.py
@@ -7,7 +7,7 @@
class Laplace(Distribution):
r"""
- Creates a Laplace distribution parameterized by `loc` and 'scale'.
+ Creates a Laplace distribution parameterized by :attr:`loc` and :attr:'scale'.
Example::
diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py
index d5a0ba1..9487dc9 100644
--- a/torch/distributions/log_normal.py
+++ b/torch/distributions/log_normal.py
@@ -7,7 +7,7 @@
class LogNormal(TransformedDistribution):
r"""
Creates a log-normal distribution parameterized by
- `loc` and `scale` where::
+ :attr:`loc` and :attr:`scale` where::
X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)
diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py
index 63098a8..39b2f3e 100644
--- a/torch/distributions/logistic_normal.py
+++ b/torch/distributions/logistic_normal.py
@@ -7,7 +7,7 @@
class LogisticNormal(TransformedDistribution):
r"""
- Creates a logistic-normal distribution parameterized by `loc` and `scale`
+ Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
that define the base `Normal` distribution transformed with the
`StickBreakingTransform` such that::
diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py
index 3b06198..591f5f4 100644
--- a/torch/distributions/lowrank_multivariate_normal.py
+++ b/torch/distributions/lowrank_multivariate_normal.py
@@ -57,7 +57,7 @@
class LowRankMultivariateNormal(Distribution):
r"""
Creates a multivariate normal distribution with covariance matrix having a low-rank form
- parameterized by `cov_factor` and `cov_diag`::
+ parameterized by :attr:`cov_factor` and :attr:`cov_diag`::
covariance_matrix = cov_factor @ cov_factor.T + cov_diag
Example:
diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py
index c045557..ccfab11 100644
--- a/torch/distributions/multinomial.py
+++ b/torch/distributions/multinomial.py
@@ -9,11 +9,11 @@
class Multinomial(Distribution):
r"""
- Creates a Multinomial distribution parameterized by `total_count` and
- either `probs` or `logits` (but not both). The innermost dimension of
- `probs` indexes over categories. All other dimensions index over batches.
+ Creates a Multinomial distribution parameterized by :attr:`total_count` and
+ either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
+ :attr:`probs` indexes over categories. All other dimensions index over batches.
- Note that `total_count` need not be specified if only :meth:`log_prob` is
+ Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
called (see example below)
.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py
index 854ad5b..4f9804e 100644
--- a/torch/distributions/negative_binomial.py
+++ b/torch/distributions/negative_binomial.py
@@ -9,8 +9,8 @@
r"""
Creates a Negative Binomial distribution, i.e. distribution
of the number of independent identical Bernoulli trials
- needed before `total_count` failures are achieved. The probability
- of success of each Bernoulli trial is `probs`.
+ needed before :attr:`total_count` failures are achieved. The probability
+ of success of each Bernoulli trial is :attr:`probs`.
Args:
total_count (float or Tensor): non-negative number of negative Bernoulli
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 24c24ae..0f1375e 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -10,7 +10,7 @@
class Normal(ExponentialFamily):
r"""
Creates a normal (also called Gaussian) distribution parameterized by
- `loc` and `scale`.
+ :attr:`loc` and :attr:`scale`.
Example::
diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py
index fde403a..6be5407 100644
--- a/torch/distributions/poisson.py
+++ b/torch/distributions/poisson.py
@@ -8,7 +8,7 @@
class Poisson(ExponentialFamily):
r"""
- Creates a Poisson distribution parameterized by `rate`, the rate parameter.
+ Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
Samples are nonnegative integers, with a pmf given by
diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py
index b75341f..5923e34 100644
--- a/torch/distributions/relaxed_bernoulli.py
+++ b/torch/distributions/relaxed_bernoulli.py
@@ -9,8 +9,8 @@
class LogitRelaxedBernoulli(Distribution):
r"""
- Creates a LogitRelaxedBernoulli distribution parameterized by `probs` or `logits`,
- which is the logit of a RelaxedBernoulli distribution.
+ Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs` or :attr:`logits`
+ (but not both), which is the logit of a RelaxedBernoulli distribution.
Samples are logits of values in (0, 1). See [1] for more details.
@@ -76,9 +76,9 @@
class RelaxedBernoulli(TransformedDistribution):
r"""
- Creates a RelaxedBernoulli distribution, parametrized by `temperature`, and either
- `probs` or `logits`. This is a relaxed version of the `Bernoulli` distribution, so
- the values are in (0, 1), and has reparametrizable samples.
+ Creates a RelaxedBernoulli distribution, parametrized by :attr:`temperature`, and either
+ :attr:`probs` or :attr:`logits` (but not both). This is a relaxed version of the `Bernoulli`
+ distribution, so the values are in (0, 1), and has reparametrizable samples.
Example::
diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py
index e5b3c71..45b8791 100644
--- a/torch/distributions/relaxed_categorical.py
+++ b/torch/distributions/relaxed_categorical.py
@@ -10,7 +10,7 @@
class ExpRelaxedCategorical(Distribution):
r"""
Creates a ExpRelaxedCategorical parameterized by
- :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
+ :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
Returns the log of a point in the simplex. Based on the interface to
:class:`OneHotCategorical`.
diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py
index e91c7cf..c7738af 100644
--- a/torch/distributions/studentT.py
+++ b/torch/distributions/studentT.py
@@ -10,7 +10,7 @@
class StudentT(Distribution):
r"""
- Creates a Student's t-distribution parameterized by `df`.
+ Creates a Student's t-distribution parameterized by :attr:`df`.
Example::
diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py
index 9b241f6..a90ccee 100644
--- a/torch/distributions/transforms.py
+++ b/torch/distributions/transforms.py
@@ -157,6 +157,9 @@
"""
raise NotImplementedError
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+
class _InverseTransform(Transform):
"""
@@ -274,6 +277,12 @@
x = y
return result
+ def __repr__(self):
+ fmt_string = self.__class__.__name__ + '(\n '
+ fmt_string += ',\n '.join([p.__repr__() for p in self.parts])
+ fmt_string += '\n)'
+ return fmt_string
+
identity_transform = ComposeTransform([])
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index 3e64de1..fa39f80 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -10,7 +10,7 @@
class Uniform(Distribution):
r"""
Generates uniformly distributed random samples from the half-open interval
- `[low, high)`.
+ ``[low, high)``.
Example::