Add documentation for FeatureAlphaDropout (#36295)
Summary:
These changes add documentation for FeatureAlphaDropout, based on a need raised in an issue by SsnL (Issue https://github.com/pytorch/pytorch/issues/9886).
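For quick reference, a minimal usage sketch of the two entry points the new documentation covers (module and functional forms); the shapes and probabilities below are illustrative only:

    >>> import torch
    >>> import torch.nn as nn
    >>> import torch.nn.functional as F
    >>> x = torch.randn(20, 16, 4, 32, 32)  # (N, C, D, H, W)
    >>> m = nn.FeatureAlphaDropout(p=0.2)
    >>> out_module = m(x)                                                  # module form
    >>> out_functional = F.feature_alpha_dropout(x, p=0.2, training=True)  # functional form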
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36295
Differential Revision: D21478591
Pulled By: zou3519
fbshipit-source-id: a73c40bf1c7e3b1f301dc3347cef7b32e9842320
diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst
index 9fe5bfc..a712a86 100644
--- a/docs/source/nn.functional.rst
+++ b/docs/source/nn.functional.rst
@@ -328,6 +328,11 @@
.. autofunction:: alpha_dropout
+:hidden:`feature_alpha_dropout`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: feature_alpha_dropout
+
:hidden:`dropout2d`
~~~~~~~~~~~~~~~~~~~
diff --git a/test/test_torch.py b/test/test_torch.py
index 2f076a3..c94d711 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -188,7 +188,7 @@
'sparse_resize_and_clear_',
)
test_namespace(torch.nn)
- test_namespace(torch.nn.functional, 'assert_int_or_pair', 'feature_alpha_dropout')
+ test_namespace(torch.nn.functional, 'assert_int_or_pair')
# TODO: add torch.* tests when we have proper namespacing on ATen functions
# test_namespace(torch)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 19e2593..df32d4a 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1051,6 +1051,25 @@
def feature_alpha_dropout(input, p=0.5, training=False, inplace=False):
# type: (Tensor, float, bool, bool) -> Tensor
+ r"""
+ Randomly masks out entire channels (a channel is a feature map,
+ e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batched input
+ is a tensor :math:`\text{input}[i, j]`) of the input tensor. Instead of
+ setting activations to zero, as in regular Dropout, the activations are set
+ to the negative saturation value of the SELU activation function.
+
+ Each element will be masked independently on every forward call with
+ probability :attr:`p` using samples from a Bernoulli distribution.
+ The elements to be masked are randomized on every forward call, and scaled
+ and shifted to maintain zero mean and unit variance.
+
+ See :class:`~torch.nn.FeatureAlphaDropout` for details.
+
+ Args:
+ p: probability of a channel to be masked. Default: 0.5
+ training: apply dropout if ``True``. Default: ``False``
+ inplace: If set to ``True``, will do this operation in-place. Default: ``False``
+ """
if not torch.jit.is_scripting():
if type(input) is not Tensor and has_torch_function((input,)):
return handle_torch_function(
diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py
index 1f499b0..df0465d 100644
--- a/torch/nn/modules/dropout.py
+++ b/torch/nn/modules/dropout.py
@@ -181,6 +181,49 @@
class FeatureAlphaDropout(_DropoutNd):
+ r"""Randomly masks out entire channels (a channel is a feature map,
+ e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batched input
+ is a tensor :math:`\text{input}[i, j]`) of the input tensor. Instead of
+ setting activations to zero, as in regular Dropout, the activations are set
+ to the negative saturation value of the SELU activation function. More details
+ can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+ Each element will be masked independently for each sample on every forward
+ call with probability :attr:`p` using samples from a Bernoulli distribution.
+ The elements to be masked are randomized on every forward call, and scaled
+ and shifted to maintain zero mean and unit variance.
+
+ Usually the input comes from :class:`nn.AlphaDropout` modules.
+
+ As described in the paper
+ `Efficient Object Localization Using Convolutional Networks`_ ,
+ if adjacent pixels within feature maps are strongly correlated
+ (as is normally the case in early convolution layers) then i.i.d. dropout
+ will not regularize the activations and will instead just result
+ in an effective learning rate decrease.
+
+ In this case, :class:`nn.FeatureAlphaDropout` will help promote independence between
+ feature maps and should be used instead.
+
+ Args:
+ p (float, optional): probability of a channel to be masked. Default: 0.5
+ inplace (bool, optional): If set to ``True``, will do this operation
+ in-place. Default: ``False``
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples::
+
+ >>> m = nn.FeatureAlphaDropout(p=0.2)
+ >>> input = torch.randn(20, 16, 4, 32, 32)
+ >>> output = m(input)
+
+ .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+ .. _Efficient Object Localization Using Convolutional Networks:
+ http://arxiv.org/abs/1411.4280
+ """
def forward(self, input):
return F.feature_alpha_dropout(input, self.p, self.training)
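As an illustrative sanity check of the channel-wise behavior described in the new docstrings (masked channels are set to a single value and then scaled and shifted, so they come out constant); the shape and p below are arbitrary:

    >>> import torch
    >>> import torch.nn.functional as F
    >>> x = torch.randn(4, 8, 2, 5, 5)                    # (N, C, D, H, W)
    >>> y = F.feature_alpha_dropout(x, p=0.5, training=True)
    >>> channel_std = y.flatten(start_dim=2).std(dim=-1)  # spread within each (sample, channel)
    >>> dropped = channel_std == 0                         # constant channels were masked
    >>> dropped.float().mean()                             # fraction masked, roughly p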