Fix the kl_div docs (#67443)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67443

Fixes https://github.com/pytorch/pytorch/issues/57459

After discussing the linked issue, we concluded that `F.kl_div` computes
the right quantity, in a way that is consistent with the rest of the losses
in PyTorch.

To avoid any confusion, these docs add a note discussing how the PyTorch
implementation differs from the mathematical definition and the reasons
for doing so.

These docs also add an example that should help clarify the intended use of
this loss.
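
In short, the intended argument order is `kl_div(input, target)`, where `input`
holds the model's log-probabilities and `target` holds the reference
distribution. A minimal sketch along the lines of the docstring example (the
shapes and `dim=1` are only for illustration):

    import torch
    import torch.nn.functional as F

    # Model output as log-probabilities (first argument).
    input = F.log_softmax(torch.randn(3, 5), dim=1)
    # Reference distribution as probabilities (second argument).
    target = F.softmax(torch.rand(3, 5), dim=1)

    # Pointwise target * (target.log() - input), summed and divided by the batch size.
    loss = F.kl_div(input, target, reduction="batchmean")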

cc brianjo mruberry

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D32136888

Pulled By: jbschlosser

fbshipit-source-id: 1ad0a606948656b44ff7d2a701d995c75875e671
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index b81696f..3172727 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -373,77 +373,86 @@
 
 
 class KLDivLoss(_Loss):
-    r"""The Kullback-Leibler divergence loss measure
+    r"""The Kullback-Leibler divergence loss.
 
-    `Kullback-Leibler divergence`_ is a useful distance measure for continuous
-    distributions and is often useful when performing direct regression over
-    the space of (discretely sampled) continuous output distributions.
-
-    As with :class:`~torch.nn.NLLLoss`, the `input` given is expected to contain
-    *log-probabilities* and is not restricted to a 2D Tensor.
-    The targets are interpreted as *probabilities* by default, but could be considered
-    as *log-probabilities* with :attr:`log_target` set to ``True``.
-
-    This criterion expects a `target` `Tensor` of the same size as the
-    `input` `Tensor`.
-
-    The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
+    For tensors of the same shape :math:`y_{\text{pred}},\ y_{\text{true}}`,
+    where :math:`y_{\text{pred}}` is the :attr:`input` and :math:`y_{\text{true}}` is the
+    :attr:`target`, we define the **pointwise KL-divergence** as
 
     .. math::
-        l(x,y) = L = \{ l_1,\dots,l_N \}, \quad
-        l_n = y_n \cdot \left( \log y_n - x_n \right)
 
-    where the index :math:`N` spans all dimensions of ``input`` and :math:`L` has the same
-    shape as ``input``. If :attr:`reduction` is not ``'none'`` (default ``'mean'``), then:
+        L(y_{\text{pred}},\ y_{\text{true}})
+            = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}}
+            = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}})
 
-    .. math::
-        \ell(x, y) = \begin{cases}
-            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';} \\
-            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
-        \end{cases}
+    To avoid underflow issues when computing this quantity, this loss expects the argument
+    :attr:`input` to be given in log-space. The argument :attr:`target` may also be provided
+    in log-space if :attr:`log_target`\ `= True`.
 
-    In default :attr:`reduction` mode ``'mean'``, the losses are averaged for each minibatch over observations
-    **as well as** over dimensions. ``'batchmean'`` mode gives the correct KL divergence where losses
-    are averaged over batch dimension only. ``'mean'`` mode's behavior will be changed to the same as
-    ``'batchmean'`` in the next major release.
+    To summarise, this function is roughly equivalent to computing
 
-    .. _`kullback-leibler divergence`: https://en.wikipedia.org/wiki/Kullback-Leibler_divergence
+    .. code-block:: python
+
+        if not log_target: # default
+            loss_pointwise = target * (target.log() - input)
+        else:
+            loss_pointwise = target.exp() * (target - input)
+
+    and then reducing this result depending on the argument :attr:`reduction` as
+
+    .. code-block:: python
+
+        if reduction == "mean":  # default
+            loss = loss_pointwise.mean()
+        elif reduction == "batchmean":  # mathematically correct
+            loss = loss_pointwise.sum() / input.size(0)
+        elif reduction == "sum":
+            loss = loss_pointwise.sum()
+        else:  # reduction == "none"
+            loss = loss_pointwise
+
+    .. note::
+        As with all the other losses in PyTorch, this function expects the first argument,
+        :attr:`input`, to be the output of the model (e.g. the neural network)
+        and the second, :attr:`target`, to be the observations in the dataset.
+        This differs from the standard mathematical notation :math:`KL(P\ ||\ Q)` where
+        :math:`P` denotes the distribution of the observations and :math:`Q` denotes the model.
+
+    .. warning::
+        :attr:`reduction`\ `= "mean"` doesn't return the true KL divergence value; please use
+        :attr:`reduction`\ `= "batchmean"`, which aligns with the mathematical definition.
+        In a future release, `"mean"` will be changed to behave the same as `"batchmean"`.
 
     Args:
         size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
             the losses are averaged over each loss element in the batch. Note that for
             some losses, there are multiple elements per sample. If the field :attr:`size_average`
-            is set to ``False``, the losses are instead summed for each minibatch. Ignored
-            when :attr:`reduce` is ``False``. Default: ``True``
+            is set to `False`, the losses are instead summed for each minibatch. Ignored
+            when :attr:`reduce` is `False`. Default: `True`
         reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
             losses are averaged or summed over observations for each minibatch depending
-            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
-            batch element instead and ignores :attr:`size_average`. Default: ``True``
-        reduction (string, optional): Specifies the reduction to apply to the output:
-            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
-            ``'none'``: no reduction will be applied.
-            ``'batchmean'``: the sum of the output will be divided by batchsize.
-            ``'sum'``: the output will be summed.
-            ``'mean'``: the output will be divided by the number of elements in the output.
-            Default: ``'mean'``
-        log_target (bool, optional): Specifies whether `target` is passed in the log space.
-            Default: ``False``
-
-    .. note::
-        :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
-        and in the meantime, specifying either of those two args will override :attr:`reduction`.
-
-    .. note::
-        :attr:`reduction` = ``'mean'`` doesn't return the true kl divergence value, please use
-        :attr:`reduction` = ``'batchmean'`` which aligns with KL math definition.
-        In the next major release, ``'mean'`` will be changed to be the same as ``'batchmean'``.
+            on :attr:`size_average`. When :attr:`reduce` is `False`, returns a loss per
+            batch element instead and ignores :attr:`size_average`. Default: `True`
+        reduction (string, optional): Specifies the reduction to apply to the output. Default: `"mean"`
+        log_target (bool, optional): Specifies whether `target` is passed in log-space. Default: `False`
 
     Shape:
         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
         - Target: :math:`(*)`, same shape as the input.
-        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`,
+        - Output: scalar by default. If :attr:`reduction` is `"none"`, then :math:`(*)`,
           same shape as the input.
 
+    Examples::
+
+        >>> kl_loss = nn.KLDivLoss(reduction="batchmean")
+        >>> # input should be a distribution in the log space
+        >>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
+        >>> # Sample a batch of distributions. Usually this would come from the dataset
+        >>> target = F.softmax(torch.rand(3, 5), dim=1)
+        >>> output = kl_loss(input, target)
+
+        >>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
+        >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1)
+        >>> output = kl_loss(input, log_target)
     """
     __constants__ = ['reduction']