Add the hyperlink of the transformer doc (#120565)
Fixes #120488
- The shapes for the forward pass are clearly stated in the main [transformer class](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)
- The boolean semantics of `*_key_padding_mask` are also explained in the main transformer class.
Therefore, add the hyperlink to the transformer class explicitly so the user can refer back to the main class. Also, correct several symbols in the transformer doc from normal text style to math style.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120565
Approved by: https://github.com/mikaylagawarecki
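As a minimal pure-Python sketch of the boolean-mask semantics described above (not PyTorch's actual implementation, which lives in `F._canonical_mask`): a ``True`` entry in a ``*_key_padding_mask`` is converted to ``-inf`` before the softmax, so that position receives zero attention weight, while ``False`` entries are left unchanged. The helper names below are hypothetical:

```python
import math

def canonicalize_mask(bool_mask):
    # Convert a boolean key-padding mask to an additive float mask:
    # True  -> -inf (this position may NOT be attended to)
    # False -> 0.0  (this position is left unchanged)
    return [-math.inf if masked else 0.0 for masked in bool_mask]

def masked_softmax(scores, bool_mask):
    # Add the float mask to the raw attention scores, then softmax;
    # -inf entries end up with exactly zero attention weight.
    additive = canonicalize_mask(bool_mask)
    shifted = [s + a for s, a in zip(scores, additive)]
    peak = max(x for x in shifted if x != -math.inf)
    exps = [0.0 if x == -math.inf else math.exp(x - peak) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]

# Mask out the last key position (True = masked).
weights = masked_softmax([1.0, 2.0, 3.0], [False, False, True])
print(weights)  # last (masked) position receives zero weight
```

A float mask follows the other convention mentioned in the diff: it is simply added to the attention scores as-is.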
diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py
index 011c791..5f48326 100644
--- a/torch/nn/modules/transformer.py
+++ b/torch/nn/modules/transformer.py
@@ -176,7 +176,7 @@
- tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
- memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
- Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
+ Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
positions. If a BoolTensor is provided, positions with ``True``
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
@@ -191,8 +191,8 @@
the output sequence length of a transformer is same as the input sequence
(i.e. target) length of the decoder.
- where S is the source sequence length, T is the target sequence length, N is the
- batch size, E is the feature number
+ where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
+ batch size, :math:`E` is the feature number
Examples:
>>> # xdoctest: +SKIP
@@ -314,7 +314,7 @@
compatibility.
Shape:
- see the docs in Transformer class.
+ see the docs in :class:`~torch.nn.Transformer`.
"""
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
@@ -464,7 +464,7 @@
forward and backward compatibility.
Shape:
- see the docs in Transformer class.
+ see the docs in :class:`~torch.nn.Transformer`.
"""
output = tgt
@@ -622,7 +622,7 @@
compatibility.
Shape:
- see the docs in Transformer class.
+ see the docs in :class:`~torch.nn.Transformer`.
"""
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
@@ -858,7 +858,7 @@
forward and backward compatibility.
Shape:
- see the docs in Transformer class.
+ see the docs in :class:`~torch.nn.Transformer`.
"""
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf