Support dropout(nested tensor) (#79318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79318
Approved by: https://github.com/jbschlosser
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 0b7c4fe..574ddf1 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -240,8 +240,14 @@
- func: _shape_as_tensor(Tensor self) -> Tensor
+# Explicit dispatch table: the nested-tensor keys select dropout_nested, while the CompositeImplicitAutograd key names dropout.
- func: dropout(Tensor input, float p, bool train) -> Tensor
+ dispatch:
+ CompositeImplicitAutograd: dropout
+ NestedTensorCPU, NestedTensorCUDA: dropout_nested
+# Explicit dispatch table: the nested-tensor keys select dropout_nested_, while the CompositeImplicitAutograd key names dropout_.
- func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+ dispatch:
+ CompositeImplicitAutograd: dropout_
+ NestedTensorCPU, NestedTensorCUDA: dropout_nested_
- func: feature_dropout(Tensor input, float p, bool train) -> Tensor
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 6cca419..c6a2eee 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -647,5 +647,19 @@
return get_nested_size_tensor(self);
}
+// Out-of-place dropout for nested tensors.
+//
+// Runs at::dropout on the nested tensor's buffer and wraps the resulting
+// buffer together with a clone of the input's nested-size tensor.
+//
+//   input: nested tensor, unwrapped through get_nested_tensor_impl.
+//   p:     forwarded unchanged to at::dropout.
+//   train: forwarded unchanged to at::dropout.
+//
+// Returns the tensor produced by wrap_buffer.
+Tensor dropout_nested(const Tensor& input, double p, bool train) {
+  auto input_ptr = get_nested_tensor_impl(input);
+  // Declared one per statement so that BOTH names are references: in a
+  // combined declaration the `&` attaches to the first declarator only and
+  // the second name would be a by-value Tensor copy.
+  const Tensor& input_buffer = input_ptr->get_buffer();
+  const Tensor& sizemat = input_ptr->get_nested_size_tensor();
+  Tensor output_buffer = at::dropout(input_buffer, p, train);
+  // The output is wrapped with its own copy of the size tensor.
+  return wrap_buffer(output_buffer, sizemat.clone());
+}
+
+// In-place dropout for nested tensors: applies at::dropout_ to the tensor
+// obtained from get_buffer(input) and returns the same `input` reference.
+//
+//   p, train: forwarded unchanged to at::dropout_.
+Tensor& dropout_nested_(Tensor& input, double p, bool train) {
+  // NOTE(review): presumably this local shares storage with input's buffer,
+  // so the in-place update is visible through `input` -- confirm against
+  // get_buffer.
+  Tensor buffer = get_buffer(input);
+  at::dropout_(buffer, p, train);
+  return input;
+}
+
} // namespace native
} // namespace at
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 1351cec..64dcc3d 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -9,7 +9,7 @@
instantiate_device_type_tests,
skipMeta,
)
-from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests
+from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state
from torch import nested_tensor
# Tests are ported from pytorch/nestedtensor.
@@ -482,6 +482,74 @@
with self.assertRaisesRegex(RuntimeError, msg):
nt1.clone(memory_format=torch.channels_last)
+ # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
+ @dtypes(torch.float, torch.double)
+ @torch.inference_mode()
+ def test_dropout(self, device, dtype):
+ # edge case: empty nested tensor
+ nt0 = torch.nested_tensor([])
+ y = torch.nn.functional.dropout(nt0, 0.5)
+ self.nt_equal(nt0, y)
+ # normal nested tensor
+ ntensors = 4
+ nt = self.random_nt(device, dtype, ntensors, (4, 4))
+ # edge case: invalid dropout
+ self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
+ self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
+ self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
+ self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
+ # edge case: no dropout
+ dropouter = torch.nn.Dropout(0.0)
+ y0 = dropouter(nt)
+ y1 = torch.nn.functional.dropout(nt, 0.0)
+ self.nt_equal(nt, y0)
+ self.nt_equal(nt, y1)
+ # edge case: all dropout
+ dropouter = torch.nn.Dropout(1.0)
+ y0 = dropouter(nt)
+ y1 = torch.nn.functional.dropout(nt, 1.0)
+ nt0 = nt.clone()
+ for i in range(ntensors):
+ nt0[i].fill_(0.0)
+ self.nt_equal(nt0, y0)
+ self.nt_equal(nt0, y1)
+ # normal case: normal dropout
+ p = 0.2
+ y = torch.nn.functional.dropout(nt, p)
+ expect = nt.clone()
+ for i in range(ntensors):
+ actual_tensor = y[i].view(-1)
+ expect_tensor = expect[i].view(-1)
+ for j in range(actual_tensor.shape[0]):
+ if actual_tensor[j].item() == 0.0:
+ expect_tensor[j] = 0.0
+ else:
+ expect_tensor[j] /= 1.0 - p
+ self.nt_equal(y, expect)
+ with freeze_rng_state():
+ dropouter = torch.nn.Dropout(p)
+ y0 = dropouter(nt)
+ with freeze_rng_state():
+ y1 = torch.nn.functional.dropout(nt, p)
+ self.nt_equal(y0, y1)
+ # inplace
+ # in principle, since we have established the correctness of functional, we could simply compare inplace vs functional
+ # in practice, cuda functional has its own implementation to skip `bernoulli_`
+ # so cuda functional will differ from cuda inplace causing test failure
+ # in `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
+ # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`
+ expect = nt.clone()
+ torch.nn.functional.dropout(nt, p, inplace=True)
+ for i in range(ntensors):
+ actual_tensor = nt[i].view(-1)
+ expect_tensor = expect[i].view(-1)
+ for j in range(actual_tensor.shape[0]):
+ if actual_tensor[j].item() == 0.0:
+ expect_tensor[j] = 0.0
+ else:
+ expect_tensor[j] /= 1.0 - p
+ self.nt_equal(nt, expect)
+
class TestNestedTensorAutograd(TestCase):
def nt_equal(self, nt1, nt2):
self.assertEqual(nt1.dtype, nt2.dtype)