tree 5148c43c38bc39b2fe7c8ac2b25fb1cdabcb17b7
parent 46f9e16afecbab0a57f6d3a0bb489787cd8cf979
author patel-zeel <patel_zeel@iitgn.ac.in> 1645249293 -0800
committer PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> 1645252388 +0000

Adding details to kl.py (#72845)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/72765.

- [x] Improved `NotImplementedError` verbosity.
- [x] Automate the docstring generation process

## Improved `NotImplementedError` verbosity
### Code
```python
import torch

dist = torch.distributions

torch_normal = dist.Normal(loc=0.0, scale=1.0)
torch_mixture = dist.MixtureSameFamily(
    dist.Categorical(torch.ones(5,)
    ),
    dist.Normal(torch.randn(5,), torch.rand(5,)),
)

dist.kl_divergence(torch_normal, torch_mixture)
```
#### Output before this PR
```python
NotImplementedError:
```
#### Output after this PR
```python
NotImplementedError: No KL(p || q) is implemented for p type Normal and q type MixtureSameFamily
```

## Automate the docstring generation process
### Docstring before this PR
```python
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
```
### Docstring after this PR
```python
Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    KL divergence is currently implemented for the following distribution pairs:
        * :class:`~torch.distributions.Bernoulli` and :class:`~torch.distributions.Bernoulli`
        * :class:`~torch.distributions.Bernoulli` and :class:`~torch.distributions.Poisson`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Beta` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.Binomial` and :class:`~torch.distributions.Binomial`
        * :class:`~torch.distributions.Categorical` and :class:`~torch.distributions.Categorical`
        * :class:`~torch.distributions.Cauchy` and :class:`~torch.distributions.Cauchy`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.ContinuousBernoulli` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.Dirichlet` and :class:`~torch.distributions.Dirichlet`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Exponential` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.ExponentialFamily` and :class:`~torch.distributions.ExponentialFamily`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Gamma` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.Geometric` and :class:`~torch.distributions.Geometric`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Gumbel` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.HalfNormal` and :class:`~torch.distributions.HalfNormal`
        * :class:`~torch.distributions.Independent` and :class:`~torch.distributions.Independent`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Laplace`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Laplace` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.LowRankMultivariateNormal` and :class:`~torch.distributions.LowRankMultivariateNormal`
        * :class:`~torch.distributions.LowRankMultivariateNormal` and :class:`~torch.distributions.MultivariateNormal`
        * :class:`~torch.distributions.MultivariateNormal` and :class:`~torch.distributions.LowRankMultivariateNormal`
        * :class:`~torch.distributions.MultivariateNormal` and :class:`~torch.distributions.MultivariateNormal`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Laplace`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Normal` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.OneHotCategorical` and :class:`~torch.distributions.OneHotCategorical`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Pareto` and :class:`~torch.distributions.Uniform`
        * :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Bernoulli`
        * :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Binomial`
        * :class:`~torch.distributions.Poisson` and :class:`~torch.distributions.Poisson`
        * :class:`~torch.distributions.TransformedDistribution` and :class:`~torch.distributions.TransformedDistribution`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Beta`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.ContinuousBernoulli`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Exponential`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Gamma`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Gumbel`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Normal`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Pareto`
        * :class:`~torch.distributions.Uniform` and :class:`~torch.distributions.Uniform`
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72845

Reviewed By: mikaylagawarecki

Differential Revision: D34344551

Pulled By: soulitzer

fbshipit-source-id: 7a603613a2f56f71138d56399c7c521e2238e8c5
(cherry picked from commit 6b2a51c796cd8a16551d629ca368360eec34faef)
