Support dropout(nested tensor) (#79318)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79318
Approved by: https://github.com/jbschlosser
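
Registers NestedTensorCPU/NestedTensorCUDA dispatch for `dropout` and `dropout_`, implements `dropout_nested` and `dropout_nested_` by applying dropout to the nested tensor's underlying buffer, and adds `test_dropout` coverage for the empty, invalid-p, p=0, p=1, regular, and in-place cases.

A minimal usage sketch (shapes and the dropout probability are illustrative only):

    import torch

    # a nested tensor holding two entries with different leading dimensions
    nt = torch.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])

    # functional, module, and in-place forms all route through the new nested kernels
    out = torch.nn.functional.dropout(nt, 0.2)
    out = torch.nn.Dropout(0.2)(nt)
    torch.nn.functional.dropout(nt, 0.2, inplace=True)
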
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 0b7c4fe..574ddf1 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -240,8 +240,14 @@
 - func: _shape_as_tensor(Tensor self) -> Tensor
 
 - func: dropout(Tensor input, float p, bool train) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: dropout
+    NestedTensorCPU, NestedTensorCUDA: dropout_nested
 
 - func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+  dispatch:
+    CompositeImplicitAutograd: dropout_
+    NestedTensorCPU, NestedTensorCUDA: dropout_nested_
 
 - func: feature_dropout(Tensor input, float p, bool train) -> Tensor
 
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 6cca419..c6a2eee 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -647,5 +647,19 @@
   return get_nested_size_tensor(self);
 }
 
+Tensor dropout_nested(const Tensor& input, double p, bool train) {
+  auto input_ptr = get_nested_tensor_impl(input);
+  const Tensor& input_buffer = input_ptr->get_buffer();
+  const Tensor& sizemat = input_ptr->get_nested_size_tensor();
+  Tensor output_buffer = at::dropout(input_buffer, p, train);
+  return wrap_buffer(output_buffer, sizemat.clone());
+}
+
+Tensor& dropout_nested_(Tensor& input, double p, bool train) {
+  Tensor input_buffer = get_buffer(input);
+  at::dropout_(input_buffer, p, train);
+  return input;
+}
+
 } // namespace native
 } // namespace at
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 1351cec..64dcc3d 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -9,7 +9,7 @@
     instantiate_device_type_tests,
     skipMeta,
 )
-from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests
+from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state
 from torch import nested_tensor
 
 # Tests are ported from pytorch/nestedtensor.
@@ -482,6 +482,74 @@
         with self.assertRaisesRegex(RuntimeError, msg):
             nt1.clone(memory_format=torch.channels_last)
 
+    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
+    @dtypes(torch.float, torch.double)
+    @torch.inference_mode()
+    def test_dropout(self, device, dtype):
+        # edge case: empty nested tensor
+        nt0 = torch.nested_tensor([])
+        y = torch.nn.functional.dropout(nt0, 0.5)
+        self.nt_equal(nt0, y)
+        # normal nested tensor
+        ntensors = 4
+        nt = self.random_nt(device, dtype, ntensors, (4, 4))
+        # edge case: invalid dropout
+        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
+        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
+        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
+        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
+        # edge case: no dropout
+        dropouter = torch.nn.Dropout(0.0)
+        y0 = dropouter(nt)
+        y1 = torch.nn.functional.dropout(nt, 0.0)
+        self.nt_equal(nt, y0)
+        self.nt_equal(nt, y1)
+        # edge case: all dropout
+        dropouter = torch.nn.Dropout(1.0)
+        y0 = dropouter(nt)
+        y1 = torch.nn.functional.dropout(nt, 1.0)
+        nt0 = nt.clone()
+        for i in range(ntensors):
+            nt0[i].fill_(0.0)
+        self.nt_equal(nt0, y0)
+        self.nt_equal(nt0, y1)
+        # normal case: normal dropout
+        p = 0.2
+        y = torch.nn.functional.dropout(nt, p)
+        expect = nt.clone()
+        for i in range(ntensors):
+            actual_tensor = y[i].view(-1)
+            expect_tensor = expect[i].view(-1)
+            for j in range(actual_tensor.shape[0]):
+                if actual_tensor[j].item() == 0.0:
+                    expect_tensor[j] = 0.0
+                else:
+                    expect_tensor[j] /= 1.0 - p
+        self.nt_equal(y, expect)
+        with freeze_rng_state():
+            dropouter = torch.nn.Dropout(p)
+            y0 = dropouter(nt)
+        with freeze_rng_state():
+            y1 = torch.nn.functional.dropout(nt, p)
+        self.nt_equal(y0, y1)
+        # inplace
+        # in principle, since the correctness of functional dropout is established above, we could
+        # simply compare inplace against functional; in practice, the cuda functional path has its own
+        # implementation that skips `bernoulli_`, so cuda functional differs from cuda inplace, causing a failure
+        # in `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
+        # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`
+        expect = nt.clone()
+        torch.nn.functional.dropout(nt, p, inplace=True)
+        for i in range(ntensors):
+            actual_tensor = nt[i].view(-1)
+            expect_tensor = expect[i].view(-1)
+            for j in range(actual_tensor.shape[0]):
+                if actual_tensor[j].item() == 0.0:
+                    expect_tensor[j] = 0.0
+                else:
+                    expect_tensor[j] /= 1.0 - p
+        self.nt_equal(nt, expect)
+
 class TestNestedTensorAutograd(TestCase):
     def nt_equal(self, nt1, nt2):
         self.assertEqual(nt1.dtype, nt2.dtype)