Enable half for CUDA dense EmbeddingBag backward. (#19293)
Summary:
I audited the relevant kernel and saw it accumulates a good deal into float
so it should be fine.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19293
Differential Revision: D14942274
Pulled By: zou3519
fbshipit-source-id: 36996ba0fbb29fbfb12b27bfe9c0ad1eb012ba3c
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index 677d04f..8e0f6e8 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -491,7 +491,7 @@
dim3 grid((num_samples + warps_per_block - 1) / warps_per_block);
auto output = at::empty({num_samples}, grad.options());
- AT_DISPATCH_FLOATING_TYPES(
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
_embedding_bag_per_sample_weights_backward_kernel<scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
diff --git a/test/test_nn.py b/test/test_nn.py
index c093502..d489e50 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2381,7 +2381,8 @@
self.assertEqual(es_weight_grad, e.weight.grad, needed_prec)
if test_per_sample_weights and trainable_per_sample_weights:
- self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad)
+ self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
+ dtype2prec[dtype])
def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
# check a known test example
@@ -2653,16 +2654,21 @@
expected = self._embedding_bag_reference_impl(
input, reference_weights, offsets, mode, ref_per_sample_weights)
result = es(input, offsets, per_sample_weights)
- self.assertEqual(result, expected)
+ self.assertEqual(result, expected, prec=dtype2prec[dtype])
grad = torch.randn_like(expected)
result.backward(grad)
expected.backward(grad)
- self.assertEqual(es.weight.grad, reference_weights.grad)
+ self.assertEqual(es.weight.grad, reference_weights.grad,
+ dtype2prec[dtype])
if trainable_scale:
- self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad)
+ self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
+ prec=dtype2prec[dtype])
- dtypes = (torch.float, torch.double)
+ if device == 'cuda':
+ dtypes = (torch.float, torch.double, torch.half)
+ else:
+ dtypes = (torch.float, torch.double)
modes = ('sum',)
trainable_scale = (True, False)
for dtype, mode, trainable in itertools.product(dtypes, modes, trainable_scale):
@@ -2677,12 +2683,7 @@
@staticmethod
def _test_EmbeddingBag_per_sample_weights_and_no_offsets(self, device='cpu'):
- dtypes = (torch.float, torch.double)
- modes = ('sum',)
- sparsity = (True, False)
- trainable_scale = (True, False)
- for dtype, mode, sparse, trainable_per_sample_weights in \
- itertools.product(dtypes, modes, sparsity, trainable_scale):
+ def run_tests(dtype, mode, sparse, trainable_per_sample_weights):
kwargs = dict(test_per_sample_weights=True, device=device,
mode=mode, dtype=dtype, sparse=sparse,
trainable_per_sample_weights=trainable_per_sample_weights)
@@ -2699,6 +2700,24 @@
# Large embedding_dim
self._test_EmbeddingBag_vs_Embedding(2, 101, 3, 7, **kwargs)
+ dtypes = (torch.float, torch.double)
+ modes = ('sum',)
+ sparsity = (True, False)
+ trainable_scale = (True, False)
+ for dtype, mode, sparse, trainable_per_sample_weights in \
+ itertools.product(dtypes, modes, sparsity, trainable_scale):
+ run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+
+ # Test CUDA Dense on half precision
+ if device == 'cuda':
+ dtypes = (torch.half,)
+ modes = ('sum',)
+ sparsity = (False,)
+ trainable_scale = (True, False)
+ for dtype, mode, sparse, trainable_per_sample_weights in \
+ itertools.product(dtypes, modes, sparsity, trainable_scale):
+ run_tests(dtype, mode, sparse, trainable_per_sample_weights)
+
def test_EmbeddingBag_per_sample_weights_and_no_offsets(self):
self._test_EmbeddingBag_per_sample_weights_and_no_offsets(self)