Add missing Modules to nn.functional (#1801)

* add dropout2d and dropout3d to functional, routing both through the
  FeatureDropout function from the autograd backend

* add the missing loss functions to functional: l1_loss, mse_loss,
  margin_ranking_loss, hinge_embedding_loss, multilabel_margin_loss,
  soft_margin_loss, multilabel_soft_margin_loss, cosine_embedding_loss,
  and multi_margin_loss

* add tests and documentation for the new functions, plus assorted
  small fixes

* edit the loss modules to call their functional counterparts
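
A minimal sketch of the new functional interface (the tensor shapes and
hyperparameters below are illustrative only):

    import torch
    from torch.autograd import Variable
    import torch.nn.functional as F

    input = Variable(torch.randn(3, 5), requires_grad=True)
    target = Variable(torch.randn(3, 5))

    # stateless equivalents of nn.L1Loss and nn.MSELoss
    l1 = F.l1_loss(input, target)
    mse = F.mse_loss(input, target, size_average=False)

    # channel-wise dropout without constructing an nn.Dropout2d module
    images = Variable(torch.randn(4, 16, 32, 32))
    output = F.dropout2d(images, p=0.2, training=True)
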
diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index 75a8fa9..221e8be 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -913,6 +913,16 @@
 
 .. autofunction:: alpha_dropout
 
+:hidden:`dropout2d`
+~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: dropout2d
+
+:hidden:`dropout3d`
+~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: dropout3d
+
 Distance functions
 ----------------------------------
 
@@ -930,30 +940,70 @@
 Loss functions
 --------------
 
-:hidden:`nll_loss`
-~~~~~~~~~~~~~~~~~~
+:hidden:`binary_cross_entropy`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autofunction:: nll_loss
+.. autofunction:: binary_cross_entropy
 
 :hidden:`poisson_nll_loss`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autofunction:: poisson_nll_loss
 
-:hidden:`kl_div`
-~~~~~~~~~~~~~~~~
+:hidden:`cosine_embedding_loss`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autofunction:: kl_div
+.. autofunction:: cosine_embedding_loss
 
 :hidden:`cross_entropy`
 ~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autofunction:: cross_entropy
 
-:hidden:`binary_cross_entropy`
+:hidden:`hinge_embedding_loss`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autofunction:: binary_cross_entropy
+.. autofunction:: hinge_embedding_loss
+
+:hidden:`kl_div`
+~~~~~~~~~~~~~~~~
+
+.. autofunction:: kl_div
+
+:hidden:`l1_loss`
+~~~~~~~~~~~~~~~~~
+
+.. autofunction:: l1_loss
+
+:hidden:`mse_loss`
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: mse_loss
+
+:hidden:`margin_ranking_loss`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: margin_ranking_loss
+
+:hidden:`multilabel_margin_loss`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: multilabel_margin_loss
+
+:hidden:`multilabel_soft_margin_loss`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: multilabel_soft_margin_loss
+
+:hidden:`multi_margin_loss`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: multi_margin_loss
+
+:hidden:`nll_loss`
+~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: nll_loss
 
 :hidden:`binary_cross_entropy_with_logits`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -965,6 +1015,11 @@
 
 .. autofunction:: smooth_l1_loss
 
+:hidden:`soft_margin_loss`
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: soft_margin_loss
+
 :hidden:`triplet_margin_loss`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/test/common_nn.py b/test/common_nn.py
index 57c4c73..d25a81a 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -216,7 +216,6 @@
     ),
 ]
 
-
 criterion_tests = [
     dict(module_name='L1Loss',
          input_size=(2, 3, 4),
diff --git a/test/test_nn.py b/test/test_nn.py
index 9cf3030..b70a7f0 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3669,6 +3669,7 @@
         test_params['constructor'] = getattr(nn, name)
     test = NewModuleTest(**test_params)
     add_test(test)
+
 for test_params in criterion_tests + new_criterion_tests:
     name = test_params.pop('module_name')
     test_params['constructor'] = getattr(nn, name)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 245112e..1f1c99a 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -447,6 +447,22 @@
     return output.mul_(a).add_(b)
 
 
+def dropout2d(input, p=0.5, training=False, inplace=False):
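+    r"""Randomly zeroes whole channels of the input tensor.
+
+    See :class:`~torch.nn.Dropout2d` for details.
+    """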
+    return _functions.dropout.FeatureDropout.apply(input, p, training, inplace)
+
+
+def dropout3d(input, p=0.5, training=False, inplace=False):
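+    r"""Randomly zeroes whole channels of the input tensor.
+
+    See :class:`~torch.nn.Dropout3d` for details.
+    """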
+    return _functions.dropout.FeatureDropout.apply(input, p, training, inplace)
+
+
 def threshold(input, threshold, value, inplace=False):
     return _functions.thnn.Threshold.apply(input, threshold, value, inplace)
 
@@ -632,7 +640,7 @@
         return torch.sum(loss)
 
 
-def kl_div(input, target, size_average=True):
+def kl_div(input, target, size_average=True, weight=None):
     r"""The `Kullback-Leibler divergence`_ Loss.
 
     See :class:`~torch.nn.KLDivLoss` for details.
@@ -642,8 +650,10 @@
         target: Variable of the same shape as input
         size_average: if True the output is divided by the number of elements
           in input tensor
+        weight (Tensor, optional): a manual rescaling weight given to each
+                class. If given, has to be a Tensor of size "nclasses"
     """
-    return _functions.thnn.KLDivLoss(size_average)(input, target)
+    return _functions.thnn.KLDivLoss(size_average, weight=weight)(input, target)
 
 
 def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-100):
@@ -730,6 +740,58 @@
     return _functions.thnn.SmoothL1Loss(size_average)(input, target)
 
 
+def l1_loss(input, target, size_average=True):
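+    r"""See :class:`~torch.nn.L1Loss` for details."""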
+    return _functions.thnn.L1Loss(size_average)(input, target)
+
+
+def mse_loss(input, target, size_average=True):
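+    r"""See :class:`~torch.nn.MSELoss` for details."""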
+    return _functions.thnn.MSELoss(size_average)(input, target)
+
+
+def margin_ranking_loss(input1, input2, target, margin=0, size_average=True):
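+    r"""See :class:`~torch.nn.MarginRankingLoss` for details."""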
+    return _functions.loss.MarginRankingLoss(margin, size_average)(input1, input2, target)
+
+
+def hinge_embedding_loss(input, target, margin=1.0, size_average=True):
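+    r"""See :class:`~torch.nn.HingeEmbeddingLoss` for details."""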
+    return _functions.loss.HingeEmbeddingLoss(margin, size_average)(input, target)
+
+
+def multilabel_margin_loss(input, target, size_average=True):
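+    r"""See :class:`~torch.nn.MultiLabelMarginLoss` for details."""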
+    return _functions.thnn.MultiLabelMarginLoss(size_average)(input, target)
+
+
+def soft_margin_loss(input, target, size_average=True):
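+    r"""See :class:`~torch.nn.SoftMarginLoss` for details."""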
+    return _functions.thnn.SoftMarginLoss(size_average)(input, target)
+
+
+def multilabel_soft_margin_loss(input, target, weight=None, size_average=True):
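+    r"""See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details."""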
+    input = torch.sigmoid(input)
+    return binary_cross_entropy(input, target, weight, size_average)
+
+
+def cosine_embedding_loss(input1, input2, target, margin=0, size_average=True):
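+    r"""See :class:`~torch.nn.CosineEmbeddingLoss` for details."""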
+    return _functions.loss.CosineEmbeddingLoss(margin, size_average)(input1, input2, target)
+
+
+def multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=True):
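+    r"""See :class:`~torch.nn.MultiMarginLoss` for details."""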
+    if p != 1 and p != 2:
+        raise ValueError('only p == 1 and p == 2 supported')
+    if weight is not None and weight.dim() != 1:
+        raise ValueError('weight must be one-dimensional')
+
+    return _functions.thnn.MultiMarginLoss(size_average, p, margin,
+                                           weight=weight)(input, target)
+
+
 def pixel_shuffle(input, upscale_factor):
     r"""Rearranges elements in a tensor of shape ``[*, C*r^2, H, W]`` to a
     tensor of shape ``[*, C, H*r, W*r]``.
diff --git a/torch/nn/modules/dropout.py b/torch/nn/modules/dropout.py
index df1601e..264009d 100644
--- a/torch/nn/modules/dropout.py
+++ b/torch/nn/modules/dropout.py
@@ -96,7 +96,7 @@
         self.inplace = inplace
 
     def forward(self, input):
-        return self._backend.Dropout2d.apply(input, self.p, self.training, self.inplace)
+        return F.dropout2d(input, self.p, self.training, self.inplace)
 
     def __repr__(self):
         inplace_str = ', inplace' if self.inplace else ''
@@ -149,7 +149,7 @@
         self.inplace = inplace
 
     def forward(self, input):
-        return self._backend.Dropout3d.apply(input, self.p, self.training, self.inplace)
+        return F.dropout3d(input, self.p, self.training, self.inplace)
 
     def __repr__(self):
         inplace_str = ', inplace' if self.inplace else ''
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index fbb1fdb..a0d10f7 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -17,22 +17,12 @@
         super(_Loss, self).__init__()
         self.size_average = size_average
 
-    def forward(self, input, target):
-        _assert_no_grad(target)
-        backend_fn = getattr(self._backend, type(self).__name__)
-        return backend_fn(self.size_average)(input, target)
-
 
 class _WeightedLoss(_Loss):
     def __init__(self, weight=None, size_average=True):
         super(_WeightedLoss, self).__init__(size_average)
         self.register_buffer('weight', weight)
 
-    def forward(self, input, target):
-        _assert_no_grad(target)
-        backend_fn = getattr(self._backend, type(self).__name__)
-        return backend_fn(self.size_average, weight=self.weight)(input, target)
-
 
 class L1Loss(_Loss):
     r"""Creates a criterion that measures the mean absolute value of the
@@ -66,7 +56,9 @@
         >>> output = loss(input, target)
         >>> output.backward()
     """
-    pass
+    def forward(self, input, target):
+        _assert_no_grad(target)
+        return F.l1_loss(input, target, size_average=self.size_average)
 
 
 class NLLLoss(_WeightedLoss):
@@ -236,7 +228,9 @@
     .. _Kullback-Leibler divergence:
         https://en.wikipedia.org/wiki/Kullback-Leibler_divergence
     """
-    pass
+    def forward(self, input, target):
+        _assert_no_grad(target)
+        return F.kl_div(input, target, size_average=self.size_average, weight=self.weight)
 
 
 class MSELoss(_Loss):
@@ -271,7 +265,9 @@
         >>> output = loss(input, target)
         >>> output.backward()
     """
-    pass
+    def forward(self, input, target):
+        _assert_no_grad(target)
+        return F.mse_loss(input, target, size_average=self.size_average)
 
 
 class BCELoss(_WeightedLoss):
@@ -293,7 +289,10 @@
     to `False`, the losses are instead summed.
 
     """
-    pass
+    def forward(self, input, target):
+        _assert_no_grad(target)
+        return F.binary_cross_entropy(input, target, weight=self.weight,
+                                      size_average=self.size_average)
 
 
 class BCEWithLogitsLoss(Module):
@@ -358,8 +357,7 @@
         self.size_average = size_average
 
     def forward(self, input, target):
-        return self._backend.HingeEmbeddingLoss(self.margin,
-                                                self.size_average)(input, target)
+        return F.hinge_embedding_loss(input, target, self.margin, self.size_average)
 
 
 class MultiLabelMarginLoss(_Loss):
@@ -379,7 +377,9 @@
 
     This allows for different samples to have variable amounts of target classes
     """
-    pass
+    def forward(self, input, target):
+        _assert_no_grad(target)
+        return F.multilabel_margin_loss(input, target, size_average=self.size_average)
 
 
 class SmoothL1Loss(_Loss):
@@ -399,7 +399,9 @@
     The division by `n` can be avoided if one sets the internal variable
     `size_average` to `False`
     """
-    pass
+    def forward(self, input, target):
+        _assert_no_grad(target)
+        return F.smooth_l1_loss(input, target, size_average=self.size_average)
 
 
 class SoftMarginLoss(_Loss):
@@ -414,7 +416,9 @@
     The normalization by the number of elements in the input can be disabled by
     setting `self.size_average` to `False`.
     """
-    pass
+    def forward(self, input, target):
+        _assert_no_grad(target)
+        return F.soft_margin_loss(input, target, size_average=self.size_average)
 
 
 class CrossEntropyLoss(_WeightedLoss):
@@ -481,8 +485,7 @@
     """
 
     def forward(self, input, target):
-        return F.binary_cross_entropy(torch.sigmoid(input), target,
-                                      self.weight, self.size_average)
+        return F.multilabel_soft_margin_loss(input, target, self.weight, self.size_average)
 
 
 class CosineEmbeddingLoss(Module):
@@ -513,8 +516,7 @@
         self.size_average = size_average
 
     def forward(self, input1, input2, target):
-        return self._backend.CosineEmbeddingLoss(self.margin,
-                                                 self.size_average)(input1, input2, target)
+        return F.cosine_embedding_loss(input1, input2, target, self.margin, self.size_average)
 
 
 class MarginRankingLoss(Module):
@@ -542,8 +544,7 @@
         self.size_average = size_average
 
     def forward(self, input1, input2, target):
-        return self._backend.MarginRankingLoss(self.margin,
-                                               self.size_average)(input1, input2, target)
+        return F.margin_ranking_loss(input1, input2, target, self.margin, self.size_average)
 
 
 class MultiMarginLoss(Module):
@@ -580,8 +581,8 @@
         self.weight = weight
 
     def forward(self, input, target):
-        return self._backend.MultiMarginLoss(self.size_average, self.p,
-                                             self.margin, weight=self.weight)(input, target)
+        return F.multi_margin_loss(input, target, self.p, self.margin,
+                                   self.weight, self.size_average)
 
 
 class TripletMarginLoss(Module):
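
Since the loss modules now delegate to their functional counterparts, the
two call paths should produce identical results; a quick illustrative
sanity check:

    import torch
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.nn.functional as F

    input = Variable(torch.randn(10, 5), requires_grad=True)
    target = Variable(torch.randn(10, 5))

    module_loss = nn.MSELoss()(input, target)
    functional_loss = F.mse_loss(input, target)
    assert (module_loss - functional_loss).abs().data[0] < 1e-6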